// reth_tasks/pool.rs
1//! Additional helpers for executing tracing calls
2
3use std::{
4    any::Any,
5    cell::RefCell,
6    future::Future,
7    panic::{catch_unwind, AssertUnwindSafe},
8    pin::Pin,
9    sync::{
10        atomic::{AtomicUsize, Ordering},
11        Arc,
12    },
13    task::{ready, Context, Poll},
14    thread,
15};
16use tokio::sync::{oneshot, AcquireError, OwnedSemaphorePermit, Semaphore};
17
/// RPC Tracing call guard semaphore.
///
/// This is used to restrict the number of concurrent RPC requests to tracing methods like
/// `debug_traceTransaction` as well as `eth_getProof` because they can consume a lot of
/// memory and CPU.
///
/// This type serves as an entry guard for the [`BlockingTaskPool`] and is used to rate limit
/// parallel blocking tasks in the pool.
#[derive(Clone, Debug)]
pub struct BlockingTaskGuard(Arc<Semaphore>);
28
29impl BlockingTaskGuard {
30    /// Create a new `BlockingTaskGuard` with the given maximum number of blocking tasks in
31    /// parallel.
32    pub fn new(max_blocking_tasks: usize) -> Self {
33        Self(Arc::new(Semaphore::new(max_blocking_tasks)))
34    }
35
36    /// See also [`Semaphore::acquire_owned`]
37    pub async fn acquire_owned(self) -> Result<OwnedSemaphorePermit, AcquireError> {
38        self.0.acquire_owned().await
39    }
40
41    /// See also [`Semaphore::acquire_many_owned`]
42    pub async fn acquire_many_owned(self, n: u32) -> Result<OwnedSemaphorePermit, AcquireError> {
43        self.0.acquire_many_owned(n).await
44    }
45}
46
/// Used to execute blocking tasks on a rayon threadpool from within a tokio runtime.
///
/// This is a dedicated threadpool for blocking tasks which are CPU bound.
/// RPC calls that perform blocking IO (disk lookups) are not executed on this pool but on the tokio
/// runtime's blocking pool, which performs poorly with CPU bound tasks (see
/// <https://ryhl.io/blog/async-what-is-blocking/>). Once the tokio blocking
/// pool is saturated, additional tasks are queued; long-running CPU bound tasks in that queue
/// would then delay other RPC calls.
///
/// See also [tokio-docs] for more information.
///
/// [tokio-docs]: https://docs.rs/tokio/latest/tokio/index.html#cpu-bound-tasks-and-blocking-code
#[derive(Clone, Debug)]
pub struct BlockingTaskPool {
    // The rayon threadpool tasks are spawned on; `Arc` so clones share the same pool.
    pool: Arc<rayon::ThreadPool>,
}
63
64impl BlockingTaskPool {
65    /// Create a new `BlockingTaskPool` with the given threadpool.
66    pub fn new(pool: rayon::ThreadPool) -> Self {
67        Self { pool: Arc::new(pool) }
68    }
69
70    /// Convenience function to start building a new threadpool.
71    pub fn builder() -> rayon::ThreadPoolBuilder {
72        rayon::ThreadPoolBuilder::new()
73    }
74
75    /// Convenience function to build a new threadpool with the default configuration.
76    ///
77    /// Uses [`rayon::ThreadPoolBuilder::build`](rayon::ThreadPoolBuilder::build) defaults.
78    /// If a different stack size or other parameters are needed, they can be configured via
79    /// [`rayon::ThreadPoolBuilder`] returned by [`Self::builder`].
80    pub fn build() -> Result<Self, rayon::ThreadPoolBuildError> {
81        Self::builder().build().map(Self::new)
82    }
83
84    /// Asynchronous wrapper around Rayon's
85    /// [`ThreadPool::spawn`](rayon::ThreadPool::spawn).
86    ///
87    /// Runs a function on the configured threadpool, returning a future that resolves with the
88    /// function's return value.
89    ///
90    /// If the function panics, the future will resolve to an error.
91    pub fn spawn<F, R>(&self, func: F) -> BlockingTaskHandle<R>
92    where
93        F: FnOnce() -> R + Send + 'static,
94        R: Send + 'static,
95    {
96        let (tx, rx) = oneshot::channel();
97
98        self.pool.spawn(move || {
99            let _result = tx.send(catch_unwind(AssertUnwindSafe(func)));
100        });
101
102        BlockingTaskHandle { rx }
103    }
104
105    /// Asynchronous wrapper around Rayon's
106    /// [`ThreadPool::spawn_fifo`](rayon::ThreadPool::spawn_fifo).
107    ///
108    /// Runs a function on the configured threadpool, returning a future that resolves with the
109    /// function's return value.
110    ///
111    /// If the function panics, the future will resolve to an error.
112    pub fn spawn_fifo<F, R>(&self, func: F) -> BlockingTaskHandle<R>
113    where
114        F: FnOnce() -> R + Send + 'static,
115        R: Send + 'static,
116    {
117        let (tx, rx) = oneshot::channel();
118
119        self.pool.spawn_fifo(move || {
120            let _result = tx.send(catch_unwind(AssertUnwindSafe(func)));
121        });
122
123        BlockingTaskHandle { rx }
124    }
125}
126
/// Async handle for a blocking task running in a Rayon thread pool.
///
/// Resolves with the task's return value, or with an error if the sending side was dropped
/// before a result arrived.
///
/// ## Panics
///
/// If polled from outside a tokio runtime.
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
#[pin_project::pin_project]
pub struct BlockingTaskHandle<T> {
    // Receives the task's `catch_unwind` result from the rayon worker thread.
    #[pin]
    pub(crate) rx: oneshot::Receiver<thread::Result<T>>,
}
139
140impl<T> Future for BlockingTaskHandle<T> {
141    type Output = thread::Result<T>;
142
143    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
144        match ready!(self.project().rx.poll(cx)) {
145            Ok(res) => Poll::Ready(res),
146            Err(_) => Poll::Ready(Err(Box::<TokioBlockingTaskError>::default())),
147        }
148    }
149}
150
/// An error returned when the Tokio channel is dropped while awaiting a result.
///
/// This should only happen if the sending half of the channel is dropped before a result was
/// sent — presumably because the spawned closure never ran (the pool's sender is dropped
/// without sending only in that case, since panics are caught and sent as `Err`).
#[derive(Debug, Default, thiserror::Error)]
#[error("tokio channel dropped while awaiting result")]
#[non_exhaustive]
pub struct TokioBlockingTaskError;
158
thread_local! {
    /// Per-thread [`Worker`] state used by [`WorkerPool`]; starts empty on every thread.
    static WORKER: RefCell<Worker> = const { RefCell::new(Worker::new()) };
}
162
/// A rayon thread pool with per-thread [`Worker`] state.
///
/// Each thread in the pool has its own [`Worker`] that can hold arbitrary state via
/// [`Worker::init`]. The state is thread-local and accessible during [`install`](Self::install)
/// calls.
///
/// The pool supports multiple init/clear cycles, allowing reuse of the same threads with
/// different state configurations.
#[derive(Debug)]
pub struct WorkerPool {
    // The underlying rayon pool whose threads carry the thread-local `WORKER` state.
    pool: rayon::ThreadPool,
}
175
176impl WorkerPool {
177    /// Creates a new `WorkerPool` with the given number of threads.
178    pub fn new(num_threads: usize) -> Result<Self, rayon::ThreadPoolBuildError> {
179        Self::from_builder(rayon::ThreadPoolBuilder::new().num_threads(num_threads))
180    }
181
182    /// Creates a new `WorkerPool` from a [`rayon::ThreadPoolBuilder`].
183    pub fn from_builder(
184        builder: rayon::ThreadPoolBuilder,
185    ) -> Result<Self, rayon::ThreadPoolBuildError> {
186        Ok(Self { pool: builder.build()? })
187    }
188
189    /// Returns the total number of threads in the underlying rayon pool.
190    pub fn current_num_threads(&self) -> usize {
191        self.pool.current_num_threads()
192    }
193
194    /// Initializes per-thread [`Worker`] state on every thread in the pool.
195    pub fn init<T: 'static>(&self, f: impl Fn(Option<&mut T>) -> T + Sync) {
196        self.broadcast(self.pool.current_num_threads(), |worker| {
197            worker.init::<T>(&f);
198        });
199    }
200
201    /// Runs a closure on `num_threads` threads in the pool, giving mutable access to each
202    /// thread's [`Worker`].
203    ///
204    /// Use this to initialize or re-initialize per-thread state via [`Worker::init`].
205    /// Only `num_threads` threads execute the closure; the rest skip it.
206    pub fn broadcast(&self, num_threads: usize, f: impl Fn(&mut Worker) + Sync) {
207        if num_threads >= self.pool.current_num_threads() {
208            // Fast path: run on every thread, no atomic coordination needed.
209            self.pool.broadcast(|_| {
210                WORKER.with_borrow_mut(|worker| f(worker));
211            });
212        } else {
213            let remaining = AtomicUsize::new(num_threads);
214            self.pool.broadcast(|_| {
215                // Atomically claim a slot; threads that can't decrement skip the closure.
216                let mut current = remaining.load(Ordering::Relaxed);
217                loop {
218                    if current == 0 {
219                        return;
220                    }
221                    match remaining.compare_exchange_weak(
222                        current,
223                        current - 1,
224                        Ordering::Relaxed,
225                        Ordering::Relaxed,
226                    ) {
227                        Ok(_) => break,
228                        Err(actual) => current = actual,
229                    }
230                }
231                WORKER.with_borrow_mut(|worker| f(worker));
232            });
233        }
234    }
235
236    /// Clears the state on every thread in the pool.
237    pub fn clear(&self) {
238        self.pool.broadcast(|_| {
239            WORKER.with_borrow_mut(Worker::clear);
240        });
241    }
242
243    /// Runs a closure on the pool with access to the calling thread's [`Worker`].
244    ///
245    /// All rayon parallelism (e.g. `par_iter`) spawned inside the closure executes on this pool.
246    /// Each thread can access its own [`Worker`] via the provided reference or through additional
247    /// [`WorkerPool::with_worker`] calls.
248    pub fn install<R: Send>(&self, f: impl FnOnce(&Worker) -> R + Send) -> R {
249        self.pool.install(|| WORKER.with_borrow(|worker| f(worker)))
250    }
251
252    /// Runs a closure on the pool without worker state access.
253    ///
254    /// Like [`install`](Self::install) but for closures that don't need per-thread [`Worker`]
255    /// state.
256    pub fn install_fn<R: Send>(&self, f: impl FnOnce() -> R + Send) -> R {
257        self.pool.install(f)
258    }
259
260    /// Spawns a closure on the pool.
261    pub fn spawn(&self, f: impl FnOnce() + Send + 'static) {
262        self.pool.spawn(f);
263    }
264
265    /// Executes `f` on this pool using [`rayon::in_place_scope`], which converts the calling
266    /// thread into a worker for the duration — tasks spawned inside the scope run on the pool
267    /// and the call blocks until all of them complete.
268    pub fn in_place_scope<'scope, R>(&self, f: impl FnOnce(&rayon::Scope<'scope>) -> R) -> R {
269        self.pool.in_place_scope(f)
270    }
271
272    /// Access the current thread's [`Worker`] from within an [`install`](Self::install) closure.
273    ///
274    /// This is useful for accessing the worker from inside `par_iter` where the initial `&Worker`
275    /// reference from `install` belongs to a different thread.
276    pub fn with_worker<R>(f: impl FnOnce(&Worker) -> R) -> R {
277        WORKER.with_borrow(|worker| f(worker))
278    }
279
280    /// Mutably access the current thread's [`Worker`] from within a pool closure.
281    pub fn with_worker_mut<R>(f: impl FnOnce(&mut Worker) -> R) -> R {
282        WORKER.with_borrow_mut(|worker| f(worker))
283    }
284}
285
/// Per-thread state container for a [`WorkerPool`].
///
/// Holds a type-erased `Box<dyn Any>` that can be initialized and accessed with concrete types
/// via [`init`](Self::init) and [`get`](Self::get).
#[derive(Debug, Default)]
pub struct Worker {
    // Type-erased per-thread state; `None` until initialized or after `clear`.
    state: Option<Box<dyn Any>>,
}

