1#[cfg(feature = "rayon")]
10use crate::pool::{build_pool_with_panic_handler, BlockingTaskGuard, BlockingTaskPool, WorkerPool};
11use crate::{
12 metrics::{IncCounterOnDrop, TaskExecutorMetrics},
13 shutdown::{GracefulShutdown, GracefulShutdownGuard, Shutdown},
14 worker_map::WorkerMap,
15 PanickedTaskError, TaskEvent, TaskManager,
16};
17use futures_util::{future::select, Future, FutureExt, TryFutureExt};
18#[cfg(feature = "rayon")]
19use std::{num::NonZeroUsize, thread::available_parallelism};
20use std::{
21 pin::pin,
22 sync::{
23 atomic::{AtomicUsize, Ordering},
24 Arc, Mutex,
25 },
26 time::{Duration, Instant},
27};
28use tokio::{runtime::Handle, sync::mpsc::UnboundedSender, task::JoinHandle};
29use tracing::{debug, error};
30use tracing_futures::Instrument;
31
32use tokio::runtime::Runtime as TokioRuntime;
33
/// Keep-alive for idle threads of an owned tokio runtime.
///
/// Used as `thread_keep_alive` by the [`TokioConfig`] constructors and forwarded to
/// `tokio::runtime::Builder::thread_keep_alive` when the runtime is built.
pub const DEFAULT_THREAD_KEEP_ALIVE: Duration = Duration::from_secs(15);

/// Default value of `RayonConfig::reserved_cpu_cores`.
///
/// NOTE(review): `RayonConfig::default_thread_count` reads but does not subtract this value,
/// so it currently does not shrink any pool.
pub const DEFAULT_RESERVED_CPU_CORES: usize = 2;

/// Storage pool size used when `RayonConfig::storage_threads` is `None`.
pub const DEFAULT_STORAGE_POOL_THREADS: usize = 16;

/// Default value of `RayonConfig::max_blocking_tasks`, which is the limit handed to
/// `BlockingTaskGuard::new` when the runtime is built.
pub const DEFAULT_MAX_BLOCKING_TASKS: usize = 512;
45
46#[derive(Debug, Clone)]
48pub enum TokioConfig {
49 Owned {
51 worker_threads: Option<usize>,
53 thread_keep_alive: Duration,
55 thread_name: &'static str,
57 },
58 ExistingHandle(Handle),
60}
61
62impl Default for TokioConfig {
63 fn default() -> Self {
64 Self::Owned {
65 worker_threads: None,
66 thread_keep_alive: DEFAULT_THREAD_KEEP_ALIVE,
67 thread_name: "tokio-rt",
68 }
69 }
70}
71
72impl TokioConfig {
73 pub const fn existing_handle(handle: Handle) -> Self {
75 Self::ExistingHandle(handle)
76 }
77
78 pub const fn with_worker_threads(worker_threads: usize) -> Self {
80 Self::Owned {
81 worker_threads: Some(worker_threads),
82 thread_keep_alive: DEFAULT_THREAD_KEEP_ALIVE,
83 thread_name: "tokio-rt",
84 }
85 }
86}
87
/// Sizing of the rayon thread pools created by `RuntimeBuilder::build`.
#[cfg(feature = "rayon")]
#[derive(Debug, Clone)]
pub struct RayonConfig {
    /// Size of the general CPU pool. It is also the fallback for the other pools.
    /// `None` means the machine's available parallelism (or 1 if that cannot be queried).
    pub cpu_threads: Option<usize>,
    /// Number of CPU cores to reserve.
    ///
    /// NOTE(review): `default_thread_count` reads but does not apply this value.
    pub reserved_cpu_cores: usize,
    /// RPC pool size; `None` falls back to the default thread count.
    pub rpc_threads: Option<usize>,
    /// Storage pool size; `None` falls back to [`DEFAULT_STORAGE_POOL_THREADS`].
    pub storage_threads: Option<usize>,
    /// Limit passed to `BlockingTaskGuard::new`.
    pub max_blocking_tasks: usize,
    /// Proof-storage worker pool size; `None` means twice the default thread count.
    pub proof_storage_worker_threads: Option<usize>,
    /// Proof-account worker pool size; `None` means twice the default thread count.
    pub proof_account_worker_threads: Option<usize>,
    /// Prewarming pool size; `None` falls back to the default thread count.
    pub prewarming_threads: Option<usize>,
}

