Skip to main content

reth_tasks/
for_each_ordered.rs

1use crossbeam_utils::CachePadded;
2use parking_lot::{Condvar, Mutex};
3use rayon::iter::{IndexedParallelIterator, ParallelIterator};
4use std::sync::atomic::{AtomicBool, Ordering};
5
6/// Extension trait for [`IndexedParallelIterator`]
7/// that streams results to a sequential consumer in index order.
8pub trait ForEachOrdered: IndexedParallelIterator {
9    /// Executes the parallel iterator, calling `f` on each result **sequentially in index
10    /// order**.
11    ///
12    /// Items are computed in parallel, but `f` is invoked as `f(item_0)`, `f(item_1)`, …,
13    /// `f(item_{n-1})` on the calling thread. The calling thread receives each item as soon
14    /// as it (and all preceding items) are ready.
15    ///
16    /// `f` does **not** need to be [`Send`] — it runs exclusively on the calling thread.
17    ///
18    /// # Blocking
19    ///
20    /// The calling thread blocks (via [`Condvar`]) while waiting for the next item to become
21    /// ready. It does **not** participate in rayon's work-stealing while blocked. Callers
22    /// should invoke this from a dedicated blocking thread (e.g. via
23    /// [`tokio::task::spawn_blocking`]) rather than from within the rayon thread pool.
24    fn for_each_ordered<F>(self, f: F)
25    where
26        Self::Item: Send,
27        F: FnMut(Self::Item);
28
29    /// Like [`for_each_ordered`](Self::for_each_ordered), but runs the parallel work on the
30    /// given `pool` instead of the global rayon thread pool.
31    fn for_each_ordered_in<F>(self, pool: &rayon::ThreadPool, f: F)
32    where
33        Self::Item: Send,
34        F: FnMut(Self::Item);
35}
36
37impl<I: IndexedParallelIterator> ForEachOrdered for I {
38    fn for_each_ordered<F>(self, f: F)
39    where
40        Self::Item: Send,
41        F: FnMut(Self::Item),
42    {
43        ordered_impl(self, None, f);
44    }
45
46    fn for_each_ordered_in<F>(self, pool: &rayon::ThreadPool, f: F)
47    where
48        Self::Item: Send,
49        F: FnMut(Self::Item),
50    {
51        ordered_impl(self, Some(pool), f);
52    }
53}
54
55/// A slot holding an optional value and a condvar for notification.
56struct Slot<T> {
57    value: Mutex<Option<T>>,
58    notify: Condvar,
59}
60
61impl<T> Slot<T> {
62    const fn new() -> Self {
63        Self { value: Mutex::new(None), notify: Condvar::new() }
64    }
65}
66
67struct Shared<T> {
68    slots: Box<[CachePadded<Slot<T>>]>,
69    panicked: AtomicBool,
70}
71
72impl<T> Shared<T> {
73    fn new(n: usize) -> Self {
74        let slots =
75            (0..n).map(|_| CachePadded::new(Slot::new())).collect::<Vec<_>>().into_boxed_slice();
76        Self { slots, panicked: AtomicBool::new(false) }
77    }
78
79    /// Writes a value into slot `i`. Must only be called once per index.
80    #[inline]
81    fn write(&self, i: usize, val: T) {
82        let slot = &self.slots[i];
83        *slot.value.lock() = Some(val);
84        slot.notify.notify_one();
85    }
86
87    /// Blocks until slot `i` is ready and takes the value.
88    /// Returns `None` if the producer panicked.
89    fn take(&self, i: usize) -> Option<T> {
90        let slot = &self.slots[i];
91        let mut guard = slot.value.lock();
92        loop {
93            if let Some(val) = guard.take() {
94                return Some(val);
95            }
96            if self.panicked.load(Ordering::Acquire) {
97                return None;
98            }
99            slot.notify.wait(&mut guard);
100        }
101    }
102}
103
104/// Executes a parallel iterator and delivers results to a sequential callback in index order.
105///
106/// Each slot has its own [`Condvar`], so the consumer blocks precisely on the slot it needs
107/// with zero spurious wakeups.
108fn ordered_impl<I, F>(iter: I, pool: Option<&rayon::ThreadPool>, mut f: F)
109where
110    I: IndexedParallelIterator,
111    I::Item: Send,
112    F: FnMut(I::Item),
113{
114    use std::panic::{catch_unwind, AssertUnwindSafe};
115
116    let n = iter.len();
117    if n == 0 {
118        return;
119    }
120
121    let shared = Shared::<I::Item>::new(n);
122
123    in_place_scope_in(pool, |s| {
124        // Producer: compute items in parallel and write them into their slots.
125        s.spawn(|_| {
126            let res = catch_unwind(AssertUnwindSafe(|| {
127                iter.enumerate().for_each(|(i, item)| {
128                    shared.write(i, item);
129                });
130            }));
131            if let Err(payload) = res {
132                shared.panicked.store(true, Ordering::Release);
133                // Wake all slots so the consumer doesn't hang. Lock each slot's mutex
134                // first to serialize with the consumer's panicked check → wait sequence.
135                for slot in &*shared.slots {
136                    let _guard = slot.value.lock();
137                    slot.notify.notify_one();
138                }
139                std::panic::resume_unwind(payload);
140            }
141        });
142
143        // Consumer: sequential, ordered, on the calling thread.
144        for i in 0..n {
145            let Some(value) = shared.take(i) else {
146                return;
147            };
148            f(value);
149        }
150    });
151}
152
153fn in_place_scope_in<'scope, F, R>(pool: Option<&rayon::ThreadPool>, f: F)
154where
155    F: FnOnce(&rayon::Scope<'scope>) -> R,
156{
157    if let Some(pool) = pool {
158        pool.in_place_scope(f);
159    } else {
160        rayon::in_place_scope(f);
161    }
162}
163
#[cfg(test)]
mod tests {
    use super::*;
    use rayon::prelude::*;
    use std::sync::{
        atomic::{AtomicUsize, Ordering},
        Barrier,
    };

    /// Builds a dedicated rayon pool with exactly two worker threads.
    ///
    /// Tests that rendezvous two items on a two-party [`Barrier`] need two workers computing
    /// items at the same time. The global pool may have only one worker (single CPU, or
    /// `RAYON_NUM_THREADS=1`), in which case such a test would block forever.
    fn two_thread_pool() -> rayon::ThreadPool {
        rayon::ThreadPoolBuilder::new().num_threads(2).build().expect("failed to build rayon pool")
    }

    #[test]
    fn preserves_order() {
        let input: Vec<u64> = (0..1000).collect();
        let mut output = Vec::with_capacity(input.len());
        input.par_iter().map(|x| x * 2).for_each_ordered(|x| output.push(x));
        let expected: Vec<u64> = (0..1000).map(|x| x * 2).collect();
        assert_eq!(output, expected);
    }

    #[test]
    fn empty_iterator() {
        let input: Vec<u64> = vec![];
        let mut output = Vec::new();
        input.par_iter().map(|x| *x).for_each_ordered(|x| output.push(x));
        assert!(output.is_empty());
    }

    #[test]
    fn single_element() {
        let mut output = Vec::new();
        vec![42u64].par_iter().map(|x| *x).for_each_ordered(|x| output.push(x));
        assert_eq!(output, vec![42]);
    }

    #[test]
    fn runs_on_given_pool() {
        let pool = two_thread_pool();
        let input: Vec<usize> = (0..256).collect();
        let mut output = Vec::with_capacity(input.len());
        input
            .par_iter()
            .map(|&i| {
                assert!(pool.current_thread_index().is_some(), "item computed outside `pool`");
                i
            })
            .for_each_ordered_in(&pool, |x| output.push(x));
        assert_eq!(output, input);
    }

    #[test]
    fn slow_early_items_still_delivered_in_order() {
        // Item 0 cannot finish until the last item is being computed, so the consumer has to
        // wait on slot 0 while later slots fill up. Needs two workers (see `two_thread_pool`).
        let pool = two_thread_pool();
        let barrier = Barrier::new(2);
        let n = 64usize;
        let input: Vec<usize> = (0..n).collect();
        let mut output = Vec::with_capacity(n);

        input
            .par_iter()
            .map(|&i| {
                if i == 0 || i == n - 1 {
                    barrier.wait();
                }
                i
            })
            .for_each_ordered_in(&pool, |x| output.push(x));

        assert_eq!(output, input);
    }

    #[test]
    fn drops_unconsumed_slots_on_panic() {
        static CREATED: AtomicUsize = AtomicUsize::new(0);
        static DROP_COUNT: AtomicUsize = AtomicUsize::new(0);

        struct Tracked(#[allow(dead_code)] u64);
        impl Tracked {
            fn new(value: u64) -> Self {
                CREATED.fetch_add(1, Ordering::Relaxed);
                Self(value)
            }
        }
        impl Drop for Tracked {
            fn drop(&mut self) {
                DROP_COUNT.fetch_add(1, Ordering::Relaxed);
            }
        }

        CREATED.store(0, Ordering::Relaxed);
        DROP_COUNT.store(0, Ordering::Relaxed);

        let input: Vec<u64> = (0..100).collect();
        let result = std::panic::catch_unwind(|| {
            input
                .par_iter()
                .map(|&i| {
                    assert!(i != 50, "intentional");
                    Tracked::new(i)
                })
                .for_each_ordered(|_item| {});
        });

        assert!(result.is_err());
        let drops = DROP_COUNT.load(Ordering::Relaxed);
        assert!(drops > 0, "some items should have been dropped");
        assert_eq!(
            drops,
            CREATED.load(Ordering::Relaxed),
            "every created item must be dropped exactly once"
        );
    }

    #[test]
    fn no_double_drop() {
        static DROP_COUNT: AtomicUsize = AtomicUsize::new(0);

        struct Counted(#[allow(dead_code)] u64);
        impl Drop for Counted {
            fn drop(&mut self) {
                DROP_COUNT.fetch_add(1, Ordering::Relaxed);
            }
        }

        DROP_COUNT.store(0, Ordering::Relaxed);
        let n = 200u64;
        let input: Vec<u64> = (0..n).collect();
        input.par_iter().map(|&i| Counted(i)).for_each_ordered(|_item| {});

        assert_eq!(DROP_COUNT.load(Ordering::Relaxed), n as usize);
    }

    #[test]
    fn callback_is_not_send() {
        use std::rc::Rc;
        let counter = Rc::new(std::cell::Cell::new(0u64));
        let input: Vec<u64> = (0..100).collect();
        input.par_iter().map(|&x| x).for_each_ordered(|x| {
            counter.set(counter.get() + x);
        });
        assert_eq!(counter.get(), (0..100u64).sum::<u64>());
    }

    #[test]
    fn consumer_panic_propagates() {
        let input: Vec<usize> = (0..256).collect();
        let mut seen = Vec::new();
        let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            input.par_iter().map(|&i| i).for_each_ordered(|x| {
                assert!(x != 10, "consumer boom");
                seen.push(x);
            });
        }));
        assert!(result.is_err());
        assert_eq!(seen, (0..10).collect::<Vec<_>>());
    }

    #[test]
    fn panic_does_not_deadlock_consumer() {
        // Regression test: producer panics while consumer is waiting on a condvar.
        // Without the lock-before-notify fix, the consumer could miss the wakeup
        // and deadlock.
        for _ in 0..100 {
            let result = std::panic::catch_unwind(|| {
                let input: Vec<usize> = (0..256).collect();
                input
                    .par_iter()
                    .map(|&i| {
                        if i == 128 {
                            // Yield to increase the chance of the consumer being
                            // between the panicked check and condvar wait.
                            std::thread::yield_now();
                            panic!("intentional");
                        }
                        i
                    })
                    .for_each_ordered(|_| {});
            });
            assert!(result.is_err());
        }
    }

    #[test]
    fn early_panic_at_item_zero() {
        let result = std::panic::catch_unwind(|| {
            let input: Vec<u64> = (0..10).collect();
            input
                .par_iter()
                .map(|&i| {
                    assert!(i != 0, "boom at zero");
                    i
                })
                .for_each_ordered(|_| {});
        });
        assert!(result.is_err());
    }

    #[test]
    fn late_panic_at_last_item() {
        let n = 100usize;
        let result = std::panic::catch_unwind(|| {
            let input: Vec<usize> = (0..n).collect();
            input
                .par_iter()
                .map(|&i| {
                    assert!(i != n - 1, "boom at last");
                    i
                })
                .for_each_ordered(|_| {});
        });
        assert!(result.is_err());
    }

    #[test]
    fn large_items() {
        let n = 500usize;
        let input: Vec<usize> = (0..n).collect();
        let mut output = Vec::with_capacity(n);
        input
            .par_iter()
            .map(|&i| {
                // Return a heap-allocated value to stress drop semantics.
                vec![i; 64]
            })
            .for_each_ordered(|v| output.push(v[0]));
        assert_eq!(output, input);
    }

    #[test]
    fn consumer_slower_than_producer() {
        // Producer is fast, consumer is slow. All items should still arrive in order.
        let n = 64usize;
        let input: Vec<usize> = (0..n).collect();
        let mut output = Vec::with_capacity(n);
        input.par_iter().map(|&i| i).for_each_ordered(|x| {
            if x % 8 == 0 {
                std::thread::yield_now();
            }
            output.push(x);
        });
        assert_eq!(output, input);
    }

    #[test]
    fn concurrent_panic_and_drop_no_leak() {
        // Ensure items produced before a panic are all dropped.
        static DROP_COUNT: AtomicUsize = AtomicUsize::new(0);
        static PRODUCED: AtomicUsize = AtomicUsize::new(0);

        struct Tracked;
        impl Drop for Tracked {
            fn drop(&mut self) {
                DROP_COUNT.fetch_add(1, Ordering::Relaxed);
            }
        }

        DROP_COUNT.store(0, Ordering::Relaxed);
        PRODUCED.store(0, Ordering::Relaxed);

        // Items 0 and 32 rendezvous on the barrier, which needs two workers.
        let pool = two_thread_pool();
        let barrier = Barrier::new(2);
        let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            let input: Vec<usize> = (0..64).collect();
            input
                .par_iter()
                .map(|&i| {
                    if i == 32 {
                        barrier.wait();
                        panic!("intentional");
                    }
                    if i == 0 {
                        barrier.wait();
                    }
                    PRODUCED.fetch_add(1, Ordering::Relaxed);
                    Tracked
                })
                .for_each_ordered_in(&pool, |_| {});
        }));

        assert!(result.is_err());
        let produced = PRODUCED.load(Ordering::Relaxed);
        let dropped = DROP_COUNT.load(Ordering::Relaxed);
        assert_eq!(dropped, produced, "all produced items must be dropped");
    }
}