1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
//! Event streams related functionality.

use std::{
    pin::Pin,
    task::{Context, Poll},
};
use tokio_stream::Stream;
use tracing::warn;

/// Thin wrapper around tokio's `BroadcastStream` to allow skipping broadcast errors.
#[derive(Debug)]
pub struct EventStream<T> {
    inner: tokio_stream::wrappers::BroadcastStream<T>,
}

impl<T> EventStream<T>
where
    T: Clone + Send + 'static,
{
    /// Creates a new `EventStream`.
    pub fn new(receiver: tokio::sync::broadcast::Receiver<T>) -> Self {
        let inner = tokio_stream::wrappers::BroadcastStream::new(receiver);
        Self { inner }
    }
}

impl<T> Stream for EventStream<T>
where
    T: Clone + Send + 'static,
{
    type Item = T;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        loop {
            match Pin::new(&mut self.inner).poll_next(cx) {
                Poll::Ready(Some(Ok(item))) => return Poll::Ready(Some(item)),
                Poll::Ready(Some(Err(e))) => {
                    warn!("BroadcastStream lagged: {e:?}");
                    continue
                }
                Poll::Ready(None) => return Poll::Ready(None),
                Poll::Pending => return Poll::Pending,
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use tokio::sync::broadcast;
    use tokio_stream::StreamExt;

    #[tokio::test]
    async fn test_event_stream_yields_items() {
        let (tx, _) = broadcast::channel(16);
        let my_stream = EventStream::new(tx.subscribe());

        tx.send(1).unwrap();
        tx.send(2).unwrap();
        tx.send(3).unwrap();

        // drop the sender to terminate the stream and allow collect to work.
        drop(tx);

        let items: Vec<i32> = my_stream.collect().await;

        assert_eq!(items, vec![1, 2, 3]);
    }

    #[tokio::test]
    async fn test_event_stream_skips_lag_errors() {
        let (tx, _) = broadcast::channel(2);
        let my_stream = EventStream::new(tx.subscribe());

        let mut _rx2 = tx.subscribe();
        let mut _rx3 = tx.subscribe();

        tx.send(1).unwrap();
        tx.send(2).unwrap();
        tx.send(3).unwrap();
        tx.send(4).unwrap(); // This will cause lag for the first subscriber

        // drop the sender to terminate the stream and allow collect to work.
        drop(tx);

        // Ensure lag errors are skipped and only valid items are collected
        let items: Vec<i32> = my_stream.collect().await;

        assert_eq!(items, vec![3, 4]);
    }
}