// reth_tasks/pool.rs
1//! Additional helpers for executing tracing calls
2
3use std::{
4    any::Any,
5    cell::RefCell,
6    future::Future,
7    panic::{catch_unwind, AssertUnwindSafe},
8    pin::Pin,
9    sync::{
10        atomic::{AtomicUsize, Ordering},
11        Arc,
12    },
13    task::{ready, Context, Poll},
14    thread,
15};
16use tokio::sync::{oneshot, AcquireError, OwnedSemaphorePermit, Semaphore};
17
/// RPC Tracing call guard semaphore.
///
/// This is used to restrict the number of concurrent RPC requests to tracing methods like
/// `debug_traceTransaction` as well as `eth_getProof` because they can consume a lot of
/// memory and CPU.
///
/// This type serves as an entry guard for the [`BlockingTaskPool`] and is used to rate limit
/// parallel blocking tasks in the pool.
#[derive(Clone, Debug)]
pub struct BlockingTaskGuard(Arc<Semaphore>);
28
29impl BlockingTaskGuard {
30    /// Create a new `BlockingTaskGuard` with the given maximum number of blocking tasks in
31    /// parallel.
32    pub fn new(max_blocking_tasks: usize) -> Self {
33        Self(Arc::new(Semaphore::new(max_blocking_tasks)))
34    }
35
36    /// See also [`Semaphore::acquire_owned`]
37    pub async fn acquire_owned(self) -> Result<OwnedSemaphorePermit, AcquireError> {
38        self.0.acquire_owned().await
39    }
40
41    /// See also [`Semaphore::acquire_many_owned`]
42    pub async fn acquire_many_owned(self, n: u32) -> Result<OwnedSemaphorePermit, AcquireError> {
43        self.0.acquire_many_owned(n).await
44    }
45}
46
/// Used to execute blocking tasks on a rayon threadpool from within a tokio runtime.
///
/// This is a dedicated threadpool for blocking tasks which are CPU bound.
/// RPC calls that perform blocking IO (disk lookups) are not executed on this pool but on the tokio
/// runtime's blocking pool, which performs poorly with CPU bound tasks (see
/// <https://ryhl.io/blog/async-what-is-blocking/>). Once the tokio blocking
/// pool is saturated it is converted into a queue, blocking tasks could then interfere with the
/// queue and block other RPC calls.
///
/// See also [tokio-docs] for more information.
///
/// [tokio-docs]: https://docs.rs/tokio/latest/tokio/index.html#cpu-bound-tasks-and-blocking-code
#[derive(Clone, Debug)]
pub struct BlockingTaskPool {
    /// Shared rayon thread pool on which the blocking tasks are executed.
    pool: Arc<rayon::ThreadPool>,
}
63
64impl BlockingTaskPool {
65    /// Create a new `BlockingTaskPool` with the given threadpool.
66    pub fn new(pool: rayon::ThreadPool) -> Self {
67        Self { pool: Arc::new(pool) }
68    }
69
70    /// Convenience function to start building a new threadpool.
71    pub fn builder() -> rayon::ThreadPoolBuilder {
72        rayon::ThreadPoolBuilder::new()
73    }
74
75    /// Convenience function to build a new threadpool with the default configuration.
76    ///
77    /// Uses [`rayon::ThreadPoolBuilder::build`](rayon::ThreadPoolBuilder::build) defaults.
78    /// If a different stack size or other parameters are needed, they can be configured via
79    /// [`rayon::ThreadPoolBuilder`] returned by [`Self::builder`].
80    pub fn build() -> Result<Self, rayon::ThreadPoolBuildError> {
81        Self::builder().build().map(Self::new)
82    }
83
84    /// Asynchronous wrapper around Rayon's
85    /// [`ThreadPool::spawn`](rayon::ThreadPool::spawn).
86    ///
87    /// Runs a function on the configured threadpool, returning a future that resolves with the
88    /// function's return value.
89    ///
90    /// If the function panics, the future will resolve to an error.
91    pub fn spawn<F, R>(&self, func: F) -> BlockingTaskHandle<R>
92    where
93        F: FnOnce() -> R + Send + 'static,
94        R: Send + 'static,
95    {
96        let (tx, rx) = oneshot::channel();
97
98        self.pool.spawn(move || {
99            let _result = tx.send(catch_unwind(AssertUnwindSafe(func)));
100        });
101
102        BlockingTaskHandle { rx }
103    }
104
105    /// Asynchronous wrapper around Rayon's
106    /// [`ThreadPool::spawn_fifo`](rayon::ThreadPool::spawn_fifo).
107    ///
108    /// Runs a function on the configured threadpool, returning a future that resolves with the
109    /// function's return value.
110    ///
111    /// If the function panics, the future will resolve to an error.
112    pub fn spawn_fifo<F, R>(&self, func: F) -> BlockingTaskHandle<R>
113    where
114        F: FnOnce() -> R + Send + 'static,
115        R: Send + 'static,
116    {
117        let (tx, rx) = oneshot::channel();
118
119        self.pool.spawn_fifo(move || {
120            let _result = tx.send(catch_unwind(AssertUnwindSafe(func)));
121        });
122
123        BlockingTaskHandle { rx }
124    }
125}
126
/// Async handle for a blocking task running in a Rayon thread pool.
///
/// Resolves with the task's return value; a panic in the task is surfaced as the `Err`
/// variant of [`thread::Result`].
///
/// ## Panics
///
/// If polled from outside a tokio runtime.
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
#[pin_project::pin_project]
pub struct BlockingTaskHandle<T> {
    // Receives the task's outcome; `Err` carries the payload of a caught panic.
    #[pin]
    pub(crate) rx: oneshot::Receiver<thread::Result<T>>,
}
139
140impl<T> Future for BlockingTaskHandle<T> {
141    type Output = thread::Result<T>;
142
143    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
144        match ready!(self.project().rx.poll(cx)) {
145            Ok(res) => Poll::Ready(res),
146            Err(_) => Poll::Ready(Err(Box::<TokioBlockingTaskError>::default())),
147        }
148    }
149}
150
/// An error returned when the Tokio channel is dropped while awaiting a result.
///
/// This should only happen if the sending half of the channel was dropped before a result
/// was delivered — i.e. the spawned task was torn down without completing.
#[derive(Debug, Default, thiserror::Error)]
#[error("tokio channel dropped while awaiting result")]
#[non_exhaustive]
pub struct TokioBlockingTaskError;
158
// Per-thread [`Worker`] slot for [`WorkerPool`] threads; each rayon thread accesses its own
// instance through this thread-local.
thread_local! {
    static WORKER: RefCell<Worker> = const { RefCell::new(Worker::new()) };
}
162
/// A rayon thread pool with per-thread [`Worker`] state.
///
/// Each thread in the pool has its own [`Worker`] that can hold arbitrary state via
/// [`Worker::init`]. The state is thread-local and accessible during [`install`](Self::install)
/// calls.
///
/// The pool supports multiple init/clear cycles, allowing reuse of the same threads with
/// different state configurations.
#[derive(Debug)]
pub struct WorkerPool {
    /// The underlying rayon pool on which all closures run.
    pool: rayon::ThreadPool,
}
175
176impl WorkerPool {
177    /// Creates a new `WorkerPool` with the given number of threads.
178    pub fn new(num_threads: usize) -> Result<Self, rayon::ThreadPoolBuildError> {
179        Self::from_builder(rayon::ThreadPoolBuilder::new().num_threads(num_threads))
180    }
181
182    /// Creates a new `WorkerPool` from a [`rayon::ThreadPoolBuilder`].
183    ///
184    /// Installs a panic handler that logs panics instead of aborting the process.
185    pub fn from_builder(
186        builder: rayon::ThreadPoolBuilder,
187    ) -> Result<Self, rayon::ThreadPoolBuildError> {
188        Ok(Self { pool: build_pool_with_panic_handler(builder)? })
189    }
190
191    /// Returns the total number of threads in the underlying rayon pool.
192    pub fn current_num_threads(&self) -> usize {
193        self.pool.current_num_threads()
194    }
195
196    /// Initializes per-thread [`Worker`] state on every thread in the pool.
197    pub fn init<T: 'static>(&self, f: impl Fn(Option<&mut T>) -> T + Sync) {
198        self.broadcast(self.pool.current_num_threads(), |worker| {
199            worker.init::<T>(&f);
200        });
201    }
202
203    /// Runs a closure on `num_threads` threads in the pool, giving mutable access to each
204    /// thread's [`Worker`].
205    ///
206    /// Use this to initialize or re-initialize per-thread state via [`Worker::init`].
207    /// Only `num_threads` threads execute the closure; the rest skip it.
208    pub fn broadcast(&self, num_threads: usize, f: impl Fn(&mut Worker) + Sync) {
209        if num_threads >= self.pool.current_num_threads() {
210            // Fast path: run on every thread, no atomic coordination needed.
211            self.pool.broadcast(|_| {
212                WORKER.with_borrow_mut(|worker| f(worker));
213            });
214        } else {
215            let remaining = AtomicUsize::new(num_threads);
216            self.pool.broadcast(|_| {
217                // Atomically claim a slot; threads that can't decrement skip the closure.
218                let mut current = remaining.load(Ordering::Relaxed);
219                loop {
220                    if current == 0 {
221                        return;
222                    }
223                    match remaining.compare_exchange_weak(
224                        current,
225                        current - 1,
226                        Ordering::Relaxed,
227                        Ordering::Relaxed,
228                    ) {
229                        Ok(_) => break,
230                        Err(actual) => current = actual,
231                    }
232                }
233                WORKER.with_borrow_mut(|worker| f(worker));
234            });
235        }
236    }
237
238    /// Clears the state on every thread in the pool.
239    pub fn clear(&self) {
240        self.pool.broadcast(|_| {
241            WORKER.with_borrow_mut(Worker::clear);
242        });
243    }
244
245    /// Runs a closure on the pool with access to the calling thread's [`Worker`].
246    ///
247    /// All rayon parallelism (e.g. `par_iter`) spawned inside the closure executes on this pool.
248    /// Each thread can access its own [`Worker`] via the provided reference or through additional
249    /// [`WorkerPool::with_worker`] calls.
250    pub fn install<R: Send>(&self, f: impl FnOnce(&Worker) -> R + Send) -> R {
251        self.pool.install(|| WORKER.with_borrow(|worker| f(worker)))
252    }
253
254    /// Runs a closure on the pool without worker state access.
255    ///
256    /// Like [`install`](Self::install) but for closures that don't need per-thread [`Worker`]
257    /// state.
258    pub fn install_fn<R: Send>(&self, f: impl FnOnce() -> R + Send) -> R {
259        self.pool.install(f)
260    }
261
262    /// Spawns a closure on the pool.
263    pub fn spawn(&self, f: impl FnOnce() + Send + 'static) {
264        self.pool.spawn(f);
265    }
266
267    /// Executes `f` on this pool using [`rayon::in_place_scope`], which converts the calling
268    /// thread into a worker for the duration — tasks spawned inside the scope run on the pool
269    /// and the call blocks until all of them complete.
270    pub fn in_place_scope<'scope, R>(&self, f: impl FnOnce(&rayon::Scope<'scope>) -> R) -> R {
271        self.pool.in_place_scope(f)
272    }
273
274    /// Access the current thread's [`Worker`] from within an [`install`](Self::install) closure.
275    ///
276    /// This is useful for accessing the worker from inside `par_iter` where the initial `&Worker`
277    /// reference from `install` belongs to a different thread.
278    pub fn with_worker<R>(f: impl FnOnce(&Worker) -> R) -> R {
279        WORKER.with_borrow(|worker| f(worker))
280    }
281
282    /// Mutably access the current thread's [`Worker`] from within a pool closure.
283    pub fn with_worker_mut<R>(f: impl FnOnce(&mut Worker) -> R) -> R {
284        WORKER.with_borrow_mut(|worker| f(worker))
285    }
286}
287
288/// Builds a rayon thread pool with a panic handler that prevents aborting the process.
289///
290/// Rust's default panic hook already logs the panic message and backtrace to stderr, so the handler
291/// itself is intentionally a no-op.
292pub fn build_pool_with_panic_handler(
293    builder: rayon::ThreadPoolBuilder,
294) -> Result<rayon::ThreadPool, rayon::ThreadPoolBuildError> {
295    builder.panic_handler(|_| {}).build()
296}
297
/// Per-thread state container for a [`WorkerPool`].
///
/// Holds a type-erased `Box<dyn Any>` that can be initialized and accessed with concrete types
/// via [`init`](Self::init) and [`get`](Self::get).
#[derive(Debug, Default)]
pub struct Worker {
    // Type-erased state; `None` until `init`/`get_or_init` runs or after `clear`.
    state: Option<Box<dyn Any>>,
}

