// reth_tasks/pool.rs
//! Additional helpers for executing tracing calls
2
3use std::{
4    any::Any,
5    cell::RefCell,
6    future::Future,
7    panic::{catch_unwind, AssertUnwindSafe},
8    pin::Pin,
9    sync::{
10        atomic::{AtomicUsize, Ordering},
11        Arc, OnceLock,
12    },
13    task::{ready, Context, Poll},
14    thread,
15};
16use tokio::sync::{oneshot, AcquireError, OwnedSemaphorePermit, Semaphore};
17
/// RPC Tracing call guard semaphore.
///
/// This is used to restrict the number of concurrent RPC requests to tracing methods like
/// `debug_traceTransaction` as well as `eth_getProof` because they can consume a lot of
/// memory and CPU.
///
/// This type serves as an entry guard for the [`BlockingTaskPool`] and is used to rate limit
/// parallel blocking tasks in the pool.
#[derive(Clone, Debug)]
pub struct BlockingTaskGuard(Arc<Semaphore>);
28
29impl BlockingTaskGuard {
30    /// Create a new `BlockingTaskGuard` with the given maximum number of blocking tasks in
31    /// parallel.
32    pub fn new(max_blocking_tasks: usize) -> Self {
33        Self(Arc::new(Semaphore::new(max_blocking_tasks)))
34    }
35
36    /// See also [`Semaphore::acquire_owned`]
37    pub async fn acquire_owned(self) -> Result<OwnedSemaphorePermit, AcquireError> {
38        self.0.acquire_owned().await
39    }
40
41    /// See also [`Semaphore::acquire_many_owned`]
42    pub async fn acquire_many_owned(self, n: u32) -> Result<OwnedSemaphorePermit, AcquireError> {
43        self.0.acquire_many_owned(n).await
44    }
45}
46
/// Used to execute blocking tasks on a rayon threadpool from within a tokio runtime.
///
/// This is a dedicated threadpool for blocking tasks which are CPU bound.
/// RPC calls that perform blocking IO (disk lookups) are not executed on this pool but on the tokio
/// runtime's blocking pool, which performs poorly with CPU bound tasks (see
/// <https://ryhl.io/blog/async-what-is-blocking/>). Once the tokio blocking
/// pool is saturated it is converted into a queue, blocking tasks could then interfere with the
/// queue and block other RPC calls.
///
/// See also [tokio-docs] for more information.
///
/// [tokio-docs]: https://docs.rs/tokio/latest/tokio/index.html#cpu-bound-tasks-and-blocking-code
#[derive(Clone, Debug)]
pub struct BlockingTaskPool {
    // The dedicated rayon pool; wrapped in `Arc` so clones of this handle share the same threads.
    pool: Arc<rayon::ThreadPool>,
}
63
64impl BlockingTaskPool {
65    /// Create a new `BlockingTaskPool` with the given threadpool.
66    pub fn new(pool: rayon::ThreadPool) -> Self {
67        Self { pool: Arc::new(pool) }
68    }
69
70    /// Convenience function to start building a new threadpool.
71    pub fn builder() -> rayon::ThreadPoolBuilder {
72        rayon::ThreadPoolBuilder::new()
73    }
74
75    /// Convenience function to build a new threadpool with the default configuration.
76    ///
77    /// Uses [`rayon::ThreadPoolBuilder::build`](rayon::ThreadPoolBuilder::build) defaults.
78    /// If a different stack size or other parameters are needed, they can be configured via
79    /// [`rayon::ThreadPoolBuilder`] returned by [`Self::builder`].
80    pub fn build() -> Result<Self, rayon::ThreadPoolBuildError> {
81        Self::builder().build().map(Self::new)
82    }
83
84    /// Asynchronous wrapper around Rayon's
85    /// [`ThreadPool::spawn`](rayon::ThreadPool::spawn).
86    ///
87    /// Runs a function on the configured threadpool, returning a future that resolves with the
88    /// function's return value.
89    ///
90    /// If the function panics, the future will resolve to an error.
91    pub fn spawn<F, R>(&self, func: F) -> BlockingTaskHandle<R>
92    where
93        F: FnOnce() -> R + Send + 'static,
94        R: Send + 'static,
95    {
96        let (tx, rx) = oneshot::channel();
97
98        self.pool.spawn(move || {
99            let _result = tx.send(catch_unwind(AssertUnwindSafe(func)));
100        });
101
102        BlockingTaskHandle { rx }
103    }
104
105    /// Asynchronous wrapper around Rayon's
106    /// [`ThreadPool::spawn_fifo`](rayon::ThreadPool::spawn_fifo).
107    ///
108    /// Runs a function on the configured threadpool, returning a future that resolves with the
109    /// function's return value.
110    ///
111    /// If the function panics, the future will resolve to an error.
112    pub fn spawn_fifo<F, R>(&self, func: F) -> BlockingTaskHandle<R>
113    where
114        F: FnOnce() -> R + Send + 'static,
115        R: Send + 'static,
116    {
117        let (tx, rx) = oneshot::channel();
118
119        self.pool.spawn_fifo(move || {
120            let _result = tx.send(catch_unwind(AssertUnwindSafe(func)));
121        });
122
123        BlockingTaskHandle { rx }
124    }
125}
126
/// Async handle for a blocking task running in a Rayon thread pool.
///
/// Resolves with the task's caught result: `Ok` with the return value, or `Err` carrying the
/// panic payload if the task panicked.
///
/// ## Panics
///
/// If polled from outside a tokio runtime.
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
#[pin_project::pin_project]
pub struct BlockingTaskHandle<T> {
    // Receives the task's `catch_unwind` result from the rayon worker thread.
    #[pin]
    pub(crate) rx: oneshot::Receiver<thread::Result<T>>,
}
139
140impl<T> Future for BlockingTaskHandle<T> {
141    type Output = thread::Result<T>;
142
143    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
144        match ready!(self.project().rx.poll(cx)) {
145            Ok(res) => Poll::Ready(res),
146            Err(_) => Poll::Ready(Err(Box::<TokioBlockingTaskError>::default())),
147        }
148    }
149}
150
/// An error returned when the Tokio channel is dropped while awaiting a result.
///
/// This should only happen if the sending half was dropped without ever delivering a result,
/// e.g. the task was dropped before it could run. A panic inside the task is caught and
/// delivered as the panic payload instead, not as this error.
#[derive(Debug, Default, thiserror::Error)]
#[error("tokio channel dropped while awaiting result")]
#[non_exhaustive]
pub struct TokioBlockingTaskError;
158
thread_local! {
    // Per-thread `Worker` slot for `WorkerPool` threads; populated via
    // `WorkerPool::broadcast`/`init` and read through `install`/`with_worker`.
    static WORKER: RefCell<Worker> = const { RefCell::new(Worker::new()) };
}
162
/// A rayon thread pool with per-thread [`Worker`] state.
///
/// Each thread in the pool has its own [`Worker`] that can hold arbitrary state via
/// [`Worker::init`]. The state is thread-local and accessible during [`install`](Self::install)
/// calls.
///
/// The pool supports multiple init/clear cycles, allowing reuse of the same threads with
/// different state configurations.
///
/// The underlying rayon pool is created lazily on first access.
#[derive(Debug)]
pub struct WorkerPool {
    // Lazily-built rayon pool; empty until the first method that needs it is called.
    pool: OnceLock<rayon::ThreadPool>,
    // Number of threads the pool will be built with.
    num_threads: usize,
    // Prefix for worker thread names, formatted as `"{prefix}-{index:02}"`.
    thread_name_prefix: &'static str,
}
179
180impl WorkerPool {
181    /// Creates a new lazy `WorkerPool` with the given number of threads and a thread name prefix.
182    ///
183    /// The underlying rayon pool is not created until the first method that requires it is called.
184    /// Thread names follow the pattern `"{prefix}-{index:02}"`.
185    pub const fn new(num_threads: usize, thread_name_prefix: &'static str) -> Self {
186        Self { pool: OnceLock::new(), num_threads, thread_name_prefix }
187    }
188
189    /// Returns a reference to the underlying rayon pool, creating it on first access.
190    fn pool(&self) -> &rayon::ThreadPool {
191        self.pool.get_or_init(|| {
192            let prefix = self.thread_name_prefix;
193            build_pool_with_panic_handler(
194                rayon::ThreadPoolBuilder::new()
195                    .num_threads(self.num_threads)
196                    .thread_name(move |i| format!("{prefix}-{i:02}")),
197            )
198            .unwrap_or_else(|err| panic!("failed to build {prefix} worker pool: {err}"))
199        })
200    }
201
202    /// Returns `true` if the underlying rayon pool has been initialized.
203    pub fn is_initialized(&self) -> bool {
204        self.pool.get().is_some()
205    }
206
207    /// Returns the total number of threads in the underlying rayon pool.
208    pub fn current_num_threads(&self) -> usize {
209        self.pool().current_num_threads()
210    }
211
212    /// Initializes per-thread [`Worker`] state on every thread in the pool.
213    pub fn init<T: 'static>(&self, f: impl Fn(Option<&mut T>) -> T + Sync) {
214        self.broadcast(self.pool().current_num_threads(), |worker| {
215            worker.init::<T>(&f);
216        });
217    }
218
219    /// Runs a closure on `num_threads` threads in the pool, giving mutable access to each
220    /// thread's [`Worker`].
221    ///
222    /// Use this to initialize or re-initialize per-thread state via [`Worker::init`].
223    /// Only `num_threads` threads execute the closure; the rest skip it.
224    pub fn broadcast(&self, num_threads: usize, f: impl Fn(&mut Worker) + Sync) {
225        if num_threads >= self.pool().current_num_threads() {
226            // Fast path: run on every thread, no atomic coordination needed.
227            self.pool().broadcast(|_| {
228                WORKER.with_borrow_mut(|worker| f(worker));
229            });
230        } else {
231            let remaining = AtomicUsize::new(num_threads);
232            self.pool().broadcast(|_| {
233                // Atomically claim a slot; threads that can't decrement skip the closure.
234                let mut current = remaining.load(Ordering::Relaxed);
235                loop {
236                    if current == 0 {
237                        return;
238                    }
239                    match remaining.compare_exchange_weak(
240                        current,
241                        current - 1,
242                        Ordering::Relaxed,
243                        Ordering::Relaxed,
244                    ) {
245                        Ok(_) => break,
246                        Err(actual) => current = actual,
247                    }
248                }
249                WORKER.with_borrow_mut(|worker| f(worker));
250            });
251        }
252    }
253
254    /// Clears the state on every thread in the pool.
255    pub fn clear(&self) {
256        self.pool().broadcast(|_| {
257            WORKER.with_borrow_mut(Worker::clear);
258        });
259    }
260
261    /// Runs a closure on the pool with access to the calling thread's [`Worker`].
262    ///
263    /// All rayon parallelism (e.g. `par_iter`) spawned inside the closure executes on this pool.
264    /// Each thread can access its own [`Worker`] via the provided reference or through additional
265    /// [`WorkerPool::with_worker`] calls.
266    pub fn install<R: Send>(&self, f: impl FnOnce(&Worker) -> R + Send) -> R {
267        self.pool().install(|| WORKER.with_borrow(|worker| f(worker)))
268    }
269
270    /// Runs a closure on the pool without worker state access.
271    ///
272    /// Like [`install`](Self::install) but for closures that don't need per-thread [`Worker`]
273    /// state.
274    pub fn install_fn<R: Send>(&self, f: impl FnOnce() -> R + Send) -> R {
275        self.pool().install(f)
276    }
277
278    /// Spawns a closure on the pool.
279    pub fn spawn(&self, f: impl FnOnce() + Send + 'static) {
280        self.pool().spawn(f);
281    }
282
283    /// Executes `f` on this pool using [`rayon::in_place_scope`], which converts the calling
284    /// thread into a worker for the duration — tasks spawned inside the scope run on the pool
285    /// and the call blocks until all of them complete.
286    pub fn in_place_scope<'scope, R>(&self, f: impl FnOnce(&rayon::Scope<'scope>) -> R) -> R {
287        self.pool().in_place_scope(f)
288    }
289
290    /// Access the current thread's [`Worker`] from within an [`install`](Self::install) closure.
291    ///
292    /// This is useful for accessing the worker from inside `par_iter` where the initial `&Worker`
293    /// reference from `install` belongs to a different thread.
294    pub fn with_worker<R>(f: impl FnOnce(&Worker) -> R) -> R {
295        WORKER.with_borrow(|worker| f(worker))
296    }
297
298    /// Mutably access the current thread's [`Worker`] from within a pool closure.
299    pub fn with_worker_mut<R>(f: impl FnOnce(&mut Worker) -> R) -> R {
300        WORKER.with_borrow_mut(|worker| f(worker))
301    }
302}
303
/// Builds a rayon thread pool with a panic handler that prevents aborting the process.
///
/// Rust's default panic hook already logs the panic message and backtrace to stderr, so the handler
/// itself is intentionally a no-op.
pub fn build_pool_with_panic_handler(
    builder: rayon::ThreadPoolBuilder,
) -> Result<rayon::ThreadPool, rayon::ThreadPoolBuildError> {
    // Installing any handler (even a no-op) stops a worker-thread panic from aborting the process.
    builder.panic_handler(|_| {}).build()
}
313
/// Per-thread state container for a [`WorkerPool`].
///
/// Holds a type-erased `Box<dyn Any>` that can be initialized and accessed with concrete types
/// via [`init`](Self::init) and [`get`](Self::get).
#[derive(Debug, Default)]
pub struct Worker {
    // Type-erased per-thread state; `None` until initialized or after `clear`.
    state: Option<Box<dyn Any>>,
}

