reth_tasks/
lib.rs

1//! Reth task management.
2//!
3//! # Feature Flags
4//!
5//! - `rayon`: Enable rayon thread pool for blocking tasks.
6
7#![doc(
8    html_logo_url = "https://raw.githubusercontent.com/paradigmxyz/reth/main/assets/reth-docs.png",
9    html_favicon_url = "https://avatars0.githubusercontent.com/u/97369466?s=256",
10    issue_tracker_base_url = "https://github.com/paradigmxyz/reth/issues/"
11)]
12#![cfg_attr(not(test), warn(unused_crate_dependencies))]
13#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))]
14
15use crate::{
16    metrics::{IncCounterOnDrop, TaskExecutorMetrics},
17    shutdown::{signal, GracefulShutdown, GracefulShutdownGuard, Shutdown, Signal},
18};
19use dyn_clone::DynClone;
20use futures_util::{
21    future::{select, BoxFuture},
22    Future, FutureExt, TryFutureExt,
23};
24use std::{
25    any::Any,
26    fmt::{Display, Formatter},
27    pin::{pin, Pin},
28    sync::{
29        atomic::{AtomicUsize, Ordering},
30        Arc, OnceLock,
31    },
32    task::{ready, Context, Poll},
33};
34use tokio::{
35    runtime::Handle,
36    sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender},
37    task::JoinHandle,
38};
39use tracing::{debug, error};
40use tracing_futures::Instrument;
41
42pub mod metrics;
43pub mod shutdown;
44
45#[cfg(feature = "rayon")]
46pub mod pool;
47
/// Global [`TaskExecutor`] instance that can be accessed from anywhere.
///
/// It is set by [`TaskManager::new`] and read through [`TaskExecutor::try_current`]. Because this
/// is a [`OnceLock`], only the first successful `set` is kept.
static GLOBAL_EXECUTOR: OnceLock<TaskExecutor> = OnceLock::new();
50
/// A type that can spawn tasks.
///
/// The main purpose of this type is to abstract over [`TaskExecutor`] so it's more convenient to
/// provide default impls for testing.
///
/// # Examples
///
/// Use the [`TokioTaskExecutor`] that spawns with [`tokio::task::spawn`]
///
/// ```
/// # async fn t() {
/// use reth_tasks::{TaskSpawner, TokioTaskExecutor};
/// let executor = TokioTaskExecutor::default();
///
/// let task = executor.spawn(Box::pin(async {
///     // -- snip --
/// }));
/// task.await.unwrap();
/// # }
/// ```
///
/// Use the [`TaskExecutor`] that spawns tasks directly onto the tokio runtime via the [`Handle`].
///
/// ```
/// # use reth_tasks::TaskManager;
/// # fn t() {
/// use reth_tasks::TaskSpawner;
/// let rt = tokio::runtime::Runtime::new().unwrap();
/// let manager = TaskManager::new(rt.handle().clone());
/// let executor = manager.executor();
/// let task = TaskSpawner::spawn(&executor, Box::pin(async {
///     // -- snip --
/// }));
/// rt.block_on(task).unwrap();
/// # }
/// ```
///
/// The [`TaskSpawner`] trait is [`DynClone`] so a `Box<dyn TaskSpawner>` is also `Clone`.
#[auto_impl::auto_impl(&, Arc)]
pub trait TaskSpawner: Send + Sync + Unpin + std::fmt::Debug + DynClone {
    /// Spawns the task onto the runtime.
    ///
    /// Returns the [`JoinHandle`] of the spawned task.
    ///
    /// See also [`Handle::spawn`].
    fn spawn(&self, fut: BoxFuture<'static, ()>) -> JoinHandle<()>;

    /// This spawns a critical task onto the runtime.
    ///
    /// `name` identifies the task. How it is used is up to the implementation: the
    /// [`TaskExecutor`] reports it when the task panics, the [`TokioTaskExecutor`] ignores it.
    fn spawn_critical(&self, name: &'static str, fut: BoxFuture<'static, ()>) -> JoinHandle<()>;

    /// Spawns a blocking task onto the runtime.
    ///
    /// See also [`Handle::spawn_blocking`].
    fn spawn_blocking(&self, fut: BoxFuture<'static, ()>) -> JoinHandle<()>;

    /// This spawns a critical blocking task onto the runtime.
    ///
    /// `name` identifies the task, see [`TaskSpawner::spawn_critical`].
    fn spawn_critical_blocking(
        &self,
        name: &'static str,
        fut: BoxFuture<'static, ()>,
    ) -> JoinHandle<()>;
}

// Makes `Box<dyn TaskSpawner>` implement `Clone`.
dyn_clone::clone_trait_object!(TaskSpawner);
111
/// A [`TaskSpawner`] that uses [`tokio::task::spawn`] to execute tasks.
///
/// Blocking tasks are handed to [`tokio::task::spawn_blocking`]. The task name passed to the
/// critical variants is ignored by this spawner.
#[derive(Debug, Clone, Default)]
#[non_exhaustive]
pub struct TokioTaskExecutor;
116
117impl TokioTaskExecutor {
118    /// Converts the instance to a boxed [`TaskSpawner`].
119    pub fn boxed(self) -> Box<dyn TaskSpawner + 'static> {
120        Box::new(self)
121    }
122}
123
124impl TaskSpawner for TokioTaskExecutor {
125    fn spawn(&self, fut: BoxFuture<'static, ()>) -> JoinHandle<()> {
126        tokio::task::spawn(fut)
127    }
128
129    fn spawn_critical(&self, _name: &'static str, fut: BoxFuture<'static, ()>) -> JoinHandle<()> {
130        tokio::task::spawn(fut)
131    }
132
133    fn spawn_blocking(&self, fut: BoxFuture<'static, ()>) -> JoinHandle<()> {
134        tokio::task::spawn_blocking(move || tokio::runtime::Handle::current().block_on(fut))
135    }
136
137    fn spawn_critical_blocking(
138        &self,
139        _name: &'static str,
140        fut: BoxFuture<'static, ()>,
141    ) -> JoinHandle<()> {
142        tokio::task::spawn_blocking(move || tokio::runtime::Handle::current().block_on(fut))
143    }
144}
145
/// Many reth components need to spawn tasks for long-running jobs. For example, `discovery`
/// spawns tasks to handle egress and ingress of udp traffic, and `network` spawns session tasks
/// that handle the traffic to and from a peer.
///
/// To unify how tasks are created, the [`TaskManager`] provides access to the configured Tokio
/// runtime. A [`TaskManager`] stores the [`tokio::runtime::Handle`] it is associated with. In this
/// way it is possible to configure on which runtime a task is executed.
///
/// The main purpose of this type is to be able to monitor if a critical task panicked, for
/// diagnostic purposes, since tokio tasks essentially fail silently. Therefore, this type is a
/// [`Future`] that resolves with a [`PanickedTaskError`] describing the panicked task, see
/// [`TaskExecutor::spawn_critical`]. In order to execute tasks use the [`TaskExecutor`] type
/// returned by [`TaskManager::executor`].
#[derive(Debug)]
#[must_use = "TaskManager must be polled to monitor critical tasks"]
pub struct TaskManager {
    /// Handle to the tokio runtime this task manager is associated with.
    ///
    /// See [`Handle`] docs.
    handle: Handle,
    /// Sender half for sending panic signals to this type.
    ///
    /// Cloned into every [`TaskExecutor`] created by [`TaskManager::executor`].
    panicked_tasks_tx: UnboundedSender<PanickedTaskError>,
    /// Listens for panicked tasks.
    panicked_tasks_rx: UnboundedReceiver<PanickedTaskError>,
    /// The [Signal] to fire when all tasks should be shutdown.
    ///
    /// This is fired when dropped.
    signal: Option<Signal>,
    /// Receiver of the shutdown signal.
    on_shutdown: Shutdown,
    /// How many [`GracefulShutdown`] tasks are currently active.
    ///
    /// Shared with every [`TaskExecutor`] created by [`TaskManager::executor`].
    graceful_tasks: Arc<AtomicUsize>,
}
178
179// === impl TaskManager ===
180
181impl TaskManager {
182    /// Returns a new [`TaskManager`] over the currently running Runtime.
183    ///
184    /// # Panics
185    ///
186    /// This will panic if called outside the context of a Tokio runtime.
187    pub fn current() -> Self {
188        let handle = Handle::current();
189        Self::new(handle)
190    }
191
192    /// Create a new instance connected to the given handle's tokio runtime.
193    ///
194    /// This also sets the global [`TaskExecutor`].
195    pub fn new(handle: Handle) -> Self {
196        let (panicked_tasks_tx, panicked_tasks_rx) = unbounded_channel();
197        let (signal, on_shutdown) = signal();
198        let manager = Self {
199            handle,
200            panicked_tasks_tx,
201            panicked_tasks_rx,
202            signal: Some(signal),
203            on_shutdown,
204            graceful_tasks: Arc::new(AtomicUsize::new(0)),
205        };
206
207        let _ = GLOBAL_EXECUTOR
208            .set(manager.executor())
209            .inspect_err(|_| error!("Global executor already set"));
210
211        manager
212    }
213
214    /// Returns a new [`TaskExecutor`] that can spawn new tasks onto the tokio runtime this type is
215    /// connected to.
216    pub fn executor(&self) -> TaskExecutor {
217        TaskExecutor {
218            handle: self.handle.clone(),
219            on_shutdown: self.on_shutdown.clone(),
220            panicked_tasks_tx: self.panicked_tasks_tx.clone(),
221            metrics: Default::default(),
222            graceful_tasks: Arc::clone(&self.graceful_tasks),
223        }
224    }
225
226    /// Fires the shutdown signal and awaits until all tasks are shutdown.
227    pub fn graceful_shutdown(self) {
228        let _ = self.do_graceful_shutdown(None);
229    }
230
231    /// Fires the shutdown signal and awaits until all tasks are shutdown.
232    ///
233    /// Returns true if all tasks were shutdown before the timeout elapsed.
234    pub fn graceful_shutdown_with_timeout(self, timeout: std::time::Duration) -> bool {
235        self.do_graceful_shutdown(Some(timeout))
236    }
237
238    fn do_graceful_shutdown(self, timeout: Option<std::time::Duration>) -> bool {
239        drop(self.signal);
240        let when = timeout.map(|t| std::time::Instant::now() + t);
241        while self.graceful_tasks.load(Ordering::Relaxed) > 0 {
242            if when.map(|when| std::time::Instant::now() > when).unwrap_or(false) {
243                debug!("graceful shutdown timed out");
244                return false
245            }
246            std::hint::spin_loop();
247        }
248
249        debug!("gracefully shut down");
250        true
251    }
252}
253
254/// An endless future that resolves if a critical task panicked.
255///
256/// See [`TaskExecutor::spawn_critical`]
257impl Future for TaskManager {
258    type Output = PanickedTaskError;
259
260    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
261        let err = ready!(self.get_mut().panicked_tasks_rx.poll_recv(cx));
262        Poll::Ready(err.expect("stream can not end"))
263    }
264}
265
/// Error with the name of the task that panicked and an error downcasted to string, if possible.
#[derive(Debug, thiserror::Error)]
pub struct PanickedTaskError {
    /// Name the critical task was spawned with.
    task_name: &'static str,
    /// The panic payload as a string, `None` if the payload was neither a `String` nor a `&str`.
    error: Option<String>,
}
272
273impl Display for PanickedTaskError {
274    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
275        let task_name = self.task_name;
276        if let Some(error) = &self.error {
277            write!(f, "Critical task `{task_name}` panicked: `{error}`")
278        } else {
279            write!(f, "Critical task `{task_name}` panicked")
280        }
281    }
282}
283
284impl PanickedTaskError {
285    fn new(task_name: &'static str, error: Box<dyn Any>) -> Self {
286        let error = match error.downcast::<String>() {
287            Ok(value) => Some(*value),
288            Err(error) => match error.downcast::<&str>() {
289                Ok(value) => Some(value.to_string()),
290                Err(_) => None,
291            },
292        };
293
294        Self { task_name, error }
295    }
296}
297
/// A type that can spawn new tokio tasks.
///
/// Instances are created via [`TaskManager::executor`] and are cheap to hand around as they are
/// `Clone`.
#[derive(Debug, Clone)]
pub struct TaskExecutor {
    /// Handle to the tokio runtime this executor spawns tasks onto.
    ///
    /// See [`Handle`] docs.
    handle: Handle,
    /// Receiver of the shutdown signal.
    on_shutdown: Shutdown,
    /// Sender half for sending panic signals to the [`TaskManager`].
    panicked_tasks_tx: UnboundedSender<PanickedTaskError>,
    /// Task Executor Metrics
    metrics: TaskExecutorMetrics,
    /// How many [`GracefulShutdown`] tasks are currently active
    graceful_tasks: Arc<AtomicUsize>,
}
314
315// === impl TaskExecutor ===
316
317impl TaskExecutor {
318    /// Attempts to get the current `TaskExecutor` if one has been initialized.
319    ///
320    /// Returns an error if no [`TaskExecutor`] has been initialized via [`TaskManager`].
321    pub fn try_current() -> Result<Self, NoCurrentTaskExecutorError> {
322        GLOBAL_EXECUTOR.get().cloned().ok_or_else(NoCurrentTaskExecutorError::default)
323    }
324
325    /// Returns the current `TaskExecutor`.
326    ///
327    /// # Panics
328    ///
329    /// Panics if no global executor has been initialized. Use [`try_current`](Self::try_current)
330    /// for a non-panicking version.
331    pub fn current() -> Self {
332        Self::try_current().unwrap()
333    }
334
335    /// Returns the [Handle] to the tokio runtime.
336    pub const fn handle(&self) -> &Handle {
337        &self.handle
338    }
339
340    /// Returns the receiver of the shutdown signal.
341    pub const fn on_shutdown_signal(&self) -> &Shutdown {
342        &self.on_shutdown
343    }
344
345    /// Spawns a future on the tokio runtime depending on the [`TaskKind`]
346    fn spawn_on_rt<F>(&self, fut: F, task_kind: TaskKind) -> JoinHandle<()>
347    where
348        F: Future<Output = ()> + Send + 'static,
349    {
350        match task_kind {
351            TaskKind::Default => self.handle.spawn(fut),
352            TaskKind::Blocking => {
353                let handle = self.handle.clone();
354                self.handle.spawn_blocking(move || handle.block_on(fut))
355            }
356        }
357    }
358
359    /// Spawns a regular task depending on the given [`TaskKind`]
360    fn spawn_task_as<F>(&self, fut: F, task_kind: TaskKind) -> JoinHandle<()>
361    where
362        F: Future<Output = ()> + Send + 'static,
363    {
364        let on_shutdown = self.on_shutdown.clone();
365
366        // Clone only the specific counter that we need.
367        let finished_regular_tasks_total_metrics =
368            self.metrics.finished_regular_tasks_total.clone();
369        // Wrap the original future to increment the finished tasks counter upon completion
370        let task = {
371            async move {
372                // Create an instance of IncCounterOnDrop with the counter to increment
373                let _inc_counter_on_drop =
374                    IncCounterOnDrop::new(finished_regular_tasks_total_metrics);
375                let fut = pin!(fut);
376                let _ = select(on_shutdown, fut).await;
377            }
378        }
379        .in_current_span();
380
381        self.spawn_on_rt(task, task_kind)
382    }
383
384    /// Spawns the task onto the runtime.
385    /// The given future resolves as soon as the [Shutdown] signal is received.
386    ///
387    /// See also [`Handle::spawn`].
388    pub fn spawn<F>(&self, fut: F) -> JoinHandle<()>
389    where
390        F: Future<Output = ()> + Send + 'static,
391    {
392        self.spawn_task_as(fut, TaskKind::Default)
393    }
394
395    /// Spawns a blocking task onto the runtime.
396    /// The given future resolves as soon as the [Shutdown] signal is received.
397    ///
398    /// See also [`Handle::spawn_blocking`].
399    pub fn spawn_blocking<F>(&self, fut: F) -> JoinHandle<()>
400    where
401        F: Future<Output = ()> + Send + 'static,
402    {
403        self.spawn_task_as(fut, TaskKind::Blocking)
404    }
405
406    /// Spawns the task onto the runtime.
407    /// The given future resolves as soon as the [Shutdown] signal is received.
408    ///
409    /// See also [`Handle::spawn`].
410    pub fn spawn_with_signal<F>(&self, f: impl FnOnce(Shutdown) -> F) -> JoinHandle<()>
411    where
412        F: Future<Output = ()> + Send + 'static,
413    {
414        let on_shutdown = self.on_shutdown.clone();
415        let fut = f(on_shutdown);
416
417        let task = fut.in_current_span();
418
419        self.handle.spawn(task)
420    }
421
422    /// Spawns a critical task depending on the given [`TaskKind`]
423    fn spawn_critical_as<F>(
424        &self,
425        name: &'static str,
426        fut: F,
427        task_kind: TaskKind,
428    ) -> JoinHandle<()>
429    where
430        F: Future<Output = ()> + Send + 'static,
431    {
432        let panicked_tasks_tx = self.panicked_tasks_tx.clone();
433        let on_shutdown = self.on_shutdown.clone();
434
435        // wrap the task in catch unwind
436        let task = std::panic::AssertUnwindSafe(fut)
437            .catch_unwind()
438            .map_err(move |error| {
439                let task_error = PanickedTaskError::new(name, error);
440                error!("{task_error}");
441                let _ = panicked_tasks_tx.send(task_error);
442            })
443            .in_current_span();
444
445        // Clone only the specific counter that we need.
446        let finished_critical_tasks_total_metrics =
447            self.metrics.finished_critical_tasks_total.clone();
448        let task = async move {
449            // Create an instance of IncCounterOnDrop with the counter to increment
450            let _inc_counter_on_drop = IncCounterOnDrop::new(finished_critical_tasks_total_metrics);
451            let task = pin!(task);
452            let _ = select(on_shutdown, task).await;
453        };
454
455        self.spawn_on_rt(task, task_kind)
456    }
457
458    /// This spawns a critical blocking task onto the runtime.
459    /// The given future resolves as soon as the [Shutdown] signal is received.
460    ///
461    /// If this task panics, the [`TaskManager`] is notified.
462    pub fn spawn_critical_blocking<F>(&self, name: &'static str, fut: F) -> JoinHandle<()>
463    where
464        F: Future<Output = ()> + Send + 'static,
465    {
466        self.spawn_critical_as(name, fut, TaskKind::Blocking)
467    }
468
469    /// This spawns a critical task onto the runtime.
470    /// The given future resolves as soon as the [Shutdown] signal is received.
471    ///
472    /// If this task panics, the [`TaskManager`] is notified.
473    pub fn spawn_critical<F>(&self, name: &'static str, fut: F) -> JoinHandle<()>
474    where
475        F: Future<Output = ()> + Send + 'static,
476    {
477        self.spawn_critical_as(name, fut, TaskKind::Default)
478    }
479
480    /// This spawns a critical task onto the runtime.
481    ///
482    /// If this task panics, the [`TaskManager`] is notified.
483    pub fn spawn_critical_with_shutdown_signal<F>(
484        &self,
485        name: &'static str,
486        f: impl FnOnce(Shutdown) -> F,
487    ) -> JoinHandle<()>
488    where
489        F: Future<Output = ()> + Send + 'static,
490    {
491        let panicked_tasks_tx = self.panicked_tasks_tx.clone();
492        let on_shutdown = self.on_shutdown.clone();
493        let fut = f(on_shutdown);
494
495        // wrap the task in catch unwind
496        let task = std::panic::AssertUnwindSafe(fut)
497            .catch_unwind()
498            .map_err(move |error| {
499                let task_error = PanickedTaskError::new(name, error);
500                error!("{task_error}");
501                let _ = panicked_tasks_tx.send(task_error);
502            })
503            .map(drop)
504            .in_current_span();
505
506        self.handle.spawn(task)
507    }
508
509    /// This spawns a critical task onto the runtime.
510    ///
511    /// If this task panics, the [`TaskManager`] is notified.
512    /// The [`TaskManager`] will wait until the given future has completed before shutting down.
513    ///
514    /// # Example
515    ///
516    /// ```no_run
517    /// # async fn t(executor: reth_tasks::TaskExecutor) {
518    ///
519    /// executor.spawn_critical_with_graceful_shutdown_signal("grace", |shutdown| async move {
520    ///     // await the shutdown signal
521    ///     let guard = shutdown.await;
522    ///     // do work before exiting the program
523    ///     tokio::time::sleep(std::time::Duration::from_secs(1)).await;
524    ///     // allow graceful shutdown
525    ///     drop(guard);
526    /// });
527    /// # }
528    /// ```
529    pub fn spawn_critical_with_graceful_shutdown_signal<F>(
530        &self,
531        name: &'static str,
532        f: impl FnOnce(GracefulShutdown) -> F,
533    ) -> JoinHandle<()>
534    where
535        F: Future<Output = ()> + Send + 'static,
536    {
537        let panicked_tasks_tx = self.panicked_tasks_tx.clone();
538        let on_shutdown = GracefulShutdown::new(
539            self.on_shutdown.clone(),
540            GracefulShutdownGuard::new(Arc::clone(&self.graceful_tasks)),
541        );
542        let fut = f(on_shutdown);
543
544        // wrap the task in catch unwind
545        let task = std::panic::AssertUnwindSafe(fut)
546            .catch_unwind()
547            .map_err(move |error| {
548                let task_error = PanickedTaskError::new(name, error);
549                error!("{task_error}");
550                let _ = panicked_tasks_tx.send(task_error);
551            })
552            .map(drop)
553            .in_current_span();
554
555        self.handle.spawn(task)
556    }
557
558    /// This spawns a regular task onto the runtime.
559    ///
560    /// The [`TaskManager`] will wait until the given future has completed before shutting down.
561    ///
562    /// # Example
563    ///
564    /// ```no_run
565    /// # async fn t(executor: reth_tasks::TaskExecutor) {
566    ///
567    /// executor.spawn_with_graceful_shutdown_signal(|shutdown| async move {
568    ///     // await the shutdown signal
569    ///     let guard = shutdown.await;
570    ///     // do work before exiting the program
571    ///     tokio::time::sleep(std::time::Duration::from_secs(1)).await;
572    ///     // allow graceful shutdown
573    ///     drop(guard);
574    /// });
575    /// # }
576    /// ```
577    pub fn spawn_with_graceful_shutdown_signal<F>(
578        &self,
579        f: impl FnOnce(GracefulShutdown) -> F,
580    ) -> JoinHandle<()>
581    where
582        F: Future<Output = ()> + Send + 'static,
583    {
584        let on_shutdown = GracefulShutdown::new(
585            self.on_shutdown.clone(),
586            GracefulShutdownGuard::new(Arc::clone(&self.graceful_tasks)),
587        );
588        let fut = f(on_shutdown);
589
590        self.handle.spawn(fut)
591    }
592}
593
594impl TaskSpawner for TaskExecutor {
595    fn spawn(&self, fut: BoxFuture<'static, ()>) -> JoinHandle<()> {
596        self.metrics.inc_regular_tasks();
597        self.spawn(fut)
598    }
599
600    fn spawn_critical(&self, name: &'static str, fut: BoxFuture<'static, ()>) -> JoinHandle<()> {
601        self.metrics.inc_critical_tasks();
602        Self::spawn_critical(self, name, fut)
603    }
604
605    fn spawn_blocking(&self, fut: BoxFuture<'static, ()>) -> JoinHandle<()> {
606        self.spawn_blocking(fut)
607    }
608
609    fn spawn_critical_blocking(
610        &self,
611        name: &'static str,
612        fut: BoxFuture<'static, ()>,
613    ) -> JoinHandle<()> {
614        Self::spawn_critical_blocking(self, name, fut)
615    }
616}
617
/// `TaskSpawner` with extended behaviour: spawning tasks that take part in the graceful shutdown.
///
/// The methods are generic, unlike the methods of [`TaskSpawner`], which take boxed futures.
#[auto_impl::auto_impl(&, Arc)]
pub trait TaskSpawnerExt: Send + Sync + Unpin + std::fmt::Debug + DynClone {
    /// This spawns a critical task onto the runtime.
    ///
    /// `name` identifies the task and `f` builds the task's future from the
    /// [`GracefulShutdown`] it is given.
    ///
    /// If this task panics, the [`TaskManager`] is notified.
    /// The [`TaskManager`] will wait until the given future has completed before shutting down.
    fn spawn_critical_with_graceful_shutdown_signal<F>(
        &self,
        name: &'static str,
        f: impl FnOnce(GracefulShutdown) -> F,
    ) -> JoinHandle<()>
    where
        F: Future<Output = ()> + Send + 'static;

    /// This spawns a regular task onto the runtime.
    ///
    /// `f` builds the task's future from the [`GracefulShutdown`] it is given.
    ///
    /// The [`TaskManager`] will wait until the given future has completed before shutting down.
    fn spawn_with_graceful_shutdown_signal<F>(
        &self,
        f: impl FnOnce(GracefulShutdown) -> F,
    ) -> JoinHandle<()>
    where
        F: Future<Output = ()> + Send + 'static;
}
643
644impl TaskSpawnerExt for TaskExecutor {
645    fn spawn_critical_with_graceful_shutdown_signal<F>(
646        &self,
647        name: &'static str,
648        f: impl FnOnce(GracefulShutdown) -> F,
649    ) -> JoinHandle<()>
650    where
651        F: Future<Output = ()> + Send + 'static,
652    {
653        Self::spawn_critical_with_graceful_shutdown_signal(self, name, f)
654    }
655
656    fn spawn_with_graceful_shutdown_signal<F>(
657        &self,
658        f: impl FnOnce(GracefulShutdown) -> F,
659    ) -> JoinHandle<()>
660    where
661        F: Future<Output = ()> + Send + 'static,
662    {
663        Self::spawn_with_graceful_shutdown_signal(self, f)
664    }
665}
666
/// Determines how a task is spawned.
///
/// Used by the private spawn helpers of [`TaskExecutor`] to pick between the regular and the
/// blocking executor of the runtime.
enum TaskKind {
    /// Spawn the task to the default executor [`Handle::spawn`]
    Default,
    /// Spawn the task to the blocking executor [`Handle::spawn_blocking`]
    Blocking,
}
674
/// Error returned by [`TaskExecutor::try_current`] when no task executor has been configured,
/// i.e. when no [`TaskManager`] has installed the global executor yet.
#[derive(Debug, Default, thiserror::Error)]
#[error("No current task executor available.")]
#[non_exhaustive]
pub struct NoCurrentTaskExecutorError;
680
#[cfg(test)]
mod tests {
    use super::*;
    use std::{sync::atomic::AtomicBool, time::Duration};

    // Checks that a boxed `TaskSpawner` can be cloned, directly and as a field of a `Clone` type.
    #[test]
    fn test_cloneable() {
        #[derive(Clone)]
        struct ExecutorWrapper {
            _e: Box<dyn TaskSpawner>,
        }

        let executor: Box<dyn TaskSpawner> = Box::<TokioTaskExecutor>::default();
        let _e = dyn_clone::clone_box(&*executor);

        let e = ExecutorWrapper { _e };
        // Actually exercise the derived `Clone` instead of merely moving the wrapper.
        let _e2 = e.clone();
    }

    // Checks that a panic in a critical task is reported through the `TaskManager` future.
    #[test]
    fn test_critical() {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let handle = runtime.handle().clone();
        let manager = TaskManager::new(handle);
        let executor = manager.executor();

        executor.spawn_critical("this is a critical task", async { panic!("intentionally panic") });

        runtime.block_on(async move {
            let err = manager.await;
            assert_eq!(err.task_name, "this is a critical task");
            assert_eq!(err.error, Some("intentionally panic".to_string()));
        })
    }

    // Tests that spawned critical tasks are terminated if the `TaskManager` drops
    #[test]
    fn test_manager_shutdown_critical() {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let handle = runtime.handle().clone();
        let manager = TaskManager::new(handle.clone());
        let executor = manager.executor();

        // Independent signal that is moved into the task; it fires once the task is dropped.
        let (signal, shutdown) = signal();

        executor.spawn_critical("this is a critical task", async move {
            tokio::time::sleep(Duration::from_millis(200)).await;
            drop(signal);
        });

        drop(manager);

        handle.block_on(shutdown);
    }

    // Tests that spawned regular tasks are terminated if the `TaskManager` drops
    #[test]
    fn test_manager_shutdown() {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let handle = runtime.handle().clone();
        let manager = TaskManager::new(handle.clone());
        let executor = manager.executor();

        // Independent signal that is moved into the task; it fires once the task is dropped.
        let (signal, shutdown) = signal();

        executor.spawn(Box::pin(async move {
            tokio::time::sleep(Duration::from_millis(200)).await;
            drop(signal);
        }));

        drop(manager);

        handle.block_on(shutdown);
    }

    // Tests that a graceful shutdown waits for the task to finish its work after the signal
    #[test]
    fn test_manager_graceful_shutdown() {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let handle = runtime.handle().clone();
        let manager = TaskManager::new(handle);
        let executor = manager.executor();

        let val = Arc::new(AtomicBool::new(false));
        let c = val.clone();
        executor.spawn_critical_with_graceful_shutdown_signal("grace", |shutdown| async move {
            let _guard = shutdown.await;
            tokio::time::sleep(Duration::from_millis(200)).await;
            c.store(true, Ordering::Relaxed);
        });

        manager.graceful_shutdown();
        assert!(val.load(Ordering::Relaxed));
    }

    // Tests that a graceful shutdown waits for every registered task
    #[test]
    fn test_manager_graceful_shutdown_many() {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let handle = runtime.handle().clone();
        let manager = TaskManager::new(handle);
        let executor = manager.executor();

        let counter = Arc::new(AtomicUsize::new(0));
        let num = 10;
        for _ in 0..num {
            let c = counter.clone();
            executor.spawn_critical_with_graceful_shutdown_signal(
                "grace",
                move |shutdown| async move {
                    let _guard = shutdown.await;
                    tokio::time::sleep(Duration::from_millis(200)).await;
                    c.fetch_add(1, Ordering::SeqCst);
                },
            );
        }

        manager.graceful_shutdown();
        assert_eq!(counter.load(Ordering::Relaxed), num);
    }

    // Tests that a graceful shutdown gives up once the timeout has elapsed
    #[test]
    fn test_manager_graceful_shutdown_timeout() {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let handle = runtime.handle().clone();
        let manager = TaskManager::new(handle);
        let executor = manager.executor();

        let timeout = Duration::from_millis(500);
        let val = Arc::new(AtomicBool::new(false));
        let val2 = val.clone();
        executor.spawn_critical_with_graceful_shutdown_signal("grace", |shutdown| async move {
            let _guard = shutdown.await;
            // Takes three times as long as the shutdown is willing to wait.
            tokio::time::sleep(timeout * 3).await;
            val2.store(true, Ordering::Relaxed);
            unreachable!("should not be reached");
        });

        // The task outlives the timeout, so the shutdown must report that it did not complete.
        let completed = manager.graceful_shutdown_with_timeout(timeout);
        assert!(!completed);
        assert!(!val.load(Ordering::Relaxed));
    }

    // Checks that creating a `TaskManager` makes the global executor available
    #[test]
    fn can_access_global() {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let handle = runtime.handle().clone();
        let _manager = TaskManager::new(handle);
        let _executor = TaskExecutor::try_current().unwrap();
    }
}