1#![doc(
8 html_logo_url = "https://raw.githubusercontent.com/paradigmxyz/reth/main/assets/reth-docs.png",
9 html_favicon_url = "https://avatars0.githubusercontent.com/u/97369466?s=256",
10 issue_tracker_base_url = "https://github.com/paradigmxyz/reth/issues/"
11)]
12#![cfg_attr(not(test), warn(unused_crate_dependencies))]
13#![cfg_attr(docsrs, feature(doc_cfg))]
14
15use crate::{
16 metrics::{IncCounterOnDrop, TaskExecutorMetrics},
17 shutdown::{signal, GracefulShutdown, GracefulShutdownGuard, Shutdown, Signal},
18};
19use dyn_clone::DynClone;
20use futures_util::{
21 future::{select, BoxFuture},
22 Future, FutureExt, TryFutureExt,
23};
24use std::{
25 any::Any,
26 fmt::{Display, Formatter},
27 pin::{pin, Pin},
28 sync::{
29 atomic::{AtomicUsize, Ordering},
30 Arc, OnceLock,
31 },
32 task::{ready, Context, Poll},
33};
34use tokio::{
35 runtime::Handle,
36 sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender},
37 task::JoinHandle,
38};
39use tracing::{debug, error};
40use tracing_futures::Instrument;
41
42pub mod metrics;
43pub mod shutdown;
44
45#[cfg(feature = "rayon")]
46pub mod pool;
47
/// The executor of the first [`TaskManager`] created in this process.
///
/// It is set once in [`TaskManager::new`] and read by [`TaskExecutor::try_current`] and
/// [`TaskExecutor::current`].
static GLOBAL_EXECUTOR: OnceLock<TaskExecutor> = OnceLock::new();
50
51#[auto_impl::auto_impl(&, Arc)]
91pub trait TaskSpawner: Send + Sync + Unpin + std::fmt::Debug + DynClone {
92 fn spawn(&self, fut: BoxFuture<'static, ()>) -> JoinHandle<()>;
95
96 fn spawn_critical(&self, name: &'static str, fut: BoxFuture<'static, ()>) -> JoinHandle<()>;
98
99 fn spawn_blocking(&self, fut: BoxFuture<'static, ()>) -> JoinHandle<()>;
101
102 fn spawn_critical_blocking(
104 &self,
105 name: &'static str,
106 fut: BoxFuture<'static, ()>,
107 ) -> JoinHandle<()>;
108}
109
// Implements `Clone` for `Box<dyn TaskSpawner>` through the `DynClone` supertrait, so boxed
// spawners can be cloned like any other value.
dyn_clone::clone_trait_object!(TaskSpawner);
111
/// A [`TaskSpawner`] that hands every future straight to the ambient Tokio runtime via the
/// free `tokio::task` functions.
///
/// It has no notion of critical tasks: names are discarded and panics are not reported
/// anywhere beyond the returned [`JoinHandle`].
#[derive(Debug, Clone, Default)]
#[non_exhaustive]
pub struct TokioTaskExecutor;
116
117impl TokioTaskExecutor {
118 pub fn boxed(self) -> Box<dyn TaskSpawner + 'static> {
120 Box::new(self)
121 }
122}
123
124impl TaskSpawner for TokioTaskExecutor {
125 fn spawn(&self, fut: BoxFuture<'static, ()>) -> JoinHandle<()> {
126 tokio::task::spawn(fut)
127 }
128
129 fn spawn_critical(&self, _name: &'static str, fut: BoxFuture<'static, ()>) -> JoinHandle<()> {
130 tokio::task::spawn(fut)
131 }
132
133 fn spawn_blocking(&self, fut: BoxFuture<'static, ()>) -> JoinHandle<()> {
134 tokio::task::spawn_blocking(move || tokio::runtime::Handle::current().block_on(fut))
135 }
136
137 fn spawn_critical_blocking(
138 &self,
139 _name: &'static str,
140 fut: BoxFuture<'static, ()>,
141 ) -> JoinHandle<()> {
142 tokio::task::spawn_blocking(move || tokio::runtime::Handle::current().block_on(fut))
143 }
144}
145
146#[derive(Debug)]
159#[must_use = "TaskManager must be polled to monitor critical tasks"]
160pub struct TaskManager {
161 handle: Handle,
165 task_events_tx: UnboundedSender<TaskEvent>,
167 task_events_rx: UnboundedReceiver<TaskEvent>,
169 signal: Option<Signal>,
173 on_shutdown: Shutdown,
175 graceful_tasks: Arc<AtomicUsize>,
177}
178
179impl TaskManager {
182 pub fn current() -> Self {
192 let handle = Handle::current();
193 Self::new(handle)
194 }
195
196 pub fn new(handle: Handle) -> Self {
200 let (task_events_tx, task_events_rx) = unbounded_channel();
201 let (signal, on_shutdown) = signal();
202 let manager = Self {
203 handle,
204 task_events_tx,
205 task_events_rx,
206 signal: Some(signal),
207 on_shutdown,
208 graceful_tasks: Arc::new(AtomicUsize::new(0)),
209 };
210
211 let _ = GLOBAL_EXECUTOR
212 .set(manager.executor())
213 .inspect_err(|_| error!("Global executor already set"));
214
215 manager
216 }
217
218 pub fn executor(&self) -> TaskExecutor {
221 TaskExecutor {
222 handle: self.handle.clone(),
223 on_shutdown: self.on_shutdown.clone(),
224 task_events_tx: self.task_events_tx.clone(),
225 metrics: Default::default(),
226 graceful_tasks: Arc::clone(&self.graceful_tasks),
227 }
228 }
229
230 pub fn graceful_shutdown(self) {
232 let _ = self.do_graceful_shutdown(None);
233 }
234
235 pub fn graceful_shutdown_with_timeout(self, timeout: std::time::Duration) -> bool {
239 self.do_graceful_shutdown(Some(timeout))
240 }
241
242 fn do_graceful_shutdown(self, timeout: Option<std::time::Duration>) -> bool {
243 drop(self.signal);
244 let when = timeout.map(|t| std::time::Instant::now() + t);
245 while self.graceful_tasks.load(Ordering::Relaxed) > 0 {
246 if when.map(|when| std::time::Instant::now() > when).unwrap_or(false) {
247 debug!("graceful shutdown timed out");
248 return false
249 }
250 std::hint::spin_loop();
251 }
252
253 debug!("gracefully shut down");
254 true
255 }
256}
257
258impl Future for TaskManager {
262 type Output = Result<(), PanickedTaskError>;
263
264 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
265 match ready!(self.as_mut().get_mut().task_events_rx.poll_recv(cx)) {
266 Some(TaskEvent::Panic(err)) => Poll::Ready(Err(err)),
267 Some(TaskEvent::GracefulShutdown) | None => {
268 if let Some(signal) = self.get_mut().signal.take() {
269 signal.fire();
270 }
271 Poll::Ready(Ok(()))
272 }
273 }
274 }
275}
276
277#[derive(Debug, thiserror::Error, PartialEq, Eq)]
279pub struct PanickedTaskError {
280 task_name: &'static str,
281 error: Option<String>,
282}
283
284impl Display for PanickedTaskError {
285 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
286 let task_name = self.task_name;
287 if let Some(error) = &self.error {
288 write!(f, "Critical task `{task_name}` panicked: `{error}`")
289 } else {
290 write!(f, "Critical task `{task_name}` panicked")
291 }
292 }
293}
294
295impl PanickedTaskError {
296 fn new(task_name: &'static str, error: Box<dyn Any>) -> Self {
297 let error = match error.downcast::<String>() {
298 Ok(value) => Some(*value),
299 Err(error) => match error.downcast::<&str>() {
300 Ok(value) => Some(value.to_string()),
301 Err(_) => None,
302 },
303 };
304
305 Self { task_name, error }
306 }
307}
308
309#[derive(Debug)]
311enum TaskEvent {
312 Panic(PanickedTaskError),
314 GracefulShutdown,
316}
317
318#[derive(Debug, Clone)]
320pub struct TaskExecutor {
321 handle: Handle,
325 on_shutdown: Shutdown,
327 task_events_tx: UnboundedSender<TaskEvent>,
329 metrics: TaskExecutorMetrics,
331 graceful_tasks: Arc<AtomicUsize>,
333}
334
335impl TaskExecutor {
338 pub fn try_current() -> Result<Self, NoCurrentTaskExecutorError> {
342 GLOBAL_EXECUTOR.get().cloned().ok_or_else(NoCurrentTaskExecutorError::default)
343 }
344
345 pub fn current() -> Self {
352 Self::try_current().unwrap()
353 }
354
355 pub const fn handle(&self) -> &Handle {
357 &self.handle
358 }
359
360 pub const fn on_shutdown_signal(&self) -> &Shutdown {
362 &self.on_shutdown
363 }
364
365 fn spawn_on_rt<F>(&self, fut: F, task_kind: TaskKind) -> JoinHandle<()>
367 where
368 F: Future<Output = ()> + Send + 'static,
369 {
370 match task_kind {
371 TaskKind::Default => self.handle.spawn(fut),
372 TaskKind::Blocking => {
373 let handle = self.handle.clone();
374 self.handle.spawn_blocking(move || handle.block_on(fut))
375 }
376 }
377 }
378
379 fn spawn_task_as<F>(&self, fut: F, task_kind: TaskKind) -> JoinHandle<()>
381 where
382 F: Future<Output = ()> + Send + 'static,
383 {
384 let on_shutdown = self.on_shutdown.clone();
385
386 let finished_counter = match task_kind {
388 TaskKind::Default => self.metrics.finished_regular_tasks_total.clone(),
389 TaskKind::Blocking => self.metrics.finished_regular_blocking_tasks_total.clone(),
390 };
391
392 let task = {
394 async move {
395 let _inc_counter_on_drop = IncCounterOnDrop::new(finished_counter);
397 let fut = pin!(fut);
398 let _ = select(on_shutdown, fut).await;
399 }
400 }
401 .in_current_span();
402
403 self.spawn_on_rt(task, task_kind)
404 }
405
406 pub fn spawn<F>(&self, fut: F) -> JoinHandle<()>
411 where
412 F: Future<Output = ()> + Send + 'static,
413 {
414 self.spawn_task_as(fut, TaskKind::Default)
415 }
416
417 pub fn spawn_blocking<F>(&self, fut: F) -> JoinHandle<()>
422 where
423 F: Future<Output = ()> + Send + 'static,
424 {
425 self.spawn_task_as(fut, TaskKind::Blocking)
426 }
427
428 pub fn spawn_with_signal<F>(&self, f: impl FnOnce(Shutdown) -> F) -> JoinHandle<()>
433 where
434 F: Future<Output = ()> + Send + 'static,
435 {
436 let on_shutdown = self.on_shutdown.clone();
437 let fut = f(on_shutdown);
438
439 let task = fut.in_current_span();
440
441 self.handle.spawn(task)
442 }
443
444 fn spawn_critical_as<F>(
446 &self,
447 name: &'static str,
448 fut: F,
449 task_kind: TaskKind,
450 ) -> JoinHandle<()>
451 where
452 F: Future<Output = ()> + Send + 'static,
453 {
454 let panicked_tasks_tx = self.task_events_tx.clone();
455 let on_shutdown = self.on_shutdown.clone();
456
457 let task = std::panic::AssertUnwindSafe(fut)
459 .catch_unwind()
460 .map_err(move |error| {
461 let task_error = PanickedTaskError::new(name, error);
462 error!("{task_error}");
463 let _ = panicked_tasks_tx.send(TaskEvent::Panic(task_error));
464 })
465 .in_current_span();
466
467 let finished_critical_tasks_total_metrics =
469 self.metrics.finished_critical_tasks_total.clone();
470 let task = async move {
471 let _inc_counter_on_drop = IncCounterOnDrop::new(finished_critical_tasks_total_metrics);
473 let task = pin!(task);
474 let _ = select(on_shutdown, task).await;
475 };
476
477 self.spawn_on_rt(task, task_kind)
478 }
479
480 pub fn spawn_critical_blocking<F>(&self, name: &'static str, fut: F) -> JoinHandle<()>
485 where
486 F: Future<Output = ()> + Send + 'static,
487 {
488 self.spawn_critical_as(name, fut, TaskKind::Blocking)
489 }
490
491 pub fn spawn_critical<F>(&self, name: &'static str, fut: F) -> JoinHandle<()>
496 where
497 F: Future<Output = ()> + Send + 'static,
498 {
499 self.spawn_critical_as(name, fut, TaskKind::Default)
500 }
501
502 pub fn spawn_critical_with_shutdown_signal<F>(
506 &self,
507 name: &'static str,
508 f: impl FnOnce(Shutdown) -> F,
509 ) -> JoinHandle<()>
510 where
511 F: Future<Output = ()> + Send + 'static,
512 {
513 let panicked_tasks_tx = self.task_events_tx.clone();
514 let on_shutdown = self.on_shutdown.clone();
515 let fut = f(on_shutdown);
516
517 let task = std::panic::AssertUnwindSafe(fut)
519 .catch_unwind()
520 .map_err(move |error| {
521 let task_error = PanickedTaskError::new(name, error);
522 error!("{task_error}");
523 let _ = panicked_tasks_tx.send(TaskEvent::Panic(task_error));
524 })
525 .map(drop)
526 .in_current_span();
527
528 self.handle.spawn(task)
529 }
530
531 pub fn spawn_critical_with_graceful_shutdown_signal<F>(
552 &self,
553 name: &'static str,
554 f: impl FnOnce(GracefulShutdown) -> F,
555 ) -> JoinHandle<()>
556 where
557 F: Future<Output = ()> + Send + 'static,
558 {
559 let panicked_tasks_tx = self.task_events_tx.clone();
560 let on_shutdown = GracefulShutdown::new(
561 self.on_shutdown.clone(),
562 GracefulShutdownGuard::new(Arc::clone(&self.graceful_tasks)),
563 );
564 let fut = f(on_shutdown);
565
566 let task = std::panic::AssertUnwindSafe(fut)
568 .catch_unwind()
569 .map_err(move |error| {
570 let task_error = PanickedTaskError::new(name, error);
571 error!("{task_error}");
572 let _ = panicked_tasks_tx.send(TaskEvent::Panic(task_error));
573 })
574 .map(drop)
575 .in_current_span();
576
577 self.handle.spawn(task)
578 }
579
580 pub fn spawn_with_graceful_shutdown_signal<F>(
600 &self,
601 f: impl FnOnce(GracefulShutdown) -> F,
602 ) -> JoinHandle<()>
603 where
604 F: Future<Output = ()> + Send + 'static,
605 {
606 let on_shutdown = GracefulShutdown::new(
607 self.on_shutdown.clone(),
608 GracefulShutdownGuard::new(Arc::clone(&self.graceful_tasks)),
609 );
610 let fut = f(on_shutdown);
611
612 self.handle.spawn(fut)
613 }
614
615 pub fn initiate_graceful_shutdown(
622 &self,
623 ) -> Result<GracefulShutdown, tokio::sync::mpsc::error::SendError<()>> {
624 self.task_events_tx
625 .send(TaskEvent::GracefulShutdown)
626 .map_err(|_send_error_with_task_event| tokio::sync::mpsc::error::SendError(()))?;
627
628 Ok(GracefulShutdown::new(
629 self.on_shutdown.clone(),
630 GracefulShutdownGuard::new(Arc::clone(&self.graceful_tasks)),
631 ))
632 }
633}
634
635impl TaskSpawner for TaskExecutor {
636 fn spawn(&self, fut: BoxFuture<'static, ()>) -> JoinHandle<()> {
637 self.metrics.inc_regular_tasks();
638 self.spawn(fut)
639 }
640
641 fn spawn_critical(&self, name: &'static str, fut: BoxFuture<'static, ()>) -> JoinHandle<()> {
642 self.metrics.inc_critical_tasks();
643 Self::spawn_critical(self, name, fut)
644 }
645
646 fn spawn_blocking(&self, fut: BoxFuture<'static, ()>) -> JoinHandle<()> {
647 self.metrics.inc_regular_blocking_tasks();
648 self.spawn_blocking(fut)
649 }
650
651 fn spawn_critical_blocking(
652 &self,
653 name: &'static str,
654 fut: BoxFuture<'static, ()>,
655 ) -> JoinHandle<()> {
656 self.metrics.inc_critical_tasks();
657 Self::spawn_critical_blocking(self, name, fut)
658 }
659}
660
661#[auto_impl::auto_impl(&, Arc)]
663pub trait TaskSpawnerExt: Send + Sync + Unpin + std::fmt::Debug + DynClone {
664 fn spawn_critical_with_graceful_shutdown_signal<F>(
669 &self,
670 name: &'static str,
671 f: impl FnOnce(GracefulShutdown) -> F,
672 ) -> JoinHandle<()>
673 where
674 F: Future<Output = ()> + Send + 'static;
675
676 fn spawn_with_graceful_shutdown_signal<F>(
680 &self,
681 f: impl FnOnce(GracefulShutdown) -> F,
682 ) -> JoinHandle<()>
683 where
684 F: Future<Output = ()> + Send + 'static;
685}
686
687impl TaskSpawnerExt for TaskExecutor {
688 fn spawn_critical_with_graceful_shutdown_signal<F>(
689 &self,
690 name: &'static str,
691 f: impl FnOnce(GracefulShutdown) -> F,
692 ) -> JoinHandle<()>
693 where
694 F: Future<Output = ()> + Send + 'static,
695 {
696 Self::spawn_critical_with_graceful_shutdown_signal(self, name, f)
697 }
698
699 fn spawn_with_graceful_shutdown_signal<F>(
700 &self,
701 f: impl FnOnce(GracefulShutdown) -> F,
702 ) -> JoinHandle<()>
703 where
704 F: Future<Output = ()> + Send + 'static,
705 {
706 Self::spawn_with_graceful_shutdown_signal(self, f)
707 }
708}
709
/// Selects where a spawned task's future is driven.
enum TaskKind {
    /// On the runtime's async worker threads, via `Handle::spawn`.
    Default,
    /// On the runtime's blocking pool, via `Handle::spawn_blocking` plus `Handle::block_on`.
    Blocking,
}
717
718#[derive(Debug, Default, thiserror::Error)]
720#[error("No current task executor available.")]
721#[non_exhaustive]
722pub struct NoCurrentTaskExecutorError;
723
#[cfg(test)]
mod tests {
    use super::*;
    use std::{sync::atomic::AtomicBool, time::Duration};

    #[test]
    fn test_cloneable() {
        #[derive(Clone)]
        struct ExecutorWrapper {
            _e: Box<dyn TaskSpawner>,
        }

        let executor: Box<dyn TaskSpawner> = Box::<TokioTaskExecutor>::default();
        let _e = dyn_clone::clone_box(&*executor);

        // Actually exercise the derived `Clone`, which relies on `Box<dyn TaskSpawner>: Clone`.
        let e = ExecutorWrapper { _e };
        let _e2 = e.clone();
    }

    #[test]
    fn test_critical() {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let handle = runtime.handle().clone();
        let manager = TaskManager::new(handle);
        let executor = manager.executor();

        executor.spawn_critical("this is a critical task", async { panic!("intentionally panic") });

        runtime.block_on(async move {
            let err_result = manager.await;
            assert!(err_result.is_err(), "Expected TaskManager to return an error due to panic");
            let panicked_err = err_result.unwrap_err();

            assert_eq!(panicked_err.task_name, "this is a critical task");
            assert_eq!(panicked_err.error, Some("intentionally panic".to_string()));
        })
    }

    #[test]
    fn test_manager_shutdown_critical() {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let handle = runtime.handle().clone();
        let manager = TaskManager::new(handle.clone());
        let executor = manager.executor();

        let (signal, shutdown) = signal();

        executor.spawn_critical("this is a critical task", async move {
            tokio::time::sleep(Duration::from_millis(200)).await;
            drop(signal);
        });

        drop(manager);

        handle.block_on(shutdown);
    }

    #[test]
    fn test_manager_shutdown() {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let handle = runtime.handle().clone();
        let manager = TaskManager::new(handle.clone());
        let executor = manager.executor();

        let (signal, shutdown) = signal();

        executor.spawn(Box::pin(async move {
            tokio::time::sleep(Duration::from_millis(200)).await;
            drop(signal);
        }));

        drop(manager);

        handle.block_on(shutdown);
    }

    #[test]
    fn test_manager_graceful_shutdown() {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let handle = runtime.handle().clone();
        let manager = TaskManager::new(handle);
        let executor = manager.executor();

        let val = Arc::new(AtomicBool::new(false));
        let c = val.clone();
        executor.spawn_critical_with_graceful_shutdown_signal("grace", |shutdown| async move {
            let _guard = shutdown.await;
            tokio::time::sleep(Duration::from_millis(200)).await;
            c.store(true, Ordering::Relaxed);
        });

        manager.graceful_shutdown();
        assert!(val.load(Ordering::Relaxed));
    }

    #[test]
    fn test_manager_graceful_shutdown_many() {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let handle = runtime.handle().clone();
        let manager = TaskManager::new(handle);
        let executor = manager.executor();

        let counter = Arc::new(AtomicUsize::new(0));
        let num = 10;
        for _ in 0..num {
            let c = counter.clone();
            executor.spawn_critical_with_graceful_shutdown_signal(
                "grace",
                move |shutdown| async move {
                    let _guard = shutdown.await;
                    tokio::time::sleep(Duration::from_millis(200)).await;
                    c.fetch_add(1, Ordering::SeqCst);
                },
            );
        }

        manager.graceful_shutdown();
        assert_eq!(counter.load(Ordering::SeqCst), num);
    }

    #[test]
    fn test_manager_graceful_shutdown_timeout() {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let handle = runtime.handle().clone();
        let manager = TaskManager::new(handle);
        let executor = manager.executor();

        let timeout = Duration::from_millis(500);
        let val = Arc::new(AtomicBool::new(false));
        let val2 = val.clone();
        executor.spawn_critical_with_graceful_shutdown_signal("grace", |shutdown| async move {
            let _guard = shutdown.await;
            tokio::time::sleep(timeout * 3).await;
            val2.store(true, Ordering::Relaxed);
            unreachable!("should not be reached");
        });

        // The graceful task outlives the timeout, so the shutdown must report that it gave up.
        let completed = manager.graceful_shutdown_with_timeout(timeout);
        assert!(!completed, "graceful shutdown should report that the timeout elapsed");
        assert!(!val.load(Ordering::Relaxed));
    }

    #[test]
    fn test_manager_graceful_shutdown_without_graceful_tasks() {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let manager = TaskManager::new(runtime.handle().clone());

        // No graceful task holds a guard, so shutdown completes immediately.
        assert!(manager.graceful_shutdown_with_timeout(Duration::from_millis(100)));
    }

    #[test]
    fn can_access_global() {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let handle = runtime.handle().clone();
        let _manager = TaskManager::new(handle);
        let _executor = TaskExecutor::try_current().unwrap();
    }

    #[test]
    fn test_graceful_shutdown_triggered_by_executor() {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let task_manager = TaskManager::new(runtime.handle().clone());
        let executor = task_manager.executor();

        let task_did_shutdown_flag = Arc::new(AtomicBool::new(false));
        let flag_clone = task_did_shutdown_flag.clone();

        let spawned_task_handle = executor.spawn_with_signal(|shutdown_signal| async move {
            shutdown_signal.await;
            flag_clone.store(true, Ordering::SeqCst);
        });

        let manager_future_handle = runtime.spawn(task_manager);

        let send_result = executor.initiate_graceful_shutdown();
        assert!(
            send_result.is_ok(),
            "Sending the graceful shutdown signal should succeed and return a GracefulShutdown future"
        );

        let manager_final_result = runtime.block_on(manager_future_handle);

        assert!(manager_final_result.is_ok(), "TaskManager task should not panic");
        assert_eq!(
            manager_final_result.unwrap(),
            Ok(()),
            "TaskManager should resolve cleanly with Ok(()) after graceful shutdown request"
        );

        let task_join_result = runtime.block_on(spawned_task_handle);
        assert!(task_join_result.is_ok(), "Spawned task should complete without panic");

        assert!(
            task_did_shutdown_flag.load(Ordering::SeqCst),
            "Task should have received the shutdown signal and set the flag"
        );
    }
}