1use dashmap::DashMap;
8use std::{
9 panic::AssertUnwindSafe,
10 sync::{
11 atomic::{AtomicUsize, Ordering},
12 Arc,
13 },
14 thread,
15};
16use tokio::sync::{mpsc, oneshot};
17
/// Type-erased unit of work queued to a worker thread and invoked exactly once there.
///
/// `Send` because it crosses from the submitting thread to the worker; `'static` because the
/// worker thread may outlive the submitter's stack frame.
type BoxedTask = Box<dyn FnOnce() + Send + 'static>;
19
20struct WorkerThread {
22 tx: mpsc::UnboundedSender<BoxedTask>,
24 pending: Arc<AtomicUsize>,
26 handle: Option<thread::JoinHandle<()>>,
28}
29
30impl WorkerThread {
31 fn new(name: &'static str) -> Self {
33 let (tx, mut rx) = mpsc::unbounded_channel::<BoxedTask>();
34 let handle = thread::Builder::new()
35 .name(name.to_string())
36 .spawn(move || {
37 while let Some(task) = rx.blocking_recv() {
38 let _ = std::panic::catch_unwind(AssertUnwindSafe(task));
39 }
40 })
41 .unwrap_or_else(|e| panic!("failed to spawn worker thread {name:?}: {e}"));
42
43 Self { tx, pending: Arc::new(AtomicUsize::new(0)), handle: Some(handle) }
44 }
45
46 fn spawn<F, R>(&self, f: F) -> oneshot::Receiver<R>
48 where
49 F: FnOnce() -> R + Send + 'static,
50 R: Send + 'static,
51 {
52 self.pending.fetch_add(1, Ordering::AcqRel);
53
54 let (result_tx, result_rx) = oneshot::channel();
55 let pending = self.pending.clone();
56
57 let task: BoxedTask = Box::new(move || {
58 let _decrement_pending = DecrementPendingOnDrop(pending);
59 let _ = result_tx.send(f());
60 });
61
62 if self.tx.send(task).is_err() {
63 self.pending.fetch_sub(1, Ordering::AcqRel);
64 }
65
66 result_rx
67 }
68
69 fn try_spawn<F, R>(&self, f: F) -> Option<oneshot::Receiver<R>>
71 where
72 F: FnOnce() -> R + Send + 'static,
73 R: Send + 'static,
74 {
75 self.pending.compare_exchange(0, 1, Ordering::AcqRel, Ordering::Acquire).ok()?;
76
77 let (result_tx, result_rx) = oneshot::channel();
78 let pending = self.pending.clone();
79
80 let task: BoxedTask = Box::new(move || {
81 let _decrement_pending = DecrementPendingOnDrop(pending);
82 let _ = result_tx.send(f());
83 });
84
85 if self.tx.send(task).is_err() {
86 self.pending.fetch_sub(1, Ordering::AcqRel);
87 return None
88 }
89
90 Some(result_rx)
91 }
92}
93
/// Guard that gives back one unit of a shared pending-task counter when it is dropped.
///
/// The decrement lives in `Drop`, so it happens both on normal scope exit and while unwinding
/// from a panic.
struct DecrementPendingOnDrop(Arc<AtomicUsize>);

impl Drop for DecrementPendingOnDrop {
    fn drop(&mut self) {
        let Self(counter) = self;
        counter.fetch_sub(1, Ordering::AcqRel);
    }
}
102
103pub(crate) struct WorkerMap {
108 workers: DashMap<&'static str, WorkerThread>,
109}
110
111impl Default for WorkerMap {
112 fn default() -> Self {
113 Self::new()
114 }
115}
116
117impl WorkerMap {
118 pub(crate) fn new() -> Self {
120 Self { workers: DashMap::new() }
121 }
122
123 pub(crate) fn spawn_on<F, R>(&self, name: &'static str, f: F) -> oneshot::Receiver<R>
129 where
130 F: FnOnce() -> R + Send + 'static,
131 R: Send + 'static,
132 {
133 let worker = self.workers.entry(name).or_insert_with(|| WorkerThread::new(name));
134 worker.spawn(f)
135 }
136
137 pub(crate) fn try_spawn_on<F, R>(
141 &self,
142 name: &'static str,
143 f: F,
144 ) -> Option<oneshot::Receiver<R>>
145 where
146 F: FnOnce() -> R + Send + 'static,
147 R: Send + 'static,
148 {
149 let worker = self.workers.entry(name).or_insert_with(|| WorkerThread::new(name));
150 worker.try_spawn(f)
151 }
152}
153
154impl Drop for WorkerMap {
155 fn drop(&mut self) {
156 for (_, mut w) in std::mem::take(&mut self.workers) {
157 drop(w.tx);
159 if let Some(handle) = w.handle.take() {
160 let _ = handle.join();
161 }
162 }
163 }
164}
165
166impl std::fmt::Debug for WorkerMap {
167 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
168 f.debug_struct("WorkerMap").field("num_workers", &self.workers.len()).finish()
169 }
170}
171
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn worker_map_basic() {
        let map = WorkerMap::new();

        let answer = map.spawn_on("test", || 42).await.unwrap();
        assert_eq!(answer, 42);
    }

    #[tokio::test]
    async fn worker_map_same_thread() {
        let map = WorkerMap::new();
        let whoami = || thread::current().id();

        let first = map.spawn_on("test", whoami).await.unwrap();
        let second = map.spawn_on("test", whoami).await.unwrap();
        assert_eq!(first, second, "same name should run on the same thread");
    }

    #[tokio::test]
    async fn worker_map_different_names_different_threads() {
        let map = WorkerMap::new();
        let whoami = || thread::current().id();

        let on_a = map.spawn_on("worker-a", whoami).await.unwrap();
        let on_b = map.spawn_on("worker-b", whoami).await.unwrap();
        assert_ne!(on_a, on_b, "different names should run on different threads");
    }

    #[tokio::test]
    async fn worker_map_sequential_execution() {
        let map = WorkerMap::new();
        let counter = Arc::new(AtomicUsize::new(0));

        // Submit everything up front so the worker's queue, not the test, dictates ordering.
        let receivers: Vec<_> = (0..10)
            .map(|i| {
                let counter = Arc::clone(&counter);
                map.spawn_on("sequential", move || {
                    let val = counter.fetch_add(1, Ordering::SeqCst);
                    assert_eq!(val, i, "tasks should execute in order");
                    val
                })
            })
            .collect();

        for (expected, rx) in receivers.into_iter().enumerate() {
            assert_eq!(rx.await.unwrap(), expected);
        }
    }

    #[tokio::test]
    async fn worker_map_thread_name() {
        let map = WorkerMap::new();

        let name = map
            .spawn_on("custom-worker", || thread::current().name().unwrap().to_string())
            .await
            .unwrap();
        assert_eq!(name, "custom-worker");
    }

    #[tokio::test]
    async fn worker_map_try_spawn_busy() {
        let map = WorkerMap::new();
        let (release_tx, release_rx) = std::sync::mpsc::channel();

        // Park the worker inside its first task until the test releases it.
        let first = map.try_spawn_on("busy-worker", move || {
            release_rx.recv().unwrap();
            1
        });
        assert!(first.is_some());

        let second = map.try_spawn_on("busy-worker", || 2);
        assert!(second.is_none(), "busy worker should reject queued work");

        release_tx.send(()).unwrap();
        assert_eq!(first.unwrap().await.unwrap(), 1);

        let third = map.try_spawn_on("busy-worker", || 3).expect("worker should be idle");
        assert_eq!(third.await.unwrap(), 3);
    }
}