1#![doc(
8 html_logo_url = "https://raw.githubusercontent.com/paradigmxyz/reth/main/assets/reth-docs.png",
9 html_favicon_url = "https://avatars0.githubusercontent.com/u/97369466?s=256",
10 issue_tracker_base_url = "https://github.com/paradigmxyz/reth/issues/"
11)]
12#![cfg_attr(not(test), warn(unused_crate_dependencies))]
13#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))]
14
15use crate::{
16 metrics::{IncCounterOnDrop, TaskExecutorMetrics},
17 shutdown::{signal, GracefulShutdown, GracefulShutdownGuard, Shutdown, Signal},
18};
19use dyn_clone::DynClone;
20use futures_util::{
21 future::{select, BoxFuture},
22 Future, FutureExt, TryFutureExt,
23};
24use std::{
25 any::Any,
26 fmt::{Display, Formatter},
27 pin::{pin, Pin},
28 sync::{
29 atomic::{AtomicUsize, Ordering},
30 Arc, OnceLock,
31 },
32 task::{ready, Context, Poll},
33};
34use tokio::{
35 runtime::Handle,
36 sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender},
37 task::JoinHandle,
38};
39use tracing::{debug, error};
40use tracing_futures::Instrument;
41
/// Metrics used by the executor; this file imports `IncCounterOnDrop` and
/// `TaskExecutorMetrics` from it.
pub mod metrics;
/// Shutdown signalling; this file imports `signal`, `Signal`, `Shutdown`,
/// `GracefulShutdown` and `GracefulShutdownGuard` from it.
pub mod shutdown;

// Only compiled when the `rayon` feature is enabled.
// NOTE(review): the module's contents are not visible from this file.
#[cfg(feature = "rayon")]
pub mod pool;

