Skip to main content

reth_tasks/
worker_map.rs

//! A map of named single-thread worker pools.
//!
//! Each worker is a dedicated OS thread that runs the closures sent to it over a channel,
//! one at a time. This is a substitute for `spawn_blocking` that reuses the same OS thread
//! for every task submitted under the same name, like a single-thread pool keyed by name.
6
7use dashmap::DashMap;
8use std::{panic::AssertUnwindSafe, thread};
9use tokio::sync::{mpsc, oneshot};
10
11type BoxedTask = Box<dyn FnOnce() + Send + 'static>;
12
13/// A single-thread worker that processes closures sequentially on a dedicated OS thread.
14struct WorkerThread {
15    /// Sender to submit work to this worker's thread.
16    tx: mpsc::UnboundedSender<BoxedTask>,
17    /// The OS thread handle. Taken during shutdown to join.
18    handle: Option<thread::JoinHandle<()>>,
19}
20
21impl WorkerThread {
22    /// Spawns a new worker thread with the given name.
23    fn new(name: &'static str) -> Self {
24        let (tx, mut rx) = mpsc::unbounded_channel::<BoxedTask>();
25        let handle = thread::Builder::new()
26            .name(name.to_string())
27            .spawn(move || {
28                while let Some(task) = rx.blocking_recv() {
29                    let _ = std::panic::catch_unwind(AssertUnwindSafe(task));
30                }
31            })
32            .unwrap_or_else(|e| panic!("failed to spawn worker thread {name:?}: {e}"));
33
34        Self { tx, handle: Some(handle) }
35    }
36}
37
38/// A map of named single-thread workers.
39///
40/// Each unique name gets a dedicated OS thread that is reused for all tasks submitted under
41/// that name. Workers are created lazily on first use.
42pub(crate) struct WorkerMap {
43    workers: DashMap<&'static str, WorkerThread>,
44}
45
46impl Default for WorkerMap {
47    fn default() -> Self {
48        Self::new()
49    }
50}
51
52impl WorkerMap {
53    /// Creates a new empty `WorkerMap`.
54    pub(crate) fn new() -> Self {
55        Self { workers: DashMap::new() }
56    }
57
58    /// Spawns a closure on the dedicated worker thread for the given name.
59    ///
60    /// If no worker thread exists for this name yet, one is created with the given name as
61    /// the OS thread name. The closure executes on the worker's OS thread and the returned
62    /// future resolves with the result.
63    pub(crate) fn spawn_on<F, R>(&self, name: &'static str, f: F) -> oneshot::Receiver<R>
64    where
65        F: FnOnce() -> R + Send + 'static,
66        R: Send + 'static,
67    {
68        let (result_tx, result_rx) = oneshot::channel();
69
70        let task: BoxedTask = Box::new(move || {
71            let _ = result_tx.send(f());
72        });
73
74        let worker = self.workers.entry(name).or_insert_with(|| WorkerThread::new(name));
75        let _ = worker.tx.send(task);
76
77        result_rx
78    }
79}
80
81impl Drop for WorkerMap {
82    fn drop(&mut self) {
83        for (_, mut w) in std::mem::take(&mut self.workers) {
84            // Drop sender so the thread's recv loop exits, then join.
85            drop(w.tx);
86            if let Some(handle) = w.handle.take() {
87                let _ = handle.join();
88            }
89        }
90    }
91}
92
93impl std::fmt::Debug for WorkerMap {
94    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
95        f.debug_struct("WorkerMap").field("num_workers", &self.workers.len()).finish()
96    }
97}
98
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{
        atomic::{AtomicUsize, Ordering},
        Arc,
    };

    #[tokio::test]
    async fn worker_map_basic() {
        let map = WorkerMap::new();

        let result = map.spawn_on("test", || 42).await.unwrap();
        assert_eq!(result, 42);
    }

    #[tokio::test]
    async fn worker_map_same_thread() {
        let map = WorkerMap::new();

        let id1 = map.spawn_on("test", || thread::current().id()).await.unwrap();
        let id2 = map.spawn_on("test", || thread::current().id()).await.unwrap();
        assert_eq!(id1, id2, "same name should run on the same thread");
    }

    #[tokio::test]
    async fn worker_map_different_names_different_threads() {
        let map = WorkerMap::new();

        let id1 = map.spawn_on("worker-a", || thread::current().id()).await.unwrap();
        let id2 = map.spawn_on("worker-b", || thread::current().id()).await.unwrap();
        assert_ne!(id1, id2, "different names should run on different threads");
    }

    #[tokio::test]
    async fn worker_map_sequential_execution() {
        let map = WorkerMap::new();
        let counter = Arc::new(AtomicUsize::new(0));

        // Each task reports the counter value it observed. The assertions must run on the test
        // thread: a failing `assert!` inside a task would be swallowed by the worker's
        // `catch_unwind` and only surface as an opaque `RecvError`.
        let receivers: Vec<_> = (0..10)
            .map(|_| {
                let c = Arc::clone(&counter);
                map.spawn_on("sequential", move || c.fetch_add(1, Ordering::SeqCst))
            })
            .collect();

        for (i, rx) in receivers.into_iter().enumerate() {
            assert_eq!(rx.await.unwrap(), i, "tasks should execute in submission order");
        }
    }

    #[tokio::test]
    async fn worker_map_thread_name() {
        let map = WorkerMap::new();

        let name = map
            .spawn_on("custom-worker", || thread::current().name().unwrap().to_string())
            .await
            .unwrap();
        assert_eq!(name, "custom-worker");
    }

    #[tokio::test]
    async fn worker_map_panic_is_isolated() {
        let map = WorkerMap::new();

        let before = map.spawn_on("panicky", || thread::current().id()).await.unwrap();
        let panicked = map.spawn_on("panicky", || -> u32 { panic!("intentional test panic") }).await;
        assert!(panicked.is_err(), "a panicking task should resolve its receiver to an error");

        let after = map.spawn_on("panicky", || thread::current().id()).await.unwrap();
        assert_eq!(before, after, "the worker thread should survive a task panic");
    }

    #[test]
    fn worker_map_drop_runs_queued_tasks() {
        let map = WorkerMap::new();
        let counter = Arc::new(AtomicUsize::new(0));

        for _ in 0..5 {
            let c = Arc::clone(&counter);
            drop(map.spawn_on("drain", move || {
                c.fetch_add(1, Ordering::SeqCst);
            }));
        }

        // Dropping the map closes every queue and joins the workers, so every task submitted
        // before the drop must have run by the time `drop` returns.
        drop(map);
        assert_eq!(counter.load(Ordering::SeqCst), 5);
    }
}