Skip to main content

reth_tasks/
lib.rs

1//! Reth task management.
2//!
3//! # Feature Flags
4//!
5//! - `rayon`: Enable rayon thread pool for blocking tasks.
6
7#![doc(
8    html_logo_url = "https://raw.githubusercontent.com/paradigmxyz/reth/main/assets/reth-docs.png",
9    html_favicon_url = "https://avatars0.githubusercontent.com/u/97369466?s=256",
10    issue_tracker_base_url = "https://github.com/paradigmxyz/reth/issues/"
11)]
12#![cfg_attr(not(test), warn(unused_crate_dependencies))]
13#![cfg_attr(docsrs, feature(doc_cfg))]
14
15use crate::shutdown::{signal, GracefulShutdown, Shutdown, Signal};
16use dyn_clone::DynClone;
17use futures_util::future::BoxFuture;
18use std::{
19    any::Any,
20    fmt::{Display, Formatter},
21    pin::Pin,
22    sync::{
23        atomic::{AtomicUsize, Ordering},
24        Arc,
25    },
26    task::{ready, Context, Poll},
27    thread,
28};
29use tokio::{
30    runtime::Handle,
31    sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender},
32    task::JoinHandle,
33};
34use tracing::debug;
35
36pub mod metrics;
37pub mod runtime;
38pub mod shutdown;
39
40#[cfg(feature = "rayon")]
41pub mod pool;
42
43#[cfg(feature = "rayon")]
44pub use runtime::RayonConfig;
45pub use runtime::{Runtime, RuntimeBuildError, RuntimeBuilder, RuntimeConfig, TokioConfig};
46
/// A [`TaskExecutor`] is now simply an alias for [`Runtime`]: both names refer to the same
/// type, so any API accepting one accepts the other.
pub type TaskExecutor = Runtime;
49
50/// Spawns an OS thread with the current tokio runtime context propagated.
51///
52/// This function captures the current tokio runtime handle (if available) and enters it
53/// in the newly spawned thread. This ensures that code running in the spawned thread can
54/// use [`Handle::current()`], [`Handle::spawn_blocking()`], and other tokio utilities that
55/// require a runtime context.
56#[track_caller]
57pub fn spawn_os_thread<F, T>(name: &str, f: F) -> thread::JoinHandle<T>
58where
59    F: FnOnce() -> T + Send + 'static,
60    T: Send + 'static,
61{
62    let handle = Handle::try_current().ok();
63    thread::Builder::new()
64        .name(name.to_string())
65        .spawn(move || {
66            let _guard = handle.as_ref().map(Handle::enter);
67            f()
68        })
69        .unwrap_or_else(|e| panic!("failed to spawn thread {name:?}: {e}"))
70}
71
72/// Spawns a scoped OS thread with the current tokio runtime context propagated.
73///
74/// This is the scoped thread version of [`spawn_os_thread`], for use with [`std::thread::scope`].
75#[track_caller]
76pub fn spawn_scoped_os_thread<'scope, 'env, F, T>(
77    scope: &'scope thread::Scope<'scope, 'env>,
78    name: &str,
79    f: F,
80) -> thread::ScopedJoinHandle<'scope, T>
81where
82    F: FnOnce() -> T + Send + 'scope,
83    T: Send + 'scope,
84{
85    let handle = Handle::try_current().ok();
86    thread::Builder::new()
87        .name(name.to_string())
88        .spawn_scoped(scope, move || {
89            let _guard = handle.as_ref().map(Handle::enter);
90            f()
91        })
92        .unwrap_or_else(|e| panic!("failed to spawn scoped thread {name:?}: {e}"))
93}
94
95/// A type that can spawn tasks.
96///
97/// The main purpose of this type is to abstract over [`Runtime`] so it's more convenient to
98/// provide default impls for testing.
99///
100///
101/// # Examples
102///
103/// Use the [`TokioTaskExecutor`] that spawns with [`tokio::task::spawn`]
104///
105/// ```
106/// # async fn t() {
107/// use reth_tasks::{TaskSpawner, TokioTaskExecutor};
108/// let executor = TokioTaskExecutor::default();
109///
110/// let task = executor.spawn_task(Box::pin(async {
111///     // -- snip --
112/// }));
113/// task.await.unwrap();
114/// # }
115/// ```
116///
117/// Use the [`Runtime`] that spawns task directly onto the tokio runtime via the [Handle].
118///
119/// ```
120/// # use reth_tasks::Runtime;
121/// fn t() {
122///  use reth_tasks::TaskSpawner;
123/// let rt = tokio::runtime::Runtime::new().unwrap();
124/// let runtime = Runtime::with_existing_handle(rt.handle().clone()).unwrap();
125/// let task = TaskSpawner::spawn_task(&runtime, Box::pin(async {
126///     // -- snip --
127/// }));
128/// rt.block_on(task).unwrap();
129/// # }
130/// ```
131///
132/// The [`TaskSpawner`] trait is [`DynClone`] so `Box<dyn TaskSpawner>` are also `Clone`.
133#[auto_impl::auto_impl(&, Arc)]
134pub trait TaskSpawner: Send + Sync + Unpin + std::fmt::Debug + DynClone {
135    /// Spawns the task onto the runtime.
136    /// See also [`Handle::spawn`].
137    fn spawn_task(&self, fut: BoxFuture<'static, ()>) -> JoinHandle<()>;
138
139    /// This spawns a critical task onto the runtime.
140    fn spawn_critical_task(
141        &self,
142        name: &'static str,
143        fut: BoxFuture<'static, ()>,
144    ) -> JoinHandle<()>;
145
146    /// Spawns a blocking task onto the runtime.
147    fn spawn_blocking_task(&self, fut: BoxFuture<'static, ()>) -> JoinHandle<()>;
148
149    /// This spawns a critical blocking task onto the runtime.
150    fn spawn_critical_blocking_task(
151        &self,
152        name: &'static str,
153        fut: BoxFuture<'static, ()>,
154    ) -> JoinHandle<()>;
155}
156
157dyn_clone::clone_trait_object!(TaskSpawner);
158
/// A [`TaskSpawner`] that uses [`tokio::task::spawn`] to execute tasks.
///
/// It carries no state. Critical tasks are spawned like regular ones: the `name` argument is
/// ignored and no panic monitoring is performed.
#[derive(Debug, Clone, Default)]
#[non_exhaustive]
pub struct TokioTaskExecutor;
163
164impl TokioTaskExecutor {
165    /// Converts the instance to a boxed [`TaskSpawner`].
166    pub fn boxed(self) -> Box<dyn TaskSpawner + 'static> {
167        Box::new(self)
168    }
169}
170
171impl TaskSpawner for TokioTaskExecutor {
172    fn spawn_task(&self, fut: BoxFuture<'static, ()>) -> JoinHandle<()> {
173        tokio::task::spawn(fut)
174    }
175
176    fn spawn_critical_task(
177        &self,
178        _name: &'static str,
179        fut: BoxFuture<'static, ()>,
180    ) -> JoinHandle<()> {
181        tokio::task::spawn(fut)
182    }
183
184    fn spawn_blocking_task(&self, fut: BoxFuture<'static, ()>) -> JoinHandle<()> {
185        tokio::task::spawn_blocking(move || tokio::runtime::Handle::current().block_on(fut))
186    }
187
188    fn spawn_critical_blocking_task(
189        &self,
190        _name: &'static str,
191        fut: BoxFuture<'static, ()>,
192    ) -> JoinHandle<()> {
193        tokio::task::spawn_blocking(move || tokio::runtime::Handle::current().block_on(fut))
194    }
195}
196
/// Monitors critical tasks for panics and manages graceful shutdown.
///
/// The main purpose of this type is to be able to monitor if a critical task panicked, for
/// diagnostic purposes, since tokio tasks essentially fail silently. Therefore, this type is a
/// Future that resolves with a [`PanickedTaskError`] carrying the name of the panicked task.
/// See [`Runtime::spawn_critical_task`].
///
/// Automatically spawned as a background task when building a [`Runtime`]. Use
/// [`Runtime::take_task_manager_handle`] to extract the join handle if you need to poll for
/// panic errors directly.
#[derive(Debug)]
#[must_use = "TaskManager must be polled to monitor critical tasks"]
pub struct TaskManager {
    /// Receiver for task events.
    ///
    /// Delivers `TaskEvent::Panic` when a critical task panics and
    /// `TaskEvent::GracefulShutdown` when a shutdown is requested.
    task_events_rx: UnboundedReceiver<TaskEvent>,
    /// The [Signal] to fire when all tasks should be shutdown.
    ///
    /// This is fired when dropped. It is kept in an `Option` so it can be taken out and
    /// fired (or dropped) at most once.
    signal: Option<Signal>,
    /// How many [`GracefulShutdown`] tasks are currently active.
    ///
    /// This counter is shared with the runtime side (a clone is returned from `new_parts`);
    /// graceful shutdown waits for it to reach zero.
    graceful_tasks: Arc<AtomicUsize>,
}
218
219// === impl TaskManager ===
220
221impl TaskManager {
222    /// Create a new [`TaskManager`] without an associated [`Runtime`], returning
223    /// the shutdown/event primitives for [`RuntimeBuilder`] to wire up.
224    pub(crate) fn new_parts(
225        _handle: Handle,
226    ) -> (Self, Shutdown, UnboundedSender<TaskEvent>, Arc<AtomicUsize>) {
227        let (task_events_tx, task_events_rx) = unbounded_channel();
228        let (signal, on_shutdown) = signal();
229        let graceful_tasks = Arc::new(AtomicUsize::new(0));
230        let manager = Self {
231            task_events_rx,
232            signal: Some(signal),
233            graceful_tasks: Arc::clone(&graceful_tasks),
234        };
235        (manager, on_shutdown, task_events_tx, graceful_tasks)
236    }
237
238    /// Fires the shutdown signal and awaits until all tasks are shutdown.
239    pub fn graceful_shutdown(self) {
240        let _ = self.do_graceful_shutdown(None);
241    }
242
243    /// Fires the shutdown signal and awaits until all tasks are shutdown.
244    ///
245    /// Returns true if all tasks were shutdown before the timeout elapsed.
246    pub fn graceful_shutdown_with_timeout(self, timeout: std::time::Duration) -> bool {
247        self.do_graceful_shutdown(Some(timeout))
248    }
249
250    fn do_graceful_shutdown(self, timeout: Option<std::time::Duration>) -> bool {
251        drop(self.signal);
252        let deadline = timeout.map(|t| std::time::Instant::now() + t);
253        while self.graceful_tasks.load(Ordering::SeqCst) > 0 {
254            if deadline.is_some_and(|d| std::time::Instant::now() > d) {
255                debug!("graceful shutdown timed out");
256                return false;
257            }
258            thread::yield_now();
259        }
260        debug!("gracefully shut down");
261        true
262    }
263}
264
265/// An endless future that resolves if a critical task panicked.
266///
267/// See [`Runtime::spawn_critical_task`]
268impl std::future::Future for TaskManager {
269    type Output = Result<(), PanickedTaskError>;
270
271    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
272        match ready!(self.as_mut().get_mut().task_events_rx.poll_recv(cx)) {
273            Some(TaskEvent::Panic(err)) => Poll::Ready(Err(err)),
274            Some(TaskEvent::GracefulShutdown) | None => {
275                if let Some(signal) = self.get_mut().signal.take() {
276                    signal.fire();
277                }
278                Poll::Ready(Ok(()))
279            }
280        }
281    }
282}
283
/// Error with the name of the task that panicked and an error downcasted to string, if possible.
///
/// The human-readable message comes from the hand-written [`Display`] impl rather than a
/// `#[error(...)]` attribute.
#[derive(Debug, thiserror::Error, PartialEq, Eq)]
pub struct PanickedTaskError {
    /// Name the critical task was spawned with.
    task_name: &'static str,
    /// The panic payload rendered as a string, or `None` if the payload was neither a
    /// `String` nor a `&str`.
    error: Option<String>,
}
290
291impl Display for PanickedTaskError {
292    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
293        let task_name = self.task_name;
294        if let Some(error) = &self.error {
295            write!(f, "Critical task `{task_name}` panicked: `{error}`")
296        } else {
297            write!(f, "Critical task `{task_name}` panicked")
298        }
299    }
300}
301
302impl PanickedTaskError {
303    pub(crate) fn new(task_name: &'static str, error: Box<dyn Any>) -> Self {
304        let error = match error.downcast::<String>() {
305            Ok(value) => Some(*value),
306            Err(error) => match error.downcast::<&str>() {
307                Ok(value) => Some(value.to_string()),
308                Err(_) => None,
309            },
310        };
311
312        Self { task_name, error }
313    }
314}
315
/// Represents the events that the `TaskManager`'s main future can receive.
///
/// Events travel over the unbounded channel created in `TaskManager::new_parts`.
#[derive(Debug)]
pub(crate) enum TaskEvent {
    /// Indicates that a critical task has panicked.
    ///
    /// The `TaskManager` future resolves with `Err` carrying this error.
    Panic(PanickedTaskError),
    /// A signal requesting a graceful shutdown of the `TaskManager`.
    ///
    /// The `TaskManager` future fires its shutdown signal and resolves with `Ok(())`.
    GracefulShutdown,
}
324
/// `TaskSpawner` with extended behaviour: spawning tasks that cooperate with graceful shutdown.
///
/// Note that this trait does not declare [`TaskSpawner`] as a supertrait. Its methods are
/// generic, so it cannot be used as a `dyn` trait object.
#[auto_impl::auto_impl(&, Arc)]
pub trait TaskSpawnerExt: Send + Sync + Unpin + std::fmt::Debug + DynClone {
    /// This spawns a critical task onto the runtime.
    ///
    /// `f` is handed a [`GracefulShutdown`] handle and returns the future to run. `name`
    /// labels the task for panic reporting.
    ///
    /// If this task panics, the [`TaskManager`] is notified.
    /// The [`TaskManager`] will wait until the given future has completed before shutting down.
    fn spawn_critical_with_graceful_shutdown_signal<F>(
        &self,
        name: &'static str,
        f: impl FnOnce(GracefulShutdown) -> F,
    ) -> JoinHandle<()>
    where
        F: std::future::Future<Output = ()> + Send + 'static;

