reth_tasks/
lib.rs

1//! Reth task management.
2//!
3//! # Feature Flags
4//!
5//! - `rayon`: Enable rayon thread pool for blocking tasks.
6
7#![doc(
8    html_logo_url = "https://raw.githubusercontent.com/paradigmxyz/reth/main/assets/reth-docs.png",
9    html_favicon_url = "https://avatars0.githubusercontent.com/u/97369466?s=256",
10    issue_tracker_base_url = "https://github.com/paradigmxyz/reth/issues/"
11)]
12#![cfg_attr(not(test), warn(unused_crate_dependencies))]
13#![cfg_attr(docsrs, feature(doc_cfg))]
14
15use crate::{
16    metrics::{IncCounterOnDrop, TaskExecutorMetrics},
17    shutdown::{signal, GracefulShutdown, GracefulShutdownGuard, Shutdown, Signal},
18};
19use dyn_clone::DynClone;
20use futures_util::{
21    future::{select, BoxFuture},
22    Future, FutureExt, TryFutureExt,
23};
24use std::{
25    any::Any,
26    fmt::{Display, Formatter},
27    pin::{pin, Pin},
28    sync::{
29        atomic::{AtomicUsize, Ordering},
30        Arc, OnceLock,
31    },
32    task::{ready, Context, Poll},
33};
34use tokio::{
35    runtime::Handle,
36    sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender},
37    task::JoinHandle,
38};
39use tracing::{debug, error};
40use tracing_futures::Instrument;
41
42pub mod metrics;
43pub mod shutdown;
44
45#[cfg(feature = "rayon")]
46pub mod pool;
47
/// Global [`TaskExecutor`] instance that can be accessed from anywhere.
///
/// Initialized by the first call to [`TaskManager::new`]; later managers do not replace it.
/// Read through [`TaskExecutor::try_current`] / [`TaskExecutor::current`].
static GLOBAL_EXECUTOR: OnceLock<TaskExecutor> = OnceLock::new();
50
51/// A type that can spawn tasks.
52///
53/// The main purpose of this type is to abstract over [`TaskExecutor`] so it's more convenient to
54/// provide default impls for testing.
55///
56///
57/// # Examples
58///
59/// Use the [`TokioTaskExecutor`] that spawns with [`tokio::task::spawn`]
60///
61/// ```
62/// # async fn t() {
63/// use reth_tasks::{TaskSpawner, TokioTaskExecutor};
64/// let executor = TokioTaskExecutor::default();
65///
66/// let task = executor.spawn(Box::pin(async {
67///     // -- snip --
68/// }));
69/// task.await.unwrap();
70/// # }
71/// ```
72///
73/// Use the [`TaskExecutor`] that spawns task directly onto the tokio runtime via the [Handle].
74///
75/// ```
76/// # use reth_tasks::TaskManager;
77/// fn t() {
78///  use reth_tasks::TaskSpawner;
79/// let rt = tokio::runtime::Runtime::new().unwrap();
80/// let manager = TaskManager::new(rt.handle().clone());
81/// let executor = manager.executor();
82/// let task = TaskSpawner::spawn(&executor, Box::pin(async {
83///     // -- snip --
84/// }));
85/// rt.block_on(task).unwrap();
86/// # }
87/// ```
88///
89/// The [`TaskSpawner`] trait is [`DynClone`] so `Box<dyn TaskSpawner>` are also `Clone`.
90#[auto_impl::auto_impl(&, Arc)]
91pub trait TaskSpawner: Send + Sync + Unpin + std::fmt::Debug + DynClone {
92    /// Spawns the task onto the runtime.
93    /// See also [`Handle::spawn`].
94    fn spawn(&self, fut: BoxFuture<'static, ()>) -> JoinHandle<()>;
95
96    /// This spawns a critical task onto the runtime.
97    fn spawn_critical(&self, name: &'static str, fut: BoxFuture<'static, ()>) -> JoinHandle<()>;
98
99    /// Spawns a blocking task onto the runtime.
100    fn spawn_blocking(&self, fut: BoxFuture<'static, ()>) -> JoinHandle<()>;
101
102    /// This spawns a critical blocking task onto the runtime.
103    fn spawn_critical_blocking(
104        &self,
105        name: &'static str,
106        fut: BoxFuture<'static, ()>,
107    ) -> JoinHandle<()>;
108}
109
110dyn_clone::clone_trait_object!(TaskSpawner);
111
112/// An [`TaskSpawner`] that uses [`tokio::task::spawn`] to execute tasks
113#[derive(Debug, Clone, Default)]
114#[non_exhaustive]
115pub struct TokioTaskExecutor;
116
117impl TokioTaskExecutor {
118    /// Converts the instance to a boxed [`TaskSpawner`].
119    pub fn boxed(self) -> Box<dyn TaskSpawner + 'static> {
120        Box::new(self)
121    }
122}
123
124impl TaskSpawner for TokioTaskExecutor {
125    fn spawn(&self, fut: BoxFuture<'static, ()>) -> JoinHandle<()> {
126        tokio::task::spawn(fut)
127    }
128
129    fn spawn_critical(&self, _name: &'static str, fut: BoxFuture<'static, ()>) -> JoinHandle<()> {
130        tokio::task::spawn(fut)
131    }
132
133    fn spawn_blocking(&self, fut: BoxFuture<'static, ()>) -> JoinHandle<()> {
134        tokio::task::spawn_blocking(move || tokio::runtime::Handle::current().block_on(fut))
135    }
136
137    fn spawn_critical_blocking(
138        &self,
139        _name: &'static str,
140        fut: BoxFuture<'static, ()>,
141    ) -> JoinHandle<()> {
142        tokio::task::spawn_blocking(move || tokio::runtime::Handle::current().block_on(fut))
143    }
144}
145
/// Many reth components require to spawn tasks for long-running jobs. For example `discovery`
/// spawns tasks to handle egress and ingress of udp traffic or `network` that spawns session tasks
/// that handle the traffic to and from a peer.
///
/// To unify how tasks are created, the [`TaskManager`] provides access to the configured Tokio
/// runtime. A [`TaskManager`] stores the [`tokio::runtime::Handle`] it is associated with. In this
/// way it is possible to configure on which runtime a task is executed.
///
/// The main purpose of this type is to be able to monitor if a critical task panicked, for
/// diagnostic purposes, since tokio tasks essentially fail silently. Therefore, this type is a
/// [`Future`] that resolves to an error naming the panicked task, see
/// [`TaskExecutor::spawn_critical`]. It resolves to `Ok(())` once a graceful shutdown is
/// requested. In order to execute tasks use the [`TaskExecutor`] type, obtained via
/// [`TaskManager::executor`].
///
/// Dropping the manager fires the shutdown signal that tasks spawned through its executors
/// observe.
#[derive(Debug)]
#[must_use = "TaskManager must be polled to monitor critical tasks"]
pub struct TaskManager {
    /// Handle to the tokio runtime this task manager is associated with.
    ///
    /// See [`Handle`] docs.
    handle: Handle,
    /// Sender half for sending task events to this type; cloned into every [`TaskExecutor`].
    task_events_tx: UnboundedSender<TaskEvent>,
    /// Receiver for task events, drained by the `Future` impl.
    task_events_rx: UnboundedReceiver<TaskEvent>,
    /// The [Signal] to fire when all tasks should be shutdown.
    ///
    /// This is fired when dropped. It is `None` once it has been taken and fired by the `Future`
    /// impl.
    signal: Option<Signal>,
    /// Receiver of the shutdown signal.
    on_shutdown: Shutdown,
    /// How many [`GracefulShutdown`] tasks are currently active
    graceful_tasks: Arc<AtomicUsize>,
}
178
179// === impl TaskManager ===
180
181impl TaskManager {
182    /// Returns a __new__ [`TaskManager`] over the currently running Runtime.
183    ///
184    /// This must be polled for the duration of the program.
185    ///
186    /// To obtain the current [`TaskExecutor`] see [`TaskExecutor::current`].
187    ///
188    /// # Panics
189    ///
190    /// This will panic if called outside the context of a Tokio runtime.
191    pub fn current() -> Self {
192        let handle = Handle::current();
193        Self::new(handle)
194    }
195
196    /// Create a new instance connected to the given handle's tokio runtime.
197    ///
198    /// This also sets the global [`TaskExecutor`].
199    pub fn new(handle: Handle) -> Self {
200        let (task_events_tx, task_events_rx) = unbounded_channel();
201        let (signal, on_shutdown) = signal();
202        let manager = Self {
203            handle,
204            task_events_tx,
205            task_events_rx,
206            signal: Some(signal),
207            on_shutdown,
208            graceful_tasks: Arc::new(AtomicUsize::new(0)),
209        };
210
211        let _ = GLOBAL_EXECUTOR
212            .set(manager.executor())
213            .inspect_err(|_| error!("Global executor already set"));
214
215        manager
216    }
217
218    /// Returns a new [`TaskExecutor`] that can spawn new tasks onto the tokio runtime this type is
219    /// connected to.
220    pub fn executor(&self) -> TaskExecutor {
221        TaskExecutor {
222            handle: self.handle.clone(),
223            on_shutdown: self.on_shutdown.clone(),
224            task_events_tx: self.task_events_tx.clone(),
225            metrics: Default::default(),
226            graceful_tasks: Arc::clone(&self.graceful_tasks),
227        }
228    }
229
230    /// Fires the shutdown signal and awaits until all tasks are shutdown.
231    pub fn graceful_shutdown(self) {
232        let _ = self.do_graceful_shutdown(None);
233    }
234
235    /// Fires the shutdown signal and awaits until all tasks are shutdown.
236    ///
237    /// Returns true if all tasks were shutdown before the timeout elapsed.
238    pub fn graceful_shutdown_with_timeout(self, timeout: std::time::Duration) -> bool {
239        self.do_graceful_shutdown(Some(timeout))
240    }
241
242    fn do_graceful_shutdown(self, timeout: Option<std::time::Duration>) -> bool {
243        drop(self.signal);
244        let when = timeout.map(|t| std::time::Instant::now() + t);
245        while self.graceful_tasks.load(Ordering::Relaxed) > 0 {
246            if when.map(|when| std::time::Instant::now() > when).unwrap_or(false) {
247                debug!("graceful shutdown timed out");
248                return false
249            }
250            std::hint::spin_loop();
251        }
252
253        debug!("gracefully shut down");
254        true
255    }
256}
257
258/// An endless future that resolves if a critical task panicked.
259///
260/// See [`TaskExecutor::spawn_critical`]
261impl Future for TaskManager {
262    type Output = Result<(), PanickedTaskError>;
263
264    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
265        match ready!(self.as_mut().get_mut().task_events_rx.poll_recv(cx)) {
266            Some(TaskEvent::Panic(err)) => Poll::Ready(Err(err)),
267            Some(TaskEvent::GracefulShutdown) | None => {
268                if let Some(signal) = self.get_mut().signal.take() {
269                    signal.fire();
270                }
271                Poll::Ready(Ok(()))
272            }
273        }
274    }
275}
276
/// Error with the name of the task that panicked and an error downcasted to string, if possible.
#[derive(Debug, thiserror::Error, PartialEq, Eq)]
pub struct PanickedTaskError {
    /// Name the critical task was spawned with.
    task_name: &'static str,
    /// The panic payload as text if it was a `String` or `&'static str`, otherwise `None`.
    error: Option<String>,
}
283
284impl Display for PanickedTaskError {
285    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
286        let task_name = self.task_name;
287        if let Some(error) = &self.error {
288            write!(f, "Critical task `{task_name}` panicked: `{error}`")
289        } else {
290            write!(f, "Critical task `{task_name}` panicked")
291        }
292    }
293}
294
295impl PanickedTaskError {
296    fn new(task_name: &'static str, error: Box<dyn Any>) -> Self {
297        let error = match error.downcast::<String>() {
298            Ok(value) => Some(*value),
299            Err(error) => match error.downcast::<&str>() {
300                Ok(value) => Some(value.to_string()),
301                Err(_) => None,
302            },
303        };
304
305        Self { task_name, error }
306    }
307}
308
/// Represents the events that the `TaskManager`'s main future can receive.
///
/// Events are sent by [`TaskExecutor`]s over an unbounded channel and consumed by the `Future`
/// impl of [`TaskManager`].
#[derive(Debug)]
enum TaskEvent {
    /// Indicates that a critical task has panicked.
    ///
    /// Carries the task name and, if it could be extracted, the panic message.
    Panic(PanickedTaskError),
    /// A signal requesting a graceful shutdown of the `TaskManager`.
    ///
    /// Sent by [`TaskExecutor::initiate_graceful_shutdown`].
    GracefulShutdown,
}
317
/// A type that can spawn new tokio tasks
///
/// Obtained from [`TaskManager::executor`] (or the global instance via
/// [`TaskExecutor::current`]); cheap to clone. Spawned tasks are tied to the manager's shutdown
/// signal, and critical tasks report panics back to the [`TaskManager`].
#[derive(Debug, Clone)]
pub struct TaskExecutor {
    /// Handle to the tokio runtime this executor spawns tasks onto.
    ///
    /// See [`Handle`] docs.
    handle: Handle,
    /// Receiver of the shutdown signal.
    on_shutdown: Shutdown,
    /// Sender half for sending task events (panics, shutdown requests) to the [`TaskManager`].
    task_events_tx: UnboundedSender<TaskEvent>,
    /// Task Executor Metrics
    metrics: TaskExecutorMetrics,
    /// How many [`GracefulShutdown`] tasks are currently active; shared with the [`TaskManager`].
    graceful_tasks: Arc<AtomicUsize>,
}
334
335// === impl TaskExecutor ===
336
337impl TaskExecutor {
338    /// Attempts to get the current `TaskExecutor` if one has been initialized.
339    ///
340    /// Returns an error if no [`TaskExecutor`] has been initialized via [`TaskManager`].
341    pub fn try_current() -> Result<Self, NoCurrentTaskExecutorError> {
342        GLOBAL_EXECUTOR.get().cloned().ok_or_else(NoCurrentTaskExecutorError::default)
343    }
344
345    /// Returns the current `TaskExecutor`.
346    ///
347    /// # Panics
348    ///
349    /// Panics if no global executor has been initialized. Use [`try_current`](Self::try_current)
350    /// for a non-panicking version.
351    pub fn current() -> Self {
352        Self::try_current().unwrap()
353    }
354
355    /// Returns the [Handle] to the tokio runtime.
356    pub const fn handle(&self) -> &Handle {
357        &self.handle
358    }
359
360    /// Returns the receiver of the shutdown signal.
361    pub const fn on_shutdown_signal(&self) -> &Shutdown {
362        &self.on_shutdown
363    }
364
365    /// Spawns a future on the tokio runtime depending on the [`TaskKind`]
366    fn spawn_on_rt<F>(&self, fut: F, task_kind: TaskKind) -> JoinHandle<()>
367    where
368        F: Future<Output = ()> + Send + 'static,
369    {
370        match task_kind {
371            TaskKind::Default => self.handle.spawn(fut),
372            TaskKind::Blocking => {
373                let handle = self.handle.clone();
374                self.handle.spawn_blocking(move || handle.block_on(fut))
375            }
376        }
377    }
378
379    /// Spawns a regular task depending on the given [`TaskKind`]
380    fn spawn_task_as<F>(&self, fut: F, task_kind: TaskKind) -> JoinHandle<()>
381    where
382        F: Future<Output = ()> + Send + 'static,
383    {
384        let on_shutdown = self.on_shutdown.clone();
385
386        // Choose the appropriate finished counter based on task kind
387        let finished_counter = match task_kind {
388            TaskKind::Default => self.metrics.finished_regular_tasks_total.clone(),
389            TaskKind::Blocking => self.metrics.finished_regular_blocking_tasks_total.clone(),
390        };
391
392        // Wrap the original future to increment the finished tasks counter upon completion
393        let task = {
394            async move {
395                // Create an instance of IncCounterOnDrop with the counter to increment
396                let _inc_counter_on_drop = IncCounterOnDrop::new(finished_counter);
397                let fut = pin!(fut);
398                let _ = select(on_shutdown, fut).await;
399            }
400        }
401        .in_current_span();
402
403        self.spawn_on_rt(task, task_kind)
404    }
405
406    /// Spawns the task onto the runtime.
407    /// The given future resolves as soon as the [Shutdown] signal is received.
408    ///
409    /// See also [`Handle::spawn`].
410    pub fn spawn<F>(&self, fut: F) -> JoinHandle<()>
411    where
412        F: Future<Output = ()> + Send + 'static,
413    {
414        self.spawn_task_as(fut, TaskKind::Default)
415    }
416
417    /// Spawns a blocking task onto the runtime.
418    /// The given future resolves as soon as the [Shutdown] signal is received.
419    ///
420    /// See also [`Handle::spawn_blocking`].
421    pub fn spawn_blocking<F>(&self, fut: F) -> JoinHandle<()>
422    where
423        F: Future<Output = ()> + Send + 'static,
424    {
425        self.spawn_task_as(fut, TaskKind::Blocking)
426    }
427
428    /// Spawns the task onto the runtime.
429    /// The given future resolves as soon as the [Shutdown] signal is received.
430    ///
431    /// See also [`Handle::spawn`].
432    pub fn spawn_with_signal<F>(&self, f: impl FnOnce(Shutdown) -> F) -> JoinHandle<()>
433    where
434        F: Future<Output = ()> + Send + 'static,
435    {
436        let on_shutdown = self.on_shutdown.clone();
437        let fut = f(on_shutdown);
438
439        let task = fut.in_current_span();
440
441        self.handle.spawn(task)
442    }
443
444    /// Spawns a critical task depending on the given [`TaskKind`]
445    fn spawn_critical_as<F>(
446        &self,
447        name: &'static str,
448        fut: F,
449        task_kind: TaskKind,
450    ) -> JoinHandle<()>
451    where
452        F: Future<Output = ()> + Send + 'static,
453    {
454        let panicked_tasks_tx = self.task_events_tx.clone();
455        let on_shutdown = self.on_shutdown.clone();
456
457        // wrap the task in catch unwind
458        let task = std::panic::AssertUnwindSafe(fut)
459            .catch_unwind()
460            .map_err(move |error| {
461                let task_error = PanickedTaskError::new(name, error);
462                error!("{task_error}");
463                let _ = panicked_tasks_tx.send(TaskEvent::Panic(task_error));
464            })
465            .in_current_span();
466
467        // Clone only the specific counter that we need.
468        let finished_critical_tasks_total_metrics =
469            self.metrics.finished_critical_tasks_total.clone();
470        let task = async move {
471            // Create an instance of IncCounterOnDrop with the counter to increment
472            let _inc_counter_on_drop = IncCounterOnDrop::new(finished_critical_tasks_total_metrics);
473            let task = pin!(task);
474            let _ = select(on_shutdown, task).await;
475        };
476
477        self.spawn_on_rt(task, task_kind)
478    }
479
480    /// This spawns a critical blocking task onto the runtime.
481    /// The given future resolves as soon as the [Shutdown] signal is received.
482    ///
483    /// If this task panics, the [`TaskManager`] is notified.
484    pub fn spawn_critical_blocking<F>(&self, name: &'static str, fut: F) -> JoinHandle<()>
485    where
486        F: Future<Output = ()> + Send + 'static,
487    {
488        self.spawn_critical_as(name, fut, TaskKind::Blocking)
489    }
490
491    /// This spawns a critical task onto the runtime.
492    /// The given future resolves as soon as the [Shutdown] signal is received.
493    ///
494    /// If this task panics, the [`TaskManager`] is notified.
495    pub fn spawn_critical<F>(&self, name: &'static str, fut: F) -> JoinHandle<()>
496    where
497        F: Future<Output = ()> + Send + 'static,
498    {
499        self.spawn_critical_as(name, fut, TaskKind::Default)
500    }
501
502    /// This spawns a critical task onto the runtime.
503    ///
504    /// If this task panics, the [`TaskManager`] is notified.
505    pub fn spawn_critical_with_shutdown_signal<F>(
506        &self,
507        name: &'static str,
508        f: impl FnOnce(Shutdown) -> F,
509    ) -> JoinHandle<()>
510    where
511        F: Future<Output = ()> + Send + 'static,
512    {
513        let panicked_tasks_tx = self.task_events_tx.clone();
514        let on_shutdown = self.on_shutdown.clone();
515        let fut = f(on_shutdown);
516
517        // wrap the task in catch unwind
518        let task = std::panic::AssertUnwindSafe(fut)
519            .catch_unwind()
520            .map_err(move |error| {
521                let task_error = PanickedTaskError::new(name, error);
522                error!("{task_error}");
523                let _ = panicked_tasks_tx.send(TaskEvent::Panic(task_error));
524            })
525            .map(drop)
526            .in_current_span();
527
528        self.handle.spawn(task)
529    }
530
531    /// This spawns a critical task onto the runtime.
532    ///
533    /// If this task panics, the [`TaskManager`] is notified.
534    /// The [`TaskManager`] will wait until the given future has completed before shutting down.
535    ///
536    /// # Example
537    ///
538    /// ```no_run
539    /// # async fn t(executor: reth_tasks::TaskExecutor) {
540    ///
541    /// executor.spawn_critical_with_graceful_shutdown_signal("grace", |shutdown| async move {
542    ///     // await the shutdown signal
543    ///     let guard = shutdown.await;
544    ///     // do work before exiting the program
545    ///     tokio::time::sleep(std::time::Duration::from_secs(1)).await;
546    ///     // allow graceful shutdown
547    ///     drop(guard);
548    /// });
549    /// # }
550    /// ```
551    pub fn spawn_critical_with_graceful_shutdown_signal<F>(
552        &self,
553        name: &'static str,
554        f: impl FnOnce(GracefulShutdown) -> F,
555    ) -> JoinHandle<()>
556    where
557        F: Future<Output = ()> + Send + 'static,
558    {
559        let panicked_tasks_tx = self.task_events_tx.clone();
560        let on_shutdown = GracefulShutdown::new(
561            self.on_shutdown.clone(),
562            GracefulShutdownGuard::new(Arc::clone(&self.graceful_tasks)),
563        );
564        let fut = f(on_shutdown);
565
566        // wrap the task in catch unwind
567        let task = std::panic::AssertUnwindSafe(fut)
568            .catch_unwind()
569            .map_err(move |error| {
570                let task_error = PanickedTaskError::new(name, error);
571                error!("{task_error}");
572                let _ = panicked_tasks_tx.send(TaskEvent::Panic(task_error));
573            })
574            .map(drop)
575            .in_current_span();
576
577        self.handle.spawn(task)
578    }
579
580    /// This spawns a regular task onto the runtime.
581    ///
582    /// The [`TaskManager`] will wait until the given future has completed before shutting down.
583    ///
584    /// # Example
585    ///
586    /// ```no_run
587    /// # async fn t(executor: reth_tasks::TaskExecutor) {
588    ///
589    /// executor.spawn_with_graceful_shutdown_signal(|shutdown| async move {
590    ///     // await the shutdown signal
591    ///     let guard = shutdown.await;
592    ///     // do work before exiting the program
593    ///     tokio::time::sleep(std::time::Duration::from_secs(1)).await;
594    ///     // allow graceful shutdown
595    ///     drop(guard);
596    /// });
597    /// # }
598    /// ```
599    pub fn spawn_with_graceful_shutdown_signal<F>(
600        &self,
601        f: impl FnOnce(GracefulShutdown) -> F,
602    ) -> JoinHandle<()>
603    where
604        F: Future<Output = ()> + Send + 'static,
605    {
606        let on_shutdown = GracefulShutdown::new(
607            self.on_shutdown.clone(),
608            GracefulShutdownGuard::new(Arc::clone(&self.graceful_tasks)),
609        );
610        let fut = f(on_shutdown);
611
612        self.handle.spawn(fut)
613    }
614
615    /// Sends a request to the `TaskManager` to initiate a graceful shutdown.
616    ///
617    /// Caution: This will terminate the entire program.
618    ///
619    /// The [`TaskManager`] upon receiving this event, will terminate and initiate the shutdown that
620    /// can be handled via the returned [`GracefulShutdown`].
621    pub fn initiate_graceful_shutdown(
622        &self,
623    ) -> Result<GracefulShutdown, tokio::sync::mpsc::error::SendError<()>> {
624        self.task_events_tx
625            .send(TaskEvent::GracefulShutdown)
626            .map_err(|_send_error_with_task_event| tokio::sync::mpsc::error::SendError(()))?;
627
628        Ok(GracefulShutdown::new(
629            self.on_shutdown.clone(),
630            GracefulShutdownGuard::new(Arc::clone(&self.graceful_tasks)),
631        ))
632    }
633}
634
635impl TaskSpawner for TaskExecutor {
636    fn spawn(&self, fut: BoxFuture<'static, ()>) -> JoinHandle<()> {
637        self.metrics.inc_regular_tasks();
638        self.spawn(fut)
639    }
640
641    fn spawn_critical(&self, name: &'static str, fut: BoxFuture<'static, ()>) -> JoinHandle<()> {
642        self.metrics.inc_critical_tasks();
643        Self::spawn_critical(self, name, fut)
644    }
645
646    fn spawn_blocking(&self, fut: BoxFuture<'static, ()>) -> JoinHandle<()> {
647        self.metrics.inc_regular_blocking_tasks();
648        self.spawn_blocking(fut)
649    }
650
651    fn spawn_critical_blocking(
652        &self,
653        name: &'static str,
654        fut: BoxFuture<'static, ()>,
655    ) -> JoinHandle<()> {
656        self.metrics.inc_critical_tasks();
657        Self::spawn_critical_blocking(self, name, fut)
658    }
659}
660
/// `TaskSpawner` with extended behaviour
///
/// Adds spawn methods that hand the task a [`GracefulShutdown`] handle. Because these methods are
/// generic, this trait cannot be used as a `dyn` trait object (unlike [`TaskSpawner`]).
#[auto_impl::auto_impl(&, Arc)]
pub trait TaskSpawnerExt: Send + Sync + Unpin + std::fmt::Debug + DynClone {
    /// This spawns a critical task onto the runtime.
    ///
    /// If this task panics, the [`TaskManager`] is notified.
    /// The [`TaskManager`] will wait until the given future has completed before shutting down.
    fn spawn_critical_with_graceful_shutdown_signal<F>(
        &self,
        name: &'static str,
        f: impl FnOnce(GracefulShutdown) -> F,
    ) -> JoinHandle<()>
    where
        F: Future<Output = ()> + Send + 'static;

    /// This spawns a regular task onto the runtime.
    ///
    /// The [`TaskManager`] will wait until the given future has completed before shutting down.
    fn spawn_with_graceful_shutdown_signal<F>(
        &self,
        f: impl FnOnce(GracefulShutdown) -> F,
    ) -> JoinHandle<()>
    where
        F: Future<Output = ()> + Send + 'static;
}
686
687impl TaskSpawnerExt for TaskExecutor {
688    fn spawn_critical_with_graceful_shutdown_signal<F>(
689        &self,
690        name: &'static str,
691        f: impl FnOnce(GracefulShutdown) -> F,
692    ) -> JoinHandle<()>
693    where
694        F: Future<Output = ()> + Send + 'static,
695    {
696        Self::spawn_critical_with_graceful_shutdown_signal(self, name, f)
697    }
698
699    fn spawn_with_graceful_shutdown_signal<F>(
700        &self,
701        f: impl FnOnce(GracefulShutdown) -> F,
702    ) -> JoinHandle<()>
703    where
704        F: Future<Output = ()> + Send + 'static,
705    {
706        Self::spawn_with_graceful_shutdown_signal(self, f)
707    }
708}
709
/// Determines how a task is spawned.
///
/// A plain, fieldless tag, so it is `Copy` and comparable, and it can be printed with `{:?}`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum TaskKind {
    /// Spawn the task to the default executor [`Handle::spawn`]
    Default,
    /// Spawn the task to the blocking executor [`Handle::spawn_blocking`]
    Blocking,
}
717
/// Error returned by `try_current` when no task executor has been configured.
///
/// The global executor is only set by [`TaskManager::new`], so this means no [`TaskManager`] has
/// been created yet.
#[derive(Debug, Default, thiserror::Error)]
#[error("No current task executor available.")]
#[non_exhaustive]
pub struct NoCurrentTaskExecutorError;
723
#[cfg(test)]
mod tests {
    use super::*;
    use std::{sync::atomic::AtomicBool, time::Duration};

    #[test]
    fn test_cloneable() {
        #[derive(Clone)]
        struct ExecutorWrapper {
            _e: Box<dyn TaskSpawner>,
        }

        let executor: Box<dyn TaskSpawner> = Box::<TokioTaskExecutor>::default();
        let _e = dyn_clone::clone_box(&*executor);

        let e = ExecutorWrapper { _e };
        // Actually exercise the derived `Clone`, which relies on `Box<dyn TaskSpawner>: Clone`.
        let _e2 = e.clone();
    }

    #[test]
    fn test_critical() {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let handle = runtime.handle().clone();
        let manager = TaskManager::new(handle);
        let executor = manager.executor();

        executor.spawn_critical("this is a critical task", async { panic!("intentionally panic") });

        runtime.block_on(async move {
            let err_result = manager.await;
            assert!(err_result.is_err(), "Expected TaskManager to return an error due to panic");
            let panicked_err = err_result.unwrap_err();

            assert_eq!(panicked_err.task_name, "this is a critical task");
            assert_eq!(panicked_err.error, Some("intentionally panic".to_string()));
        })
    }

    // Tests that spawned critical tasks are terminated if the `TaskManager` drops
    #[test]
    fn test_manager_shutdown_critical() {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let handle = runtime.handle().clone();
        let manager = TaskManager::new(handle.clone());
        let executor = manager.executor();

        let (signal, shutdown) = signal();

        executor.spawn_critical("this is a critical task", async move {
            tokio::time::sleep(Duration::from_millis(200)).await;
            drop(signal);
        });

        drop(manager);

        handle.block_on(shutdown);
    }

    // Tests that spawned regular tasks are terminated if the `TaskManager` drops
    #[test]
    fn test_manager_shutdown() {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let handle = runtime.handle().clone();
        let manager = TaskManager::new(handle.clone());
        let executor = manager.executor();

        let (signal, shutdown) = signal();

        executor.spawn(Box::pin(async move {
            tokio::time::sleep(Duration::from_millis(200)).await;
            drop(signal);
        }));

        drop(manager);

        handle.block_on(shutdown);
    }

    #[test]
    fn test_manager_graceful_shutdown() {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let handle = runtime.handle().clone();
        let manager = TaskManager::new(handle);
        let executor = manager.executor();

        let val = Arc::new(AtomicBool::new(false));
        let c = val.clone();
        executor.spawn_critical_with_graceful_shutdown_signal("grace", |shutdown| async move {
            let _guard = shutdown.await;
            tokio::time::sleep(Duration::from_millis(200)).await;
            c.store(true, Ordering::Relaxed);
        });

        manager.graceful_shutdown();
        assert!(val.load(Ordering::Relaxed));
    }

    #[test]
    fn test_manager_graceful_shutdown_many() {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let handle = runtime.handle().clone();
        let manager = TaskManager::new(handle);
        let executor = manager.executor();

        let counter = Arc::new(AtomicUsize::new(0));
        let num = 10;
        for _ in 0..num {
            let c = counter.clone();
            executor.spawn_critical_with_graceful_shutdown_signal(
                "grace",
                move |shutdown| async move {
                    let _guard = shutdown.await;
                    tokio::time::sleep(Duration::from_millis(200)).await;
                    c.fetch_add(1, Ordering::SeqCst);
                },
            );
        }

        manager.graceful_shutdown();
        assert_eq!(counter.load(Ordering::Relaxed), num);
    }

    #[test]
    fn test_manager_graceful_shutdown_timeout() {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let handle = runtime.handle().clone();
        let manager = TaskManager::new(handle);
        let executor = manager.executor();

        let timeout = Duration::from_millis(500);
        let val = Arc::new(AtomicBool::new(false));
        let val2 = val.clone();
        executor.spawn_critical_with_graceful_shutdown_signal("grace", |shutdown| async move {
            let _guard = shutdown.await;
            tokio::time::sleep(timeout * 3).await;
            val2.store(true, Ordering::Relaxed);
            unreachable!("should not be reached");
        });

        let completed = manager.graceful_shutdown_with_timeout(timeout);
        assert!(!completed, "graceful shutdown should report that it timed out");
        assert!(!val.load(Ordering::Relaxed));
    }

    #[test]
    fn test_manager_graceful_shutdown_within_timeout() {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let manager = TaskManager::new(runtime.handle().clone());
        let executor = manager.executor();

        let val = Arc::new(AtomicBool::new(false));
        let c = val.clone();
        executor.spawn_critical_with_graceful_shutdown_signal("grace", |shutdown| async move {
            let _guard = shutdown.await;
            tokio::time::sleep(Duration::from_millis(50)).await;
            c.store(true, Ordering::Relaxed);
        });

        assert!(
            manager.graceful_shutdown_with_timeout(Duration::from_secs(5)),
            "graceful shutdown should complete before the timeout"
        );
        assert!(val.load(Ordering::Relaxed));
    }

    #[test]
    fn can_access_global() {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let handle = runtime.handle().clone();
        let _manager = TaskManager::new(handle);
        let _executor = TaskExecutor::try_current().unwrap();
    }

    #[test]
    fn test_graceful_shutdown_triggered_by_executor() {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let task_manager = TaskManager::new(runtime.handle().clone());
        let executor = task_manager.executor();

        let task_did_shutdown_flag = Arc::new(AtomicBool::new(false));
        let flag_clone = task_did_shutdown_flag.clone();

        let spawned_task_handle = executor.spawn_with_signal(|shutdown_signal| async move {
            shutdown_signal.await;
            flag_clone.store(true, Ordering::SeqCst);
        });

        let manager_future_handle = runtime.spawn(task_manager);

        let send_result = executor.initiate_graceful_shutdown();
        assert!(
            send_result.is_ok(),
            "Sending the graceful shutdown signal should succeed and return a GracefulShutdown future"
        );

        let manager_final_result = runtime.block_on(manager_future_handle);

        assert!(manager_final_result.is_ok(), "TaskManager task should not panic");
        assert_eq!(
            manager_final_result.unwrap(),
            Ok(()),
            "TaskManager should resolve cleanly with Ok(()) after graceful shutdown request"
        );

        let task_join_result = runtime.block_on(spawned_task_handle);
        assert!(task_join_result.is_ok(), "Spawned task should complete without panic");

        assert!(
            task_did_shutdown_flag.load(Ordering::Relaxed),
            "Task should have received the shutdown signal and set the flag"
        );
    }
}