1use std::{
4 any::Any,
5 cell::RefCell,
6 future::Future,
7 panic::{catch_unwind, AssertUnwindSafe},
8 pin::Pin,
9 sync::{
10 atomic::{AtomicUsize, Ordering},
11 Arc,
12 },
13 task::{ready, Context, Poll},
14 thread,
15};
16use tokio::sync::{oneshot, AcquireError, OwnedSemaphorePermit, Semaphore};
17
18#[derive(Clone, Debug)]
27pub struct BlockingTaskGuard(Arc<Semaphore>);
28
29impl BlockingTaskGuard {
30 pub fn new(max_blocking_tasks: usize) -> Self {
33 Self(Arc::new(Semaphore::new(max_blocking_tasks)))
34 }
35
36 pub async fn acquire_owned(self) -> Result<OwnedSemaphorePermit, AcquireError> {
38 self.0.acquire_owned().await
39 }
40
41 pub async fn acquire_many_owned(self, n: u32) -> Result<OwnedSemaphorePermit, AcquireError> {
43 self.0.acquire_many_owned(n).await
44 }
45}
46
47#[derive(Clone, Debug)]
60pub struct BlockingTaskPool {
61 pool: Arc<rayon::ThreadPool>,
62}
63
64impl BlockingTaskPool {
65 pub fn new(pool: rayon::ThreadPool) -> Self {
67 Self { pool: Arc::new(pool) }
68 }
69
70 pub fn builder() -> rayon::ThreadPoolBuilder {
72 rayon::ThreadPoolBuilder::new()
73 }
74
75 pub fn build() -> Result<Self, rayon::ThreadPoolBuildError> {
81 Self::builder().build().map(Self::new)
82 }
83
84 pub fn spawn<F, R>(&self, func: F) -> BlockingTaskHandle<R>
92 where
93 F: FnOnce() -> R + Send + 'static,
94 R: Send + 'static,
95 {
96 let (tx, rx) = oneshot::channel();
97
98 self.pool.spawn(move || {
99 let _result = tx.send(catch_unwind(AssertUnwindSafe(func)));
100 });
101
102 BlockingTaskHandle { rx }
103 }
104
105 pub fn spawn_fifo<F, R>(&self, func: F) -> BlockingTaskHandle<R>
113 where
114 F: FnOnce() -> R + Send + 'static,
115 R: Send + 'static,
116 {
117 let (tx, rx) = oneshot::channel();
118
119 self.pool.spawn_fifo(move || {
120 let _result = tx.send(catch_unwind(AssertUnwindSafe(func)));
121 });
122
123 BlockingTaskHandle { rx }
124 }
125}
126
127#[derive(Debug)]
133#[must_use = "futures do nothing unless you `.await` or poll them"]
134#[pin_project::pin_project]
135pub struct BlockingTaskHandle<T> {
136 #[pin]
137 pub(crate) rx: oneshot::Receiver<thread::Result<T>>,
138}
139
140impl<T> Future for BlockingTaskHandle<T> {
141 type Output = thread::Result<T>;
142
143 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
144 match ready!(self.project().rx.poll(cx)) {
145 Ok(res) => Poll::Ready(res),
146 Err(_) => Poll::Ready(Err(Box::<TokioBlockingTaskError>::default())),
147 }
148 }
149}
150
151#[derive(Debug, Default, thiserror::Error)]
155#[error("tokio channel dropped while awaiting result")]
156#[non_exhaustive]
157pub struct TokioBlockingTaskError;
158
159thread_local! {
160 static WORKER: RefCell<Worker> = const { RefCell::new(Worker::new()) };
161}
162
163#[derive(Debug)]
172pub struct WorkerPool {
173 pool: rayon::ThreadPool,
174}
175
176impl WorkerPool {
177 pub fn new(num_threads: usize) -> Result<Self, rayon::ThreadPoolBuildError> {
179 Self::from_builder(rayon::ThreadPoolBuilder::new().num_threads(num_threads))
180 }
181
182 pub fn from_builder(
184 builder: rayon::ThreadPoolBuilder,
185 ) -> Result<Self, rayon::ThreadPoolBuildError> {
186 Ok(Self { pool: builder.build()? })
187 }
188
189 pub fn current_num_threads(&self) -> usize {
191 self.pool.current_num_threads()
192 }
193
194 pub fn init<T: 'static>(&self, f: impl Fn(Option<&mut T>) -> T + Sync) {
196 self.broadcast(self.pool.current_num_threads(), |worker| {
197 worker.init::<T>(&f);
198 });
199 }
200
201 pub fn broadcast(&self, num_threads: usize, f: impl Fn(&mut Worker) + Sync) {
207 if num_threads >= self.pool.current_num_threads() {
208 self.pool.broadcast(|_| {
210 WORKER.with_borrow_mut(|worker| f(worker));
211 });
212 } else {
213 let remaining = AtomicUsize::new(num_threads);
214 self.pool.broadcast(|_| {
215 let mut current = remaining.load(Ordering::Relaxed);
217 loop {
218 if current == 0 {
219 return;
220 }
221 match remaining.compare_exchange_weak(
222 current,
223 current - 1,
224 Ordering::Relaxed,
225 Ordering::Relaxed,
226 ) {
227 Ok(_) => break,
228 Err(actual) => current = actual,
229 }
230 }
231 WORKER.with_borrow_mut(|worker| f(worker));
232 });
233 }
234 }
235
236 pub fn clear(&self) {
238 self.pool.broadcast(|_| {
239 WORKER.with_borrow_mut(Worker::clear);
240 });
241 }
242
243 pub fn install<R: Send>(&self, f: impl FnOnce(&Worker) -> R + Send) -> R {
249 self.pool.install(|| WORKER.with_borrow(|worker| f(worker)))
250 }
251
252 pub fn install_fn<R: Send>(&self, f: impl FnOnce() -> R + Send) -> R {
257 self.pool.install(f)
258 }
259
260 pub fn spawn(&self, f: impl FnOnce() + Send + 'static) {
262 self.pool.spawn(f);
263 }
264
265 pub fn in_place_scope<'scope, R>(&self, f: impl FnOnce(&rayon::Scope<'scope>) -> R) -> R {
269 self.pool.in_place_scope(f)
270 }
271
272 pub fn with_worker<R>(f: impl FnOnce(&Worker) -> R) -> R {
277 WORKER.with_borrow(|worker| f(worker))
278 }
279
280 pub fn with_worker_mut<R>(f: impl FnOnce(&mut Worker) -> R) -> R {
282 WORKER.with_borrow_mut(|worker| f(worker))
283 }
284}
285
/// Per-thread state slot used by [`WorkerPool`].
///
/// Holds at most one value of any `'static` type, stored type-erased as `Box<dyn Any>`.
/// Typed accessors downcast to the requested type and panic when the slot is empty or
/// holds a different type.
#[derive(Debug, Default)]
pub struct Worker {
    /// Type-erased state; `None` until initialized and again after [`Worker::clear`].
    state: Option<Box<dyn Any>>,
}

