1use std::{
4 any::Any,
5 cell::RefCell,
6 future::Future,
7 panic::{catch_unwind, AssertUnwindSafe},
8 pin::Pin,
9 sync::{
10 atomic::{AtomicUsize, Ordering},
11 Arc,
12 },
13 task::{ready, Context, Poll},
14 thread,
15};
16use tokio::sync::{oneshot, AcquireError, OwnedSemaphorePermit, Semaphore};
17
18#[derive(Clone, Debug)]
27pub struct BlockingTaskGuard(Arc<Semaphore>);
28
29impl BlockingTaskGuard {
30 pub fn new(max_blocking_tasks: usize) -> Self {
33 Self(Arc::new(Semaphore::new(max_blocking_tasks)))
34 }
35
36 pub async fn acquire_owned(self) -> Result<OwnedSemaphorePermit, AcquireError> {
38 self.0.acquire_owned().await
39 }
40
41 pub async fn acquire_many_owned(self, n: u32) -> Result<OwnedSemaphorePermit, AcquireError> {
43 self.0.acquire_many_owned(n).await
44 }
45}
46
47#[derive(Clone, Debug)]
60pub struct BlockingTaskPool {
61 pool: Arc<rayon::ThreadPool>,
62}
63
64impl BlockingTaskPool {
65 pub fn new(pool: rayon::ThreadPool) -> Self {
67 Self { pool: Arc::new(pool) }
68 }
69
70 pub fn builder() -> rayon::ThreadPoolBuilder {
72 rayon::ThreadPoolBuilder::new()
73 }
74
75 pub fn build() -> Result<Self, rayon::ThreadPoolBuildError> {
81 Self::builder().build().map(Self::new)
82 }
83
84 pub fn spawn<F, R>(&self, func: F) -> BlockingTaskHandle<R>
92 where
93 F: FnOnce() -> R + Send + 'static,
94 R: Send + 'static,
95 {
96 let (tx, rx) = oneshot::channel();
97
98 self.pool.spawn(move || {
99 let _result = tx.send(catch_unwind(AssertUnwindSafe(func)));
100 });
101
102 BlockingTaskHandle { rx }
103 }
104
105 pub fn spawn_fifo<F, R>(&self, func: F) -> BlockingTaskHandle<R>
113 where
114 F: FnOnce() -> R + Send + 'static,
115 R: Send + 'static,
116 {
117 let (tx, rx) = oneshot::channel();
118
119 self.pool.spawn_fifo(move || {
120 let _result = tx.send(catch_unwind(AssertUnwindSafe(func)));
121 });
122
123 BlockingTaskHandle { rx }
124 }
125}
126
127#[derive(Debug)]
133#[must_use = "futures do nothing unless you `.await` or poll them"]
134#[pin_project::pin_project]
135pub struct BlockingTaskHandle<T> {
136 #[pin]
137 pub(crate) rx: oneshot::Receiver<thread::Result<T>>,
138}
139
140impl<T> Future for BlockingTaskHandle<T> {
141 type Output = thread::Result<T>;
142
143 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
144 match ready!(self.project().rx.poll(cx)) {
145 Ok(res) => Poll::Ready(res),
146 Err(_) => Poll::Ready(Err(Box::<TokioBlockingTaskError>::default())),
147 }
148 }
149}
150
151#[derive(Debug, Default, thiserror::Error)]
155#[error("tokio channel dropped while awaiting result")]
156#[non_exhaustive]
157pub struct TokioBlockingTaskError;
158
159thread_local! {
160 static WORKER: RefCell<Worker> = const { RefCell::new(Worker::new()) };
161}
162
163#[derive(Debug)]
172pub struct WorkerPool {
173 pool: rayon::ThreadPool,
174}
175
176impl WorkerPool {
177 pub fn new(num_threads: usize) -> Result<Self, rayon::ThreadPoolBuildError> {
179 Self::from_builder(rayon::ThreadPoolBuilder::new().num_threads(num_threads))
180 }
181
182 pub fn from_builder(
186 builder: rayon::ThreadPoolBuilder,
187 ) -> Result<Self, rayon::ThreadPoolBuildError> {
188 Ok(Self { pool: build_pool_with_panic_handler(builder)? })
189 }
190
191 pub fn current_num_threads(&self) -> usize {
193 self.pool.current_num_threads()
194 }
195
196 pub fn init<T: 'static>(&self, f: impl Fn(Option<&mut T>) -> T + Sync) {
198 self.broadcast(self.pool.current_num_threads(), |worker| {
199 worker.init::<T>(&f);
200 });
201 }
202
203 pub fn broadcast(&self, num_threads: usize, f: impl Fn(&mut Worker) + Sync) {
209 if num_threads >= self.pool.current_num_threads() {
210 self.pool.broadcast(|_| {
212 WORKER.with_borrow_mut(|worker| f(worker));
213 });
214 } else {
215 let remaining = AtomicUsize::new(num_threads);
216 self.pool.broadcast(|_| {
217 let mut current = remaining.load(Ordering::Relaxed);
219 loop {
220 if current == 0 {
221 return;
222 }
223 match remaining.compare_exchange_weak(
224 current,
225 current - 1,
226 Ordering::Relaxed,
227 Ordering::Relaxed,
228 ) {
229 Ok(_) => break,
230 Err(actual) => current = actual,
231 }
232 }
233 WORKER.with_borrow_mut(|worker| f(worker));
234 });
235 }
236 }
237
238 pub fn clear(&self) {
240 self.pool.broadcast(|_| {
241 WORKER.with_borrow_mut(Worker::clear);
242 });
243 }
244
245 pub fn install<R: Send>(&self, f: impl FnOnce(&Worker) -> R + Send) -> R {
251 self.pool.install(|| WORKER.with_borrow(|worker| f(worker)))
252 }
253
254 pub fn install_fn<R: Send>(&self, f: impl FnOnce() -> R + Send) -> R {
259 self.pool.install(f)
260 }
261
262 pub fn spawn(&self, f: impl FnOnce() + Send + 'static) {
264 self.pool.spawn(f);
265 }
266
267 pub fn in_place_scope<'scope, R>(&self, f: impl FnOnce(&rayon::Scope<'scope>) -> R) -> R {
271 self.pool.in_place_scope(f)
272 }
273
274 pub fn with_worker<R>(f: impl FnOnce(&Worker) -> R) -> R {
279 WORKER.with_borrow(|worker| f(worker))
280 }
281
282 pub fn with_worker_mut<R>(f: impl FnOnce(&mut Worker) -> R) -> R {
284 WORKER.with_borrow_mut(|worker| f(worker))
285 }
286}
287
288pub fn build_pool_with_panic_handler(
293 builder: rayon::ThreadPoolBuilder,
294) -> Result<rayon::ThreadPool, rayon::ThreadPoolBuildError> {
295 builder.panic_handler(|_| {}).build()
296}
297
/// Type-erased, optional state slot.
///
/// Holds at most one value of any `'static` type; the typed accessors downcast to the
/// type the caller asks for.
#[derive(Debug, Default)]
pub struct Worker {
    /// The stored value, or `None` while uninitialized or after [`Worker::clear`].
    state: Option<Box<dyn Any>>,
}

