1#![doc(
8 html_logo_url = "https://raw.githubusercontent.com/paradigmxyz/reth/main/assets/reth-docs.png",
9 html_favicon_url = "https://avatars0.githubusercontent.com/u/97369466?s=256",
10 issue_tracker_base_url = "https://github.com/paradigmxyz/reth/issues/"
11)]
12#![cfg_attr(not(test), warn(unused_crate_dependencies))]
13#![cfg_attr(docsrs, feature(doc_cfg))]
14
15use crate::shutdown::{signal, GracefulShutdown, Shutdown, Signal};
16use dyn_clone::DynClone;
17use futures_util::future::BoxFuture;
18use std::{
19 any::Any,
20 fmt::{Display, Formatter},
21 pin::Pin,
22 sync::{
23 atomic::{AtomicUsize, Ordering},
24 Arc,
25 },
26 task::{ready, Context, Poll},
27 thread,
28};
29use tokio::{
30 runtime::Handle,
31 sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender},
32 task::JoinHandle,
33};
34use tracing::debug;
35
36pub mod metrics;
37pub mod runtime;
38pub mod shutdown;
39
40#[cfg(feature = "rayon")]
41pub mod pool;
42
43#[cfg(feature = "rayon")]
44pub use runtime::RayonConfig;
45pub use runtime::{Runtime, RuntimeBuildError, RuntimeBuilder, RuntimeConfig, TokioConfig};
46
47pub type TaskExecutor = Runtime;
49
50#[track_caller]
57pub fn spawn_os_thread<F, T>(name: &str, f: F) -> thread::JoinHandle<T>
58where
59 F: FnOnce() -> T + Send + 'static,
60 T: Send + 'static,
61{
62 let handle = Handle::try_current().ok();
63 thread::Builder::new()
64 .name(name.to_string())
65 .spawn(move || {
66 let _guard = handle.as_ref().map(Handle::enter);
67 f()
68 })
69 .unwrap_or_else(|e| panic!("failed to spawn thread {name:?}: {e}"))
70}
71
72#[track_caller]
76pub fn spawn_scoped_os_thread<'scope, 'env, F, T>(
77 scope: &'scope thread::Scope<'scope, 'env>,
78 name: &str,
79 f: F,
80) -> thread::ScopedJoinHandle<'scope, T>
81where
82 F: FnOnce() -> T + Send + 'scope,
83 T: Send + 'scope,
84{
85 let handle = Handle::try_current().ok();
86 thread::Builder::new()
87 .name(name.to_string())
88 .spawn_scoped(scope, move || {
89 let _guard = handle.as_ref().map(Handle::enter);
90 f()
91 })
92 .unwrap_or_else(|e| panic!("failed to spawn scoped thread {name:?}: {e}"))
93}
94
95#[auto_impl::auto_impl(&, Arc)]
134pub trait TaskSpawner: Send + Sync + Unpin + std::fmt::Debug + DynClone {
135 fn spawn_task(&self, fut: BoxFuture<'static, ()>) -> JoinHandle<()>;
138
139 fn spawn_critical_task(
141 &self,
142 name: &'static str,
143 fut: BoxFuture<'static, ()>,
144 ) -> JoinHandle<()>;
145
146 fn spawn_blocking_task(&self, fut: BoxFuture<'static, ()>) -> JoinHandle<()>;
148
149 fn spawn_critical_blocking_task(
151 &self,
152 name: &'static str,
153 fut: BoxFuture<'static, ()>,
154 ) -> JoinHandle<()>;
155}
156
157dyn_clone::clone_trait_object!(TaskSpawner);
158
159#[derive(Debug, Clone, Default)]
161#[non_exhaustive]
162pub struct TokioTaskExecutor;
163
164impl TokioTaskExecutor {
165 pub fn boxed(self) -> Box<dyn TaskSpawner + 'static> {
167 Box::new(self)
168 }
169}
170
171impl TaskSpawner for TokioTaskExecutor {
172 fn spawn_task(&self, fut: BoxFuture<'static, ()>) -> JoinHandle<()> {
173 tokio::task::spawn(fut)
174 }
175
176 fn spawn_critical_task(
177 &self,
178 _name: &'static str,
179 fut: BoxFuture<'static, ()>,
180 ) -> JoinHandle<()> {
181 tokio::task::spawn(fut)
182 }
183
184 fn spawn_blocking_task(&self, fut: BoxFuture<'static, ()>) -> JoinHandle<()> {
185 tokio::task::spawn_blocking(move || tokio::runtime::Handle::current().block_on(fut))
186 }
187
188 fn spawn_critical_blocking_task(
189 &self,
190 _name: &'static str,
191 fut: BoxFuture<'static, ()>,
192 ) -> JoinHandle<()> {
193 tokio::task::spawn_blocking(move || tokio::runtime::Handle::current().block_on(fut))
194 }
195}
196
197#[derive(Debug)]
207#[must_use = "TaskManager must be polled to monitor critical tasks"]
208pub struct TaskManager {
209 task_events_rx: UnboundedReceiver<TaskEvent>,
211 signal: Option<Signal>,
215 graceful_tasks: Arc<AtomicUsize>,
217}
218
219impl TaskManager {
222 pub(crate) fn new_parts(
225 _handle: Handle,
226 ) -> (Self, Shutdown, UnboundedSender<TaskEvent>, Arc<AtomicUsize>) {
227 let (task_events_tx, task_events_rx) = unbounded_channel();
228 let (signal, on_shutdown) = signal();
229 let graceful_tasks = Arc::new(AtomicUsize::new(0));
230 let manager = Self {
231 task_events_rx,
232 signal: Some(signal),
233 graceful_tasks: Arc::clone(&graceful_tasks),
234 };
235 (manager, on_shutdown, task_events_tx, graceful_tasks)
236 }
237
238 pub fn graceful_shutdown(self) {
240 let _ = self.do_graceful_shutdown(None);
241 }
242
243 pub fn graceful_shutdown_with_timeout(self, timeout: std::time::Duration) -> bool {
247 self.do_graceful_shutdown(Some(timeout))
248 }
249
250 fn do_graceful_shutdown(self, timeout: Option<std::time::Duration>) -> bool {
251 drop(self.signal);
252 let deadline = timeout.map(|t| std::time::Instant::now() + t);
253 while self.graceful_tasks.load(Ordering::SeqCst) > 0 {
254 if deadline.is_some_and(|d| std::time::Instant::now() > d) {
255 debug!("graceful shutdown timed out");
256 return false;
257 }
258 thread::yield_now();
259 }
260 debug!("gracefully shut down");
261 true
262 }
263}
264
265impl std::future::Future for TaskManager {
269 type Output = Result<(), PanickedTaskError>;
270
271 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
272 match ready!(self.as_mut().get_mut().task_events_rx.poll_recv(cx)) {
273 Some(TaskEvent::Panic(err)) => Poll::Ready(Err(err)),
274 Some(TaskEvent::GracefulShutdown) | None => {
275 if let Some(signal) = self.get_mut().signal.take() {
276 signal.fire();
277 }
278 Poll::Ready(Ok(()))
279 }
280 }
281 }
282}
283
/// Error produced by the [`TaskManager`] future when a critical task panicked.
#[derive(Debug, thiserror::Error, PartialEq, Eq)]
pub struct PanickedTaskError {
    /// Name the critical task was spawned with.
    task_name: &'static str,
    /// Panic message, when the panic payload was a `String` or `&str`; `None` otherwise.
    error: Option<String>,
}
290
291impl Display for PanickedTaskError {
292 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
293 let task_name = self.task_name;
294 if let Some(error) = &self.error {
295 write!(f, "Critical task `{task_name}` panicked: `{error}`")
296 } else {
297 write!(f, "Critical task `{task_name}` panicked")
298 }
299 }
300}
301
302impl PanickedTaskError {
303 pub(crate) fn new(task_name: &'static str, error: Box<dyn Any>) -> Self {
304 let error = match error.downcast::<String>() {
305 Ok(value) => Some(*value),
306 Err(error) => match error.downcast::<&str>() {
307 Ok(value) => Some(value.to_string()),
308 Err(_) => None,
309 },
310 };
311
312 Self { task_name, error }
313 }
314}
315
/// Event sent to the [`TaskManager`] over its unbounded channel.
#[derive(Debug)]
pub(crate) enum TaskEvent {
    /// A critical task panicked; the manager future resolves with this error.
    Panic(PanickedTaskError),
    /// Requests a graceful shutdown; the manager fires its shutdown signal and resolves with
    /// `Ok(())`.
    GracefulShutdown,
}
324
325#[auto_impl::auto_impl(&, Arc)]
327pub trait TaskSpawnerExt: Send + Sync + Unpin + std::fmt::Debug + DynClone {
328 fn spawn_critical_with_graceful_shutdown_signal<F>(
333 &self,
334 name: &'static str,
335 f: impl FnOnce(GracefulShutdown) -> F,
336 ) -> JoinHandle<()>
337 where
338 F: std::future::Future<Output = ()> + Send + 'static;
339
340 fn spawn_with_graceful_shutdown_signal<F>(
344 &self,
345 f: impl FnOnce(GracefulShutdown) -> F,
346 ) -> JoinHandle<()>
347 where
348 F: std::future::Future<Output = ()> + Send + 'static;
349}
350
#[cfg(test)]
mod tests {
    use super::*;
    use std::{
        sync::atomic::{AtomicBool, AtomicUsize, Ordering},
        time::Duration,
    };

