// reth_tasks/for_each_ordered.rs

use crossbeam_utils::CachePadded;
use parking_lot::{Condvar, Mutex};
use rayon::iter::{IndexedParallelIterator, ParallelIterator};
use std::sync::atomic::{AtomicBool, Ordering};

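/// Extension trait adding ordered, sequential consumption to any
/// [`IndexedParallelIterator`]: items are computed in parallel, but the
/// callback observes them on the calling thread in index order, so it needs
/// neither `Send` nor `Sync`.
///
/// A minimal usage sketch (illustrative, not from the original docs; assumes
/// this trait and rayon's prelude are in scope):
///
/// ```ignore
/// use rayon::prelude::*;
///
/// let input: Vec<u64> = (0..4).collect();
/// let mut out = Vec::new();
/// input.par_iter().map(|x| x * 2).for_each_ordered(|x| out.push(x));
/// assert_eq!(out, vec![0, 2, 4, 6]);
/// ```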
pub trait ForEachOrdered: IndexedParallelIterator {
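    /// Consumes the iterator, invoking `f` once per item in index order on
    /// the calling thread.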
    fn for_each_ordered<F>(self, f: F)
    where
        Self::Item: Send,
        F: FnMut(Self::Item);

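    /// Same as [`Self::for_each_ordered`], but drives the producer side on
    /// the provided [`rayon::ThreadPool`].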
    fn for_each_ordered_in<F>(self, pool: &rayon::ThreadPool, f: F)
    where
        Self::Item: Send,
        F: FnMut(Self::Item);
}

impl<I: IndexedParallelIterator> ForEachOrdered for I {
    fn for_each_ordered<F>(self, f: F)
    where
        Self::Item: Send,
        F: FnMut(Self::Item),
    {
        ordered_impl(self, None, f);
    }

    fn for_each_ordered_in<F>(self, pool: &rayon::ThreadPool, f: F)
    where
        Self::Item: Send,
        F: FnMut(Self::Item),
    {
        ordered_impl(self, Some(pool), f);
    }
}

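/// One mailbox per item index: the producer parks the finished item in
/// `value`, and `notify` wakes a consumer blocked waiting on this slot.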
struct Slot<T> {
    value: Mutex<Option<T>>,
    notify: Condvar,
}

impl<T> Slot<T> {
    const fn new() -> Self {
        Self { value: Mutex::new(None), notify: Condvar::new() }
    }
}

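/// State shared between the parallel producer and the sequential consumer:
/// one cache-padded slot per item, plus a flag recording a producer panic.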
struct Shared<T> {
    slots: Box<[CachePadded<Slot<T>>]>,
    panicked: AtomicBool,
}

impl<T> Shared<T> {
    fn new(n: usize) -> Self {
        let slots =
            (0..n).map(|_| CachePadded::new(Slot::new())).collect::<Vec<_>>().into_boxed_slice();
        Self { slots, panicked: AtomicBool::new(false) }
    }

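    /// Publishes the item for index `i` and wakes the consumer if it is
    /// blocked on this slot.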
    #[inline]
    fn write(&self, i: usize, val: T) {
        let slot = &self.slots[i];
        *slot.value.lock() = Some(val);
        slot.notify.notify_one();
    }

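    /// Blocks until the item for index `i` arrives, returning `None` if the
    /// producer panicked before (or instead of) filling this slot.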
    fn take(&self, i: usize) -> Option<T> {
        let slot = &self.slots[i];
        let mut guard = slot.value.lock();
        loop {
            if let Some(val) = guard.take() {
                return Some(val);
            }
            if self.panicked.load(Ordering::Acquire) {
                return None;
            }
            slot.notify.wait(&mut guard);
        }
    }
}

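/// Core implementation: a spawned producer task runs the parallel iterator
/// and publishes each item into its slot, while the calling thread drains
/// the slots in index order and feeds `f`.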
fn ordered_impl<I, F>(iter: I, pool: Option<&rayon::ThreadPool>, mut f: F)
where
    I: IndexedParallelIterator,
    I::Item: Send,
    F: FnMut(I::Item),
{
    use std::panic::{catch_unwind, AssertUnwindSafe};

    let n = iter.len();
    if n == 0 {
        return;
    }

    let shared = Shared::<I::Item>::new(n);

    in_place_scope_in(pool, |s| {
        s.spawn(|_| {
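            // Producer: run the parallel iterator, publishing each item into
            // its index's slot as soon as it is ready.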
            let res = catch_unwind(AssertUnwindSafe(|| {
                iter.enumerate().for_each(|(i, item)| {
                    shared.write(i, item);
                });
            }));
            if let Err(payload) = res {
                shared.panicked.store(true, Ordering::Release);
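                // Wake the consumer on every slot; taking the lock first
                // ensures the wakeup cannot slip in between the consumer's
                // check of the panic flag and its wait on the condvar.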
133 for slot in &*shared.slots {
136 let _guard = slot.value.lock();
137 slot.notify.notify_one();
138 }
139 std::panic::resume_unwind(payload);
140 }
141 });
142
        for i in 0..n {
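            // Consumer: take items strictly in index order; `None` means the
            // producer panicked, so stop and let its unwind propagate.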
            let Some(value) = shared.take(i) else {
                return;
            };
            f(value);
        }
    });
}

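/// Runs `f` in a rayon scope on the given pool, falling back to
/// `rayon::in_place_scope` on the current/global pool when no pool is given.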
fn in_place_scope_in<'scope, F, R>(pool: Option<&rayon::ThreadPool>, f: F)
where
    F: FnOnce(&rayon::Scope<'scope>) -> R,
{
    if let Some(pool) = pool {
        pool.in_place_scope(f);
    } else {
        rayon::in_place_scope(f);
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use rayon::prelude::*;
    use std::sync::{
        atomic::{AtomicUsize, Ordering},
        Barrier,
    };

    #[test]
    fn preserves_order() {
        let input: Vec<u64> = (0..1000).collect();
        let mut output = Vec::with_capacity(input.len());
        input.par_iter().map(|x| x * 2).for_each_ordered(|x| output.push(x));
        let expected: Vec<u64> = (0..1000).map(|x| x * 2).collect();
        assert_eq!(output, expected);
    }

    #[test]
    fn empty_iterator() {
        let input: Vec<u64> = vec![];
        let mut output = Vec::new();
        input.par_iter().map(|x| *x).for_each_ordered(|x| output.push(x));
        assert!(output.is_empty());
    }

    #[test]
    fn single_element() {
        let mut output = Vec::new();
        vec![42u64].par_iter().map(|x| *x).for_each_ordered(|x| output.push(x));
        assert_eq!(output, vec![42]);
    }

    #[test]
    fn slow_early_items_still_delivered_in_order() {
        let barrier = Barrier::new(2);
        let n = 64usize;
        let input: Vec<usize> = (0..n).collect();
        let mut output = Vec::with_capacity(n);

        input
            .par_iter()
            .map(|&i| {
                if i == 0 || i == n - 1 {
                    barrier.wait();
                }
                i
            })
            .for_each_ordered(|x| output.push(x));

        assert_eq!(output, input);
    }

    #[test]
    fn drops_unconsumed_slots_on_panic() {
        static DROP_COUNT: AtomicUsize = AtomicUsize::new(0);

        #[derive(Clone)]
        struct Tracked(#[allow(dead_code)] u64);
        impl Drop for Tracked {
            fn drop(&mut self) {
                DROP_COUNT.fetch_add(1, Ordering::Relaxed);
            }
        }

        DROP_COUNT.store(0, Ordering::Relaxed);

        let input: Vec<u64> = (0..100).collect();
        let result = std::panic::catch_unwind(|| {
            input
                .par_iter()
                .map(|&i| {
                    assert!(i != 50, "intentional");
                    Tracked(i)
                })
                .for_each_ordered(|_item| {});
        });

        assert!(result.is_err());
        let drops = DROP_COUNT.load(Ordering::Relaxed);
        assert!(drops > 0, "some items should have been dropped");
    }

    #[test]
    fn no_double_drop() {
        static DROP_COUNT: AtomicUsize = AtomicUsize::new(0);

        struct Counted(#[allow(dead_code)] u64);
        impl Drop for Counted {
            fn drop(&mut self) {
                DROP_COUNT.fetch_add(1, Ordering::Relaxed);
            }
        }

        DROP_COUNT.store(0, Ordering::Relaxed);
        let n = 200u64;
        let input: Vec<u64> = (0..n).collect();
        input.par_iter().map(|&i| Counted(i)).for_each_ordered(|_item| {});

        assert_eq!(DROP_COUNT.load(Ordering::Relaxed), n as usize);
    }

    #[test]
    fn callback_is_not_send() {
        use std::rc::Rc;
        let counter = Rc::new(std::cell::Cell::new(0u64));
        let input: Vec<u64> = (0..100).collect();
        input.par_iter().map(|&x| x).for_each_ordered(|x| {
            counter.set(counter.get() + x);
        });
        assert_eq!(counter.get(), (0..100u64).sum::<u64>());
    }

    #[test]
    fn panic_does_not_deadlock_consumer() {
        for _ in 0..100 {
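            // Repeat many times to exercise different interleavings of the
            // producing threads, the panic, and the waiting consumer.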
            let result = std::panic::catch_unwind(|| {
                let input: Vec<usize> = (0..256).collect();
                input
                    .par_iter()
                    .map(|&i| {
                        if i == 128 {
                            std::thread::yield_now();
                            panic!("intentional");
                        }
                        i
                    })
                    .for_each_ordered(|_| {});
            });
            assert!(result.is_err());
        }
    }

    #[test]
    fn early_panic_at_item_zero() {
        let result = std::panic::catch_unwind(|| {
            let input: Vec<u64> = (0..10).collect();
            input
                .par_iter()
                .map(|&i| {
                    assert!(i != 0, "boom at zero");
                    i
                })
                .for_each_ordered(|_| {});
        });
        assert!(result.is_err());
    }

    #[test]
    fn late_panic_at_last_item() {
        let n = 100usize;
        let result = std::panic::catch_unwind(|| {
            let input: Vec<usize> = (0..n).collect();
            input
                .par_iter()
                .map(|&i| {
                    assert!(i != n - 1, "boom at last");
                    i
                })
                .for_each_ordered(|_| {});
        });
        assert!(result.is_err());
    }

    #[test]
    fn large_items() {
        let n = 500usize;
        let input: Vec<usize> = (0..n).collect();
        let mut output = Vec::with_capacity(n);
        input
            .par_iter()
            .map(|&i| {
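                // heap-backed payload, so moving items through the slots is
                // non-trivial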
                vec![i; 64]
            })
            .for_each_ordered(|v| output.push(v[0]));
        assert_eq!(output, input);
    }

    #[test]
    fn consumer_slower_than_producer() {
        let n = 64usize;
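        // The consumer yields periodically, so produced items queue up in
        // their slots ahead of it.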
        let input: Vec<usize> = (0..n).collect();
        let mut output = Vec::with_capacity(n);
        input.par_iter().map(|&i| i).for_each_ordered(|x| {
            if x % 8 == 0 {
                std::thread::yield_now();
            }
            output.push(x);
        });
        assert_eq!(output, input);
    }

    #[test]
    fn concurrent_panic_and_drop_no_leak() {
        static DROP_COUNT: AtomicUsize = AtomicUsize::new(0);
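        // PRODUCED counts successfully mapped items, so the drop count can
        // be compared against exactly what was created.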
        static PRODUCED: AtomicUsize = AtomicUsize::new(0);

        struct Tracked;
        impl Drop for Tracked {
            fn drop(&mut self) {
                DROP_COUNT.fetch_add(1, Ordering::Relaxed);
            }
        }

        DROP_COUNT.store(0, Ordering::Relaxed);
        PRODUCED.store(0, Ordering::Relaxed);

        let barrier = Barrier::new(2);
        let result = std::panic::catch_unwind(|| {
            let input: Vec<usize> = (0..64).collect();
            input
                .par_iter()
                .map(|&i| {
                    if i == 32 {
                        barrier.wait();
                        panic!("intentional");
                    }
                    if i == 0 {
                        barrier.wait();
                    }
                    PRODUCED.fetch_add(1, Ordering::Relaxed);
                    Tracked
                })
                .for_each_ordered(|_| {});
        });

        assert!(result.is_err());
        let produced = PRODUCED.load(Ordering::Relaxed);
        let dropped = DROP_COUNT.load(Ordering::Relaxed);
        assert_eq!(dropped, produced, "all produced items must be dropped");
    }
}
403}