impl Worker {
    /// Creates an empty worker; `const` so it can be used in constant initializers.
    const fn new() -> Self {
        Self { state: None }
    }

    /// Replaces the state with the value returned by `f`.
    ///
    /// If the current state already has type `T`, `f` receives `Some(&mut existing)` and
    /// the returned value overwrites it inside the existing allocation. Otherwise any
    /// previous state is dropped first and `f` receives `None`.
    pub fn init<T: 'static>(&mut self, f: impl FnOnce(Option<&mut T>) -> T) {
        // Keep the old box only when it already stores a `T`; a box of any other type is
        // dropped here, before `f` runs.
        let reusable = self.state.take().filter(|boxed| boxed.is::<T>());

        let next: Box<dyn Any> = if let Some(mut boxed) = reusable {
            let slot = boxed.downcast_mut::<T>().expect("type checked above");
            let replacement = f(Some(&mut *slot));
            *slot = replacement;
            boxed
        } else {
            Box::new(f(None))
        };

        self.state = Some(next);
    }

    /// Returns a shared reference to the state as a `T`.
    ///
    /// # Panics
    ///
    /// Panics if no state is set or if the stored value is not a `T`.
    pub fn get<T: 'static>(&self) -> &T {
        let stored = self.state.as_deref().expect("worker not initialized");
        stored.downcast_ref::<T>().expect("worker state type mismatch")
    }

    /// Returns an exclusive reference to the state as a `T`.
    ///
    /// # Panics
    ///
    /// Panics if no state is set or if the stored value is not a `T`.
    pub fn get_mut<T: 'static>(&mut self) -> &mut T {
        let stored = self.state.as_deref_mut().expect("worker not initialized");
        stored.downcast_mut::<T>().expect("worker state type mismatch")
    }

    /// Returns the state as a `T`, first storing `f()` if no state is set.
    ///
    /// `f` is not called when a state is already present.
    ///
    /// # Panics
    ///
    /// Panics if a state is already present and is not a `T`.
    pub fn get_or_init<T: 'static>(&mut self, f: impl FnOnce() -> T) -> &mut T {
        let boxed = self.state.get_or_insert_with(|| Box::new(f()));
        boxed.downcast_mut::<T>().expect("worker state type mismatch")
    }

    /// Drops the stored state, leaving the worker empty.
    pub fn clear(&mut self) {
        drop(self.state.take());
    }
}
376
#[cfg(test)]
mod tests {
    use super::*;

