reth_tasks/
lib.rs

1//! Reth task management.
2//!
3//! # Feature Flags
4//!
5//! - `rayon`: Enable rayon thread pool for blocking tasks.
6
7#![doc(
8    html_logo_url = "https://raw.githubusercontent.com/paradigmxyz/reth/main/assets/reth-docs.png",
9    html_favicon_url = "https://avatars0.githubusercontent.com/u/97369466?s=256",
10    issue_tracker_base_url = "https://github.com/paradigmxyz/reth/issues/"
11)]
12#![cfg_attr(not(test), warn(unused_crate_dependencies))]
13#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))]
14
15use crate::{
16    metrics::{IncCounterOnDrop, TaskExecutorMetrics},
17    shutdown::{signal, GracefulShutdown, GracefulShutdownGuard, Shutdown, Signal},
18};
19use dyn_clone::DynClone;
20use futures_util::{
21    future::{select, BoxFuture},
22    Future, FutureExt, TryFutureExt,
23};
24use std::{
25    any::Any,
26    fmt::{Display, Formatter},
27    pin::{pin, Pin},
28    sync::{
29        atomic::{AtomicUsize, Ordering},
30        Arc,
31    },
32    task::{ready, Context, Poll},
33};
34use tokio::{
35    runtime::Handle,
36    sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender},
37    task::JoinHandle,
38};
39use tracing::{debug, error};
40use tracing_futures::Instrument;
41
42pub mod metrics;
43pub mod shutdown;
44
45#[cfg(feature = "rayon")]
46pub mod pool;
47
48/// A type that can spawn tasks.
49///
50/// The main purpose of this type is to abstract over [`TaskExecutor`] so it's more convenient to
51/// provide default impls for testing.
52///
53///
54/// # Examples
55///
56/// Use the [`TokioTaskExecutor`] that spawns with [`tokio::task::spawn`]
57///
58/// ```
59/// # async fn t() {
60/// use reth_tasks::{TaskSpawner, TokioTaskExecutor};
61/// let executor = TokioTaskExecutor::default();
62///
63/// let task = executor.spawn(Box::pin(async {
64///     // -- snip --
65/// }));
66/// task.await.unwrap();
67/// # }
68/// ```
69///
70/// Use the [`TaskExecutor`] that spawns task directly onto the tokio runtime via the [Handle].
71///
72/// ```
73/// # use reth_tasks::TaskManager;
74/// fn t() {
75///  use reth_tasks::TaskSpawner;
76/// let rt = tokio::runtime::Runtime::new().unwrap();
77/// let manager = TaskManager::new(rt.handle().clone());
78/// let executor = manager.executor();
79/// let task = TaskSpawner::spawn(&executor, Box::pin(async {
80///     // -- snip --
81/// }));
82/// rt.block_on(task).unwrap();
83/// # }
84/// ```
85///
86/// The [`TaskSpawner`] trait is [`DynClone`] so `Box<dyn TaskSpawner>` are also `Clone`.
87#[auto_impl::auto_impl(&, Arc)]
88pub trait TaskSpawner: Send + Sync + Unpin + std::fmt::Debug + DynClone {
89    /// Spawns the task onto the runtime.
90    /// See also [`Handle::spawn`].
91    fn spawn(&self, fut: BoxFuture<'static, ()>) -> JoinHandle<()>;
92
93    /// This spawns a critical task onto the runtime.
94    fn spawn_critical(&self, name: &'static str, fut: BoxFuture<'static, ()>) -> JoinHandle<()>;
95
96    /// Spawns a blocking task onto the runtime.
97    fn spawn_blocking(&self, fut: BoxFuture<'static, ()>) -> JoinHandle<()>;
98
99    /// This spawns a critical blocking task onto the runtime.
100    fn spawn_critical_blocking(
101        &self,
102        name: &'static str,
103        fut: BoxFuture<'static, ()>,
104    ) -> JoinHandle<()>;
105}
106
107dyn_clone::clone_trait_object!(TaskSpawner);
108
109/// An [`TaskSpawner`] that uses [`tokio::task::spawn`] to execute tasks
110#[derive(Debug, Clone, Default)]
111#[non_exhaustive]
112pub struct TokioTaskExecutor;
113
114impl TokioTaskExecutor {
115    /// Converts the instance to a boxed [`TaskSpawner`].
116    pub fn boxed(self) -> Box<dyn TaskSpawner + 'static> {
117        Box::new(self)
118    }
119}
120
121impl TaskSpawner for TokioTaskExecutor {
122    fn spawn(&self, fut: BoxFuture<'static, ()>) -> JoinHandle<()> {
123        tokio::task::spawn(fut)
124    }
125
126    fn spawn_critical(&self, _name: &'static str, fut: BoxFuture<'static, ()>) -> JoinHandle<()> {
127        tokio::task::spawn(fut)
128    }
129
130    fn spawn_blocking(&self, fut: BoxFuture<'static, ()>) -> JoinHandle<()> {
131        tokio::task::spawn_blocking(move || tokio::runtime::Handle::current().block_on(fut))
132    }
133
134    fn spawn_critical_blocking(
135        &self,
136        _name: &'static str,
137        fut: BoxFuture<'static, ()>,
138    ) -> JoinHandle<()> {
139        tokio::task::spawn_blocking(move || tokio::runtime::Handle::current().block_on(fut))
140    }
141}
142
143/// Many reth components require to spawn tasks for long-running jobs. For example `discovery`
144/// spawns tasks to handle egress and ingress of udp traffic or `network` that spawns session tasks
145/// that handle the traffic to and from a peer.
146///
147/// To unify how tasks are created, the [`TaskManager`] provides access to the configured Tokio
148/// runtime. A [`TaskManager`] stores the [`tokio::runtime::Handle`] it is associated with. In this
149/// way it is possible to configure on which runtime a task is executed.
150///
151/// The main purpose of this type is to be able to monitor if a critical task panicked, for
152/// diagnostic purposes, since tokio task essentially fail silently. Therefore, this type is a
153/// Stream that yields the name of panicked task, See [`TaskExecutor::spawn_critical`]. In order to
154/// execute Tasks use the [`TaskExecutor`] type [`TaskManager::executor`].
155#[derive(Debug)]
156#[must_use = "TaskManager must be polled to monitor critical tasks"]
157pub struct TaskManager {
158    /// Handle to the tokio runtime this task manager is associated with.
159    ///
160    /// See [`Handle`] docs.
161    handle: Handle,
162    /// Sender half for sending panic signals to this type
163    panicked_tasks_tx: UnboundedSender<PanickedTaskError>,
164    /// Listens for panicked tasks
165    panicked_tasks_rx: UnboundedReceiver<PanickedTaskError>,
166    /// The [Signal] to fire when all tasks should be shutdown.
167    ///
168    /// This is fired when dropped.
169    signal: Option<Signal>,
170    /// Receiver of the shutdown signal.
171    on_shutdown: Shutdown,
172    /// How many [`GracefulShutdown`] tasks are currently active
173    graceful_tasks: Arc<AtomicUsize>,
174}
175
176// === impl TaskManager ===
177
178impl TaskManager {
179    /// Returns a [`TaskManager`] over the currently running Runtime.
180    ///
181    /// # Panics
182    ///
183    /// This will panic if called outside the context of a Tokio runtime.
184    pub fn current() -> Self {
185        let handle = Handle::current();
186        Self::new(handle)
187    }
188
189    /// Create a new instance connected to the given handle's tokio runtime.
190    pub fn new(handle: Handle) -> Self {
191        let (panicked_tasks_tx, panicked_tasks_rx) = unbounded_channel();
192        let (signal, on_shutdown) = signal();
193        Self {
194            handle,
195            panicked_tasks_tx,
196            panicked_tasks_rx,
197            signal: Some(signal),
198            on_shutdown,
199            graceful_tasks: Arc::new(AtomicUsize::new(0)),
200        }
201    }
202
203    /// Returns a new [`TaskExecutor`] that can spawn new tasks onto the tokio runtime this type is
204    /// connected to.
205    pub fn executor(&self) -> TaskExecutor {
206        TaskExecutor {
207            handle: self.handle.clone(),
208            on_shutdown: self.on_shutdown.clone(),
209            panicked_tasks_tx: self.panicked_tasks_tx.clone(),
210            metrics: Default::default(),
211            graceful_tasks: Arc::clone(&self.graceful_tasks),
212        }
213    }
214
215    /// Fires the shutdown signal and awaits until all tasks are shutdown.
216    pub fn graceful_shutdown(self) {
217        let _ = self.do_graceful_shutdown(None);
218    }
219
220    /// Fires the shutdown signal and awaits until all tasks are shutdown.
221    ///
222    /// Returns true if all tasks were shutdown before the timeout elapsed.
223    pub fn graceful_shutdown_with_timeout(self, timeout: std::time::Duration) -> bool {
224        self.do_graceful_shutdown(Some(timeout))
225    }
226
227    fn do_graceful_shutdown(self, timeout: Option<std::time::Duration>) -> bool {
228        drop(self.signal);
229        let when = timeout.map(|t| std::time::Instant::now() + t);
230        while self.graceful_tasks.load(Ordering::Relaxed) > 0 {
231            if when.map(|when| std::time::Instant::now() > when).unwrap_or(false) {
232                debug!("graceful shutdown timed out");
233                return false
234            }
235            std::hint::spin_loop();
236        }
237
238        debug!("gracefully shut down");
239        true
240    }
241}
242
243/// An endless future that resolves if a critical task panicked.
244///
245/// See [`TaskExecutor::spawn_critical`]
246impl Future for TaskManager {
247    type Output = PanickedTaskError;
248
249    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
250        let err = ready!(self.get_mut().panicked_tasks_rx.poll_recv(cx));
251        Poll::Ready(err.expect("stream can not end"))
252    }
253}
254
/// Error with the name of the task that panicked and an error downcasted to string, if possible.
///
/// `Display` is implemented manually rather than through a `#[error(...)]` attribute.
#[derive(Debug, thiserror::Error)]
pub struct PanickedTaskError {
    /// Name the critical task was spawned with.
    task_name: &'static str,
    /// The panic payload as text if it was a `String` or `&str`, `None` for any other payload.
    error: Option<String>,
}
261
262impl Display for PanickedTaskError {
263    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
264        let task_name = self.task_name;
265        if let Some(error) = &self.error {
266            write!(f, "Critical task `{task_name}` panicked: `{error}`")
267        } else {
268            write!(f, "Critical task `{task_name}` panicked")
269        }
270    }
271}
272
273impl PanickedTaskError {
274    fn new(task_name: &'static str, error: Box<dyn Any>) -> Self {
275        let error = match error.downcast::<String>() {
276            Ok(value) => Some(*value),
277            Err(error) => match error.downcast::<&str>() {
278                Ok(value) => Some(value.to_string()),
279                Err(_) => None,
280            },
281        };
282
283        Self { task_name, error }
284    }
285}
286
287/// A type that can spawn new tokio tasks
288#[derive(Debug, Clone)]
289pub struct TaskExecutor {
290    /// Handle to the tokio runtime this task manager is associated with.
291    ///
292    /// See [`Handle`] docs.
293    handle: Handle,
294    /// Receiver of the shutdown signal.
295    on_shutdown: Shutdown,
296    /// Sender half for sending panic signals to this type
297    panicked_tasks_tx: UnboundedSender<PanickedTaskError>,
298    /// Task Executor Metrics
299    metrics: TaskExecutorMetrics,
300    /// How many [`GracefulShutdown`] tasks are currently active
301    graceful_tasks: Arc<AtomicUsize>,
302}
303
304// === impl TaskExecutor ===
305
306impl TaskExecutor {
307    /// Returns the [Handle] to the tokio runtime.
308    pub const fn handle(&self) -> &Handle {
309        &self.handle
310    }
311
312    /// Returns the receiver of the shutdown signal.
313    pub const fn on_shutdown_signal(&self) -> &Shutdown {
314        &self.on_shutdown
315    }
316
317    /// Spawns a future on the tokio runtime depending on the [`TaskKind`]
318    fn spawn_on_rt<F>(&self, fut: F, task_kind: TaskKind) -> JoinHandle<()>
319    where
320        F: Future<Output = ()> + Send + 'static,
321    {
322        match task_kind {
323            TaskKind::Default => self.handle.spawn(fut),
324            TaskKind::Blocking => {
325                let handle = self.handle.clone();
326                self.handle.spawn_blocking(move || handle.block_on(fut))
327            }
328        }
329    }
330
331    /// Spawns a regular task depending on the given [`TaskKind`]
332    fn spawn_task_as<F>(&self, fut: F, task_kind: TaskKind) -> JoinHandle<()>
333    where
334        F: Future<Output = ()> + Send + 'static,
335    {
336        let on_shutdown = self.on_shutdown.clone();
337
338        // Clone only the specific counter that we need.
339        let finished_regular_tasks_total_metrics =
340            self.metrics.finished_regular_tasks_total.clone();
341        // Wrap the original future to increment the finished tasks counter upon completion
342        let task = {
343            async move {
344                // Create an instance of IncCounterOnDrop with the counter to increment
345                let _inc_counter_on_drop =
346                    IncCounterOnDrop::new(finished_regular_tasks_total_metrics);
347                let fut = pin!(fut);
348                let _ = select(on_shutdown, fut).await;
349            }
350        }
351        .in_current_span();
352
353        self.spawn_on_rt(task, task_kind)
354    }
355
356    /// Spawns the task onto the runtime.
357    /// The given future resolves as soon as the [Shutdown] signal is received.
358    ///
359    /// See also [`Handle::spawn`].
360    pub fn spawn<F>(&self, fut: F) -> JoinHandle<()>
361    where
362        F: Future<Output = ()> + Send + 'static,
363    {
364        self.spawn_task_as(fut, TaskKind::Default)
365    }
366
367    /// Spawns a blocking task onto the runtime.
368    /// The given future resolves as soon as the [Shutdown] signal is received.
369    ///
370    /// See also [`Handle::spawn_blocking`].
371    pub fn spawn_blocking<F>(&self, fut: F) -> JoinHandle<()>
372    where
373        F: Future<Output = ()> + Send + 'static,
374    {
375        self.spawn_task_as(fut, TaskKind::Blocking)
376    }
377
378    /// Spawns the task onto the runtime.
379    /// The given future resolves as soon as the [Shutdown] signal is received.
380    ///
381    /// See also [`Handle::spawn`].
382    pub fn spawn_with_signal<F>(&self, f: impl FnOnce(Shutdown) -> F) -> JoinHandle<()>
383    where
384        F: Future<Output = ()> + Send + 'static,
385    {
386        let on_shutdown = self.on_shutdown.clone();
387        let fut = f(on_shutdown);
388
389        let task = fut.in_current_span();
390
391        self.handle.spawn(task)
392    }
393
394    /// Spawns a critical task depending on the given [`TaskKind`]
395    fn spawn_critical_as<F>(
396        &self,
397        name: &'static str,
398        fut: F,
399        task_kind: TaskKind,
400    ) -> JoinHandle<()>
401    where
402        F: Future<Output = ()> + Send + 'static,
403    {
404        let panicked_tasks_tx = self.panicked_tasks_tx.clone();
405        let on_shutdown = self.on_shutdown.clone();
406
407        // wrap the task in catch unwind
408        let task = std::panic::AssertUnwindSafe(fut)
409            .catch_unwind()
410            .map_err(move |error| {
411                let task_error = PanickedTaskError::new(name, error);
412                error!("{task_error}");
413                let _ = panicked_tasks_tx.send(task_error);
414            })
415            .in_current_span();
416
417        // Clone only the specific counter that we need.
418        let finished_critical_tasks_total_metrics =
419            self.metrics.finished_critical_tasks_total.clone();
420        let task = async move {
421            // Create an instance of IncCounterOnDrop with the counter to increment
422            let _inc_counter_on_drop = IncCounterOnDrop::new(finished_critical_tasks_total_metrics);
423            let task = pin!(task);
424            let _ = select(on_shutdown, task).await;
425        };
426
427        self.spawn_on_rt(task, task_kind)
428    }
429
430    /// This spawns a critical blocking task onto the runtime.
431    /// The given future resolves as soon as the [Shutdown] signal is received.
432    ///
433    /// If this task panics, the [`TaskManager`] is notified.
434    pub fn spawn_critical_blocking<F>(&self, name: &'static str, fut: F) -> JoinHandle<()>
435    where
436        F: Future<Output = ()> + Send + 'static,
437    {
438        self.spawn_critical_as(name, fut, TaskKind::Blocking)
439    }
440
441    /// This spawns a critical task onto the runtime.
442    /// The given future resolves as soon as the [Shutdown] signal is received.
443    ///
444    /// If this task panics, the [`TaskManager`] is notified.
445    pub fn spawn_critical<F>(&self, name: &'static str, fut: F) -> JoinHandle<()>
446    where
447        F: Future<Output = ()> + Send + 'static,
448    {
449        self.spawn_critical_as(name, fut, TaskKind::Default)
450    }
451
452    /// This spawns a critical task onto the runtime.
453    ///
454    /// If this task panics, the [`TaskManager`] is notified.
455    pub fn spawn_critical_with_shutdown_signal<F>(
456        &self,
457        name: &'static str,
458        f: impl FnOnce(Shutdown) -> F,
459    ) -> JoinHandle<()>
460    where
461        F: Future<Output = ()> + Send + 'static,
462    {
463        let panicked_tasks_tx = self.panicked_tasks_tx.clone();
464        let on_shutdown = self.on_shutdown.clone();
465        let fut = f(on_shutdown);
466
467        // wrap the task in catch unwind
468        let task = std::panic::AssertUnwindSafe(fut)
469            .catch_unwind()
470            .map_err(move |error| {
471                let task_error = PanickedTaskError::new(name, error);
472                error!("{task_error}");
473                let _ = panicked_tasks_tx.send(task_error);
474            })
475            .map(drop)
476            .in_current_span();
477
478        self.handle.spawn(task)
479    }
480
481    /// This spawns a critical task onto the runtime.
482    ///
483    /// If this task panics, the [`TaskManager`] is notified.
484    /// The [`TaskManager`] will wait until the given future has completed before shutting down.
485    ///
486    /// # Example
487    ///
488    /// ```no_run
489    /// # async fn t(executor: reth_tasks::TaskExecutor) {
490    ///
491    /// executor.spawn_critical_with_graceful_shutdown_signal("grace", |shutdown| async move {
492    ///     // await the shutdown signal
493    ///     let guard = shutdown.await;
494    ///     // do work before exiting the program
495    ///     tokio::time::sleep(std::time::Duration::from_secs(1)).await;
496    ///     // allow graceful shutdown
497    ///     drop(guard);
498    /// });
499    /// # }
500    /// ```
501    pub fn spawn_critical_with_graceful_shutdown_signal<F>(
502        &self,
503        name: &'static str,
504        f: impl FnOnce(GracefulShutdown) -> F,
505    ) -> JoinHandle<()>
506    where
507        F: Future<Output = ()> + Send + 'static,
508    {
509        let panicked_tasks_tx = self.panicked_tasks_tx.clone();
510        let on_shutdown = GracefulShutdown::new(
511            self.on_shutdown.clone(),
512            GracefulShutdownGuard::new(Arc::clone(&self.graceful_tasks)),
513        );
514        let fut = f(on_shutdown);
515
516        // wrap the task in catch unwind
517        let task = std::panic::AssertUnwindSafe(fut)
518            .catch_unwind()
519            .map_err(move |error| {
520                let task_error = PanickedTaskError::new(name, error);
521                error!("{task_error}");
522                let _ = panicked_tasks_tx.send(task_error);
523            })
524            .map(drop)
525            .in_current_span();
526
527        self.handle.spawn(task)
528    }
529
530    /// This spawns a regular task onto the runtime.
531    ///
532    /// The [`TaskManager`] will wait until the given future has completed before shutting down.
533    ///
534    /// # Example
535    ///
536    /// ```no_run
537    /// # async fn t(executor: reth_tasks::TaskExecutor) {
538    ///
539    /// executor.spawn_with_graceful_shutdown_signal(|shutdown| async move {
540    ///     // await the shutdown signal
541    ///     let guard = shutdown.await;
542    ///     // do work before exiting the program
543    ///     tokio::time::sleep(std::time::Duration::from_secs(1)).await;
544    ///     // allow graceful shutdown
545    ///     drop(guard);
546    /// });
547    /// # }
548    /// ```
549    pub fn spawn_with_graceful_shutdown_signal<F>(
550        &self,
551        f: impl FnOnce(GracefulShutdown) -> F,
552    ) -> JoinHandle<()>
553    where
554        F: Future<Output = ()> + Send + 'static,
555    {
556        let on_shutdown = GracefulShutdown::new(
557            self.on_shutdown.clone(),
558            GracefulShutdownGuard::new(Arc::clone(&self.graceful_tasks)),
559        );
560        let fut = f(on_shutdown);
561
562        self.handle.spawn(fut)
563    }
564}
565
566impl TaskSpawner for TaskExecutor {
567    fn spawn(&self, fut: BoxFuture<'static, ()>) -> JoinHandle<()> {
568        self.metrics.inc_regular_tasks();
569        self.spawn(fut)
570    }
571
572    fn spawn_critical(&self, name: &'static str, fut: BoxFuture<'static, ()>) -> JoinHandle<()> {
573        self.metrics.inc_critical_tasks();
574        Self::spawn_critical(self, name, fut)
575    }
576
577    fn spawn_blocking(&self, fut: BoxFuture<'static, ()>) -> JoinHandle<()> {
578        self.spawn_blocking(fut)
579    }
580
581    fn spawn_critical_blocking(
582        &self,
583        name: &'static str,
584        fut: BoxFuture<'static, ()>,
585    ) -> JoinHandle<()> {
586        Self::spawn_critical_blocking(self, name, fut)
587    }
588}
589
590/// `TaskSpawner` with extended behaviour
591#[auto_impl::auto_impl(&, Arc)]
592pub trait TaskSpawnerExt: Send + Sync + Unpin + std::fmt::Debug + DynClone {
593    /// This spawns a critical task onto the runtime.
594    ///
595    /// If this task panics, the [`TaskManager`] is notified.
596    /// The [`TaskManager`] will wait until the given future has completed before shutting down.
597    fn spawn_critical_with_graceful_shutdown_signal<F>(
598        &self,
599        name: &'static str,
600        f: impl FnOnce(GracefulShutdown) -> F,
601    ) -> JoinHandle<()>
602    where
603        F: Future<Output = ()> + Send + 'static;
604
605    /// This spawns a regular task onto the runtime.
606    ///
607    /// The [`TaskManager`] will wait until the given future has completed before shutting down.
608    fn spawn_with_graceful_shutdown_signal<F>(
609        &self,
610        f: impl FnOnce(GracefulShutdown) -> F,
611    ) -> JoinHandle<()>
612    where
613        F: Future<Output = ()> + Send + 'static;
614}
615
616impl TaskSpawnerExt for TaskExecutor {
617    fn spawn_critical_with_graceful_shutdown_signal<F>(
618        &self,
619        name: &'static str,
620        f: impl FnOnce(GracefulShutdown) -> F,
621    ) -> JoinHandle<()>
622    where
623        F: Future<Output = ()> + Send + 'static,
624    {
625        Self::spawn_critical_with_graceful_shutdown_signal(self, name, f)
626    }
627
628    fn spawn_with_graceful_shutdown_signal<F>(
629        &self,
630        f: impl FnOnce(GracefulShutdown) -> F,
631    ) -> JoinHandle<()>
632    where
633        F: Future<Output = ()> + Send + 'static,
634    {
635        Self::spawn_with_graceful_shutdown_signal(self, f)
636    }
637}
638
/// Selects which tokio spawning primitive a task is handed to.
enum TaskKind {
    /// Drive the task on the runtime's blocking thread pool via [`Handle::spawn_blocking`].
    Blocking,
    /// Spawn the task onto the runtime's regular executor via [`Handle::spawn`].
    Default,
}
646
#[cfg(test)]
mod tests {
    use super::*;
    use std::{sync::atomic::AtomicBool, time::Duration};

    #[test]
    fn test_cloneable() {
        #[derive(Clone)]
        struct ExecutorWrapper {
            _e: Box<dyn TaskSpawner>,
        }

        let executor: Box<dyn TaskSpawner> = Box::<TokioTaskExecutor>::default();
        let _e = dyn_clone::clone_box(&*executor);

        let e = ExecutorWrapper { _e };
        let _e2 = e;
    }

    #[test]
    fn test_critical() {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let handle = runtime.handle().clone();
        let manager = TaskManager::new(handle);
        let executor = manager.executor();

        executor.spawn_critical("this is a critical task", async { panic!("intentionally panic") });

        runtime.block_on(async move {
            let err = manager.await;
            assert_eq!(err.task_name, "this is a critical task");
            assert_eq!(err.error, Some("intentionally panic".to_string()));
        })
    }

    #[test]
    fn test_panicked_task_error_payloads() {
        let err = PanickedTaskError::new("static", Box::new("static payload"));
        assert_eq!(err.error, Some("static payload".to_string()));
        assert_eq!(err.to_string(), "Critical task `static` panicked: `static payload`");

        let err = PanickedTaskError::new("owned", Box::new(String::from("owned payload")));
        assert_eq!(err.error, Some("owned payload".to_string()));

        let err = PanickedTaskError::new("opaque", Box::new(42u64));
        assert_eq!(err.error, None);
        assert_eq!(err.to_string(), "Critical task `opaque` panicked");
    }

    // Tests that spawned tasks are terminated if the `TaskManager` drops
    #[test]
    fn test_manager_shutdown_critical() {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let handle = runtime.handle().clone();
        let manager = TaskManager::new(handle.clone());
        let executor = manager.executor();

        let (signal, shutdown) = signal();
        let completed = Arc::new(AtomicBool::new(false));
        let task_completed = Arc::clone(&completed);

        executor.spawn_critical("this is a critical task", async move {
            tokio::time::sleep(Duration::from_millis(200)).await;
            task_completed.store(true, Ordering::SeqCst);
            drop(signal);
        });

        drop(manager);

        // `shutdown` resolves once `signal` is dropped. That happens either when the task is
        // cancelled (its future, holding `signal`, is dropped) or when it runs to completion.
        handle.block_on(shutdown);
        // Only cancellation leaves the flag unset.
        assert!(!completed.load(Ordering::SeqCst));
    }

    // Tests that spawned tasks are terminated if the `TaskManager` drops
    #[test]
    fn test_manager_shutdown() {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let handle = runtime.handle().clone();
        let manager = TaskManager::new(handle.clone());
        let executor = manager.executor();

        let (signal, shutdown) = signal();
        let completed = Arc::new(AtomicBool::new(false));
        let task_completed = Arc::clone(&completed);

        executor.spawn(Box::pin(async move {
            tokio::time::sleep(Duration::from_millis(200)).await;
            task_completed.store(true, Ordering::SeqCst);
            drop(signal);
        }));

        drop(manager);

        // See `test_manager_shutdown_critical`: only cancellation leaves the flag unset.
        handle.block_on(shutdown);
        assert!(!completed.load(Ordering::SeqCst));
    }

    #[test]
    fn test_manager_graceful_shutdown() {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let handle = runtime.handle().clone();
        let manager = TaskManager::new(handle);
        let executor = manager.executor();

        let val = Arc::new(AtomicBool::new(false));
        let c = val.clone();
        executor.spawn_critical_with_graceful_shutdown_signal("grace", |shutdown| async move {
            let _guard = shutdown.await;
            tokio::time::sleep(Duration::from_millis(200)).await;
            c.store(true, Ordering::Relaxed);
        });

        manager.graceful_shutdown();
        assert!(val.load(Ordering::Relaxed));
    }

    #[test]
    fn test_manager_graceful_shutdown_many() {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let handle = runtime.handle().clone();
        let manager = TaskManager::new(handle);
        let executor = manager.executor();

        let counter = Arc::new(AtomicUsize::new(0));
        let num = 10;
        for _ in 0..num {
            let c = counter.clone();
            executor.spawn_critical_with_graceful_shutdown_signal(
                "grace",
                move |shutdown| async move {
                    let _guard = shutdown.await;
                    tokio::time::sleep(Duration::from_millis(200)).await;
                    c.fetch_add(1, Ordering::SeqCst);
                },
            );
        }

        manager.graceful_shutdown();
        assert_eq!(counter.load(Ordering::Relaxed), num);
    }

    #[test]
    fn test_manager_graceful_shutdown_timeout() {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let handle = runtime.handle().clone();
        let manager = TaskManager::new(handle);
        let executor = manager.executor();

        let timeout = Duration::from_millis(500);
        let val = Arc::new(AtomicBool::new(false));
        let val2 = val.clone();
        executor.spawn_critical_with_graceful_shutdown_signal("grace", |shutdown| async move {
            let _guard = shutdown.await;
            tokio::time::sleep(timeout * 3).await;
            val2.store(true, Ordering::Relaxed);
            unreachable!("should not be reached");
        });

        // The graceful task outlives the timeout, so the shutdown must report failure.
        assert!(!manager.graceful_shutdown_with_timeout(timeout));
        assert!(!val.load(Ordering::Relaxed));
    }
}