1use dashmap::DashMap;
8use std::{panic::AssertUnwindSafe, thread};
9use tokio::sync::{mpsc, oneshot};
10
/// A type-erased, run-once unit of work that can be moved onto a worker thread.
///
/// The object lifetime of `Box<dyn Trait>` defaults to `'static`, so the boxed closure may
/// not borrow anything from the submitting stack frame.
type BoxedTask = Box<dyn FnOnce() + Send>;
12
13struct WorkerThread {
15 tx: mpsc::UnboundedSender<BoxedTask>,
17 handle: Option<thread::JoinHandle<()>>,
19}
20
21impl WorkerThread {
22 fn new(name: &'static str) -> Self {
24 let (tx, mut rx) = mpsc::unbounded_channel::<BoxedTask>();
25 let handle = thread::Builder::new()
26 .name(name.to_string())
27 .spawn(move || {
28 while let Some(task) = rx.blocking_recv() {
29 let _ = std::panic::catch_unwind(AssertUnwindSafe(task));
30 }
31 })
32 .unwrap_or_else(|e| panic!("failed to spawn worker thread {name:?}: {e}"));
33
34 Self { tx, handle: Some(handle) }
35 }
36}
37
38pub(crate) struct WorkerMap {
43 workers: DashMap<&'static str, WorkerThread>,
44}
45
46impl Default for WorkerMap {
47 fn default() -> Self {
48 Self::new()
49 }
50}
51
52impl WorkerMap {
53 pub(crate) fn new() -> Self {
55 Self { workers: DashMap::new() }
56 }
57
58 pub(crate) fn spawn_on<F, R>(&self, name: &'static str, f: F) -> oneshot::Receiver<R>
64 where
65 F: FnOnce() -> R + Send + 'static,
66 R: Send + 'static,
67 {
68 let (result_tx, result_rx) = oneshot::channel();
69
70 let task: BoxedTask = Box::new(move || {
71 let _ = result_tx.send(f());
72 });
73
74 let worker = self.workers.entry(name).or_insert_with(|| WorkerThread::new(name));
75 let _ = worker.tx.send(task);
76
77 result_rx
78 }
79}
80
81impl Drop for WorkerMap {
82 fn drop(&mut self) {
83 for (_, mut w) in std::mem::take(&mut self.workers) {
84 drop(w.tx);
86 if let Some(handle) = w.handle.take() {
87 let _ = handle.join();
88 }
89 }
90 }
91}
92
93impl std::fmt::Debug for WorkerMap {
94 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
95 f.debug_struct("WorkerMap").field("num_workers", &self.workers.len()).finish()
96 }
97}
98
#[cfg(test)]
mod tests {
    use super::*;

    /// A single task runs, and its return value comes back through the receiver.
    #[tokio::test]
    async fn worker_map_basic() {
        let map = WorkerMap::new();

        let result = map.spawn_on("test", || 42).await.unwrap();
        assert_eq!(result, 42);
    }

    /// Repeated submissions under one name reuse the same OS thread.
    #[tokio::test]
    async fn worker_map_same_thread() {
        let map = WorkerMap::new();

        let id1 = map.spawn_on("test", || thread::current().id()).await.unwrap();
        let id2 = map.spawn_on("test", || thread::current().id()).await.unwrap();
        assert_eq!(id1, id2, "same name should run on the same thread");
    }

    /// Distinct names are served by distinct OS threads.
    #[tokio::test]
    async fn worker_map_different_names_different_threads() {
        let map = WorkerMap::new();

        let id1 = map.spawn_on("worker-a", || thread::current().id()).await.unwrap();
        let id2 = map.spawn_on("worker-b", || thread::current().id()).await.unwrap();
        assert_ne!(id1, id2, "different names should run on different threads");
    }

    /// Tasks queued under one name run strictly in submission order.
    #[tokio::test]
    async fn worker_map_sequential_execution() {
        use std::sync::{
            atomic::{AtomicUsize, Ordering},
            Arc,
        };

        let map = WorkerMap::new();
        let counter = Arc::new(AtomicUsize::new(0));

        let mut receivers = Vec::new();
        for i in 0..10 {
            let c = counter.clone();
            let rx = map.spawn_on("sequential", move || {
                let val = c.fetch_add(1, Ordering::SeqCst);
                assert_eq!(val, i, "tasks should execute in order");
                val
            });
            receivers.push(rx);
        }

        for (i, rx) in receivers.into_iter().enumerate() {
            let val = rx.await.unwrap();
            assert_eq!(val, i);
        }
    }

    /// The worker thread is named after the key it was registered under.
    #[tokio::test]
    async fn worker_map_thread_name() {
        let map = WorkerMap::new();

        let name = map
            .spawn_on("custom-worker", || thread::current().name().unwrap().to_string())
            .await
            .unwrap();
        assert_eq!(name, "custom-worker");
    }

    /// A panicking task surfaces as a receive error, and the same worker thread keeps
    /// serving later tasks.
    #[tokio::test]
    async fn worker_map_survives_task_panic() {
        let map = WorkerMap::new();

        let before = map.spawn_on("panicky", || thread::current().id()).await.unwrap();

        let failed = map
            .spawn_on("panicky", || -> thread::ThreadId { panic!("intentional test panic") })
            .await;
        assert!(failed.is_err(), "a panicking task should drop its result sender");

        let after = map.spawn_on("panicky", || thread::current().id()).await.unwrap();
        assert_eq!(before, after, "the worker thread should survive a panicking task");
    }
}