Skip to main content

reth_tasks/
lib.rs

1//! Reth task management.
2//!
3//! # Feature Flags
4//!
5//! - `rayon`: Enable rayon thread pool for blocking tasks.
6
7#![doc(
8    html_logo_url = "https://raw.githubusercontent.com/paradigmxyz/reth/main/assets/reth-docs.png",
9    html_favicon_url = "https://avatars0.githubusercontent.com/u/97369466?s=256",
10    issue_tracker_base_url = "https://github.com/paradigmxyz/reth/issues/"
11)]
12#![cfg_attr(not(test), warn(unused_crate_dependencies))]
13#![cfg_attr(docsrs, feature(doc_cfg))]
14
15use crate::shutdown::{signal, Shutdown, Signal};
16use std::{
17    any::Any,
18    fmt::{Display, Formatter},
19    pin::Pin,
20    sync::{
21        atomic::{AtomicUsize, Ordering},
22        Arc,
23    },
24    task::{ready, Context, Poll},
25    thread,
26};
27use tokio::{
28    runtime::Handle,
29    sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender},
30};
31use tracing::debug;
32
33pub mod lazy;
34pub mod metrics;
35pub mod runtime;
36pub mod shutdown;
37pub mod utils;
38pub(crate) mod worker_map;
39
40#[cfg(feature = "rayon")]
41pub mod pool;
42#[cfg(feature = "rayon")]
43pub use pool::{Worker, WorkerPool};
44
45/// Lock-free ordered parallel iterator extension trait.
46#[cfg(feature = "rayon")]
47pub mod for_each_ordered;
48#[cfg(feature = "rayon")]
49pub use for_each_ordered::ForEachOrdered;
50
51pub use lazy::LazyHandle;
52#[cfg(feature = "rayon")]
53pub use runtime::RayonConfig;
54pub use runtime::{Runtime, RuntimeBuildError, RuntimeBuilder, RuntimeConfig, TokioConfig};
55
/// Alias for [`Runtime`].
///
/// `TaskExecutor` and [`Runtime`] name the same type and can be used interchangeably.
pub type TaskExecutor = Runtime;
58
59/// Spawns an OS thread with the current tokio runtime context propagated.
60///
61/// This function captures the current tokio runtime handle (if available) and enters it
62/// in the newly spawned thread. This ensures that code running in the spawned thread can
63/// use [`Handle::current()`], [`Handle::spawn_blocking()`], and other tokio utilities that
64/// require a runtime context.
65#[track_caller]
66pub fn spawn_os_thread<F, T>(name: &str, f: F) -> thread::JoinHandle<T>
67where
68    F: FnOnce() -> T + Send + 'static,
69    T: Send + 'static,
70{
71    let handle = Handle::try_current().ok();
72    thread::Builder::new()
73        .name(name.to_string())
74        .spawn(move || {
75            let _guard = handle.as_ref().map(Handle::enter);
76            f()
77        })
78        .unwrap_or_else(|e| panic!("failed to spawn thread {name:?}: {e}"))
79}
80
81/// Spawns a scoped OS thread with the current tokio runtime context propagated.
82///
83/// This is the scoped thread version of [`spawn_os_thread`], for use with [`std::thread::scope`].
84#[track_caller]
85pub fn spawn_scoped_os_thread<'scope, 'env, F, T>(
86    scope: &'scope thread::Scope<'scope, 'env>,
87    name: &str,
88    f: F,
89) -> thread::ScopedJoinHandle<'scope, T>
90where
91    F: FnOnce() -> T + Send + 'scope,
92    T: Send + 'scope,
93{
94    let handle = Handle::try_current().ok();
95    thread::Builder::new()
96        .name(name.to_string())
97        .spawn_scoped(scope, move || {
98            let _guard = handle.as_ref().map(Handle::enter);
99            f()
100        })
101        .unwrap_or_else(|e| panic!("failed to spawn scoped thread {name:?}: {e}"))
102}
103
/// Monitors critical tasks for panics and manages graceful shutdown.
///
/// The main purpose of this type is to be able to monitor if a critical task panicked, for
/// diagnostic purposes, since tokio tasks essentially fail silently. Therefore, this type is a
/// Future that resolves with the name of the panicked task. See [`Runtime::spawn_critical_task`].
///
/// Automatically spawned as a background task when building a [`Runtime`]. Use
/// [`Runtime::take_task_manager_handle`] to extract the join handle if you need to poll for
/// panic errors directly.
#[derive(Debug)]
#[must_use = "TaskManager must be polled to monitor critical tasks"]
pub struct TaskManager {
    /// Receiver for task events.
    ///
    /// Polled by the `Future` implementation; a [`TaskEvent::Panic`] resolves the future with
    /// an error, while [`TaskEvent::GracefulShutdown`] or a closed channel resolves it with
    /// `Ok(())`.
    task_events_rx: UnboundedReceiver<TaskEvent>,
    /// The [Signal] to fire when all tasks should be shutdown.
    ///
    /// This is fired when dropped.
    ///
    /// Stored as an `Option` so it can be taken and fired explicitly exactly once when the
    /// manager future resolves without a panic.
    signal: Option<Signal>,
    /// How many [`GracefulShutdown`](crate::shutdown::GracefulShutdown) tasks are currently
    /// active.
    ///
    /// The blocking graceful shutdown methods wait until this counter reaches zero.
    graceful_tasks: Arc<AtomicUsize>,
}
126
127// === impl TaskManager ===
128
129impl TaskManager {
130    /// Create a new [`TaskManager`] without an associated [`Runtime`], returning
131    /// the shutdown/event primitives for [`RuntimeBuilder`] to wire up.
132    pub(crate) fn new_parts(
133        _handle: Handle,
134    ) -> (Self, Shutdown, UnboundedSender<TaskEvent>, Arc<AtomicUsize>) {
135        let (task_events_tx, task_events_rx) = unbounded_channel();
136        let (signal, on_shutdown) = signal();
137        let graceful_tasks = Arc::new(AtomicUsize::new(0));
138        let manager = Self {
139            task_events_rx,
140            signal: Some(signal),
141            graceful_tasks: Arc::clone(&graceful_tasks),
142        };
143        (manager, on_shutdown, task_events_tx, graceful_tasks)
144    }
145
146    /// Fires the shutdown signal and awaits until all tasks are shutdown.
147    pub fn graceful_shutdown(self) {
148        let _ = self.do_graceful_shutdown(None);
149    }
150
151    /// Fires the shutdown signal and awaits until all tasks are shutdown.
152    ///
153    /// Returns true if all tasks were shutdown before the timeout elapsed.
154    pub fn graceful_shutdown_with_timeout(self, timeout: std::time::Duration) -> bool {
155        self.do_graceful_shutdown(Some(timeout))
156    }
157
158    fn do_graceful_shutdown(self, timeout: Option<std::time::Duration>) -> bool {
159        drop(self.signal);
160        let deadline = timeout.map(|t| std::time::Instant::now() + t);
161        while self.graceful_tasks.load(Ordering::SeqCst) > 0 {
162            if deadline.is_some_and(|d| std::time::Instant::now() > d) {
163                debug!("graceful shutdown timed out");
164                return false;
165            }
166            thread::yield_now();
167        }
168        debug!("gracefully shut down");
169        true
170    }
171}
172
173/// An endless future that resolves if a critical task panicked.
174///
175/// See [`Runtime::spawn_critical_task`]
176impl std::future::Future for TaskManager {
177    type Output = Result<(), PanickedTaskError>;
178
179    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
180        match ready!(self.as_mut().get_mut().task_events_rx.poll_recv(cx)) {
181            Some(TaskEvent::Panic(err)) => Poll::Ready(Err(err)),
182            Some(TaskEvent::GracefulShutdown) | None => {
183                if let Some(signal) = self.get_mut().signal.take() {
184                    signal.fire();
185                }
186                Poll::Ready(Ok(()))
187            }
188        }
189    }
190}
191
/// Error with the name of the task that panicked and an error downcasted to string, if possible.
#[derive(Debug, thiserror::Error, PartialEq, Eq)]
pub struct PanickedTaskError {
    /// Name of the critical task that panicked.
    task_name: &'static str,
    /// The panic payload as text, if it was a `String` or `&str`; `None` for any other payload
    /// type.
    error: Option<String>,
}
198
199impl Display for PanickedTaskError {
200    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
201        let task_name = self.task_name;
202        if let Some(error) = &self.error {
203            write!(f, "Critical task `{task_name}` panicked: `{error}`")
204        } else {
205            write!(f, "Critical task `{task_name}` panicked")
206        }
207    }
208}
209
210impl PanickedTaskError {
211    pub(crate) fn new(task_name: &'static str, error: Box<dyn Any>) -> Self {
212        let error = match error.downcast::<String>() {
213            Ok(value) => Some(*value),
214            Err(error) => match error.downcast::<&str>() {
215                Ok(value) => Some(value.to_string()),
216                Err(_) => None,
217            },
218        };
219
220        Self { task_name, error }
221    }
222}
223
/// Represents the events that the `TaskManager`'s main future can receive.
///
/// Events are delivered over the unbounded channel whose receiving end is owned by the
/// `TaskManager`.
#[derive(Debug)]
pub(crate) enum TaskEvent {
    /// Indicates that a critical task has panicked.
    ///
    /// The `TaskManager` future resolves with the contained error when it receives this.
    Panic(PanickedTaskError),
    /// A signal requesting a graceful shutdown of the `TaskManager`.
    ///
    /// On receipt the `TaskManager` fires its shutdown signal and its future resolves with
    /// `Ok(())`.
    GracefulShutdown,
}
232
// Tests for panic reporting and the blocking graceful shutdown paths, driven through
// `Runtime::test()`.
#[cfg(test)]
mod tests {
    use super::*;
    use std::{
        sync::atomic::{AtomicBool, AtomicUsize, Ordering},
        time::Duration,
    };