/// Process-wide [`TaskExecutor`].
///
/// Populated (at most once) by [`TaskManager::new`] and read by
/// [`TaskExecutor::try_current`].
static GLOBAL_EXECUTOR: OnceLock<TaskExecutor> = OnceLock::new();
50
51#[auto_impl::auto_impl(&, Arc)]
91pub trait TaskSpawner: Send + Sync + Unpin + std::fmt::Debug + DynClone {
92 fn spawn(&self, fut: BoxFuture<'static, ()>) -> JoinHandle<()>;
95
96 fn spawn_critical(&self, name: &'static str, fut: BoxFuture<'static, ()>) -> JoinHandle<()>;
98
99 fn spawn_blocking(&self, fut: BoxFuture<'static, ()>) -> JoinHandle<()>;
101
102 fn spawn_critical_blocking(
104 &self,
105 name: &'static str,
106 fut: BoxFuture<'static, ()>,
107 ) -> JoinHandle<()>;
108}
109
110dyn_clone::clone_trait_object!(TaskSpawner);
111
/// A stateless [`TaskSpawner`] whose implementation forwards to `tokio::task::spawn`
/// and `tokio::task::spawn_blocking`.
#[non_exhaustive]
#[derive(Clone, Debug, Default)]
pub struct TokioTaskExecutor;
116
117impl TokioTaskExecutor {
118 pub fn boxed(self) -> Box<dyn TaskSpawner + 'static> {
120 Box::new(self)
121 }
122}
123
124impl TaskSpawner for TokioTaskExecutor {
125 fn spawn(&self, fut: BoxFuture<'static, ()>) -> JoinHandle<()> {
126 tokio::task::spawn(fut)
127 }
128
129 fn spawn_critical(&self, _name: &'static str, fut: BoxFuture<'static, ()>) -> JoinHandle<()> {
130 tokio::task::spawn(fut)
131 }
132
133 fn spawn_blocking(&self, fut: BoxFuture<'static, ()>) -> JoinHandle<()> {
134 tokio::task::spawn_blocking(move || tokio::runtime::Handle::current().block_on(fut))
135 }
136
137 fn spawn_critical_blocking(
138 &self,
139 _name: &'static str,
140 fut: BoxFuture<'static, ()>,
141 ) -> JoinHandle<()> {
142 tokio::task::spawn_blocking(move || tokio::runtime::Handle::current().block_on(fut))
143 }
144}
145
146#[derive(Debug)]
159#[must_use = "TaskManager must be polled to monitor critical tasks"]
160pub struct TaskManager {
161 handle: Handle,
165 panicked_tasks_tx: UnboundedSender<PanickedTaskError>,
167 panicked_tasks_rx: UnboundedReceiver<PanickedTaskError>,
169 signal: Option<Signal>,
173 on_shutdown: Shutdown,
175 graceful_tasks: Arc<AtomicUsize>,
177}
178
179impl TaskManager {
182 pub fn current() -> Self {
188 let handle = Handle::current();
189 Self::new(handle)
190 }
191
192 pub fn new(handle: Handle) -> Self {
196 let (panicked_tasks_tx, panicked_tasks_rx) = unbounded_channel();
197 let (signal, on_shutdown) = signal();
198 let manager = Self {
199 handle,
200 panicked_tasks_tx,
201 panicked_tasks_rx,
202 signal: Some(signal),
203 on_shutdown,
204 graceful_tasks: Arc::new(AtomicUsize::new(0)),
205 };
206
207 let _ = GLOBAL_EXECUTOR
208 .set(manager.executor())
209 .inspect_err(|_| error!("Global executor already set"));
210
211 manager
212 }
213
214 pub fn executor(&self) -> TaskExecutor {
217 TaskExecutor {
218 handle: self.handle.clone(),
219 on_shutdown: self.on_shutdown.clone(),
220 panicked_tasks_tx: self.panicked_tasks_tx.clone(),
221 metrics: Default::default(),
222 graceful_tasks: Arc::clone(&self.graceful_tasks),
223 }
224 }
225
226 pub fn graceful_shutdown(self) {
228 let _ = self.do_graceful_shutdown(None);
229 }
230
231 pub fn graceful_shutdown_with_timeout(self, timeout: std::time::Duration) -> bool {
235 self.do_graceful_shutdown(Some(timeout))
236 }
237
238 fn do_graceful_shutdown(self, timeout: Option<std::time::Duration>) -> bool {
239 drop(self.signal);
240 let when = timeout.map(|t| std::time::Instant::now() + t);
241 while self.graceful_tasks.load(Ordering::Relaxed) > 0 {
242 if when.map(|when| std::time::Instant::now() > when).unwrap_or(false) {
243 debug!("graceful shutdown timed out");
244 return false
245 }
246 std::hint::spin_loop();
247 }
248
249 debug!("gracefully shut down");
250 true
251 }
252}
253
254impl Future for TaskManager {
258 type Output = PanickedTaskError;
259
260 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
261 let err = ready!(self.get_mut().panicked_tasks_rx.poll_recv(cx));
262 Poll::Ready(err.expect("stream can not end"))
263 }
264}
265
266#[derive(Debug, thiserror::Error)]
268pub struct PanickedTaskError {
269 task_name: &'static str,
270 error: Option<String>,
271}
272
273impl Display for PanickedTaskError {
274 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
275 let task_name = self.task_name;
276 if let Some(error) = &self.error {
277 write!(f, "Critical task `{task_name}` panicked: `{error}`")
278 } else {
279 write!(f, "Critical task `{task_name}` panicked")
280 }
281 }
282}
283
284impl PanickedTaskError {
285 fn new(task_name: &'static str, error: Box<dyn Any>) -> Self {
286 let error = match error.downcast::<String>() {
287 Ok(value) => Some(*value),
288 Err(error) => match error.downcast::<&str>() {
289 Ok(value) => Some(value.to_string()),
290 Err(_) => None,
291 },
292 };
293
294 Self { task_name, error }
295 }
296}
297
298#[derive(Debug, Clone)]
300pub struct TaskExecutor {
301 handle: Handle,
305 on_shutdown: Shutdown,
307 panicked_tasks_tx: UnboundedSender<PanickedTaskError>,
309 metrics: TaskExecutorMetrics,
311 graceful_tasks: Arc<AtomicUsize>,
313}
314
315impl TaskExecutor {
318 pub fn try_current() -> Result<Self, NoCurrentTaskExecutorError> {
322 GLOBAL_EXECUTOR.get().cloned().ok_or_else(NoCurrentTaskExecutorError::default)
323 }
324
325 pub fn current() -> Self {
332 Self::try_current().unwrap()
333 }
334
335 pub const fn handle(&self) -> &Handle {
337 &self.handle
338 }
339
340 pub const fn on_shutdown_signal(&self) -> &Shutdown {
342 &self.on_shutdown
343 }
344
345 fn spawn_on_rt<F>(&self, fut: F, task_kind: TaskKind) -> JoinHandle<()>
347 where
348 F: Future<Output = ()> + Send + 'static,
349 {
350 match task_kind {
351 TaskKind::Default => self.handle.spawn(fut),
352 TaskKind::Blocking => {
353 let handle = self.handle.clone();
354 self.handle.spawn_blocking(move || handle.block_on(fut))
355 }
356 }
357 }
358
359 fn spawn_task_as<F>(&self, fut: F, task_kind: TaskKind) -> JoinHandle<()>
361 where
362 F: Future<Output = ()> + Send + 'static,
363 {
364 let on_shutdown = self.on_shutdown.clone();
365
366 let finished_regular_tasks_total_metrics =
368 self.metrics.finished_regular_tasks_total.clone();
369 let task = {
371 async move {
372 let _inc_counter_on_drop =
374 IncCounterOnDrop::new(finished_regular_tasks_total_metrics);
375 let fut = pin!(fut);
376 let _ = select(on_shutdown, fut).await;
377 }
378 }
379 .in_current_span();
380
381 self.spawn_on_rt(task, task_kind)
382 }
383
384 pub fn spawn<F>(&self, fut: F) -> JoinHandle<()>
389 where
390 F: Future<Output = ()> + Send + 'static,
391 {
392 self.spawn_task_as(fut, TaskKind::Default)
393 }
394
395 pub fn spawn_blocking<F>(&self, fut: F) -> JoinHandle<()>
400 where
401 F: Future<Output = ()> + Send + 'static,
402 {
403 self.spawn_task_as(fut, TaskKind::Blocking)
404 }
405
406 pub fn spawn_with_signal<F>(&self, f: impl FnOnce(Shutdown) -> F) -> JoinHandle<()>
411 where
412 F: Future<Output = ()> + Send + 'static,
413 {
414 let on_shutdown = self.on_shutdown.clone();
415 let fut = f(on_shutdown);
416
417 let task = fut.in_current_span();
418
419 self.handle.spawn(task)
420 }
421
422 fn spawn_critical_as<F>(
424 &self,
425 name: &'static str,
426 fut: F,
427 task_kind: TaskKind,
428 ) -> JoinHandle<()>
429 where
430 F: Future<Output = ()> + Send + 'static,
431 {
432 let panicked_tasks_tx = self.panicked_tasks_tx.clone();
433 let on_shutdown = self.on_shutdown.clone();
434
435 let task = std::panic::AssertUnwindSafe(fut)
437 .catch_unwind()
438 .map_err(move |error| {
439 let task_error = PanickedTaskError::new(name, error);
440 error!("{task_error}");
441 let _ = panicked_tasks_tx.send(task_error);
442 })
443 .in_current_span();
444
445 let finished_critical_tasks_total_metrics =
447 self.metrics.finished_critical_tasks_total.clone();
448 let task = async move {
449 let _inc_counter_on_drop = IncCounterOnDrop::new(finished_critical_tasks_total_metrics);
451 let task = pin!(task);
452 let _ = select(on_shutdown, task).await;
453 };
454
455 self.spawn_on_rt(task, task_kind)
456 }
457
458 pub fn spawn_critical_blocking<F>(&self, name: &'static str, fut: F) -> JoinHandle<()>
463 where
464 F: Future<Output = ()> + Send + 'static,
465 {
466 self.spawn_critical_as(name, fut, TaskKind::Blocking)
467 }
468
469 pub fn spawn_critical<F>(&self, name: &'static str, fut: F) -> JoinHandle<()>
474 where
475 F: Future<Output = ()> + Send + 'static,
476 {
477 self.spawn_critical_as(name, fut, TaskKind::Default)
478 }
479
480 pub fn spawn_critical_with_shutdown_signal<F>(
484 &self,
485 name: &'static str,
486 f: impl FnOnce(Shutdown) -> F,
487 ) -> JoinHandle<()>
488 where
489 F: Future<Output = ()> + Send + 'static,
490 {
491 let panicked_tasks_tx = self.panicked_tasks_tx.clone();
492 let on_shutdown = self.on_shutdown.clone();
493 let fut = f(on_shutdown);
494
495 let task = std::panic::AssertUnwindSafe(fut)
497 .catch_unwind()
498 .map_err(move |error| {
499 let task_error = PanickedTaskError::new(name, error);
500 error!("{task_error}");
501 let _ = panicked_tasks_tx.send(task_error);
502 })
503 .map(drop)
504 .in_current_span();
505
506 self.handle.spawn(task)
507 }
508
509 pub fn spawn_critical_with_graceful_shutdown_signal<F>(
530 &self,
531 name: &'static str,
532 f: impl FnOnce(GracefulShutdown) -> F,
533 ) -> JoinHandle<()>
534 where
535 F: Future<Output = ()> + Send + 'static,
536 {
537 let panicked_tasks_tx = self.panicked_tasks_tx.clone();
538 let on_shutdown = GracefulShutdown::new(
539 self.on_shutdown.clone(),
540 GracefulShutdownGuard::new(Arc::clone(&self.graceful_tasks)),
541 );
542 let fut = f(on_shutdown);
543
544 let task = std::panic::AssertUnwindSafe(fut)
546 .catch_unwind()
547 .map_err(move |error| {
548 let task_error = PanickedTaskError::new(name, error);
549 error!("{task_error}");
550 let _ = panicked_tasks_tx.send(task_error);
551 })
552 .map(drop)
553 .in_current_span();
554
555 self.handle.spawn(task)
556 }
557
558 pub fn spawn_with_graceful_shutdown_signal<F>(
578 &self,
579 f: impl FnOnce(GracefulShutdown) -> F,
580 ) -> JoinHandle<()>
581 where
582 F: Future<Output = ()> + Send + 'static,
583 {
584 let on_shutdown = GracefulShutdown::new(
585 self.on_shutdown.clone(),
586 GracefulShutdownGuard::new(Arc::clone(&self.graceful_tasks)),
587 );
588 let fut = f(on_shutdown);
589
590 self.handle.spawn(fut)
591 }
592}
593
594impl TaskSpawner for TaskExecutor {
595 fn spawn(&self, fut: BoxFuture<'static, ()>) -> JoinHandle<()> {
596 self.metrics.inc_regular_tasks();
597 self.spawn(fut)
598 }
599
600 fn spawn_critical(&self, name: &'static str, fut: BoxFuture<'static, ()>) -> JoinHandle<()> {
601 self.metrics.inc_critical_tasks();
602 Self::spawn_critical(self, name, fut)
603 }
604
605 fn spawn_blocking(&self, fut: BoxFuture<'static, ()>) -> JoinHandle<()> {
606 self.spawn_blocking(fut)
607 }
608
609 fn spawn_critical_blocking(
610 &self,
611 name: &'static str,
612 fut: BoxFuture<'static, ()>,
613 ) -> JoinHandle<()> {
614 Self::spawn_critical_blocking(self, name, fut)
615 }
616}
617
618#[auto_impl::auto_impl(&, Arc)]
620pub trait TaskSpawnerExt: Send + Sync + Unpin + std::fmt::Debug + DynClone {
621 fn spawn_critical_with_graceful_shutdown_signal<F>(
626 &self,
627 name: &'static str,
628 f: impl FnOnce(GracefulShutdown) -> F,
629 ) -> JoinHandle<()>
630 where
631 F: Future<Output = ()> + Send + 'static;
632
633 fn spawn_with_graceful_shutdown_signal<F>(
637 &self,
638 f: impl FnOnce(GracefulShutdown) -> F,
639 ) -> JoinHandle<()>
640 where
641 F: Future<Output = ()> + Send + 'static;
642}
643
644impl TaskSpawnerExt for TaskExecutor {
645 fn spawn_critical_with_graceful_shutdown_signal<F>(
646 &self,
647 name: &'static str,
648 f: impl FnOnce(GracefulShutdown) -> F,
649 ) -> JoinHandle<()>
650 where
651 F: Future<Output = ()> + Send + 'static,
652 {
653 Self::spawn_critical_with_graceful_shutdown_signal(self, name, f)
654 }
655
656 fn spawn_with_graceful_shutdown_signal<F>(
657 &self,
658 f: impl FnOnce(GracefulShutdown) -> F,
659 ) -> JoinHandle<()>
660 where
661 F: Future<Output = ()> + Send + 'static,
662 {
663 Self::spawn_with_graceful_shutdown_signal(self, f)
664 }
665}
666
/// Selects how `TaskExecutor::spawn_on_rt` hands a future to the runtime.
enum TaskKind {
    /// Spawn as a regular async task.
    Default,
    /// Drive the future to completion on the runtime's blocking pool.
    Blocking,
}
674
675#[derive(Debug, Default, thiserror::Error)]
677#[error("No current task executor available.")]
678#[non_exhaustive]
679pub struct NoCurrentTaskExecutorError;
680
#[cfg(test)]
mod tests {
    use super::*;
    use std::{sync::atomic::AtomicBool, time::Duration};

    /// Creates a fresh runtime and a [`TaskManager`] bound to it.
    ///
    /// The runtime is returned too so that the caller keeps it alive for the
    /// duration of the test.
    fn runtime_and_manager() -> (tokio::runtime::Runtime, TaskManager) {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let manager = TaskManager::new(runtime.handle().clone());
        (runtime, manager)
    }

    /// A boxed spawner can be cloned and stored in a `Clone` struct.
    #[test]
    fn test_cloneable() {
        #[derive(Clone)]
        struct ExecutorWrapper {
            _e: Box<dyn TaskSpawner>,
        }

        let boxed: Box<dyn TaskSpawner> = Box::new(TokioTaskExecutor::default());
        let cloned = dyn_clone::clone_box(&*boxed);

        let wrapper = ExecutorWrapper { _e: cloned };
        let _moved = wrapper;
    }

    /// A panicking critical task is reported through the manager future.
    #[test]
    fn test_critical() {
        let (runtime, manager) = runtime_and_manager();
        let executor = manager.executor();

        executor.spawn_critical("this is a critical task", async { panic!("intentionally panic") });

        let err = runtime.block_on(manager);
        assert_eq!(err.task_name, "this is a critical task");
        assert_eq!(err.error, Some("intentionally panic".to_string()));
    }

    /// Dropping the manager ends a critical task before it finishes on its own.
    #[test]
    fn test_manager_shutdown_critical() {
        let (runtime, manager) = runtime_and_manager();
        let executor = manager.executor();

        // `task_done` resolves once the task (and with it `task_signal`) is dropped.
        let (task_signal, task_done) = signal();

        executor.spawn_critical("this is a critical task", async move {
            tokio::time::sleep(Duration::from_millis(200)).await;
            drop(task_signal);
        });

        drop(manager);

        runtime.handle().block_on(task_done);
    }

    /// Dropping the manager ends a regular task before it finishes on its own.
    #[test]
    fn test_manager_shutdown() {
        let (runtime, manager) = runtime_and_manager();
        let executor = manager.executor();

        // `task_done` resolves once the task (and with it `task_signal`) is dropped.
        let (task_signal, task_done) = signal();

        executor.spawn(Box::pin(async move {
            tokio::time::sleep(Duration::from_millis(200)).await;
            drop(task_signal);
        }));

        drop(manager);

        runtime.handle().block_on(task_done);
    }

    /// A graceful shutdown waits for the task holding the guard to finish.
    #[test]
    fn test_manager_graceful_shutdown() {
        let (_runtime, manager) = runtime_and_manager();
        let executor = manager.executor();

        let finished = Arc::new(AtomicBool::new(false));
        let finished_in_task = finished.clone();
        executor.spawn_critical_with_graceful_shutdown_signal("grace", |shutdown| async move {
            let _guard = shutdown.await;
            tokio::time::sleep(Duration::from_millis(200)).await;
            finished_in_task.store(true, Ordering::Relaxed);
        });

        manager.graceful_shutdown();
        assert!(finished.load(Ordering::Relaxed));
    }

    /// A graceful shutdown waits for every task holding a guard.
    #[test]
    fn test_manager_graceful_shutdown_many() {
        let (_runtime, manager) = runtime_and_manager();
        let executor = manager.executor();

        let completed = Arc::new(AtomicUsize::new(0));
        let num = 10;
        (0..num).for_each(|_| {
            let completed_in_task = completed.clone();
            executor.spawn_critical_with_graceful_shutdown_signal(
                "grace",
                move |shutdown| async move {
                    let _guard = shutdown.await;
                    tokio::time::sleep(Duration::from_millis(200)).await;
                    completed_in_task.fetch_add(1, Ordering::SeqCst);
                },
            );
        });

        manager.graceful_shutdown();
        assert_eq!(completed.load(Ordering::Relaxed), num);
    }

    /// A graceful shutdown with a timeout returns before a slow task completes.
    #[test]
    fn test_manager_graceful_shutdown_timeout() {
        let (_runtime, manager) = runtime_and_manager();
        let executor = manager.executor();

        let timeout = Duration::from_millis(500);
        let finished = Arc::new(AtomicBool::new(false));
        let finished_in_task = finished.clone();
        executor.spawn_critical_with_graceful_shutdown_signal("grace", |shutdown| async move {
            let _guard = shutdown.await;
            tokio::time::sleep(timeout * 3).await;
            finished_in_task.store(true, Ordering::Relaxed);
            unreachable!("should not be reached");
        });

        manager.graceful_shutdown_with_timeout(timeout);
        assert!(!finished.load(Ordering::Relaxed));
    }

    /// Creating a manager makes a global executor available.
    #[test]
    fn can_access_global() {
        let (_runtime, _manager) = runtime_and_manager();
        let _executor = TaskExecutor::try_current().unwrap();
    }
}