impl Worker {
    /// Creates a new empty `Worker`.
    const fn new() -> Self {
        Self { state: None }
    }

    /// Initializes the worker state.
    ///
    /// If state of type `T` already exists, passes `Some(&mut T)` to the closure so resources
    /// can be reused. On first init — or if existing state has a different type, in which case
    /// the old value is dropped — passes `None`.
    pub fn init<T: 'static>(&mut self, f: impl FnOnce(Option<&mut T>) -> T) {
        self.state = Some(match self.state.take() {
            // Reuse path: hand the existing value to `f`, then store its result in the same box.
            // `is::<T>` replaces the previous double-downcast check, and `&mut *r` explicitly
            // reborrows so `r` is not moved into `Some(...)` and remains usable for the write.
            Some(mut boxed) if boxed.is::<T>() => {
                let r = boxed.downcast_mut::<T>().expect("type checked above");
                *r = f(Some(&mut *r));
                boxed
            }
            // First init, or a type mismatch: build fresh state.
            _ => Box::new(f(None)),
        });
    }

    /// Returns a reference to the state, downcasted to `T`.
    ///
    /// # Panics
    ///
    /// Panics if the worker has not been initialized or if the type does not match.
    pub fn get<T: 'static>(&self) -> &T {
        self.state
            .as_ref()
            .expect("worker not initialized")
            .downcast_ref::<T>()
            .expect("worker state type mismatch")
    }

    /// Returns a mutable reference to the state, downcasted to `T`.
    ///
    /// # Panics
    ///
    /// Panics if the worker has not been initialized or if the type does not match.
    pub fn get_mut<T: 'static>(&mut self) -> &mut T {
        self.state
            .as_mut()
            .expect("worker not initialized")
            .downcast_mut::<T>()
            .expect("worker state type mismatch")
    }

    /// Returns a mutable reference to the state, initializing it with `f` on first access.
    ///
    /// # Panics
    ///
    /// Panics if the state was previously initialized with a different type.
    pub fn get_or_init<T: 'static>(&mut self, f: impl FnOnce() -> T) -> &mut T {
        self.state
            .get_or_insert_with(|| Box::new(f()))
            .downcast_mut::<T>()
            .expect("worker state type mismatch")
    }

    /// Clears the worker state, dropping the contained value.
    pub fn clear(&mut self) {
        self.state = None;
    }
}
376
#[cfg(test)]
mod tests {
    use super::*;