    /// This spawns a regular task onto the runtime.
    ///
    /// `f` is handed a [`GracefulShutdown`] handle and returns the future to run.
    ///
    /// The [`TaskManager`] will wait until the given future has completed before shutting down.
    fn spawn_with_graceful_shutdown_signal<F>(
        &self,
        f: impl FnOnce(GracefulShutdown) -> F,
    ) -> JoinHandle<()>
    where
        F: std::future::Future<Output = ()> + Send + 'static;
}
350
#[cfg(test)]
mod tests {
    use super::*;
    use std::{
        sync::atomic::{AtomicBool, AtomicUsize, Ordering},
        time::Duration,
    };

    /// `Box<dyn TaskSpawner>` must be cloneable; the `#[derive(Clone)]` on a wrapper holding
    /// one only compiles thanks to `dyn_clone::clone_trait_object!`.
    #[test]
    fn test_cloneable() {
        #[derive(Clone)]
        struct ExecutorWrapper {
            _e: Box<dyn TaskSpawner>,
        }

        let executor: Box<dyn TaskSpawner> = Box::<TokioTaskExecutor>::default();
        let _e = dyn_clone::clone_box(&*executor);

        let e = ExecutorWrapper { _e };
        let _e2 = e;
    }

    /// A panicking critical task makes the task manager resolve with a `PanickedTaskError`
    /// carrying the task name and the panic message.
    #[test]
    fn test_critical() {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let rt = Runtime::with_existing_handle(runtime.handle().clone()).unwrap();
        let handle = rt.take_task_manager_handle().unwrap();

        rt.spawn_critical_task("this is a critical task", async { panic!("intentionally panic") });

        runtime.block_on(async move {
            let err_result = handle.await.unwrap();
            assert!(err_result.is_err(), "Expected TaskManager to return an error due to panic");
            let panicked_err = err_result.unwrap_err();

            assert_eq!(panicked_err.task_name, "this is a critical task");
            assert_eq!(panicked_err.error, Some("intentionally panic".to_string()));
        })
    }