impl Worker {
    /// Creates a new empty `Worker`.
    const fn new() -> Self {
        Self { state: None }
    }

    /// Initializes the worker state.
    ///
    /// If state of type `T` already exists, passes `Some(&mut T)` to the closure so resources
    /// can be reused. On first init — or when the existing state has a different type, in which
    /// case it is dropped and replaced — passes `None`.
    pub fn init<T: 'static>(&mut self, f: impl FnOnce(Option<&mut T>) -> T) {
        // `Box::downcast` checks the type and converts in a single step, avoiding the
        // check-then-downcast-again pattern (two `downcast_mut` calls plus an `expect`).
        let new_state: Box<dyn Any> = match self.state.take().map(|boxed| boxed.downcast::<T>()) {
            Some(Ok(mut existing)) => {
                // Same type: let `f` recycle resources from the old value, reuse the allocation.
                *existing = f(Some(&mut *existing));
                existing
            }
            // First init, or previous state of a different type (old value is dropped).
            _ => Box::new(f(None)),
        };

        self.state = Some(new_state);
    }

    /// Returns a reference to the state, downcasted to `T`.
    ///
    /// # Panics
    ///
    /// Panics if the worker has not been initialized or if the type does not match.
    pub fn get<T: 'static>(&self) -> &T {
        self.state
            .as_ref()
            .expect("worker not initialized")
            .downcast_ref::<T>()
            .expect("worker state type mismatch")
    }

    /// Returns a mutable reference to the state, downcasted to `T`.
    ///
    /// # Panics
    ///
    /// Panics if the worker has not been initialized or if the type does not match.
    pub fn get_mut<T: 'static>(&mut self) -> &mut T {
        self.state
            .as_mut()
            .expect("worker not initialized")
            .downcast_mut::<T>()
            .expect("worker state type mismatch")
    }

    /// Returns a mutable reference to the state, initializing it with `f` on first access.
    ///
    /// # Panics
    ///
    /// Panics if the state was previously initialized with a different type.
    pub fn get_or_init<T: 'static>(&mut self, f: impl FnOnce() -> T) -> &mut T {
        self.state
            .get_or_insert_with(|| Box::new(f()))
            .downcast_mut::<T>()
            .expect("worker state type mismatch")
    }

    /// Clears the worker state, dropping the contained value.
    pub fn clear(&mut self) {
        self.state = None;
    }
}
392
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn blocking_pool() {
        // A completed task resolves the handle with the closure's return value.
        let pool = BlockingTaskPool::build().unwrap();
        let res = pool.spawn(move || 5);
        let res = res.await.unwrap();
        assert_eq!(res, 5);
    }

    #[tokio::test]
    async fn blocking_pool_panic() {
        // A panicking task resolves the handle with `Err` instead of propagating the panic.
        let pool = BlockingTaskPool::build().unwrap();
        let res = pool.spawn(move || -> i32 {
            panic!();
        });
        let res = res.await;
        assert!(res.is_err());
    }

    #[test]
    fn worker_pool_init_and_access() {
        let pool = WorkerPool::new(2, "test");

        // Initialize state on both threads.
        pool.broadcast(2, |worker| {
            worker.init::<Vec<u8>>(|_| vec![1, 2, 3]);
        });

        // `install` runs on the pool and can read the executing thread's worker state.
        let sum: u8 = pool.install(|worker| {
            let v = worker.get::<Vec<u8>>();
            v.iter().sum()
        });
        assert_eq!(sum, 6);

        pool.clear();
    }

    #[test]
    fn worker_pool_reinit_reuses_resources() {
        let pool = WorkerPool::new(1, "test");

        // First init sees no existing state.
        pool.broadcast(1, |worker| {
            worker.init::<Vec<u8>>(|existing| {
                assert!(existing.is_none());
                vec![1, 2, 3]
            });
        });

        // Re-init with the same type receives the previous value for reuse.
        pool.broadcast(1, |worker| {
            worker.init::<Vec<u8>>(|existing| {
                let v = existing.expect("should have existing state");
                assert_eq!(v, &mut vec![1, 2, 3]);
                v.push(4);
                std::mem::take(v)
            });
        });

        let len = pool.install(|worker| worker.get::<Vec<u8>>().len());
        assert_eq!(len, 4);

        pool.clear();
    }

    #[test]
    fn worker_pool_clear_and_reinit() {
        let pool = WorkerPool::new(1, "test");

        pool.broadcast(1, |worker| {
            worker.init::<u64>(|_| 42);
        });
        let val = pool.install(|worker| *worker.get::<u64>());
        assert_eq!(val, 42);

        // After `clear`, the same threads accept state of a different type.
        pool.clear();

        pool.broadcast(1, |worker| {
            worker.init::<String>(|_| "hello".to_string());
        });
        let val = pool.install(|worker| worker.get::<String>().clone());
        assert_eq!(val, "hello");

        pool.clear();
    }

    #[test]
    fn worker_pool_par_iter_with_worker() {
        use rayon::prelude::*;

        let pool = WorkerPool::new(2, "test");

        pool.broadcast(2, |worker| {
            worker.init::<u64>(|_| 10);
        });

        // Inside `par_iter`, each rayon thread reaches its own worker via `with_worker`.
        let results: Vec<u64> = pool.install(|_| {
            (0u64..4)
                .into_par_iter()
                .map(|i| WorkerPool::with_worker(|w| i + *w.get::<u64>()))
                .collect()
        });
        assert_eq!(results, vec![10, 11, 12, 13]);

        pool.clear();
    }
}