Skip to main content

reth_tasks/
worker_map.rs

1//! A map of named single-thread worker pools.
2//!
3//! Each worker is a dedicated OS thread that processes closures sent to it via a channel.
4//! This is a substitute for `spawn_blocking` that reuses the same OS thread for the same
5//! named task, like a 1-thread thread pool keyed by name.
6
7use dashmap::DashMap;
8use std::{
9    panic::AssertUnwindSafe,
10    sync::{
11        atomic::{AtomicUsize, Ordering},
12        Arc,
13    },
14    thread,
15};
16use tokio::sync::{mpsc, oneshot};
17
18type BoxedTask = Box<dyn FnOnce() + Send + 'static>;
19
20/// A single-thread worker that processes closures sequentially on a dedicated OS thread.
21struct WorkerThread {
22    /// Sender to submit work to this worker's thread.
23    tx: mpsc::UnboundedSender<BoxedTask>,
24    /// Number of tasks currently running or queued on this worker.
25    pending: Arc<AtomicUsize>,
26    /// The OS thread handle. Taken during shutdown to join.
27    handle: Option<thread::JoinHandle<()>>,
28}
29
30impl WorkerThread {
31    /// Spawns a new worker thread with the given name.
32    fn new(name: &'static str) -> Self {
33        let (tx, mut rx) = mpsc::unbounded_channel::<BoxedTask>();
34        let handle = thread::Builder::new()
35            .name(name.to_string())
36            .spawn(move || {
37                while let Some(task) = rx.blocking_recv() {
38                    let _ = std::panic::catch_unwind(AssertUnwindSafe(task));
39                }
40            })
41            .unwrap_or_else(|e| panic!("failed to spawn worker thread {name:?}: {e}"));
42
43        Self { tx, pending: Arc::new(AtomicUsize::new(0)), handle: Some(handle) }
44    }
45
46    /// Spawns a closure on this worker.
47    fn spawn<F, R>(&self, f: F) -> oneshot::Receiver<R>
48    where
49        F: FnOnce() -> R + Send + 'static,
50        R: Send + 'static,
51    {
52        self.pending.fetch_add(1, Ordering::AcqRel);
53
54        let (result_tx, result_rx) = oneshot::channel();
55        let pending = self.pending.clone();
56
57        let task: BoxedTask = Box::new(move || {
58            let _decrement_pending = DecrementPendingOnDrop(pending);
59            let _ = result_tx.send(f());
60        });
61
62        if self.tx.send(task).is_err() {
63            self.pending.fetch_sub(1, Ordering::AcqRel);
64        }
65
66        result_rx
67    }
68
69    /// Attempts to spawn a closure if this worker has no task running or queued.
70    fn try_spawn<F, R>(&self, f: F) -> Option<oneshot::Receiver<R>>
71    where
72        F: FnOnce() -> R + Send + 'static,
73        R: Send + 'static,
74    {
75        self.pending.compare_exchange(0, 1, Ordering::AcqRel, Ordering::Acquire).ok()?;
76
77        let (result_tx, result_rx) = oneshot::channel();
78        let pending = self.pending.clone();
79
80        let task: BoxedTask = Box::new(move || {
81            let _decrement_pending = DecrementPendingOnDrop(pending);
82            let _ = result_tx.send(f());
83        });
84
85        if self.tx.send(task).is_err() {
86            self.pending.fetch_sub(1, Ordering::AcqRel);
87            return None
88        }
89
90        Some(result_rx)
91    }
92}
93
/// RAII guard that releases one slot of a worker's pending-task counter when dropped.
///
/// Because the release happens in `Drop`, it also runs while unwinding out of a panicking
/// task.
struct DecrementPendingOnDrop(Arc<AtomicUsize>);