    // A panicking critical task must surface through the task manager handle as a
    // `PanickedTaskError` carrying the task name and the panic message.
    #[test]
    fn test_critical() {
        let rt = Runtime::test();
        let handle = rt.take_task_manager_handle().unwrap();

        rt.spawn_critical_task("this is a critical task", async { panic!("intentionally panic") });

        rt.handle().block_on(async move {
            // Outer `unwrap` is for the join result of the manager task itself.
            let err_result = handle.await.unwrap();
            assert!(err_result.is_err(), "Expected TaskManager to return an error due to panic");
            let panicked_err = err_result.unwrap_err();

            assert_eq!(panicked_err.task_name, "this is a critical task");
            assert_eq!(panicked_err.error, Some("intentionally panic".to_string()));
        })
    }

    // Shutting down while a critical task still owns `signal`: the paired `shutdown` future
    // must still resolve afterwards.
    // NOTE(review): presumably `shutdown` resolves once `signal` is dropped, whether the task
    // ran to completion or was cancelled — confirm against the `shutdown` module.
    #[test]
    fn test_manager_shutdown_critical() {
        let rt = Runtime::test();

        let (signal, shutdown) = signal();

        rt.spawn_critical_task("this is a critical task", async move {
            tokio::time::sleep(Duration::from_millis(200)).await;
            drop(signal);
        });

        rt.graceful_shutdown();

        rt.handle().block_on(shutdown);
    }