    /// A critical task spawned before `graceful_shutdown` still runs to completion afterwards:
    /// dropping `signal` inside the task is what lets `shutdown` resolve.
    #[test]
    fn test_manager_shutdown_critical() {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let rt = Runtime::with_existing_handle(runtime.handle().clone()).unwrap();

        let (signal, shutdown) = signal();

        rt.spawn_critical_task("this is a critical task", async move {
            tokio::time::sleep(Duration::from_millis(200)).await;
            drop(signal);
        });

        rt.graceful_shutdown();

        runtime.block_on(shutdown);
    }

    /// Same as `test_manager_shutdown_critical`, but for a regular (non-critical) task.
    #[test]
    fn test_manager_shutdown() {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let rt = Runtime::with_existing_handle(runtime.handle().clone()).unwrap();

        let (signal, shutdown) = signal();

        rt.spawn_task(Box::pin(async move {
            tokio::time::sleep(Duration::from_millis(200)).await;
            drop(signal);
        }));

        rt.graceful_shutdown();

        runtime.block_on(shutdown);
    }

    /// `graceful_shutdown` must not return until the graceful task (which holds its guard
    /// across the sleep) has finished and set the flag.
    #[test]
    fn test_manager_graceful_shutdown() {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let rt = Runtime::with_existing_handle(runtime.handle().clone()).unwrap();

        let val = Arc::new(AtomicBool::new(false));
        let c = val.clone();
        rt.spawn_critical_with_graceful_shutdown_signal("grace", |shutdown| async move {
            let _guard = shutdown.await;
            tokio::time::sleep(Duration::from_millis(200)).await;
            c.store(true, Ordering::Relaxed);
        });

        rt.graceful_shutdown();
        assert!(val.load(Ordering::Relaxed));
    }

