reth_tasks/pool.rs
//! Helpers for running CPU-heavy blocking tasks, such as tracing calls, on a dedicated pool.
2
3use std::{
4 future::Future,
5 panic::{catch_unwind, AssertUnwindSafe},
6 pin::Pin,
7 sync::Arc,
8 task::{ready, Context, Poll},
9 thread,
10};
11use tokio::sync::{oneshot, AcquireError, OwnedSemaphorePermit, Semaphore};
12
13/// RPC Tracing call guard semaphore.
14///
15/// This is used to restrict the number of concurrent RPC requests to tracing methods like
16/// `debug_traceTransaction` as well as `eth_getProof` because they can consume a lot of
17/// memory and CPU.
18///
19/// This types serves as an entry guard for the [`BlockingTaskPool`] and is used to rate limit
20/// parallel blocking tasks in the pool.
21#[derive(Clone, Debug)]
22pub struct BlockingTaskGuard(Arc<Semaphore>);
23
24impl BlockingTaskGuard {
25 /// Create a new `BlockingTaskGuard` with the given maximum number of blocking tasks in
26 /// parallel.
27 pub fn new(max_blocking_tasks: usize) -> Self {
28 Self(Arc::new(Semaphore::new(max_blocking_tasks)))
29 }
30
31 /// See also [`Semaphore::acquire_owned`]
32 pub async fn acquire_owned(self) -> Result<OwnedSemaphorePermit, AcquireError> {
33 self.0.acquire_owned().await
34 }
35
36 /// See also [`Semaphore::acquire_many_owned`]
37 pub async fn acquire_many_owned(self, n: u32) -> Result<OwnedSemaphorePermit, AcquireError> {
38 self.0.acquire_many_owned(n).await
39 }
40}
41
42/// Used to execute blocking tasks on a rayon threadpool from within a tokio runtime.
43///
44/// This is a dedicated threadpool for blocking tasks which are CPU bound.
45/// RPC calls that perform blocking IO (disk lookups) are not executed on this pool but on the tokio
46/// runtime's blocking pool, which performs poorly with CPU bound tasks (see
47/// <https://ryhl.io/blog/async-what-is-blocking/>). Once the tokio blocking
48/// pool is saturated it is converted into a queue, blocking tasks could then interfere with the
49/// queue and block other RPC calls.
50///
51/// See also [tokio-docs] for more information.
52///
53/// [tokio-docs]: https://docs.rs/tokio/latest/tokio/index.html#cpu-bound-tasks-and-blocking-code
54#[derive(Clone, Debug)]
55pub struct BlockingTaskPool {
56 pool: Arc<rayon::ThreadPool>,
57}
58
59impl BlockingTaskPool {
60 /// Create a new `BlockingTaskPool` with the given threadpool.
61 pub fn new(pool: rayon::ThreadPool) -> Self {
62 Self { pool: Arc::new(pool) }
63 }
64
65 /// Convenience function to start building a new threadpool.
66 pub fn builder() -> rayon::ThreadPoolBuilder {
67 rayon::ThreadPoolBuilder::new()
68 }
69
70 /// Convenience function to build a new threadpool with the default configuration.
71 ///
72 /// Uses [`rayon::ThreadPoolBuilder::build`](rayon::ThreadPoolBuilder::build) defaults.
73 /// If a different stack size or other parameters are needed, they can be configured via
74 /// [`rayon::ThreadPoolBuilder`] returned by [`Self::builder`].
75 pub fn build() -> Result<Self, rayon::ThreadPoolBuildError> {
76 Self::builder().build().map(Self::new)
77 }
78
79 /// Asynchronous wrapper around Rayon's
80 /// [`ThreadPool::spawn`](rayon::ThreadPool::spawn).
81 ///
82 /// Runs a function on the configured threadpool, returning a future that resolves with the
83 /// function's return value.
84 ///
85 /// If the function panics, the future will resolve to an error.
86 pub fn spawn<F, R>(&self, func: F) -> BlockingTaskHandle<R>
87 where
88 F: FnOnce() -> R + Send + 'static,
89 R: Send + 'static,
90 {
91 let (tx, rx) = oneshot::channel();
92
93 self.pool.spawn(move || {
94 let _result = tx.send(catch_unwind(AssertUnwindSafe(func)));
95 });
96
97 BlockingTaskHandle { rx }
98 }
99
100 /// Asynchronous wrapper around Rayon's
101 /// [`ThreadPool::spawn_fifo`](rayon::ThreadPool::spawn_fifo).
102 ///
103 /// Runs a function on the configured threadpool, returning a future that resolves with the
104 /// function's return value.
105 ///
106 /// If the function panics, the future will resolve to an error.
107 pub fn spawn_fifo<F, R>(&self, func: F) -> BlockingTaskHandle<R>
108 where
109 F: FnOnce() -> R + Send + 'static,
110 R: Send + 'static,
111 {
112 let (tx, rx) = oneshot::channel();
113
114 self.pool.spawn_fifo(move || {
115 let _result = tx.send(catch_unwind(AssertUnwindSafe(func)));
116 });
117
118 BlockingTaskHandle { rx }
119 }
120}
121
122/// Async handle for a blocking task running in a Rayon thread pool.
123///
124/// ## Panics
125///
126/// If polled from outside a tokio runtime.
127#[derive(Debug)]
128#[must_use = "futures do nothing unless you `.await` or poll them"]
129#[pin_project::pin_project]
130pub struct BlockingTaskHandle<T> {
131 #[pin]
132 pub(crate) rx: oneshot::Receiver<thread::Result<T>>,
133}
134
135impl<T> Future for BlockingTaskHandle<T> {
136 type Output = thread::Result<T>;
137
138 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
139 match ready!(self.project().rx.poll(cx)) {
140 Ok(res) => Poll::Ready(res),
141 Err(_) => Poll::Ready(Err(Box::<TokioBlockingTaskError>::default())),
142 }
143 }
144}
145
146/// An error returned when the Tokio channel is dropped while awaiting a result.
147///
148/// This should only happen
149#[derive(Debug, Default, thiserror::Error)]
150#[error("tokio channel dropped while awaiting result")]
151#[non_exhaustive]
152pub struct TokioBlockingTaskError;
153
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn blocking_pool() {
        let pool = BlockingTaskPool::build().unwrap();
        let res = pool.spawn(move || 5);
        let res = res.await.unwrap();
        assert_eq!(res, 5);
    }

    #[tokio::test]
    async fn blocking_pool_fifo() {
        let pool = BlockingTaskPool::build().unwrap();
        let res = pool.spawn_fifo(move || 5).await.unwrap();
        assert_eq!(res, 5);
    }

    #[tokio::test]
    async fn blocking_pool_panic() {
        let pool = BlockingTaskPool::build().unwrap();
        let res = pool.spawn(move || -> i32 {
            panic!();
        });
        let err = res.await.unwrap_err();
        // The panic must surface as its payload, not as a dropped-channel error.
        assert!(!err.is::<TokioBlockingTaskError>());
    }
}
175}