    // Same scenario as `test_manager_shutdown_critical`, but with a regular (non-critical)
    // spawned task.
    #[test]
    fn test_manager_shutdown() {
        let rt = Runtime::test();

        let (signal, shutdown) = signal();

        rt.spawn_task(async move {
            tokio::time::sleep(Duration::from_millis(200)).await;
            drop(signal);
        });

        rt.graceful_shutdown();

        rt.handle().block_on(shutdown);
    }

    // `graceful_shutdown` is expected to return only after the task holding the shutdown
    // guard has finished its post-signal work, so the flag must already be set.
    #[test]
    fn test_manager_graceful_shutdown() {
        let rt = Runtime::test();

        let val = Arc::new(AtomicBool::new(false));
        let c = val.clone();
        rt.spawn_critical_with_graceful_shutdown_signal("grace", async move |shutdown| {
            // The guard is held until the end of the closure, i.e. past the sleep and store.
            let _guard = shutdown.await;
            tokio::time::sleep(Duration::from_millis(200)).await;
            c.store(true, Ordering::Relaxed);
        });

        rt.graceful_shutdown();
        assert!(val.load(Ordering::Relaxed));
    }

    // As above, but with several graceful tasks: every one of them must have completed by
    // the time `graceful_shutdown` returns.
    #[test]
    fn test_manager_graceful_shutdown_many() {
        let rt = Runtime::test();

        let counter = Arc::new(AtomicUsize::new(0));
        let num = 10;
        for _ in 0..num {
            let c = counter.clone();
            rt.spawn_critical_with_graceful_shutdown_signal("grace", async move |shutdown| {
                let _guard = shutdown.await;
                tokio::time::sleep(Duration::from_millis(200)).await;
                c.fetch_add(1, Ordering::SeqCst);
            });
        }

        rt.graceful_shutdown();
        assert_eq!(counter.load(Ordering::Relaxed), num);
    }