    /// A spawned closure's return value resolves through the handle.
    #[tokio::test]
    async fn blocking_pool() {
        let pool = BlockingTaskPool::build().unwrap();
        let res = pool.spawn(move || 5);
        let res = res.await.unwrap();
        assert_eq!(res, 5);
    }

    /// A panicking task surfaces as an `Err` on the handle instead of aborting the process.
    #[tokio::test]
    async fn blocking_pool_panic() {
        let pool = BlockingTaskPool::build().unwrap();
        let res = pool.spawn(move || -> i32 {
            panic!();
        });
        let res = res.await;
        assert!(res.is_err());
    }

    /// State initialized via `broadcast` is visible from `install` closures.
    #[test]
    fn worker_pool_init_and_access() {
        let pool = WorkerPool::new(2).unwrap();

        pool.broadcast(2, |worker| {
            worker.init::<Vec<u8>>(|_| vec![1, 2, 3]);
        });

        let sum: u8 = pool.install(|worker| {
            let v = worker.get::<Vec<u8>>();
            v.iter().sum()
        });
        assert_eq!(sum, 6);

        pool.clear();
    }

    /// Re-initializing hands the previous state to the closure so resources can be reused.
    #[test]
    fn worker_pool_reinit_reuses_resources() {
        let pool = WorkerPool::new(1).unwrap();

        pool.broadcast(1, |worker| {
            worker.init::<Vec<u8>>(|existing| {
                assert!(existing.is_none());
                vec![1, 2, 3]
            });
        });

        pool.broadcast(1, |worker| {
            worker.init::<Vec<u8>>(|existing| {
                let v = existing.expect("should have existing state");
                assert_eq!(v, &mut vec![1, 2, 3]);
                v.push(4);
                std::mem::take(v)
            });
        });

        let len = pool.install(|worker| worker.get::<Vec<u8>>().len());
        assert_eq!(len, 4);

        pool.clear();
    }