    #[test]
    fn test_cloneable() {
        #[derive(Clone)]
        struct ExecutorWrapper {
            _e: Box<dyn TaskSpawner>,
        }

        let executor: Box<dyn TaskSpawner> = Box::<TokioTaskExecutor>::default();
        let _e = dyn_clone::clone_box(&*executor);

        let e = ExecutorWrapper { _e };
        // Clone (not move) so the `Clone` impl of `Box<dyn TaskSpawner>` is actually exercised.
        let _e2 = e.clone();
    }

    #[test]
    fn test_critical() {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let rt = Runtime::with_existing_handle(runtime.handle().clone()).unwrap();
        let handle = rt.take_task_manager_handle().unwrap();

        rt.spawn_critical_task("this is a critical task", async { panic!("intentionally panic") });

        runtime.block_on(async move {
            let panicked_err = handle
                .await
                .unwrap()
                .expect_err("Expected TaskManager to return an error due to panic");

            assert_eq!(panicked_err.task_name, "this is a critical task");
            assert_eq!(panicked_err.error, Some("intentionally panic".to_string()));
        })
    }

    #[test]
    fn test_manager_shutdown_critical() {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let rt = Runtime::with_existing_handle(runtime.handle().clone()).unwrap();

        let (signal, shutdown) = signal();

        rt.spawn_critical_task("this is a critical task", async move {
            tokio::time::sleep(Duration::from_millis(200)).await;
            drop(signal);
        });

        rt.graceful_shutdown();

        runtime.block_on(shutdown);
    }

