1use std::{
4 any::Any,
5 cell::RefCell,
6 future::Future,
7 panic::{catch_unwind, AssertUnwindSafe},
8 pin::Pin,
9 sync::{
10 atomic::{AtomicUsize, Ordering},
11 Arc, OnceLock,
12 },
13 task::{ready, Context, Poll},
14 thread,
15};
16use tokio::sync::{oneshot, AcquireError, OwnedSemaphorePermit, Semaphore};
17
18#[derive(Clone, Debug)]
27pub struct BlockingTaskGuard(Arc<Semaphore>);
28
29impl BlockingTaskGuard {
30 pub fn new(max_blocking_tasks: usize) -> Self {
33 Self(Arc::new(Semaphore::new(max_blocking_tasks)))
34 }
35
36 pub async fn acquire_owned(self) -> Result<OwnedSemaphorePermit, AcquireError> {
38 self.0.acquire_owned().await
39 }
40
41 pub async fn acquire_many_owned(self, n: u32) -> Result<OwnedSemaphorePermit, AcquireError> {
43 self.0.acquire_many_owned(n).await
44 }
45}
46
47#[derive(Clone, Debug)]
60pub struct BlockingTaskPool {
61 pool: Arc<rayon::ThreadPool>,
62}
63
64impl BlockingTaskPool {
65 pub fn new(pool: rayon::ThreadPool) -> Self {
67 Self { pool: Arc::new(pool) }
68 }
69
70 pub fn builder() -> rayon::ThreadPoolBuilder {
72 rayon::ThreadPoolBuilder::new()
73 }
74
75 pub fn build() -> Result<Self, rayon::ThreadPoolBuildError> {
81 Self::builder().build().map(Self::new)
82 }
83
84 pub fn spawn<F, R>(&self, func: F) -> BlockingTaskHandle<R>
92 where
93 F: FnOnce() -> R + Send + 'static,
94 R: Send + 'static,
95 {
96 let (tx, rx) = oneshot::channel();
97
98 self.pool.spawn(move || {
99 let _result = tx.send(catch_unwind(AssertUnwindSafe(func)));
100 });
101
102 BlockingTaskHandle { rx }
103 }
104
105 pub fn spawn_fifo<F, R>(&self, func: F) -> BlockingTaskHandle<R>
113 where
114 F: FnOnce() -> R + Send + 'static,
115 R: Send + 'static,
116 {
117 let (tx, rx) = oneshot::channel();
118
119 self.pool.spawn_fifo(move || {
120 let _result = tx.send(catch_unwind(AssertUnwindSafe(func)));
121 });
122
123 BlockingTaskHandle { rx }
124 }
125}
126
127#[derive(Debug)]
133#[must_use = "futures do nothing unless you `.await` or poll them"]
134#[pin_project::pin_project]
135pub struct BlockingTaskHandle<T> {
136 #[pin]
137 pub(crate) rx: oneshot::Receiver<thread::Result<T>>,
138}
139
140impl<T> Future for BlockingTaskHandle<T> {
141 type Output = thread::Result<T>;
142
143 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
144 match ready!(self.project().rx.poll(cx)) {
145 Ok(res) => Poll::Ready(res),
146 Err(_) => Poll::Ready(Err(Box::<TokioBlockingTaskError>::default())),
147 }
148 }
149}
150
151#[derive(Debug, Default, thiserror::Error)]
155#[error("tokio channel dropped while awaiting result")]
156#[non_exhaustive]
157pub struct TokioBlockingTaskError;
158
159thread_local! {
160 static WORKER: RefCell<Worker> = const { RefCell::new(Worker::new()) };
161}
162
163#[derive(Debug)]
174pub struct WorkerPool {
175 pool: OnceLock<rayon::ThreadPool>,
176 num_threads: usize,
177 thread_name_prefix: &'static str,
178}
179
180impl WorkerPool {
181 pub const fn new(num_threads: usize, thread_name_prefix: &'static str) -> Self {
186 Self { pool: OnceLock::new(), num_threads, thread_name_prefix }
187 }
188
189 fn pool(&self) -> &rayon::ThreadPool {
191 self.pool.get_or_init(|| {
192 let prefix = self.thread_name_prefix;
193 build_pool_with_panic_handler(
194 rayon::ThreadPoolBuilder::new()
195 .num_threads(self.num_threads)
196 .thread_name(move |i| format!("{prefix}-{i:02}")),
197 )
198 .unwrap_or_else(|err| panic!("failed to build {prefix} worker pool: {err}"))
199 })
200 }
201
202 pub fn is_initialized(&self) -> bool {
204 self.pool.get().is_some()
205 }
206
207 pub fn current_num_threads(&self) -> usize {
209 self.pool().current_num_threads()
210 }
211
212 pub fn init<T: 'static>(&self, f: impl Fn(Option<&mut T>) -> T + Sync) {
214 self.broadcast(self.pool().current_num_threads(), |worker| {
215 worker.init::<T>(&f);
216 });
217 }
218
219 pub fn broadcast(&self, num_threads: usize, f: impl Fn(&mut Worker) + Sync) {
225 if num_threads >= self.pool().current_num_threads() {
226 self.pool().broadcast(|_| {
228 WORKER.with_borrow_mut(|worker| f(worker));
229 });
230 } else {
231 let remaining = AtomicUsize::new(num_threads);
232 self.pool().broadcast(|_| {
233 let mut current = remaining.load(Ordering::Relaxed);
235 loop {
236 if current == 0 {
237 return;
238 }
239 match remaining.compare_exchange_weak(
240 current,
241 current - 1,
242 Ordering::Relaxed,
243 Ordering::Relaxed,
244 ) {
245 Ok(_) => break,
246 Err(actual) => current = actual,
247 }
248 }
249 WORKER.with_borrow_mut(|worker| f(worker));
250 });
251 }
252 }
253
254 pub fn clear(&self) {
256 self.pool().broadcast(|_| {
257 WORKER.with_borrow_mut(Worker::clear);
258 });
259 }
260
261 pub fn install<R: Send>(&self, f: impl FnOnce(&Worker) -> R + Send) -> R {
267 self.pool().install(|| WORKER.with_borrow(|worker| f(worker)))
268 }
269
270 pub fn install_fn<R: Send>(&self, f: impl FnOnce() -> R + Send) -> R {
275 self.pool().install(f)
276 }
277
278 pub fn spawn(&self, f: impl FnOnce() + Send + 'static) {
280 self.pool().spawn(f);
281 }
282
283 pub fn in_place_scope<'scope, R>(&self, f: impl FnOnce(&rayon::Scope<'scope>) -> R) -> R {
287 self.pool().in_place_scope(f)
288 }
289
290 pub fn with_worker<R>(f: impl FnOnce(&Worker) -> R) -> R {
295 WORKER.with_borrow(|worker| f(worker))
296 }
297
298 pub fn with_worker_mut<R>(f: impl FnOnce(&mut Worker) -> R) -> R {
300 WORKER.with_borrow_mut(|worker| f(worker))
301 }
302}
303
304pub fn build_pool_with_panic_handler(
309 builder: rayon::ThreadPoolBuilder,
310) -> Result<rayon::ThreadPool, rayon::ThreadPoolBuildError> {
311 builder.panic_handler(|_| {}).build()
312}
313
/// A per-thread state slot holding at most one value of an arbitrary `'static` type.
///
/// The value is type-erased as `Box<dyn Any>`. The typed accessors downcast to the
/// requested type and panic when the slot is empty or holds a different type.
#[derive(Debug, Default)]
pub struct Worker {
    /// Type-erased state; `None` when never initialized or after [`Worker::clear`].
    state: Option<Box<dyn Any>>,
}