    /// Like `test_manager_graceful_shutdown`, with many concurrent graceful tasks: all of them
    /// must have finished when `graceful_shutdown` returns.
    #[test]
    fn test_manager_graceful_shutdown_many() {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let rt = Runtime::with_existing_handle(runtime.handle().clone()).unwrap();

        let counter = Arc::new(AtomicUsize::new(0));
        let num = 10;
        for _ in 0..num {
            let c = counter.clone();
            rt.spawn_critical_with_graceful_shutdown_signal("grace", move |shutdown| async move {
                let _guard = shutdown.await;
                tokio::time::sleep(Duration::from_millis(200)).await;
                c.fetch_add(1, Ordering::SeqCst);
            });
        }

        rt.graceful_shutdown();
        assert_eq!(counter.load(Ordering::Relaxed), num);
    }

    /// The graceful task sleeps for three times the timeout, so the shutdown gives up first and
    /// the flag is never set. NOTE(review): the task is presumably cancelled when `runtime` is
    /// dropped at the end of the test, which is why `unreachable!` never fires.
    #[test]
    fn test_manager_graceful_shutdown_timeout() {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let rt = Runtime::with_existing_handle(runtime.handle().clone()).unwrap();

        let timeout = Duration::from_millis(500);
        let val = Arc::new(AtomicBool::new(false));
        let val2 = val.clone();
        rt.spawn_critical_with_graceful_shutdown_signal("grace", |shutdown| async move {
            let _guard = shutdown.await;
            tokio::time::sleep(timeout * 3).await;
            val2.store(true, Ordering::Relaxed);
            unreachable!("should not be reached");
        });

        rt.graceful_shutdown_with_timeout(timeout);
        assert!(!val.load(Ordering::Relaxed));
    }