    /// `clear` drops the state and allows re-initialization with a different type.
    #[test]
    fn worker_pool_clear_and_reinit() {
        let pool = WorkerPool::new(1).unwrap();

        pool.broadcast(1, |worker| {
            worker.init::<u64>(|_| 42);
        });
        let val = pool.install(|worker| *worker.get::<u64>());
        assert_eq!(val, 42);

        pool.clear();

        pool.broadcast(1, |worker| {
            worker.init::<String>(|_| "hello".to_string());
        });
        let val = pool.install(|worker| worker.get::<String>().clone());
        assert_eq!(val, "hello");

        pool.clear();
    }

    /// Worker state is reachable from inside `par_iter` via `WorkerPool::with_worker`.
    #[test]
    fn worker_pool_par_iter_with_worker() {
        use rayon::prelude::*;

        let pool = WorkerPool::new(2).unwrap();

        pool.broadcast(2, |worker| {
            worker.init::<u64>(|_| 10);
        });

        let results: Vec<u64> = pool.install(|_| {
            (0u64..4)
                .into_par_iter()
                .map(|i| WorkerPool::with_worker(|w| i + *w.get::<u64>()))
                .collect()
        });
        assert_eq!(results, vec![10, 11, 12, 13]);

        pool.clear();
    }
}