#[cfg(feature = "rayon")]
impl Default for RayonConfig {
    /// Every pool size is derived at build time.
    /// Reserved cores and blocking-task limits use their `DEFAULT_*` constants.
    fn default() -> Self {
        Self {
            reserved_cpu_cores: DEFAULT_RESERVED_CPU_CORES,
            max_blocking_tasks: DEFAULT_MAX_BLOCKING_TASKS,
            cpu_threads: None,
            rpc_threads: None,
            storage_threads: None,
            proof_storage_worker_threads: None,
            proof_account_worker_threads: None,
            prewarming_threads: None,
        }
    }
}

#[cfg(feature = "rayon")]
impl RayonConfig {
    /// Sets [`Self::reserved_cpu_cores`].
    pub const fn with_reserved_cpu_cores(self, reserved_cpu_cores: usize) -> Self {
        Self { reserved_cpu_cores, ..self }
    }

    /// Sets [`Self::max_blocking_tasks`].
    pub const fn with_max_blocking_tasks(self, max_blocking_tasks: usize) -> Self {
        Self { max_blocking_tasks, ..self }
    }

    /// Pins the RPC pool size.
    pub const fn with_rpc_threads(self, rpc_threads: usize) -> Self {
        Self { rpc_threads: Some(rpc_threads), ..self }
    }

    /// Pins the storage pool size.
    pub const fn with_storage_threads(self, storage_threads: usize) -> Self {
        Self { storage_threads: Some(storage_threads), ..self }
    }

    /// Pins the proof-storage worker pool size.
    pub const fn with_proof_storage_worker_threads(
        self,
        proof_storage_worker_threads: usize,
    ) -> Self {
        Self { proof_storage_worker_threads: Some(proof_storage_worker_threads), ..self }
    }

    /// Pins the proof-account worker pool size.
    pub const fn with_proof_account_worker_threads(
        self,
        proof_account_worker_threads: usize,
    ) -> Self {
        Self { proof_account_worker_threads: Some(proof_account_worker_threads), ..self }
    }

    /// Pins the prewarming pool size.
    pub const fn with_prewarming_threads(self, prewarming_threads: usize) -> Self {
        Self { prewarming_threads: Some(prewarming_threads), ..self }
    }

