1#![doc(
8 html_logo_url = "https://raw.githubusercontent.com/paradigmxyz/reth/main/assets/reth-docs.png",
9 html_favicon_url = "https://avatars0.githubusercontent.com/u/97369466?s=256",
10 issue_tracker_base_url = "https://github.com/paradigmxyz/reth/issues/"
11)]
12#![cfg_attr(not(test), warn(unused_crate_dependencies))]
13#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))]
14
15use crate::{
16 metrics::{IncCounterOnDrop, TaskExecutorMetrics},
17 shutdown::{signal, GracefulShutdown, GracefulShutdownGuard, Shutdown, Signal},
18};
19use dyn_clone::DynClone;
20use futures_util::{
21 future::{select, BoxFuture},
22 Future, FutureExt, TryFutureExt,
23};
24use std::{
25 any::Any,
26 fmt::{Display, Formatter},
27 pin::{pin, Pin},
28 sync::{
29 atomic::{AtomicUsize, Ordering},
30 Arc,
31 },
32 task::{ready, Context, Poll},
33};
34use tokio::{
35 runtime::Handle,
36 sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender},
37 task::JoinHandle,
38};
39use tracing::{debug, error};
40use tracing_futures::Instrument;
41
42pub mod metrics;
43pub mod shutdown;
44
45#[cfg(feature = "rayon")]
46pub mod pool;
47
/// A type that can spawn futures as tasks onto some runtime.
///
/// Every method returns the [`JoinHandle`] of the spawned task. What "critical" means is up to
/// the implementation: [`TokioTaskExecutor`] ignores the `name` and spawns critical tasks like
/// regular ones, while [`TaskExecutor`] reports panics of critical tasks to its [`TaskManager`].
#[auto_impl::auto_impl(&, Arc)]
pub trait TaskSpawner: Send + Sync + Unpin + std::fmt::Debug + DynClone {
    /// Spawns the future as a regular task.
    fn spawn(&self, fut: BoxFuture<'static, ()>) -> JoinHandle<()>;

    /// Spawns the future as a critical task identified by `name`.
    fn spawn_critical(&self, name: &'static str, fut: BoxFuture<'static, ()>) -> JoinHandle<()>;

    /// Spawns a future that is allowed to block the thread it runs on.
    fn spawn_blocking(&self, fut: BoxFuture<'static, ()>) -> JoinHandle<()>;

    /// Spawns a critical future, identified by `name`, that is allowed to block the thread it
    /// runs on.
    fn spawn_critical_blocking(
        &self,
        name: &'static str,
        fut: BoxFuture<'static, ()>,
    ) -> JoinHandle<()>;
}

// Implements `Clone` for `Box<dyn TaskSpawner>` (via the `DynClone` supertrait), so boxed
// spawners can be stored in types that derive `Clone`.
dyn_clone::clone_trait_object!(TaskSpawner);
108
/// A [`TaskSpawner`] that spawns directly onto the ambient Tokio runtime via
/// `tokio::task::spawn` / `tokio::task::spawn_blocking`.
///
/// It has no shutdown signal and no panic reporting: the `name` of critical tasks is ignored.
#[derive(Debug, Clone, Default)]
#[non_exhaustive]
pub struct TokioTaskExecutor;
113
114impl TokioTaskExecutor {
115 pub fn boxed(self) -> Box<dyn TaskSpawner + 'static> {
117 Box::new(self)
118 }
119}
120
121impl TaskSpawner for TokioTaskExecutor {
122 fn spawn(&self, fut: BoxFuture<'static, ()>) -> JoinHandle<()> {
123 tokio::task::spawn(fut)
124 }
125
126 fn spawn_critical(&self, _name: &'static str, fut: BoxFuture<'static, ()>) -> JoinHandle<()> {
127 tokio::task::spawn(fut)
128 }
129
130 fn spawn_blocking(&self, fut: BoxFuture<'static, ()>) -> JoinHandle<()> {
131 tokio::task::spawn_blocking(move || tokio::runtime::Handle::current().block_on(fut))
132 }
133
134 fn spawn_critical_blocking(
135 &self,
136 _name: &'static str,
137 fut: BoxFuture<'static, ()>,
138 ) -> JoinHandle<()> {
139 tokio::task::spawn_blocking(move || tokio::runtime::Handle::current().block_on(fut))
140 }
141}
142
/// Owns the shutdown signal for a set of [`TaskExecutor`]s and collects panics of their critical
/// tasks.
///
/// Polled as a [`Future`], it resolves with the first [`PanickedTaskError`] reported by a
/// critical task.
#[derive(Debug)]
#[must_use = "TaskManager must be polled to monitor critical tasks"]
pub struct TaskManager {
    /// Handle to the Tokio runtime that executors created by this manager spawn onto.
    handle: Handle,
    /// Sender for panicked critical tasks; a clone is handed to every [`TaskExecutor`].
    panicked_tasks_tx: UnboundedSender<PanickedTaskError>,
    /// Receiver for panicked critical tasks, polled by this type's `Future` impl.
    panicked_tasks_rx: UnboundedReceiver<PanickedTaskError>,
    /// Shutdown signal; dropping it resolves `on_shutdown` and every clone of it.
    signal: Option<Signal>,
    /// Shutdown future paired with `signal`; cloned into every [`TaskExecutor`].
    on_shutdown: Shutdown,
    /// Counter shared with the `GracefulShutdownGuard`s handed out by executors; graceful
    /// shutdown waits until it drops to zero.
    graceful_tasks: Arc<AtomicUsize>,
}
175
176impl TaskManager {
179 pub fn current() -> Self {
185 let handle = Handle::current();
186 Self::new(handle)
187 }
188
189 pub fn new(handle: Handle) -> Self {
191 let (panicked_tasks_tx, panicked_tasks_rx) = unbounded_channel();
192 let (signal, on_shutdown) = signal();
193 Self {
194 handle,
195 panicked_tasks_tx,
196 panicked_tasks_rx,
197 signal: Some(signal),
198 on_shutdown,
199 graceful_tasks: Arc::new(AtomicUsize::new(0)),
200 }
201 }
202
203 pub fn executor(&self) -> TaskExecutor {
206 TaskExecutor {
207 handle: self.handle.clone(),
208 on_shutdown: self.on_shutdown.clone(),
209 panicked_tasks_tx: self.panicked_tasks_tx.clone(),
210 metrics: Default::default(),
211 graceful_tasks: Arc::clone(&self.graceful_tasks),
212 }
213 }
214
215 pub fn graceful_shutdown(self) {
217 let _ = self.do_graceful_shutdown(None);
218 }
219
220 pub fn graceful_shutdown_with_timeout(self, timeout: std::time::Duration) -> bool {
224 self.do_graceful_shutdown(Some(timeout))
225 }
226
227 fn do_graceful_shutdown(self, timeout: Option<std::time::Duration>) -> bool {
228 drop(self.signal);
229 let when = timeout.map(|t| std::time::Instant::now() + t);
230 while self.graceful_tasks.load(Ordering::Relaxed) > 0 {
231 if when.map(|when| std::time::Instant::now() > when).unwrap_or(false) {
232 debug!("graceful shutdown timed out");
233 return false
234 }
235 std::hint::spin_loop();
236 }
237
238 debug!("gracefully shut down");
239 true
240 }
241}
242
243impl Future for TaskManager {
247 type Output = PanickedTaskError;
248
249 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
250 let err = ready!(self.get_mut().panicked_tasks_rx.poll_recv(cx));
251 Poll::Ready(err.expect("stream can not end"))
252 }
253}
254
/// Error yielded by the [`TaskManager`] future when a critical task panicked.
#[derive(Debug, thiserror::Error)]
pub struct PanickedTaskError {
    /// Name the critical task was spawned with.
    task_name: &'static str,
    /// Panic message, if the panic payload was a `String` or `&str`.
    error: Option<String>,
}
261
262impl Display for PanickedTaskError {
263 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
264 let task_name = self.task_name;
265 if let Some(error) = &self.error {
266 write!(f, "Critical task `{task_name}` panicked: `{error}`")
267 } else {
268 write!(f, "Critical task `{task_name}` panicked")
269 }
270 }
271}
272
273impl PanickedTaskError {
274 fn new(task_name: &'static str, error: Box<dyn Any>) -> Self {
275 let error = match error.downcast::<String>() {
276 Ok(value) => Some(*value),
277 Err(error) => match error.downcast::<&str>() {
278 Ok(value) => Some(value.to_string()),
279 Err(_) => None,
280 },
281 };
282
283 Self { task_name, error }
284 }
285}
286
/// Spawns tasks onto a Tokio runtime on behalf of a [`TaskManager`].
///
/// Obtained via [`TaskManager::executor`]; clones share the manager's shutdown signal, panic
/// channel and graceful-shutdown counter.
#[derive(Debug, Clone)]
pub struct TaskExecutor {
    /// Handle to the Tokio runtime tasks are spawned onto.
    handle: Handle,
    /// Shutdown future that resolves once the owning [`TaskManager`] drops its signal.
    on_shutdown: Shutdown,
    /// Channel used to report panicked critical tasks to the [`TaskManager`].
    panicked_tasks_tx: UnboundedSender<PanickedTaskError>,
    /// Spawned/finished task metrics.
    metrics: TaskExecutorMetrics,
    /// Counter shared with the [`TaskManager`]; each `GracefulShutdownGuard` handed out by this
    /// executor is created over it.
    graceful_tasks: Arc<AtomicUsize>,
}
303
304impl TaskExecutor {
307 pub const fn handle(&self) -> &Handle {
309 &self.handle
310 }
311
312 pub const fn on_shutdown_signal(&self) -> &Shutdown {
314 &self.on_shutdown
315 }
316
317 fn spawn_on_rt<F>(&self, fut: F, task_kind: TaskKind) -> JoinHandle<()>
319 where
320 F: Future<Output = ()> + Send + 'static,
321 {
322 match task_kind {
323 TaskKind::Default => self.handle.spawn(fut),
324 TaskKind::Blocking => {
325 let handle = self.handle.clone();
326 self.handle.spawn_blocking(move || handle.block_on(fut))
327 }
328 }
329 }
330
331 fn spawn_task_as<F>(&self, fut: F, task_kind: TaskKind) -> JoinHandle<()>
333 where
334 F: Future<Output = ()> + Send + 'static,
335 {
336 let on_shutdown = self.on_shutdown.clone();
337
338 let finished_regular_tasks_total_metrics =
340 self.metrics.finished_regular_tasks_total.clone();
341 let task = {
343 async move {
344 let _inc_counter_on_drop =
346 IncCounterOnDrop::new(finished_regular_tasks_total_metrics);
347 let fut = pin!(fut);
348 let _ = select(on_shutdown, fut).await;
349 }
350 }
351 .in_current_span();
352
353 self.spawn_on_rt(task, task_kind)
354 }
355
356 pub fn spawn<F>(&self, fut: F) -> JoinHandle<()>
361 where
362 F: Future<Output = ()> + Send + 'static,
363 {
364 self.spawn_task_as(fut, TaskKind::Default)
365 }
366
367 pub fn spawn_blocking<F>(&self, fut: F) -> JoinHandle<()>
372 where
373 F: Future<Output = ()> + Send + 'static,
374 {
375 self.spawn_task_as(fut, TaskKind::Blocking)
376 }
377
378 pub fn spawn_with_signal<F>(&self, f: impl FnOnce(Shutdown) -> F) -> JoinHandle<()>
383 where
384 F: Future<Output = ()> + Send + 'static,
385 {
386 let on_shutdown = self.on_shutdown.clone();
387 let fut = f(on_shutdown);
388
389 let task = fut.in_current_span();
390
391 self.handle.spawn(task)
392 }
393
394 fn spawn_critical_as<F>(
396 &self,
397 name: &'static str,
398 fut: F,
399 task_kind: TaskKind,
400 ) -> JoinHandle<()>
401 where
402 F: Future<Output = ()> + Send + 'static,
403 {
404 let panicked_tasks_tx = self.panicked_tasks_tx.clone();
405 let on_shutdown = self.on_shutdown.clone();
406
407 let task = std::panic::AssertUnwindSafe(fut)
409 .catch_unwind()
410 .map_err(move |error| {
411 let task_error = PanickedTaskError::new(name, error);
412 error!("{task_error}");
413 let _ = panicked_tasks_tx.send(task_error);
414 })
415 .in_current_span();
416
417 let finished_critical_tasks_total_metrics =
419 self.metrics.finished_critical_tasks_total.clone();
420 let task = async move {
421 let _inc_counter_on_drop = IncCounterOnDrop::new(finished_critical_tasks_total_metrics);
423 let task = pin!(task);
424 let _ = select(on_shutdown, task).await;
425 };
426
427 self.spawn_on_rt(task, task_kind)
428 }
429
430 pub fn spawn_critical_blocking<F>(&self, name: &'static str, fut: F) -> JoinHandle<()>
435 where
436 F: Future<Output = ()> + Send + 'static,
437 {
438 self.spawn_critical_as(name, fut, TaskKind::Blocking)
439 }
440
441 pub fn spawn_critical<F>(&self, name: &'static str, fut: F) -> JoinHandle<()>
446 where
447 F: Future<Output = ()> + Send + 'static,
448 {
449 self.spawn_critical_as(name, fut, TaskKind::Default)
450 }
451
452 pub fn spawn_critical_with_shutdown_signal<F>(
456 &self,
457 name: &'static str,
458 f: impl FnOnce(Shutdown) -> F,
459 ) -> JoinHandle<()>
460 where
461 F: Future<Output = ()> + Send + 'static,
462 {
463 let panicked_tasks_tx = self.panicked_tasks_tx.clone();
464 let on_shutdown = self.on_shutdown.clone();
465 let fut = f(on_shutdown);
466
467 let task = std::panic::AssertUnwindSafe(fut)
469 .catch_unwind()
470 .map_err(move |error| {
471 let task_error = PanickedTaskError::new(name, error);
472 error!("{task_error}");
473 let _ = panicked_tasks_tx.send(task_error);
474 })
475 .map(drop)
476 .in_current_span();
477
478 self.handle.spawn(task)
479 }
480
481 pub fn spawn_critical_with_graceful_shutdown_signal<F>(
502 &self,
503 name: &'static str,
504 f: impl FnOnce(GracefulShutdown) -> F,
505 ) -> JoinHandle<()>
506 where
507 F: Future<Output = ()> + Send + 'static,
508 {
509 let panicked_tasks_tx = self.panicked_tasks_tx.clone();
510 let on_shutdown = GracefulShutdown::new(
511 self.on_shutdown.clone(),
512 GracefulShutdownGuard::new(Arc::clone(&self.graceful_tasks)),
513 );
514 let fut = f(on_shutdown);
515
516 let task = std::panic::AssertUnwindSafe(fut)
518 .catch_unwind()
519 .map_err(move |error| {
520 let task_error = PanickedTaskError::new(name, error);
521 error!("{task_error}");
522 let _ = panicked_tasks_tx.send(task_error);
523 })
524 .map(drop)
525 .in_current_span();
526
527 self.handle.spawn(task)
528 }
529
530 pub fn spawn_with_graceful_shutdown_signal<F>(
550 &self,
551 f: impl FnOnce(GracefulShutdown) -> F,
552 ) -> JoinHandle<()>
553 where
554 F: Future<Output = ()> + Send + 'static,
555 {
556 let on_shutdown = GracefulShutdown::new(
557 self.on_shutdown.clone(),
558 GracefulShutdownGuard::new(Arc::clone(&self.graceful_tasks)),
559 );
560 let fut = f(on_shutdown);
561
562 self.handle.spawn(fut)
563 }
564}
565
566impl TaskSpawner for TaskExecutor {
567 fn spawn(&self, fut: BoxFuture<'static, ()>) -> JoinHandle<()> {
568 self.metrics.inc_regular_tasks();
569 self.spawn(fut)
570 }
571
572 fn spawn_critical(&self, name: &'static str, fut: BoxFuture<'static, ()>) -> JoinHandle<()> {
573 self.metrics.inc_critical_tasks();
574 Self::spawn_critical(self, name, fut)
575 }
576
577 fn spawn_blocking(&self, fut: BoxFuture<'static, ()>) -> JoinHandle<()> {
578 self.spawn_blocking(fut)
579 }
580
581 fn spawn_critical_blocking(
582 &self,
583 name: &'static str,
584 fut: BoxFuture<'static, ()>,
585 ) -> JoinHandle<()> {
586 Self::spawn_critical_blocking(self, name, fut)
587 }
588}
589
/// Spawning methods that hand the task a [`GracefulShutdown`] signal.
///
/// The methods are generic, so unlike [`TaskSpawner`] this trait cannot be used as a trait
/// object.
#[auto_impl::auto_impl(&, Arc)]
pub trait TaskSpawnerExt: Send + Sync + Unpin + std::fmt::Debug + DynClone {
    /// Spawns the critical future produced by `f`, which receives a [`GracefulShutdown`]
    /// signal; the task is identified by `name`.
    fn spawn_critical_with_graceful_shutdown_signal<F>(
        &self,
        name: &'static str,
        f: impl FnOnce(GracefulShutdown) -> F,
    ) -> JoinHandle<()>
    where
        F: Future<Output = ()> + Send + 'static;

    /// Spawns the future produced by `f`, which receives a [`GracefulShutdown`] signal.
    fn spawn_with_graceful_shutdown_signal<F>(
        &self,
        f: impl FnOnce(GracefulShutdown) -> F,
    ) -> JoinHandle<()>
    where
        F: Future<Output = ()> + Send + 'static;
}
615
616impl TaskSpawnerExt for TaskExecutor {
617 fn spawn_critical_with_graceful_shutdown_signal<F>(
618 &self,
619 name: &'static str,
620 f: impl FnOnce(GracefulShutdown) -> F,
621 ) -> JoinHandle<()>
622 where
623 F: Future<Output = ()> + Send + 'static,
624 {
625 Self::spawn_critical_with_graceful_shutdown_signal(self, name, f)
626 }
627
628 fn spawn_with_graceful_shutdown_signal<F>(
629 &self,
630 f: impl FnOnce(GracefulShutdown) -> F,
631 ) -> JoinHandle<()>
632 where
633 F: Future<Output = ()> + Send + 'static,
634 {
635 Self::spawn_with_graceful_shutdown_signal(self, f)
636 }
637}
638
/// Selects how a task is placed onto the runtime.
enum TaskKind {
    /// Spawn as a regular async task (`Handle::spawn`).
    Default,
    /// Spawn on the blocking thread pool (`Handle::spawn_blocking`) and drive the future there
    /// with `Handle::block_on`.
    Blocking,
}
646
#[cfg(test)]
mod tests {
    use super::*;
    use std::{sync::atomic::AtomicBool, time::Duration};

    #[test]
    fn test_cloneable() {
        #[derive(Clone)]
        struct ExecutorWrapper {
            _e: Box<dyn TaskSpawner>,
        }

        let executor: Box<dyn TaskSpawner> = Box::<TokioTaskExecutor>::default();
        let _e = dyn_clone::clone_box(&*executor);

        let e = ExecutorWrapper { _e };
        let _e2 = e;
    }

    #[test]
    fn test_critical() {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let handle = runtime.handle().clone();
        let manager = TaskManager::new(handle);
        let executor = manager.executor();

        executor.spawn_critical("this is a critical task", async { panic!("intentionally panic") });

        runtime.block_on(async move {
            let err = manager.await;
            assert_eq!(err.task_name, "this is a critical task");
            assert_eq!(err.error, Some("intentionally panic".to_string()));
        })
    }

    // Tests that spawned critical tasks are terminated when the `TaskManager` is dropped: the
    // task owns `signal`, so `shutdown` only resolves early if the task is cancelled.
    #[test]
    fn test_manager_shutdown_critical() {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let handle = runtime.handle().clone();
        let manager = TaskManager::new(handle.clone());
        let executor = manager.executor();

        let (signal, shutdown) = signal();

        executor.spawn_critical("this is a critical task", async move {
            tokio::time::sleep(Duration::from_millis(200)).await;
            drop(signal);
        });

        drop(manager);

        handle.block_on(shutdown);
    }

    // Tests that spawned regular tasks are terminated when the `TaskManager` is dropped.
    #[test]
    fn test_manager_shutdown() {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let handle = runtime.handle().clone();
        let manager = TaskManager::new(handle.clone());
        let executor = manager.executor();

        let (signal, shutdown) = signal();

        executor.spawn(Box::pin(async move {
            tokio::time::sleep(Duration::from_millis(200)).await;
            drop(signal);
        }));

        drop(manager);

        handle.block_on(shutdown);
    }

    #[test]
    fn test_manager_graceful_shutdown() {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let handle = runtime.handle().clone();
        let manager = TaskManager::new(handle);
        let executor = manager.executor();

        let val = Arc::new(AtomicBool::new(false));
        let c = val.clone();
        executor.spawn_critical_with_graceful_shutdown_signal("grace", |shutdown| async move {
            let _guard = shutdown.await;
            tokio::time::sleep(Duration::from_millis(200)).await;
            c.store(true, Ordering::Relaxed);
        });

        manager.graceful_shutdown();
        assert!(val.load(Ordering::Relaxed));
    }

    #[test]
    fn test_manager_graceful_shutdown_many() {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let handle = runtime.handle().clone();
        let manager = TaskManager::new(handle);
        let executor = manager.executor();

        let counter = Arc::new(AtomicUsize::new(0));
        let num = 10;
        for _ in 0..num {
            let c = counter.clone();
            executor.spawn_critical_with_graceful_shutdown_signal(
                "grace",
                move |shutdown| async move {
                    let _guard = shutdown.await;
                    tokio::time::sleep(Duration::from_millis(200)).await;
                    c.fetch_add(1, Ordering::SeqCst);
                },
            );
        }

        manager.graceful_shutdown();
        assert_eq!(counter.load(Ordering::Relaxed), num);
    }

    #[test]
    fn test_manager_graceful_shutdown_timeout() {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let handle = runtime.handle().clone();
        let manager = TaskManager::new(handle);
        let executor = manager.executor();

        let timeout = Duration::from_millis(500);
        let val = Arc::new(AtomicBool::new(false));
        let val2 = val.clone();
        executor.spawn_critical_with_graceful_shutdown_signal("grace", |shutdown| async move {
            let _guard = shutdown.await;
            tokio::time::sleep(timeout * 3).await;
            val2.store(true, Ordering::Relaxed);
            unreachable!("should not be reached");
        });

        // The task outlives the timeout, so the shutdown must report failure.
        assert!(!manager.graceful_shutdown_with_timeout(timeout));
        assert!(!val.load(Ordering::Relaxed));
    }

    #[test]
    fn test_manager_graceful_shutdown_with_timeout_completes() {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let handle = runtime.handle().clone();
        let manager = TaskManager::new(handle);
        let executor = manager.executor();

        let val = Arc::new(AtomicBool::new(false));
        let c = val.clone();
        executor.spawn_critical_with_graceful_shutdown_signal("grace", |shutdown| async move {
            let _guard = shutdown.await;
            tokio::time::sleep(Duration::from_millis(50)).await;
            c.store(true, Ordering::Relaxed);
        });

        // The task finishes well within the timeout, so the shutdown must report success.
        assert!(manager.graceful_shutdown_with_timeout(Duration::from_secs(10)));
        assert!(val.load(Ordering::Relaxed));
    }
}