impl Worker {
    /// Creates an empty worker. `const` so it can be used as a const thread-local initializer.
    const fn new() -> Self {
        Self { state: None }
    }

    /// Initializes the slot with a `T`, reusing an existing `T` when present.
    ///
    /// If the slot already holds a `T`, `f` receives `Some(&mut existing)` and can reuse its
    /// resources. Its return value then replaces the existing value in place. Otherwise, any
    /// state of a different type is dropped first, and `f` receives `None`.
    ///
    /// If `f` panics while re-initializing an existing `T`, that value stays in the slot,
    /// possibly modified by whatever `f` did before panicking. If `f` panics while building
    /// fresh state, the slot is left empty.
    pub fn init<T: 'static>(&mut self, f: impl FnOnce(Option<&mut T>) -> T) {
        // Borrow in place instead of `take()`-ing the box, so a panic in `f` does not
        // silently discard the existing state.
        match self.state.as_mut().and_then(|state| state.downcast_mut::<T>()) {
            Some(existing) => *existing = f(Some(existing)),
            None => {
                // Drop any differently-typed state before building the new value, so the
                // two never coexist in memory.
                self.state = None;
                self.state = Some(Box::new(f(None)));
            }
        }
    }

    /// Returns a shared reference to the state as a `T`.
    ///
    /// # Panics
    ///
    /// Panics if the worker is uninitialized or holds a different type.
    pub fn get<T: 'static>(&self) -> &T {
        self.state
            .as_ref()
            .expect("worker not initialized")
            .downcast_ref::<T>()
            .expect("worker state type mismatch")
    }

    /// Returns a mutable reference to the state as a `T`.
    ///
    /// # Panics
    ///
    /// Panics if the worker is uninitialized or holds a different type.
    pub fn get_mut<T: 'static>(&mut self) -> &mut T {
        self.state
            .as_mut()
            .expect("worker not initialized")
            .downcast_mut::<T>()
            .expect("worker state type mismatch")
    }

    /// Returns the state as a `T`, first storing `f()` if the worker is uninitialized.
    ///
    /// # Panics
    ///
    /// Panics if the worker already holds state of a different type.
    pub fn get_or_init<T: 'static>(&mut self, f: impl FnOnce() -> T) -> &mut T {
        self.state
            .get_or_insert_with(|| Box::new(f()))
            .downcast_mut::<T>()
            .expect("worker state type mismatch")
    }

    /// Drops any stored state, returning the worker to the uninitialized state.
    pub fn clear(&mut self) {
        self.state = None;
    }
}
364
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn blocking_pool() {
        let pool = BlockingTaskPool::build().unwrap();
        let res = pool.spawn(move || 5);
        let res = res.await.unwrap();
        assert_eq!(res, 5);
    }

    #[tokio::test]
    async fn blocking_pool_panic() {
        let pool = BlockingTaskPool::build().unwrap();
        let res = pool.spawn(move || -> i32 {
            panic!();
        });
        let res = res.await;
        assert!(res.is_err());
    }

    // `spawn_fifo` goes through a different rayon entry point than `spawn`.
    #[tokio::test]
    async fn blocking_pool_spawn_fifo() {
        let pool = BlockingTaskPool::build().unwrap();
        let res = pool.spawn_fifo(move || 7).await.unwrap();
        assert_eq!(res, 7);
    }

    #[test]
    fn worker_pool_init_and_access() {
        let pool = WorkerPool::new(2).unwrap();

        pool.broadcast(2, |worker| {
            worker.init::<Vec<u8>>(|_| vec![1, 2, 3]);
        });

        let sum: u8 = pool.install(|worker| {
            let v = worker.get::<Vec<u8>>();
            v.iter().sum()
        });
        assert_eq!(sum, 6);

        pool.clear();
    }

    #[test]
    fn worker_pool_reinit_reuses_resources() {
        let pool = WorkerPool::new(1).unwrap();

        pool.broadcast(1, |worker| {
            worker.init::<Vec<u8>>(|existing| {
                assert!(existing.is_none());
                vec![1, 2, 3]
            });
        });

        pool.broadcast(1, |worker| {
            worker.init::<Vec<u8>>(|existing| {
                let v = existing.expect("should have existing state");
                assert_eq!(v, &mut vec![1, 2, 3]);
                v.push(4);
                std::mem::take(v)
            });
        });

        let len = pool.install(|worker| worker.get::<Vec<u8>>().len());
        assert_eq!(len, 4);

        pool.clear();
    }

    #[test]
    fn worker_pool_clear_and_reinit() {
        let pool = WorkerPool::new(1).unwrap();

        pool.broadcast(1, |worker| {
            worker.init::<u64>(|_| 42);
        });
        let val = pool.install(|worker| *worker.get::<u64>());
        assert_eq!(val, 42);

        pool.clear();

        pool.broadcast(1, |worker| {
            worker.init::<String>(|_| "hello".to_string());
        });
        let val = pool.install(|worker| worker.get::<String>().clone());
        assert_eq!(val, "hello");

        pool.clear();
    }

    #[test]
    fn worker_pool_par_iter_with_worker() {
        use rayon::prelude::*;

        let pool = WorkerPool::new(2).unwrap();

        pool.broadcast(2, |worker| {
            worker.init::<u64>(|_| 10);
        });

        let results: Vec<u64> = pool.install(|_| {
            (0u64..4)
                .into_par_iter()
                .map(|i| WorkerPool::with_worker(|w| i + *w.get::<u64>()))
                .collect()
        });
        assert_eq!(results, vec![10, 11, 12, 13]);

        pool.clear();
    }

    // Exercises the partial-broadcast branch (`num_threads` < pool size), which the tests
    // above never reach because they always broadcast to the whole pool.
    #[test]
    fn worker_pool_broadcast_subset_of_threads() {
        let pool = WorkerPool::new(4).unwrap();
        let calls = AtomicUsize::new(0);

        pool.broadcast(2, |_| {
            calls.fetch_add(1, Ordering::Relaxed);
        });
        assert_eq!(calls.load(Ordering::Relaxed), 2);

        pool.broadcast(0, |_| {
            calls.fetch_add(1, Ordering::Relaxed);
        });
        assert_eq!(calls.load(Ordering::Relaxed), 2);

        // Requests larger than the pool are capped at one call per thread.
        pool.broadcast(10, |_| {
            calls.fetch_add(1, Ordering::Relaxed);
        });
        assert_eq!(calls.load(Ordering::Relaxed), 6);
    }

    #[test]
    fn worker_init_replaces_state_of_other_type() {
        let mut worker = Worker::default();
        worker.init::<u64>(|_| 1);
        worker.init::<String>(|existing| {
            assert!(existing.is_none());
            "replaced".to_string()
        });
        assert_eq!(worker.get::<String>().as_str(), "replaced");
    }

    #[test]
    fn worker_get_or_init_get_mut_and_clear() {
        let mut worker = Worker::default();
        assert_eq!(*worker.get_or_init::<u32>(|| 3), 3);
        assert_eq!(*worker.get_or_init::<u32>(|| 99), 3);
        *worker.get_mut::<u32>() += 1;
        assert_eq!(*worker.get::<u32>(), 4);
        worker.clear();
        assert_eq!(*worker.get_or_init::<u32>(|| 99), 99);
    }

    #[test]
    #[should_panic(expected = "worker state type mismatch")]
    fn worker_get_wrong_type_panics() {
        let mut worker = Worker::default();
        worker.init::<u32>(|_| 1);
        let _ = worker.get::<u64>();
    }
}