impl Worker {
    /// Creates a new empty `Worker`.
    const fn new() -> Self {
        Self { state: None }
    }

    /// Initializes the worker state.
    ///
    /// If state of type `T` already exists, passes `Some(&mut T)` to the closure so resources
    /// can be reused. On first init — or when the existing state holds a different type —
    /// passes `None`.
    pub fn init<T: 'static>(&mut self, f: impl FnOnce(Option<&mut T>) -> T) {
        // Keep the existing allocation only if it already holds a `T`; state of any other
        // type is dropped here (inside `filter`) before `f` runs, so `f` observes a fresh
        // start. `is::<T>` avoids the previous throwaway `downcast_mut` type probe.
        let existing = self.state.take().filter(|boxed| boxed.is::<T>());

        let new_state = match existing {
            Some(mut boxed) => {
                let slot = boxed.downcast_mut::<T>().expect("type checked by filter above");
                // Explicit reborrow: `f` only borrows the slot for the call, then the new
                // value is written back into the existing allocation.
                *slot = f(Some(&mut *slot));
                boxed
            }
            None => Box::new(f(None)),
        };

        self.state = Some(new_state);
    }

    /// Returns a reference to the state, downcasted to `T`.
    ///
    /// # Panics
    ///
    /// Panics if the worker has not been initialized or if the type does not match.
    pub fn get<T: 'static>(&self) -> &T {
        self.state
            .as_ref()
            .expect("worker not initialized")
            .downcast_ref::<T>()
            .expect("worker state type mismatch")
    }

    /// Returns a mutable reference to the state, downcasted to `T`.
    ///
    /// # Panics
    ///
    /// Panics if the worker has not been initialized or if the type does not match.
    pub fn get_mut<T: 'static>(&mut self) -> &mut T {
        self.state
            .as_mut()
            .expect("worker not initialized")
            .downcast_mut::<T>()
            .expect("worker state type mismatch")
    }

    /// Returns a mutable reference to the state, initializing it with `f` on first access.
    ///
    /// # Panics
    ///
    /// Panics if the state was previously initialized with a different type.
    pub fn get_or_init<T: 'static>(&mut self, f: impl FnOnce() -> T) -> &mut T {
        self.state
            .get_or_insert_with(|| Box::new(f()))
            .downcast_mut::<T>()
            .expect("worker state type mismatch")
    }

    /// Clears the worker state, dropping the contained value.
    pub fn clear(&mut self) {
        self.state = None;
    }
}
364
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn blocking_pool() {
        // A plain value computed on the rayon pool resolves through the handle.
        let pool = BlockingTaskPool::build().unwrap();
        let handle = pool.spawn(move || 5);
        assert_eq!(handle.await.unwrap(), 5);
    }

    #[tokio::test]
    async fn blocking_pool_panic() {
        // A panicking task surfaces as `Err` instead of tearing down the pool.
        let pool = BlockingTaskPool::build().unwrap();
        let handle = pool.spawn(move || -> i32 {
            panic!();
        });
        assert!(handle.await.is_err());
    }

    #[test]
    fn worker_pool_init_and_access() {
        let pool = WorkerPool::new(2).unwrap();

        pool.broadcast(2, |worker| worker.init::<Vec<u8>>(|_| vec![1, 2, 3]));

        let sum: u8 = pool.install(|worker| worker.get::<Vec<u8>>().iter().sum());
        assert_eq!(sum, 6);

        pool.clear();
    }

    #[test]
    fn worker_pool_reinit_reuses_resources() {
        let pool = WorkerPool::new(1).unwrap();

        // First init: no prior state is handed to the closure.
        pool.broadcast(1, |worker| {
            worker.init::<Vec<u8>>(|existing| {
                assert!(existing.is_none());
                vec![1, 2, 3]
            });
        });

        // Second init: the previous vector is handed back for reuse.
        pool.broadcast(1, |worker| {
            worker.init::<Vec<u8>>(|existing| {
                let prev = existing.expect("should have existing state");
                assert_eq!(prev, &mut vec![1, 2, 3]);
                prev.push(4);
                std::mem::take(prev)
            });
        });

        assert_eq!(pool.install(|worker| worker.get::<Vec<u8>>().len()), 4);

        pool.clear();
    }

    #[test]
    fn worker_pool_clear_and_reinit() {
        let pool = WorkerPool::new(1).unwrap();

        pool.broadcast(1, |worker| worker.init::<u64>(|_| 42));
        assert_eq!(pool.install(|worker| *worker.get::<u64>()), 42);

        pool.clear();

        // After clearing, state of a completely different type can be installed.
        pool.broadcast(1, |worker| worker.init::<String>(|_| "hello".to_string()));
        assert_eq!(pool.install(|worker| worker.get::<String>().clone()), "hello");

        pool.clear();
    }

    #[test]
    fn worker_pool_par_iter_with_worker() {
        use rayon::prelude::*;

        let pool = WorkerPool::new(2).unwrap();

        pool.broadcast(2, |worker| worker.init::<u64>(|_| 10));

        // Each rayon worker thread reads its own thread-local state via `with_worker`.
        let results: Vec<u64> = pool.install(|_| {
            (0u64..4)
                .into_par_iter()
                .map(|i| WorkerPool::with_worker(|w| i + *w.get::<u64>()))
                .collect()
        });
        assert_eq!(results, vec![10, 11, 12, 13]);

        pool.clear();
    }
}