reth_optimism_flashblocks/ws/
stream.rs

1use crate::{ws::FlashBlockDecoder, FlashBlock};
2use futures_util::{
3    stream::{SplitSink, SplitStream},
4    FutureExt, Sink, Stream, StreamExt,
5};
6use std::{
7    fmt::{Debug, Formatter},
8    future::Future,
9    pin::Pin,
10    task::{ready, Context, Poll},
11};
12use tokio::net::TcpStream;
13use tokio_tungstenite::{
14    connect_async,
15    tungstenite::{protocol::CloseFrame, Bytes, Error, Message},
16    MaybeTlsStream, WebSocketStream,
17};
18use tracing::debug;
19use url::Url;
20
/// An asynchronous stream of [`FlashBlock`] from a websocket connection.
///
/// The stream attempts to connect to a websocket URL and then decode each received item.
///
/// If the connection fails, the error is yielded to the caller and the connection is retried on
/// the next poll. The number of retries is unbounded.
pub struct WsFlashBlockStream<Stream, Sink, Connector> {
    /// Websocket endpoint handed to the connector on every connection attempt.
    ws_url: Url,
    /// Current phase of the connect/stream cycle.
    state: State,
    /// Connector that is cloned for each connection attempt.
    connector: Connector,
    /// Decodes the payload of binary and text messages into [`FlashBlock`]s.
    decoder: Box<dyn FlashBlockDecoder>,
    /// Connection attempt currently in flight. Constructors install a placeholder that resolves
    /// to an error; it is swapped for a real attempt when connecting starts.
    connect: ConnectFuture<Sink, Stream>,
    /// Receiving half of the established connection, if any.
    stream: Option<Stream>,
    /// Sending half of the established connection, used for pong and close replies.
    sink: Option<Sink>,
}
36
37impl WsFlashBlockStream<WsStream, WsSink, WsConnector> {
38    /// Creates a new websocket stream over `ws_url`.
39    pub fn new(ws_url: Url) -> Self {
40        Self {
41            ws_url,
42            state: State::default(),
43            connector: WsConnector,
44            decoder: Box::new(()),
45            connect: Box::pin(async move { Err(Error::ConnectionClosed)? }),
46            stream: None,
47            sink: None,
48        }
49    }
50
51    /// Sets the [`FlashBlock`] decoder for the websocket stream.
52    pub fn with_decoder(self, decoder: Box<dyn FlashBlockDecoder>) -> Self {
53        Self { decoder, ..self }
54    }
55}
56
57impl<Stream, S, C> WsFlashBlockStream<Stream, S, C> {
58    /// Creates a new websocket stream over `ws_url`.
59    pub fn with_connector(ws_url: Url, connector: C) -> Self {
60        Self {
61            ws_url,
62            state: State::default(),
63            decoder: Box::new(()),
64            connector,
65            connect: Box::pin(async move { Err(Error::ConnectionClosed)? }),
66            stream: None,
67            sink: None,
68        }
69    }
70}
71
72impl<Str, S, C> Stream for WsFlashBlockStream<Str, S, C>
73where
74    Str: Stream<Item = Result<Message, Error>> + Unpin,
75    S: Sink<Message> + Send + Unpin,
76    C: WsConnect<Stream = Str, Sink = S> + Clone + Send + 'static + Unpin,
77{
78    type Item = eyre::Result<FlashBlock>;
79
80    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
81        let this = self.get_mut();
82
83        'start: loop {
84            if this.state == State::Initial {
85                this.connect();
86            }
87
88            if this.state == State::Connect {
89                match ready!(this.connect.poll_unpin(cx)) {
90                    Ok((sink, stream)) => this.stream(sink, stream),
91                    Err(err) => {
92                        this.state = State::Initial;
93
94                        return Poll::Ready(Some(Err(err)));
95                    }
96                }
97            }
98
99            while let State::Stream(msg) = &mut this.state {
100                if msg.is_some() {
101                    let mut sink = Pin::new(this.sink.as_mut().unwrap());
102                    let _ = ready!(sink.as_mut().poll_ready(cx));
103                    if let Some(pong) = msg.take() {
104                        let _ = sink.as_mut().start_send(pong);
105                    }
106                    let _ = ready!(sink.as_mut().poll_flush(cx));
107                }
108
109                let Some(msg) = ready!(this
110                    .stream
111                    .as_mut()
112                    .expect("Stream state should be unreachable without stream")
113                    .poll_next_unpin(cx))
114                else {
115                    this.state = State::Initial;
116
117                    continue 'start;
118                };
119
120                match msg {
121                    Ok(Message::Binary(bytes)) => {
122                        return Poll::Ready(Some(this.decoder.decode(bytes)))
123                    }
124                    Ok(Message::Text(bytes)) => {
125                        return Poll::Ready(Some(this.decoder.decode(bytes.into())))
126                    }
127                    Ok(Message::Ping(bytes)) => this.ping(bytes),
128                    Ok(Message::Close(frame)) => this.close(frame),
129                    Ok(msg) => {
130                        debug!(target: "flashblocks", "Received unexpected message: {:?}", msg)
131                    }
132                    Err(err) => return Poll::Ready(Some(Err(err.into()))),
133                }
134            }
135        }
136    }
137}
138
139impl<Stream, S, C> WsFlashBlockStream<Stream, S, C>
140where
141    C: WsConnect<Stream = Stream, Sink = S> + Clone + Send + 'static,
142{
143    fn connect(&mut self) {
144        let ws_url = self.ws_url.clone();
145        let mut connector = self.connector.clone();
146
147        Pin::new(&mut self.connect).set(Box::pin(async move { connector.connect(ws_url).await }));
148
149        self.state = State::Connect;
150    }
151
152    fn stream(&mut self, sink: S, stream: Stream) {
153        self.sink.replace(sink);
154        self.stream.replace(stream);
155
156        self.state = State::Stream(None);
157    }
158
159    fn ping(&mut self, pong: Bytes) {
160        if let State::Stream(current) = &mut self.state {
161            current.replace(Message::Pong(pong));
162        }
163    }
164
165    fn close(&mut self, frame: Option<CloseFrame>) {
166        if let State::Stream(current) = &mut self.state {
167            current.replace(Message::Close(frame));
168        }
169    }
170}
171
172impl<Stream: Debug, S: Debug, C: Debug> Debug for WsFlashBlockStream<Stream, S, C> {
173    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
174        f.debug_struct("FlashBlockStream")
175            .field("ws_url", &self.ws_url)
176            .field("state", &self.state)
177            .field("connector", &self.connector)
178            .field("connect", &"Pin<Box<dyn Future<..>>>")
179            .field("stream", &self.stream)
180            .finish()
181    }
182}
183
/// Phases of the connect/stream cycle of the websocket flashblock stream.
#[derive(Default, Debug, Eq, PartialEq)]
enum State {
    /// No connection attempt is in flight; polling starts one.
    #[default]
    Initial,
    /// A connection attempt is in flight and is being polled.
    Connect,
    /// Connected. Holds the reply (pong or close) that still has to be sent, if any.
    Stream(Option<Message>),
}
191
/// Websocket over a TCP stream that may or may not be wrapped in TLS.
type Ws = WebSocketStream<MaybeTlsStream<TcpStream>>;
/// Receiving half of a split [`Ws`].
type WsStream = SplitStream<Ws>;
/// Sending half of a split [`Ws`], accepting [`Message`]s.
type WsSink = SplitSink<Ws, Message>;
/// Boxed connection attempt resolving to the sink/stream pair or an error.
type ConnectFuture<Sink, Stream> =
    Pin<Box<dyn Future<Output = eyre::Result<(Sink, Stream)>> + Send + 'static>>;