    /// Thread count for the CPU pool, used as the fallback for the other pools.
    ///
    /// This is `cpu_threads` if set. Otherwise it is the available parallelism, or `1` when
    /// that cannot be determined.
    fn default_thread_count(&self) -> usize {
        // `reserved_cpu_cores` is read but deliberately not subtracted here (presumably
        // intentional; confirm before changing pool sizing).
        let _ = self.reserved_cpu_cores;
        match self.cpu_threads {
            Some(threads) => threads,
            None => available_parallelism().map(NonZeroUsize::get).unwrap_or(1),
        }
    }
}
190
191#[derive(Debug, Clone, Default)]
193pub struct RuntimeConfig {
194 pub tokio: TokioConfig,
196 #[cfg(feature = "rayon")]
198 pub rayon: RayonConfig,
199}
200
201impl RuntimeConfig {
202 pub fn with_tokio(mut self, tokio: TokioConfig) -> Self {
204 self.tokio = tokio;
205 self
206 }
207
208 #[cfg(feature = "rayon")]
210 pub const fn with_rayon(mut self, rayon: RayonConfig) -> Self {
211 self.rayon = rayon;
212 self
213 }
214}
215
216#[derive(Debug, thiserror::Error)]
218pub enum RuntimeBuildError {
219 #[error("Failed to build tokio runtime: {0}")]
221 TokioBuild(#[from] std::io::Error),
222 #[cfg(feature = "rayon")]
224 #[error("Failed to build rayon thread pool: {0}")]
225 RayonBuild(#[from] rayon::ThreadPoolBuildError),
226}
227
/// State shared by every clone of [`Runtime`].
///
/// NOTE(review): Rust drops struct fields in declaration order, so an owned tokio runtime
/// (`_tokio_runtime`) is dropped before the handle, the rayon pools and the worker map below.
/// Keep this field order unless that ordering is intentionally being changed.
struct RuntimeInner {
    /// Tokio runtime built by `RuntimeBuilder::build` for `TokioConfig::Owned`; `None` when an
    /// existing handle was supplied. Never read; it only keeps the owned runtime alive.
    _tokio_runtime: Option<TokioRuntime>,
    /// Handle used for every tokio spawn made through [`Runtime`].
    handle: Handle,
    /// Shutdown signal cloned into spawned tasks.
    on_shutdown: Shutdown,
    /// Sender to the task manager; carries `TaskEvent::Panic` and `TaskEvent::GracefulShutdown`.
    task_events_tx: UnboundedSender<TaskEvent>,
    /// Spawn/finish counters for regular, blocking and critical tasks.
    metrics: TaskExecutorMetrics,
    /// Counter shared with each `GracefulShutdownGuard`. Graceful shutdown waits for it to reach 0.
    graceful_tasks: Arc<AtomicUsize>,
    #[cfg(feature = "rayon")]
    /// General CPU pool, sized by `RayonConfig::default_thread_count`.
    cpu_pool: rayon::ThreadPool,
    #[cfg(feature = "rayon")]
    /// Pool for RPC work, wrapped in a `BlockingTaskPool`.
    rpc_pool: BlockingTaskPool,
    #[cfg(feature = "rayon")]
    /// Pool for storage work.
    storage_pool: rayon::ThreadPool,
    #[cfg(feature = "rayon")]
    /// Guard built from `RayonConfig::max_blocking_tasks`; handed out by cloning.
    blocking_guard: BlockingTaskGuard,
    #[cfg(feature = "rayon")]
    /// Worker pool for proof storage work.
    proof_storage_worker_pool: WorkerPool,
    #[cfg(feature = "rayon")]
    /// Worker pool for proof account work.
    proof_account_worker_pool: WorkerPool,
    #[cfg(feature = "rayon")]
    /// Worker pool for prewarming work.
    prewarming_pool: WorkerPool,
    /// Named workers used by `Runtime::spawn_blocking_named`.
    worker_map: WorkerMap,
    /// Join handle of the task-manager future spawned in `RuntimeBuilder::build`.
    /// Moved out (at most once) by `Runtime::take_task_manager_handle`.
    task_manager_handle: Mutex<Option<JoinHandle<Result<(), PanickedTaskError>>>>,
}
272
273#[derive(Clone)]
282pub struct Runtime(Arc<RuntimeInner>);
283
284impl std::fmt::Debug for Runtime {
285 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
286 f.debug_struct("Runtime").field("handle", &self.0.handle).finish()
287 }
288}
289
290impl Runtime {
293 pub fn take_task_manager_handle(&self) -> Option<JoinHandle<Result<(), PanickedTaskError>>> {
299 self.0.task_manager_handle.lock().unwrap().take()
300 }
301
302 pub fn handle(&self) -> &Handle {
304 &self.0.handle
305 }
306
307 #[cfg(feature = "rayon")]
309 pub fn cpu_pool(&self) -> &rayon::ThreadPool {
310 &self.0.cpu_pool
311 }
312
313 #[cfg(feature = "rayon")]
315 pub fn rpc_pool(&self) -> &BlockingTaskPool {
316 &self.0.rpc_pool
317 }
318
319 #[cfg(feature = "rayon")]
321 pub fn storage_pool(&self) -> &rayon::ThreadPool {
322 &self.0.storage_pool
323 }
324
325 #[cfg(feature = "rayon")]
327 pub fn blocking_guard(&self) -> BlockingTaskGuard {
328 self.0.blocking_guard.clone()
329 }
330
331 #[cfg(feature = "rayon")]
333 pub fn proof_storage_worker_pool(&self) -> &WorkerPool {
334 &self.0.proof_storage_worker_pool
335 }
336
337 #[cfg(feature = "rayon")]
339 pub fn proof_account_worker_pool(&self) -> &WorkerPool {
340 &self.0.proof_account_worker_pool
341 }
342
343 #[cfg(feature = "rayon")]
345 pub fn prewarming_pool(&self) -> &WorkerPool {
346 &self.0.prewarming_pool
347 }
348}
349
350impl Runtime {
353 pub fn test() -> Self {
358 let config = match Handle::try_current() {
359 Ok(handle) => Self::test_config().with_tokio(TokioConfig::existing_handle(handle)),
360 Err(_) => Self::test_config(),
361 };
362 RuntimeBuilder::new(config).build().expect("failed to build test Runtime")
363 }
364
365 const fn test_config() -> RuntimeConfig {
366 RuntimeConfig {
367 tokio: TokioConfig::Owned {
368 worker_threads: Some(2),
369 thread_keep_alive: DEFAULT_THREAD_KEEP_ALIVE,
370 thread_name: "tokio-test",
371 },
372 #[cfg(feature = "rayon")]
373 rayon: RayonConfig {
374 cpu_threads: Some(2),
375 reserved_cpu_cores: 0,
376 rpc_threads: Some(2),
377 storage_threads: Some(2),
378 max_blocking_tasks: 16,
379 proof_storage_worker_threads: Some(2),
380 proof_account_worker_threads: Some(2),
381 prewarming_threads: Some(2),
382 },
383 }
384 }
385}
386
/// Where a spawned future is driven.
enum TaskKind {
    /// Spawned directly on the tokio runtime (`Handle::spawn`).
    Default,
    /// Moved to tokio's blocking pool and driven there with `Handle::block_on`.
    Blocking,
}
396
397impl Runtime {
398 pub fn on_shutdown_signal(&self) -> &Shutdown {
400 &self.0.on_shutdown
401 }
402
403 fn spawn_on_rt<F>(&self, fut: F, task_kind: TaskKind) -> JoinHandle<()>
405 where
406 F: Future<Output = ()> + Send + 'static,
407 {
408 match task_kind {
409 TaskKind::Default => self.0.handle.spawn(fut),
410 TaskKind::Blocking => {
411 let handle = self.0.handle.clone();
412 self.0.handle.spawn_blocking(move || handle.block_on(fut))
413 }
414 }
415 }
416
417 fn spawn_task_as<F>(&self, fut: F, task_kind: TaskKind) -> JoinHandle<()>
419 where
420 F: Future<Output = ()> + Send + 'static,
421 {
422 match task_kind {
423 TaskKind::Default => self.0.metrics.inc_regular_tasks(),
424 TaskKind::Blocking => self.0.metrics.inc_regular_blocking_tasks(),
425 }
426 let on_shutdown = self.0.on_shutdown.clone();
427
428 let finished_counter = match task_kind {
429 TaskKind::Default => self.0.metrics.finished_regular_tasks_total.clone(),
430 TaskKind::Blocking => self.0.metrics.finished_regular_blocking_tasks_total.clone(),
431 };
432
433 let task = {
434 async move {
435 let _inc_counter_on_drop = IncCounterOnDrop::new(finished_counter);
436 let fut = pin!(fut);
437 let _ = select(on_shutdown, fut).await;
438 }
439 }
440 .in_current_span();
441
442 self.spawn_on_rt(task, task_kind)
443 }
444
445 pub fn spawn_task<F>(&self, fut: F) -> JoinHandle<()>
450 where
451 F: Future<Output = ()> + Send + 'static,
452 {
453 self.spawn_task_as(fut, TaskKind::Default)
454 }
455
456 pub fn spawn_blocking_task<F>(&self, fut: F) -> JoinHandle<()>
461 where
462 F: Future<Output = ()> + Send + 'static,
463 {
464 self.spawn_task_as(fut, TaskKind::Blocking)
465 }
466
467 pub fn spawn_blocking<F, R>(&self, func: F) -> JoinHandle<R>
470 where
471 F: FnOnce() -> R + Send + 'static,
472 R: Send + 'static,
473 {
474 self.0.handle.spawn_blocking(func)
475 }
476
477 pub fn spawn_drop<T: Send + 'static>(&self, value: T) {
483 self.spawn_blocking_named("drop", move || drop(value));
484 }
485
486 pub fn spawn_blocking_named<F, R>(&self, name: &'static str, func: F) -> crate::LazyHandle<R>
498 where
499 F: FnOnce() -> R + Send + 'static,
500 R: Send + 'static,
501 {
502 crate::LazyHandle::new(self.0.worker_map.spawn_on(name, func))
503 }
504
505 pub fn spawn_with_signal<F>(&self, f: impl FnOnce(Shutdown) -> F) -> JoinHandle<()>
510 where
511 F: Future<Output = ()> + Send + 'static,
512 {
513 let on_shutdown = self.0.on_shutdown.clone();
514 let fut = f(on_shutdown);
515 let task = fut.in_current_span();
516 self.0.handle.spawn(task)
517 }
518
519 fn spawn_critical_as<F>(
521 &self,
522 name: &'static str,
523 fut: F,
524 task_kind: TaskKind,
525 ) -> JoinHandle<()>
526 where
527 F: Future<Output = ()> + Send + 'static,
528 {
529 self.0.metrics.inc_critical_tasks();
530 let panicked_tasks_tx = self.0.task_events_tx.clone();
531 let on_shutdown = self.0.on_shutdown.clone();
532
533 let task = std::panic::AssertUnwindSafe(fut)
535 .catch_unwind()
536 .map_err(move |error| {
537 let task_error = PanickedTaskError::new(name, error);
538 error!("{task_error}");
539 let _ = panicked_tasks_tx.send(TaskEvent::Panic(task_error));
540 })
541 .in_current_span();
542
543 let finished_critical_tasks_total_metrics =
544 self.0.metrics.finished_critical_tasks_total.clone();
545 let task = async move {
546 let _inc_counter_on_drop = IncCounterOnDrop::new(finished_critical_tasks_total_metrics);
547 let task = pin!(task);
548 let _ = select(on_shutdown, task).await;
549 };
550
551 self.spawn_on_rt(task, task_kind)
552 }
553
554 pub fn spawn_critical_task<F>(&self, name: &'static str, fut: F) -> JoinHandle<()>
559 where
560 F: Future<Output = ()> + Send + 'static,
561 {
562 self.spawn_critical_as(name, fut, TaskKind::Default)
563 }
564
565 pub fn spawn_critical_blocking_task<F>(&self, name: &'static str, fut: F) -> JoinHandle<()>
570 where
571 F: Future<Output = ()> + Send + 'static,
572 {
573 self.spawn_critical_as(name, fut, TaskKind::Blocking)
574 }
575
576 pub fn spawn_critical_with_graceful_shutdown_signal<F>(
597 &self,
598 name: &'static str,
599 f: impl FnOnce(GracefulShutdown) -> F,
600 ) -> JoinHandle<()>
601 where
602 F: Future<Output = ()> + Send + 'static,
603 {
604 let panicked_tasks_tx = self.0.task_events_tx.clone();
605 let on_shutdown = GracefulShutdown::new(
606 self.0.on_shutdown.clone(),
607 GracefulShutdownGuard::new(Arc::clone(&self.0.graceful_tasks)),
608 );
609 let fut = f(on_shutdown);
610
611 let task = std::panic::AssertUnwindSafe(fut)
613 .catch_unwind()
614 .map_err(move |error| {
615 let task_error = PanickedTaskError::new(name, error);
616 error!("{task_error}");
617 let _ = panicked_tasks_tx.send(TaskEvent::Panic(task_error));
618 })
619 .map(drop)
620 .in_current_span();
621
622 self.0.handle.spawn(task)
623 }
624
625 pub fn spawn_with_graceful_shutdown_signal<F>(
645 &self,
646 f: impl FnOnce(GracefulShutdown) -> F,
647 ) -> JoinHandle<()>
648 where
649 F: Future<Output = ()> + Send + 'static,
650 {
651 let on_shutdown = GracefulShutdown::new(
652 self.0.on_shutdown.clone(),
653 GracefulShutdownGuard::new(Arc::clone(&self.0.graceful_tasks)),
654 );
655 let fut = f(on_shutdown);
656
657 self.0.handle.spawn(fut)
658 }
659
660 pub fn initiate_graceful_shutdown(
664 &self,
665 ) -> Result<GracefulShutdown, tokio::sync::mpsc::error::SendError<()>> {
666 self.0
667 .task_events_tx
668 .send(TaskEvent::GracefulShutdown)
669 .map_err(|_send_error_with_task_event| tokio::sync::mpsc::error::SendError(()))?;
670
671 Ok(GracefulShutdown::new(
672 self.0.on_shutdown.clone(),
673 GracefulShutdownGuard::new(Arc::clone(&self.0.graceful_tasks)),
674 ))
675 }
676
677 pub fn graceful_shutdown(&self) {
679 let _ = self.do_graceful_shutdown(None);
680 }
681
682 pub fn graceful_shutdown_with_timeout(&self, timeout: Duration) -> bool {
687 self.do_graceful_shutdown(Some(timeout))
688 }
689
690 fn do_graceful_shutdown(&self, timeout: Option<Duration>) -> bool {
691 let _ = self.0.task_events_tx.send(TaskEvent::GracefulShutdown);
692 let deadline = timeout.map(|t| Instant::now() + t);
693 while self.0.graceful_tasks.load(Ordering::SeqCst) > 0 {
694 if deadline.is_some_and(|d| Instant::now() > d) {
695 debug!("graceful shutdown timed out");
696 return false;
697 }
698 std::thread::yield_now();
699 }
700 debug!("gracefully shut down");
701 true
702 }
703}
704
705#[derive(Debug, Clone)]
709pub struct RuntimeBuilder {
710 config: RuntimeConfig,
711}
712
713impl RuntimeBuilder {
714 pub const fn new(config: RuntimeConfig) -> Self {
716 Self { config }
717 }
718
719 #[tracing::instrument(name = "RuntimeBuilder::build", level = "debug", skip_all)]
725 pub fn build(self) -> Result<Runtime, RuntimeBuildError> {
726 debug!(?self.config, "Building runtime");
727 let config = self.config;
728
729 let (owned_runtime, handle) = match &config.tokio {
730 TokioConfig::Owned { worker_threads, thread_keep_alive, thread_name } => {
731 let mut builder = tokio::runtime::Builder::new_multi_thread();
732 builder
733 .enable_all()
734 .thread_keep_alive(*thread_keep_alive)
735 .thread_name(*thread_name);
736
737 if let Some(threads) = worker_threads {
738 builder.worker_threads(*threads);
739 }
740
741 let runtime = builder.build()?;
742 let h = runtime.handle().clone();
743 (Some(runtime), h)
744 }
745 TokioConfig::ExistingHandle(h) => (None, h.clone()),
746 };
747
748 let (task_manager, on_shutdown, task_events_tx, graceful_tasks) =
749 TaskManager::new_parts(handle.clone());
750
751 #[cfg(feature = "rayon")]
752 let (
753 cpu_pool,
754 rpc_pool,
755 storage_pool,
756 blocking_guard,
757 proof_storage_worker_pool,
758 proof_account_worker_pool,
759 prewarming_pool,
760 ) = {
761 let default_threads = config.rayon.default_thread_count();
762 let rpc_threads = config.rayon.rpc_threads.unwrap_or(default_threads);
763
764 let cpu_pool = build_pool_with_panic_handler(
765 rayon::ThreadPoolBuilder::new()
766 .num_threads(default_threads)
767 .thread_name(|i| format!("cpu-{i:02}")),
768 )?;
769
770 let rpc_raw = build_pool_with_panic_handler(
771 rayon::ThreadPoolBuilder::new()
772 .num_threads(rpc_threads)
773 .thread_name(|i| format!("rpc-{i:02}")),
774 )?;
775 let rpc_pool = BlockingTaskPool::new(rpc_raw);
776
777 let storage_threads =
778 config.rayon.storage_threads.unwrap_or(DEFAULT_STORAGE_POOL_THREADS);
779 let storage_pool = build_pool_with_panic_handler(
780 rayon::ThreadPoolBuilder::new()
781 .num_threads(storage_threads)
782 .thread_name(|i| format!("storage-{i:02}")),
783 )?;
784
785 let blocking_guard = BlockingTaskGuard::new(config.rayon.max_blocking_tasks);
786
787 let proof_storage_worker_threads =
788 config.rayon.proof_storage_worker_threads.unwrap_or(default_threads * 2);
789 let proof_storage_worker_pool = WorkerPool::from_builder(
790 rayon::ThreadPoolBuilder::new()
791 .num_threads(proof_storage_worker_threads)
792 .thread_name(|i| format!("proof-strg-{i:02}")),
793 )?;
794
795 let proof_account_worker_threads =
796 config.rayon.proof_account_worker_threads.unwrap_or(default_threads * 2);
797 let proof_account_worker_pool = WorkerPool::from_builder(
798 rayon::ThreadPoolBuilder::new()
799 .num_threads(proof_account_worker_threads)
800 .thread_name(|i| format!("proof-acct-{i:02}")),
801 )?;
802
803 let prewarming_threads = config.rayon.prewarming_threads.unwrap_or(default_threads);
804 let prewarming_pool = WorkerPool::from_builder(
805 rayon::ThreadPoolBuilder::new()
806 .num_threads(prewarming_threads)
807 .thread_name(|i| format!("prewarm-{i:02}")),
808 )?;
809
810 debug!(
811 default_threads,
812 rpc_threads,
813 storage_threads,
814 proof_storage_worker_threads,
815 proof_account_worker_threads,
816 prewarming_threads,
817 max_blocking_tasks = config.rayon.max_blocking_tasks,
818 "Initialized rayon thread pools"
819 );
820
821 (
822 cpu_pool,
823 rpc_pool,
824 storage_pool,
825 blocking_guard,
826 proof_storage_worker_pool,
827 proof_account_worker_pool,
828 prewarming_pool,
829 )
830 };
831
832 let task_manager_handle = handle.spawn(async move {
833 let result = task_manager.await;
834 if let Err(ref err) = result {
835 debug!("{err}");
836 }
837 result
838 });
839
840 let inner = RuntimeInner {
841 _tokio_runtime: owned_runtime,
842 handle,
843 on_shutdown,
844 task_events_tx,
845 metrics: Default::default(),
846 graceful_tasks,
847 #[cfg(feature = "rayon")]
848 cpu_pool,
849 #[cfg(feature = "rayon")]
850 rpc_pool,
851 #[cfg(feature = "rayon")]
852 storage_pool,
853 #[cfg(feature = "rayon")]
854 blocking_guard,
855 #[cfg(feature = "rayon")]
856 proof_storage_worker_pool,
857 #[cfg(feature = "rayon")]
858 proof_account_worker_pool,
859 #[cfg(feature = "rayon")]
860 prewarming_pool,
861 worker_map: WorkerMap::new(),
862 task_manager_handle: Mutex::new(Some(task_manager_handle)),
863 };
864
865 Ok(Runtime(Arc::new(inner)))
866 }
867}
868
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration builds (and owns) its own tokio runtime.
    #[test]
    fn test_runtime_config_default() {
        let RuntimeConfig { tokio, .. } = RuntimeConfig::default();
        assert!(matches!(tokio, TokioConfig::Owned { .. }));
    }

    /// `with_tokio` swaps in an existing handle.
    #[test]
    fn test_runtime_config_existing_handle() {
        let rt = TokioRuntime::new().unwrap();
        let existing = TokioConfig::existing_handle(rt.handle().clone());
        let config = Runtime::test_config().with_tokio(existing);
        assert!(matches!(config.tokio, TokioConfig::ExistingHandle(_)));
    }

    /// The derived default thread count is never zero.
    #[cfg(feature = "rayon")]
    #[test]
    fn test_rayon_config_thread_count() {
        assert!(RayonConfig::default().default_thread_count() >= 1);
    }

    /// A runtime can be built on top of an existing tokio handle.
    #[test]
    fn test_runtime_builder() {
        let rt = TokioRuntime::new().unwrap();
        let existing = TokioConfig::existing_handle(rt.handle().clone());
        let config = Runtime::test_config().with_tokio(existing);
        let runtime = RuntimeBuilder::new(config).build().unwrap();
        let _ = runtime.handle();
    }
}
903}