    // The task needs three times the allowed timeout to finish, so the bounded shutdown is
    // expected to give up first and the flag must still be unset when it returns.
    #[test]
    fn test_manager_graceful_shutdown_timeout() {
        let rt = Runtime::test();

        let timeout = Duration::from_millis(500);
        let val = Arc::new(AtomicBool::new(false));
        let val2 = val.clone();
        rt.spawn_critical_with_graceful_shutdown_signal("grace", async move |shutdown| {
            let _guard = shutdown.await;
            tokio::time::sleep(timeout * 3).await;
            val2.store(true, Ordering::Relaxed);
            unreachable!("should not be reached");
        });

        rt.graceful_shutdown_with_timeout(timeout);
        assert!(!val.load(Ordering::Relaxed));
    }

    // Smoke test: a test runtime can be constructed and exposes a handle.
    #[test]
    fn can_build_runtime() {
        let rt = Runtime::test();
        let _handle = rt.handle();
    }

    // A graceful shutdown initiated through the runtime must make the task manager resolve
    // with `Ok(())` and must be observed by tasks waiting on the shutdown signal.
    #[test]
    fn test_graceful_shutdown_triggered_by_executor() {
        let rt = Runtime::test();
        let task_manager_handle = rt.take_task_manager_handle().unwrap();

        let task_did_shutdown_flag = Arc::new(AtomicBool::new(false));
        let flag_clone = task_did_shutdown_flag.clone();

        let spawned_task_handle = rt.spawn_with_signal(async move |shutdown_signal| {
            shutdown_signal.await;
            flag_clone.store(true, Ordering::SeqCst);
        });

        let send_result = rt.initiate_graceful_shutdown();
        assert!(send_result.is_ok());

        // The join result covers the manager task itself; the inner value is the manager's
        // own output.
        let manager_final_result = rt.handle().block_on(task_manager_handle);
        assert!(manager_final_result.is_ok(), "TaskManager task should not panic");
        assert_eq!(manager_final_result.unwrap(), Ok(()));

        let task_join_result = rt.handle().block_on(spawned_task_handle);
        assert!(task_join_result.is_ok());

        assert!(task_did_shutdown_flag.load(Ordering::Relaxed));
    }
}