    /// A spawned closure's return value is delivered through the handle.
    #[tokio::test]
    async fn blocking_pool() {
        let pool = BlockingTaskPool::build().unwrap();
        let handle = pool.spawn(move || 5);
        assert_eq!(handle.await.unwrap(), 5);
    }

    /// A panicking closure surfaces as `Err` on the handle.
    #[tokio::test]
    async fn blocking_pool_panic() {
        let pool = BlockingTaskPool::build().unwrap();
        let handle = pool.spawn(move || -> i32 {
            panic!();
        });
        assert!(handle.await.is_err());
    }

    /// State initialized through `broadcast` is readable through `install`.
    #[test]
    fn worker_pool_init_and_access() {
        let pool = WorkerPool::new(2).unwrap();

        pool.broadcast(2, |worker| worker.init::<Vec<u8>>(|_| vec![1, 2, 3]));

        let sum: u8 = pool.install(|worker| worker.get::<Vec<u8>>().iter().sum());
        assert_eq!(sum, 6);

        pool.clear();
    }

    /// A second `init` with the same type is handed the existing state.
    #[test]
    fn worker_pool_reinit_reuses_resources() {
        let pool = WorkerPool::new(1).unwrap();

        // First pass: nothing stored yet.
        pool.broadcast(1, |worker| {
            worker.init::<Vec<u8>>(|existing| {
                assert!(existing.is_none());
                vec![1, 2, 3]
            });
        });

        // Second pass: the previous vector is visible and can be reused.
        pool.broadcast(1, |worker| {
            worker.init::<Vec<u8>>(|existing| {
                let bytes = existing.expect("should have existing state");
                assert_eq!(*bytes, vec![1, 2, 3]);
                bytes.push(4);
                std::mem::take(bytes)
            });
        });

        let stored_len = pool.install(|worker| worker.get::<Vec<u8>>().len());
        assert_eq!(stored_len, 4);

        pool.clear();
    }

    /// After `clear`, the worker can be initialized with a different type.
    #[test]
    fn worker_pool_clear_and_reinit() {
        let pool = WorkerPool::new(1).unwrap();

        pool.broadcast(1, |worker| worker.init::<u64>(|_| 42));
        let number = pool.install(|worker| *worker.get::<u64>());
        assert_eq!(number, 42);

        pool.clear();

        pool.broadcast(1, |worker| worker.init::<String>(|_| "hello".to_string()));
        let text = pool.install(|worker| worker.get::<String>().clone());
        assert_eq!(text, "hello");

        pool.clear();
    }

    /// Parallel iterators running inside the pool can read the worker state.
    #[test]
    fn worker_pool_par_iter_with_worker() {
        use rayon::prelude::*;

        let pool = WorkerPool::new(2).unwrap();

        pool.broadcast(2, |worker| worker.init::<u64>(|_| 10));

        let results: Vec<u64> = pool.install(|_| {
            let add_worker_value = |i: u64| WorkerPool::with_worker(|w| i + *w.get::<u64>());
            (0u64..4).into_par_iter().map(add_worker_value).collect()
        });
        assert_eq!(results, vec![10, 11, 12, 13]);

        pool.clear();
    }
}