reth_tokio_util/
event_stream.rs

1//! Event streams related functionality.
2
3use std::{
4    pin::Pin,
5    task::{Context, Poll},
6};
7use tokio_stream::Stream;
8use tracing::warn;
9
10/// Thin wrapper around tokio's `BroadcastStream` to allow skipping broadcast errors.
11#[derive(Debug)]
12pub struct EventStream<T> {
13    inner: tokio_stream::wrappers::BroadcastStream<T>,
14}
15
16impl<T> EventStream<T>
17where
18    T: Clone + Send + 'static,
19{
20    /// Creates a new `EventStream`.
21    pub fn new(receiver: tokio::sync::broadcast::Receiver<T>) -> Self {
22        let inner = tokio_stream::wrappers::BroadcastStream::new(receiver);
23        Self { inner }
24    }
25}
26
27impl<T> Stream for EventStream<T>
28where
29    T: Clone + Send + 'static,
30{
31    type Item = T;
32
33    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
34        loop {
35            match Pin::new(&mut self.inner).poll_next(cx) {
36                Poll::Ready(Some(Ok(item))) => return Poll::Ready(Some(item)),
37                Poll::Ready(Some(Err(e))) => {
38                    warn!("BroadcastStream lagged: {e:?}");
39                }
40                Poll::Ready(None) => return Poll::Ready(None),
41                Poll::Pending => return Poll::Pending,
42            }
43        }
44    }
45}
46
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::sync::broadcast;
    use tokio_stream::StreamExt;

    /// Items sent while the stream is subscribed are yielded in order.
    #[tokio::test]
    async fn test_event_stream_yields_items() {
        let (tx, _) = broadcast::channel(16);
        let my_stream = EventStream::new(tx.subscribe());

        tx.send(1).unwrap();
        tx.send(2).unwrap();
        tx.send(3).unwrap();

        // Drop the sender to terminate the stream so `collect` can complete.
        drop(tx);

        let items: Vec<i32> = my_stream.collect().await;

        assert_eq!(items, vec![1, 2, 3]);
    }

    /// A lagged receiver does not surface an error; the stream resumes with the retained items.
    #[tokio::test]
    async fn test_event_stream_skips_lag_errors() {
        // Capacity 2: sending four values before the stream reads anything overwrites the two
        // oldest, so the stream's receiver lags. Lag is tracked per receiver, so no extra
        // subscribers are needed to provoke it.
        let (tx, _) = broadcast::channel(2);
        let my_stream = EventStream::new(tx.subscribe());

        for value in 1..=4 {
            tx.send(value).unwrap();
        }

        // Drop the sender to terminate the stream so `collect` can complete.
        drop(tx);

        // The lag error is skipped; only the values still buffered are collected.
        let items: Vec<i32> = my_stream.collect().await;

        assert_eq!(items, vec![3, 4]);
    }

    /// A stream whose sender is dropped before anything is sent ends without yielding items.
    #[tokio::test]
    async fn test_event_stream_ends_when_sender_dropped() {
        let (tx, rx) = broadcast::channel::<i32>(4);
        let my_stream = EventStream::new(rx);

        drop(tx);

        let items: Vec<i32> = my_stream.collect().await;

        assert!(items.is_empty());
    }
}