    /// Smoke test: a `Runtime` can wrap an existing tokio handle and expose it again.
    #[test]
    fn can_build_runtime() {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let rt = Runtime::with_existing_handle(runtime.handle().clone()).unwrap();
        let _handle = rt.handle();
    }

    /// `initiate_graceful_shutdown` makes the task manager resolve with `Ok(())` and fires the
    /// shutdown signal observed by tasks spawned via `spawn_with_signal`.
    #[test]
    fn test_graceful_shutdown_triggered_by_executor() {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let rt = Runtime::with_existing_handle(runtime.handle().clone()).unwrap();
        let task_manager_handle = rt.take_task_manager_handle().unwrap();

        let task_did_shutdown_flag = Arc::new(AtomicBool::new(false));
        let flag_clone = task_did_shutdown_flag.clone();

        let spawned_task_handle = rt.spawn_with_signal(|shutdown_signal| async move {
            shutdown_signal.await;
            flag_clone.store(true, Ordering::SeqCst);
        });

        let send_result = rt.initiate_graceful_shutdown();
        assert!(send_result.is_ok());

        let manager_final_result = runtime.block_on(task_manager_handle);
        assert!(manager_final_result.is_ok(), "TaskManager task should not panic");
        assert_eq!(manager_final_result.unwrap(), Ok(()));

        let task_join_result = runtime.block_on(spawned_task_handle);
        assert!(task_join_result.is_ok());

        assert!(task_did_shutdown_flag.load(Ordering::Relaxed));
    }
}