197
/// The `WsConnect` trait allows for connecting to a websocket.
///
/// Implementors of the `WsConnect` trait are called 'connectors'.
///
/// Connectors are defined by one method, [`connect()`]. A call to [`connect()`] attempts to
/// establish a websocket connection and, on success, returns its two halves: a sink that accepts
/// outgoing [`Message`]s and an asynchronous stream of incoming [`Message`]s wrapped in a
/// [`Result`].
///
/// [`connect()`]: Self::connect
pub trait WsConnect {
    /// An associated `Stream` of [`Message`]s wrapped in a [`Result`] that this connection returns.
    type Stream;

    /// An associated `Sink` of [`Message`]s that this connection sends.
    type Sink;

    /// Asynchronously connects to a websocket hosted on `ws_url`.
    ///
    /// Resolves to the `(sink, stream)` pair of the new connection, or to an error if the
    /// connection could not be established.
    ///
    /// See the [`WsConnect`] documentation for details.
    fn connect(
        &mut self,
        ws_url: Url,
    ) -> impl Future<Output = eyre::Result<(Self::Sink, Self::Stream)>> + Send;
}
222
223/// Establishes a secure websocket subscription.
224///
225/// See the [`WsConnect`] documentation for details.
226#[derive(Debug, Clone)]
227pub struct WsConnector;
228
229impl WsConnect for WsConnector {
230    type Stream = WsStream;
231    type Sink = WsSink;
232
233    async fn connect(&mut self, ws_url: Url) -> eyre::Result<(WsSink, WsStream)> {
234        let (stream, _response) = connect_async(ws_url.as_str()).await?;
235
236        Ok(stream.split())
237    }
238}
239
240#[cfg(test)]
241mod tests {
242    use super::*;
243    use alloy_primitives::bytes::Bytes;
244    use brotli::enc::BrotliEncoderParams;
245    use std::{future, iter};
246    use tokio_tungstenite::tungstenite::{
247        protocol::frame::{coding::CloseCode, Frame},
248        Error,
249    };
250
    /// A `FakeConnector` creates [`FakeStream`].
    ///
    /// It simulates the websocket stream instead of connecting to a real websocket. Every
    /// connection hands out a clone of the wrapped stream paired with a [`NoopSink`].
    #[derive(Clone)]
    struct FakeConnector(FakeStream);
256
    /// A `FakeConnectorWithSink` creates [`FakeStream`] and [`FakeSink`].
    ///
    /// It simulates the websocket stream instead of connecting to a real websocket. It also accepts
    /// messages into an in-memory buffer, so replies sent by the stream can be inspected.
    #[derive(Clone)]
    struct FakeConnectorWithSink(FakeStream);
263
    /// Simulates a websocket stream while using a preprogrammed set of messages instead.
    ///
    /// Items are yielded by popping from the back of the vector, so it is stored in reverse
    /// delivery order.
    #[derive(Default)]
    struct FakeStream(Vec<Result<Message, Error>>);
267
268    impl FakeStream {
269        fn new(mut messages: Vec<Result<Message, Error>>) -> Self {
270            messages.reverse();
271
272            Self(messages)
273        }
274    }
275
276    impl Clone for FakeStream {
277        fn clone(&self) -> Self {
278            Self(
279                self.0
280                    .iter()
281                    .map(|v| match v {
282                        Ok(msg) => Ok(msg.clone()),
283                        Err(err) => Err(match err {
284                            Error::AttackAttempt => Error::AttackAttempt,
285                            err => unimplemented!("Cannot clone this error: {err}"),
286                        }),
287                    })
288                    .collect(),
289            )
290        }
291    }
292
293    impl Stream for FakeStream {
294        type Item = Result<Message, Error>;
295
296        fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
297            let this = self.get_mut();
298
299            Poll::Ready(this.0.pop())
300        }
301    }
302
    /// A sink that must never be used: every method panics with `unimplemented!()`.
    ///
    /// It is paired with streams whose messages do not require a reply.
    #[derive(Clone)]
    struct NoopSink;

    impl<T> Sink<T> for NoopSink {
        type Error = ();

        /// Panics: this sink is not expected to be polled for readiness.
        fn poll_ready(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Result<(), Self::Error>> {
            unimplemented!()
        }

        /// Panics: this sink is not expected to receive items.
        fn start_send(self: Pin<&mut Self>, _item: T) -> Result<(), Self::Error> {
            unimplemented!()
        }

        /// Panics: this sink is not expected to be flushed.
        fn poll_flush(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Result<(), Self::Error>> {
            unimplemented!()
        }

        /// Panics: this sink is not expected to be closed.
        fn poll_close(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Result<(), Self::Error>> {
            unimplemented!()
        }
    }
334
335    /// Receives [`Message`]s and stores them. A call to `start_send` first buffers the message
336    /// to simulate flushing behavior.
337    #[derive(Clone, Default)]
338    struct FakeSink(Option<Message>, Vec<Message>);
339
340    impl Sink<Message> for FakeSink {
341        type Error = ();
342
343        fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
344            self.poll_flush(cx)
345        }
346
347        fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
348            self.get_mut().0.replace(item);
349            Ok(())
350        }
351
352        fn poll_flush(
353            self: Pin<&mut Self>,
354            _cx: &mut Context<'_>,
355        ) -> Poll<Result<(), Self::Error>> {
356            let this = self.get_mut();
357            if let Some(item) = this.0.take() {
358                this.1.push(item);
359            }
360            Poll::Ready(Ok(()))
361        }
362
363        fn poll_close(
364            self: Pin<&mut Self>,
365            _cx: &mut Context<'_>,
366        ) -> Poll<Result<(), Self::Error>> {
367            Poll::Ready(Ok(()))
368        }
369    }
370
371    impl WsConnect for FakeConnector {
372        type Stream = FakeStream;
373        type Sink = NoopSink;
374
375        fn connect(
376            &mut self,
377            _ws_url: Url,
378        ) -> impl Future<Output = eyre::Result<(Self::Sink, Self::Stream)>> + Send {
379            future::ready(Ok((NoopSink, self.0.clone())))
380        }
381    }
382
383    impl<T: IntoIterator<Item = Result<Message, Error>>> From<T> for FakeConnector {
384        fn from(value: T) -> Self {
385            Self(FakeStream::new(value.into_iter().collect()))
386        }
387    }
388
389    impl WsConnect for FakeConnectorWithSink {
390        type Stream = FakeStream;
391        type Sink = FakeSink;
392
393        fn connect(
394            &mut self,
395            _ws_url: Url,
396        ) -> impl Future<Output = eyre::Result<(Self::Sink, Self::Stream)>> + Send {
397            future::ready(Ok((FakeSink::default(), self.0.clone())))
398        }
399    }
400
401    impl<T: IntoIterator<Item = Result<Message, Error>>> From<T> for FakeConnectorWithSink {
402        fn from(value: T) -> Self {
403            Self(FakeStream::new(value.into_iter().collect()))
404        }
405    }
406
407    /// Repeatedly fails to connect with the given error message.
408    #[derive(Clone)]
409    struct FailingConnector(String);
410
411    impl WsConnect for FailingConnector {
412        type Stream = FakeStream;
413        type Sink = NoopSink;
414
415        fn connect(
416            &mut self,
417            _ws_url: Url,
418        ) -> impl Future<Output = eyre::Result<(Self::Sink, Self::Stream)>> + Send {
419            future::ready(Err(eyre::eyre!("{}", &self.0)))
420        }
421    }
422
423    fn to_json_message<B: TryFrom<Bytes, Error: Debug>, F: Fn(B) -> Message>(
424        wrapper_f: F,
425    ) -> impl Fn(&FlashBlock) -> Result<Message, Error> + use<F, B> {
426        move |block| to_json_message_using(block, &wrapper_f)
427    }
428
429    fn to_json_binary_message(block: &FlashBlock) -> Result<Message, Error> {
430        to_json_message_using(block, Message::Binary)
431    }
432
433    fn to_json_message_using<B: TryFrom<Bytes, Error: Debug>, F: Fn(B) -> Message>(
434        block: &FlashBlock,
435        wrapper_f: F,
436    ) -> Result<Message, Error> {
437        Ok(wrapper_f(B::try_from(Bytes::from(serde_json::to_vec(block).unwrap())).unwrap()))
438    }
439
440    fn to_brotli_message(block: &FlashBlock) -> Result<Message, Error> {
441        let json = serde_json::to_vec(block).unwrap();
442        let mut compressed = Vec::new();
443        brotli::BrotliCompress(
444            &mut json.as_slice(),
445            &mut compressed,
446            &BrotliEncoderParams::default(),
447        )?;
448
449        Ok(Message::Binary(Bytes::from(compressed)))
450    }
451
452    fn flashblock() -> FlashBlock {
453        Default::default()
454    }
455
456    #[test_case::test_case(to_json_message(Message::Binary); "json binary")]
457    #[test_case::test_case(to_json_message(Message::Text); "json UTF-8")]
458    #[test_case::test_case(to_brotli_message; "brotli")]
459    #[tokio::test]
460    async fn test_stream_decodes_messages_successfully(
461        to_message: impl Fn(&FlashBlock) -> Result<Message, Error>,
462    ) {
463        let flashblocks = [flashblock()];
464        let connector = FakeConnector::from(flashblocks.iter().map(to_message));
465        let ws_url = "http://localhost".parse().unwrap();
466        let stream = WsFlashBlockStream::with_connector(ws_url, connector);
467
468        let actual_messages: Vec<_> = stream.take(1).map(Result::unwrap).collect().await;
469        let expected_messages = flashblocks.to_vec();
470
471        assert_eq!(actual_messages, expected_messages);
472    }
473
474    #[test_case::test_case(Message::Pong(Bytes::from(b"test".as_slice())); "pong")]
475    #[test_case::test_case(Message::Frame(Frame::pong(b"test".as_slice())); "frame")]
476    #[tokio::test]
477    async fn test_stream_ignores_unexpected_message(message: Message) {
478        let flashblock = flashblock();
479        let connector = FakeConnector::from([Ok(message), to_json_binary_message(&flashblock)]);
480        let ws_url = "http://localhost".parse().unwrap();
481        let mut stream = WsFlashBlockStream::with_connector(ws_url, connector);
482
483        let expected_message = flashblock;
484        let actual_message =
485            stream.next().await.expect("Binary message should not be ignored").unwrap();
486
487        assert_eq!(actual_message, expected_message)
488    }
489
490    #[tokio::test]
491    async fn test_stream_passes_errors_through() {
492        let connector = FakeConnector::from([Err(Error::AttackAttempt)]);
493        let ws_url = "http://localhost".parse().unwrap();
494        let stream = WsFlashBlockStream::with_connector(ws_url, connector);
495
496        let actual_messages: Vec<_> =
497            stream.take(1).map(Result::unwrap_err).map(|e| format!("{e}")).collect().await;
498        let expected_messages = vec!["Attack attempt detected".to_owned()];
499
500        assert_eq!(actual_messages, expected_messages);
501    }
502
503    #[tokio::test]
504    async fn test_connect_error_causes_retries() {
505        let tries = 3;
506        let error_msg = "test".to_owned();
507        let connector = FailingConnector(error_msg.clone());
508        let ws_url = "http://localhost".parse().unwrap();
509        let stream = WsFlashBlockStream::with_connector(ws_url, connector);
510
511        let actual_errors: Vec<_> =
512            stream.take(tries).map(Result::unwrap_err).map(|e| format!("{e}")).collect().await;
513        let expected_errors: Vec<_> = iter::repeat_n(error_msg, tries).collect();
514
515        assert_eq!(actual_errors, expected_errors);
516    }
517
518    #[test_case::test_case(
519        Message::Close(Some(CloseFrame { code: CloseCode::Normal, reason: "test".into() })),
520        Message::Close(Some(CloseFrame { code: CloseCode::Normal, reason: "test".into() }));
521        "close"
522    )]
523    #[test_case::test_case(
524        Message::Ping(Bytes::from_static(&[1u8, 2, 3])),
525        Message::Pong(Bytes::from_static(&[1u8, 2, 3]));
526        "ping"
527    )]
528    #[tokio::test]
529    async fn test_stream_responds_to_messages(msg: Message, expected_response: Message) {
530        let flashblock = flashblock();
531        let messages = [Ok(msg), to_json_binary_message(&flashblock)];
532        let connector = FakeConnectorWithSink::from(messages);
533        let ws_url = "http://localhost".parse().unwrap();
534        let mut stream = WsFlashBlockStream::with_connector(ws_url, connector);
535
536        let _ = stream.next().await;
537
538        let expected_response = vec![expected_response];
539        let FakeSink(actual_buffer, actual_response) = stream.sink.unwrap();
540
541        assert!(actual_buffer.is_none(), "buffer not flushed: {actual_buffer:#?}");
542        assert_eq!(actual_response, expected_response);
543    }
544}