    #[test]
    fn test_manager_shutdown() {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let rt = Runtime::with_existing_handle(runtime.handle().clone()).unwrap();

        let (signal, shutdown) = signal();

        rt.spawn_task(Box::pin(async move {
            tokio::time::sleep(Duration::from_millis(200)).await;
            drop(signal);
        }));

        rt.graceful_shutdown();

        runtime.block_on(shutdown);
    }

    #[test]
    fn test_manager_graceful_shutdown() {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let rt = Runtime::with_existing_handle(runtime.handle().clone()).unwrap();

        let val = Arc::new(AtomicBool::new(false));
        let c = val.clone();
        rt.spawn_critical_with_graceful_shutdown_signal("grace", |shutdown| async move {
            let _guard = shutdown.await;
            tokio::time::sleep(Duration::from_millis(200)).await;
            c.store(true, Ordering::Relaxed);
        });

        rt.graceful_shutdown();
        assert!(val.load(Ordering::Relaxed));
    }

    #[test]
    fn test_manager_graceful_shutdown_many() {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let rt = Runtime::with_existing_handle(runtime.handle().clone()).unwrap();

        let counter = Arc::new(AtomicUsize::new(0));
        let num = 10;
        for _ in 0..num {
            let c = counter.clone();
            rt.spawn_critical_with_graceful_shutdown_signal("grace", move |shutdown| async move {
                let _guard = shutdown.await;
                tokio::time::sleep(Duration::from_millis(200)).await;
                c.fetch_add(1, Ordering::SeqCst);
            });
        }

        rt.graceful_shutdown();
        assert_eq!(counter.load(Ordering::Relaxed), num);
    }

    #[test]
    fn test_manager_graceful_shutdown_timeout() {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let rt = Runtime::with_existing_handle(runtime.handle().clone()).unwrap();

        let timeout = Duration::from_millis(500);
        let val = Arc::new(AtomicBool::new(false));
        let val2 = val.clone();
        rt.spawn_critical_with_graceful_shutdown_signal("grace", |shutdown| async move {
            let _guard = shutdown.await;
            tokio::time::sleep(timeout * 3).await;
            val2.store(true, Ordering::Relaxed);
            unreachable!("should not be reached");
        });

        rt.graceful_shutdown_with_timeout(timeout);
        assert!(!val.load(Ordering::Relaxed));
    }

    #[test]
    fn can_build_runtime() {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let rt = Runtime::with_existing_handle(runtime.handle().clone()).unwrap();
        let _handle = rt.handle();
    }

    #[test]
    fn test_graceful_shutdown_triggered_by_executor() {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let rt = Runtime::with_existing_handle(runtime.handle().clone()).unwrap();
        let task_manager_handle = rt.take_task_manager_handle().unwrap();

        let task_did_shutdown_flag = Arc::new(AtomicBool::new(false));
        let flag_clone = task_did_shutdown_flag.clone();

        let spawned_task_handle = rt.spawn_with_signal(|shutdown_signal| async move {
            shutdown_signal.await;
            flag_clone.store(true, Ordering::SeqCst);
        });

        let send_result = rt.initiate_graceful_shutdown();
        assert!(send_result.is_ok());

        let manager_final_result = runtime
            .block_on(task_manager_handle)
            .expect("TaskManager task should not panic");
        assert_eq!(manager_final_result, Ok(()));

        let task_join_result = runtime.block_on(spawned_task_handle);
        assert!(task_join_result.is_ok());

        assert!(task_did_shutdown_flag.load(Ordering::Relaxed));
    }
}