impl Worker {
    /// Creates an empty worker. `const` so it can seed a `const` thread-local initializer.
    const fn new() -> Self {
        Self { state: None }
    }

    /// Initializes or re-initializes the state as a `T`.
    ///
    /// If the slot already holds a `T`, `f` receives `Some(&mut existing)`. Its return value
    /// then overwrites that value in place, reusing the existing box. Otherwise, any state of
    /// a different type is dropped *before* `f` runs, and `f` receives `None`.
    ///
    /// If `f` panics, the previous state has already been removed, so the worker is left
    /// uninitialized.
    pub fn init<T: 'static>(&mut self, f: impl FnOnce(Option<&mut T>) -> T) {
        // Keep the old box only if it already holds a `T`. `filter` drops a value of any
        // other type right here, before `f` is invoked.
        let reusable = self.state.take().filter(|boxed| boxed.is::<T>());

        let next: Box<dyn Any> = match reusable {
            Some(mut boxed) => {
                let slot = boxed.downcast_mut::<T>().expect("type checked above");
                *slot = f(Some(slot));
                boxed
            }
            None => Box::new(f(None)),
        };

        self.state = Some(next);
    }

    /// Returns a shared reference to the state as a `T`.
    ///
    /// # Panics
    /// Panics with "worker not initialized" if the slot is empty, or with
    /// "worker state type mismatch" if it holds a type other than `T`.
    pub fn get<T: 'static>(&self) -> &T {
        self.state
            .as_ref()
            .expect("worker not initialized")
            .downcast_ref::<T>()
            .expect("worker state type mismatch")
    }

    /// Returns a mutable reference to the state as a `T`.
    ///
    /// # Panics
    /// Panics with "worker not initialized" if the slot is empty, or with
    /// "worker state type mismatch" if it holds a type other than `T`.
    pub fn get_mut<T: 'static>(&mut self) -> &mut T {
        self.state
            .as_mut()
            .expect("worker not initialized")
            .downcast_mut::<T>()
            .expect("worker state type mismatch")
    }

    /// Returns the state as `&mut T`, first storing `f()` if the slot is empty.
    ///
    /// `f` is not called when state already exists.
    ///
    /// # Panics
    /// Panics with "worker state type mismatch" if the existing state is not a `T`.
    pub fn get_or_init<T: 'static>(&mut self, f: impl FnOnce() -> T) -> &mut T {
        self.state
            .get_or_insert_with(|| Box::new(f()))
            .downcast_mut::<T>()
            .expect("worker state type mismatch")
    }

    /// Drops the state, leaving the worker uninitialized.
    pub fn clear(&mut self) {
        self.state = None;
    }
}
392
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn blocking_pool() {
        let pool = BlockingTaskPool::build().unwrap();
        let res = pool.spawn(move || 5);
        let res = res.await.unwrap();
        assert_eq!(res, 5);
    }

    #[tokio::test]
    async fn blocking_pool_fifo() {
        let pool = BlockingTaskPool::build().unwrap();
        let res = pool.spawn_fifo(move || "done").await.unwrap();
        assert_eq!(res, "done");
    }

    #[tokio::test]
    async fn blocking_pool_panic() {
        let pool = BlockingTaskPool::build().unwrap();
        let res = pool.spawn(move || -> i32 {
            panic!();
        });
        let res = res.await;
        assert!(res.is_err());
    }

    #[test]
    fn worker_pool_is_built_lazily() {
        let pool = WorkerPool::new(1, "test");
        assert!(!pool.is_initialized());
        assert_eq!(pool.current_num_threads(), 1);
        assert!(pool.is_initialized());
    }

    #[test]
    fn worker_pool_init_and_access() {
        let pool = WorkerPool::new(2, "test");

        pool.broadcast(2, |worker| {
            worker.init::<Vec<u8>>(|_| vec![1, 2, 3]);
        });

        let sum: u8 = pool.install(|worker| {
            let v = worker.get::<Vec<u8>>();
            v.iter().sum()
        });
        assert_eq!(sum, 6);

        pool.clear();
    }

    #[test]
    fn worker_pool_partial_broadcast_runs_requested_threads() {
        use std::sync::atomic::{AtomicUsize, Ordering};

        let pool = WorkerPool::new(4, "test");
        let calls = AtomicUsize::new(0);

        // Fewer than the pool size: exactly that many threads run the closure.
        pool.broadcast(2, |_| {
            calls.fetch_add(1, Ordering::Relaxed);
        });
        assert_eq!(calls.load(Ordering::Relaxed), 2);

        // Zero requested threads: nobody runs it.
        pool.broadcast(0, |_| {
            calls.fetch_add(1, Ordering::Relaxed);
        });
        assert_eq!(calls.load(Ordering::Relaxed), 2);

        // At least the pool size: every thread runs it.
        pool.broadcast(10, |_| {
            calls.fetch_add(1, Ordering::Relaxed);
        });
        assert_eq!(calls.load(Ordering::Relaxed), 6);
    }

    #[test]
    fn worker_pool_reinit_reuses_resources() {
        let pool = WorkerPool::new(1, "test");

        pool.broadcast(1, |worker| {
            worker.init::<Vec<u8>>(|existing| {
                assert!(existing.is_none());
                vec![1, 2, 3]
            });
        });

        pool.broadcast(1, |worker| {
            worker.init::<Vec<u8>>(|existing| {
                let v = existing.expect("should have existing state");
                assert_eq!(v, &mut vec![1, 2, 3]);
                v.push(4);
                std::mem::take(v)
            });
        });

        let len = pool.install(|worker| worker.get::<Vec<u8>>().len());
        assert_eq!(len, 4);

        pool.clear();
    }

    #[test]
    fn worker_init_replaces_state_of_other_type() {
        let mut worker = Worker::default();
        worker.init::<u64>(|_| 7);
        worker.init::<String>(|existing| {
            assert!(existing.is_none());
            "x".to_string()
        });
        assert_eq!(worker.get::<String>(), "x");
    }

    #[test]
    fn worker_pool_clear_and_reinit() {
        let pool = WorkerPool::new(1, "test");

        pool.broadcast(1, |worker| {
            worker.init::<u64>(|_| 42);
        });
        let val = pool.install(|worker| *worker.get::<u64>());
        assert_eq!(val, 42);

        pool.clear();

        pool.broadcast(1, |worker| {
            worker.init::<String>(|_| "hello".to_string());
        });
        let val = pool.install(|worker| worker.get::<String>().clone());
        assert_eq!(val, "hello");

        pool.clear();
    }

    #[test]
    fn worker_pool_par_iter_with_worker() {
        use rayon::prelude::*;

        let pool = WorkerPool::new(2, "test");

        pool.broadcast(2, |worker| {
            worker.init::<u64>(|_| 10);
        });

        let results: Vec<u64> = pool.install(|_| {
            (0u64..4)
                .into_par_iter()
                .map(|i| WorkerPool::with_worker(|w| i + *w.get::<u64>()))
                .collect()
        });
        assert_eq!(results, vec![10, 11, 12, 13]);

        pool.clear();
    }
}