impl Drop for DecrementPendingOnDrop {
    fn drop(&mut self) {
        let Self(pending) = self;
        pending.fetch_sub(1, Ordering::AcqRel);
    }
}
102
103/// A map of named single-thread workers.
104///
105/// Each unique name gets a dedicated OS thread that is reused for all tasks submitted under
106/// that name. Workers are created lazily on first use.
107pub(crate) struct WorkerMap {
108    workers: DashMap<&'static str, WorkerThread>,
109}
110
111impl Default for WorkerMap {
112    fn default() -> Self {
113        Self::new()
114    }
115}
116
117impl WorkerMap {
118    /// Creates a new empty `WorkerMap`.
119    pub(crate) fn new() -> Self {
120        Self { workers: DashMap::new() }
121    }
122
123    /// Spawns a closure on the dedicated worker thread for the given name.
124    ///
125    /// If no worker thread exists for this name yet, one is created with the given name as
126    /// the OS thread name. The closure executes on the worker's OS thread and the returned
127    /// future resolves with the result.
128    pub(crate) fn spawn_on<F, R>(&self, name: &'static str, f: F) -> oneshot::Receiver<R>
129    where
130        F: FnOnce() -> R + Send + 'static,
131        R: Send + 'static,
132    {
133        let worker = self.workers.entry(name).or_insert_with(|| WorkerThread::new(name));
134        worker.spawn(f)
135    }
136
137    /// Attempts to spawn a closure on the dedicated worker thread for the given name.
138    ///
139    /// Returns `None` if the named worker already has a task running or queued.
140    pub(crate) fn try_spawn_on<F, R>(
141        &self,
142        name: &'static str,
143        f: F,
144    ) -> Option<oneshot::Receiver<R>>
145    where
146        F: FnOnce() -> R + Send + 'static,
147        R: Send + 'static,
148    {
149        let worker = self.workers.entry(name).or_insert_with(|| WorkerThread::new(name));
150        worker.try_spawn(f)
151    }
152}
153
154impl Drop for WorkerMap {
155    fn drop(&mut self) {
156        for (_, mut w) in std::mem::take(&mut self.workers) {
157            // Drop sender so the thread's recv loop exits, then join.
158            drop(w.tx);
159            if let Some(handle) = w.handle.take() {
160                let _ = handle.join();
161            }
162        }
163    }
164}
165
166impl std::fmt::Debug for WorkerMap {
167    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
168        f.debug_struct("WorkerMap").field("num_workers", &self.workers.len()).finish()
169    }
170}
171
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn worker_map_basic() {
        let map = WorkerMap::new();

        let result = map.spawn_on("test", || 42).await.unwrap();
        assert_eq!(result, 42);
    }

    #[tokio::test]
    async fn worker_map_same_thread() {
        let map = WorkerMap::new();

        let id1 = map.spawn_on("test", || thread::current().id()).await.unwrap();
        let id2 = map.spawn_on("test", || thread::current().id()).await.unwrap();
        assert_eq!(id1, id2, "same name should run on the same thread");
    }

    #[tokio::test]
    async fn worker_map_different_names_different_threads() {
        let map = WorkerMap::new();

        let id1 = map.spawn_on("worker-a", || thread::current().id()).await.unwrap();
        let id2 = map.spawn_on("worker-b", || thread::current().id()).await.unwrap();
        assert_ne!(id1, id2, "different names should run on different threads");
    }

    #[tokio::test]
    async fn worker_map_sequential_execution() {
        let map = WorkerMap::new();
        let counter = Arc::new(AtomicUsize::new(0));

        let mut receivers = Vec::new();
        for i in 0..10 {
            let c = counter.clone();
            let rx = map.spawn_on("sequential", move || {
                let val = c.fetch_add(1, Ordering::SeqCst);
                assert_eq!(val, i, "tasks should execute in order");
                val
            });
            receivers.push(rx);
        }

        for (i, rx) in receivers.into_iter().enumerate() {
            let val = rx.await.unwrap();
            assert_eq!(val, i);
        }
    }

    #[tokio::test]
    async fn worker_map_thread_name() {
        let map = WorkerMap::new();

        let name = map
            .spawn_on("custom-worker", || thread::current().name().unwrap().to_string())
            .await
            .unwrap();
        assert_eq!(name, "custom-worker");
    }

    #[tokio::test]
    async fn worker_map_try_spawn_busy() {
        let map = WorkerMap::new();
        let (release_tx, release_rx) = std::sync::mpsc::channel();

        let first = map.try_spawn_on("busy-worker", move || {
            release_rx.recv().unwrap();
            1
        });
        assert!(first.is_some());

        let second = map.try_spawn_on("busy-worker", || 2);
        assert!(second.is_none(), "busy worker should reject queued work");

        release_tx.send(()).unwrap();
        assert_eq!(first.unwrap().await.unwrap(), 1);

        let third = map.try_spawn_on("busy-worker", || 3).expect("worker should be idle");
        assert_eq!(third.await.unwrap(), 3);
    }

    /// A panicking task must not kill its worker thread or leave it marked as busy.
    #[tokio::test]
    async fn worker_map_survives_panicking_task() {
        let map = WorkerMap::new();

        let panicked = map.spawn_on("panicky", || -> usize { panic!("boom") });
        assert!(panicked.await.is_err(), "a panicking task should close its result channel");

        let rx = map.try_spawn_on("panicky", || 7).expect("worker should be idle after a panic");
        assert_eq!(rx.await.unwrap(), 7);
    }

    /// Dropping the map must run every already-queued task before the threads are joined.
    #[test]
    fn worker_map_drop_runs_queued_tasks() {
        let map = WorkerMap::new();
        let counter = Arc::new(AtomicUsize::new(0));

        for _ in 0..5 {
            let c = Arc::clone(&counter);
            let _ = map.spawn_on("drain", move || {
                c.fetch_add(1, Ordering::SeqCst);
            });
        }

        drop(map);
        assert_eq!(counter.load(Ordering::SeqCst), 5);
    }
}