reth_optimism_flashblocks/ws/stream.rs

1use crate::{FlashBlock, FlashBlockDecoder};
2use futures_util::{
3    stream::{SplitSink, SplitStream},
4    FutureExt, Sink, Stream, StreamExt,
5};
6use std::{
7    fmt::{Debug, Formatter},
8    future::Future,
9    pin::Pin,
10    task::{ready, Context, Poll},
11};
12use tokio::net::TcpStream;
13use tokio_tungstenite::{
14    connect_async,
15    tungstenite::{protocol::CloseFrame, Bytes, Error, Message},
16    MaybeTlsStream, WebSocketStream,
17};
18use tracing::debug;
19use url::Url;
20
21/// An asynchronous stream of [`FlashBlock`] from a websocket connection.
22///
23/// The stream attempts to connect to a websocket URL and then decode each received item.
24///
25/// If the connection fails, the error is returned and connection retried. The number of retries is
26/// unbounded.
27pub struct WsFlashBlockStream<Stream, Sink, Connector> {
28    ws_url: Url,
29    state: State,
30    connector: Connector,
31    decoder: Box<dyn FlashBlockDecoder>,
32    connect: ConnectFuture<Sink, Stream>,
33    stream: Option<Stream>,
34    sink: Option<Sink>,
35}
36
37impl WsFlashBlockStream<WsStream, WsSink, WsConnector> {
38    /// Creates a new websocket stream over `ws_url`.
39    pub fn new(ws_url: Url) -> Self {
40        Self {
41            ws_url,
42            state: State::default(),
43            connector: WsConnector,
44            decoder: Box::new(()),
45            connect: Box::pin(async move { Err(Error::ConnectionClosed)? }),
46            stream: None,
47            sink: None,
48        }
49    }
50
51    /// Sets the [`FlashBlock`] decoder for the websocket stream.
52    pub fn with_decoder(self, decoder: Box<dyn FlashBlockDecoder>) -> Self {
53        Self { decoder, ..self }
54    }
55}
56
57impl<Stream, S, C> WsFlashBlockStream<Stream, S, C> {
58    /// Creates a new websocket stream over `ws_url`.
59    pub fn with_connector(ws_url: Url, connector: C) -> Self {
60        Self {
61            ws_url,
62            state: State::default(),
63            decoder: Box::new(()),
64            connector,
65            connect: Box::pin(async move { Err(Error::ConnectionClosed)? }),
66            stream: None,
67            sink: None,
68        }
69    }
70}
71
72impl<Str, S, C> Stream for WsFlashBlockStream<Str, S, C>
73where
74    Str: Stream<Item = Result<Message, Error>> + Unpin,
75    S: Sink<Message> + Send + Unpin,
76    C: WsConnect<Stream = Str, Sink = S> + Clone + Send + 'static + Unpin,
77{
78    type Item = eyre::Result<FlashBlock>;
79
80    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
81        let this = self.get_mut();
82
83        'start: loop {
84            if this.state == State::Initial {
85                this.connect();
86            }
87
88            if this.state == State::Connect {
89                match ready!(this.connect.poll_unpin(cx)) {
90                    Ok((sink, stream)) => this.stream(sink, stream),
91                    Err(err) => {
92                        this.state = State::Initial;
93
94                        return Poll::Ready(Some(Err(err)));
95                    }
96                }
97            }
98
99            while let State::Stream(msg) = &mut this.state {
100                if msg.is_some() {
101                    let mut sink = Pin::new(this.sink.as_mut().unwrap());
102                    let _ = ready!(sink.as_mut().poll_ready(cx));
103                    if let Some(pong) = msg.take() {
104                        let _ = sink.as_mut().start_send(pong);
105                    }
106                    let _ = ready!(sink.as_mut().poll_flush(cx));
107                }
108
109                let Some(msg) = ready!(this
110                    .stream
111                    .as_mut()
112                    .expect("Stream state should be unreachable without stream")
113                    .poll_next_unpin(cx))
114                else {
115                    this.state = State::Initial;
116
117                    continue 'start;
118                };
119
120                match msg {
121                    Ok(Message::Binary(bytes)) => {
122                        return Poll::Ready(Some(this.decoder.decode(bytes)))
123                    }
124                    Ok(Message::Text(bytes)) => {
125                        return Poll::Ready(Some(this.decoder.decode(bytes.into())))
126                    }
127                    Ok(Message::Ping(bytes)) => this.ping(bytes),
128                    Ok(Message::Close(frame)) => this.close(frame),
129                    Ok(msg) => debug!("Received unexpected message: {:?}", msg),
130                    Err(err) => return Poll::Ready(Some(Err(err.into()))),
131                }
132            }
133        }
134    }
135}
136
137impl<Stream, S, C> WsFlashBlockStream<Stream, S, C>
138where
139    C: WsConnect<Stream = Stream, Sink = S> + Clone + Send + 'static,
140{
141    fn connect(&mut self) {
142        let ws_url = self.ws_url.clone();
143        let mut connector = self.connector.clone();
144
145        Pin::new(&mut self.connect).set(Box::pin(async move { connector.connect(ws_url).await }));
146
147        self.state = State::Connect;
148    }
149
150    fn stream(&mut self, sink: S, stream: Stream) {
151        self.sink.replace(sink);
152        self.stream.replace(stream);
153
154        self.state = State::Stream(None);
155    }
156
157    fn ping(&mut self, pong: Bytes) {
158        if let State::Stream(current) = &mut self.state {
159            current.replace(Message::Pong(pong));
160        }
161    }
162
163    fn close(&mut self, frame: Option<CloseFrame>) {
164        if let State::Stream(current) = &mut self.state {
165            current.replace(Message::Close(frame));
166        }
167    }
168}
169
170impl<Stream: Debug, S: Debug, C: Debug> Debug for WsFlashBlockStream<Stream, S, C> {
171    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
172        f.debug_struct("FlashBlockStream")
173            .field("ws_url", &self.ws_url)
174            .field("state", &self.state)
175            .field("connector", &self.connector)
176            .field("connect", &"Pin<Box<dyn Future<..>>>")
177            .field("stream", &self.stream)
178            .finish()
179    }
180}
181
182#[derive(Default, Debug, Eq, PartialEq)]
183enum State {
184    #[default]
185    Initial,
186    Connect,
187    Stream(Option<Message>),
188}
189
190type Ws = WebSocketStream<MaybeTlsStream<TcpStream>>;
191type WsStream = SplitStream<Ws>;
192type WsSink = SplitSink<Ws, Message>;
193type ConnectFuture<Sink, Stream> =
194    Pin<Box<dyn Future<Output = eyre::Result<(Sink, Stream)>> + Send + 'static>>;
195
196/// The `WsConnect` trait allows for connecting to a websocket.
197///
198/// Implementors of the `WsConnect` trait are called 'connectors'.
199///
200/// Connectors are defined by one method, [`connect()`]. A call to [`connect()`] attempts to
201/// establish a secure websocket connection and return an asynchronous stream of [`Message`]s
202/// wrapped in a [`Result`].
203///
204/// [`connect()`]: Self::connect
205pub trait WsConnect {
206    /// An associated `Stream` of [`Message`]s wrapped in a [`Result`] that this connection returns.
207    type Stream;
208
209    /// An associated `Sink` of [`Message`]s that this connection sends.
210    type Sink;
211
212    /// Asynchronously connects to a websocket hosted on `ws_url`.
213    ///
214    /// See the [`WsConnect`] documentation for details.
215    fn connect(
216        &mut self,
217        ws_url: Url,
218    ) -> impl Future<Output = eyre::Result<(Self::Sink, Self::Stream)>> + Send;
219}
220
221/// Establishes a secure websocket subscription.
222///
223/// See the [`WsConnect`] documentation for details.
224#[derive(Debug, Clone)]
225pub struct WsConnector;
226
227impl WsConnect for WsConnector {
228    type Stream = WsStream;
229    type Sink = WsSink;
230
231    async fn connect(&mut self, ws_url: Url) -> eyre::Result<(WsSink, WsStream)> {
232        let (stream, _response) = connect_async(ws_url.as_str()).await?;
233
234        Ok(stream.split())
235    }
236}
237
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ExecutionPayloadBaseV1;
    use alloy_primitives::bytes::Bytes;
    use brotli::enc::BrotliEncoderParams;
    use std::{future, iter};
    use tokio_tungstenite::tungstenite::{
        protocol::frame::{coding::CloseCode, Frame},
        Error,
    };

    /// A `FakeConnector` creates [`FakeStream`].
    ///
    /// It simulates the websocket stream instead of connecting to a real websocket.
    #[derive(Clone)]
    struct FakeConnector(FakeStream);

    /// A `FakeConnectorWithSink` creates [`FakeStream`] and [`FakeSink`].
    ///
    /// It simulates the websocket stream instead of connecting to a real websocket. It also accepts
    /// messages into an in-memory buffer.
    #[derive(Clone)]
    struct FakeConnectorWithSink(FakeStream);

    /// Simulates a websocket stream by yielding a preprogrammed sequence of messages in order and
    /// then ending.
    #[derive(Default)]
    struct FakeStream(Vec<Result<Message, Error>>);

    impl FakeStream {
        fn new(mut messages: Vec<Result<Message, Error>>) -> Self {
            // Stored reversed so `pop` yields messages in their original order.
            messages.reverse();

            Self(messages)
        }
    }

    impl Clone for FakeStream {
        fn clone(&self) -> Self {
            Self(
                self.0
                    .iter()
                    .map(|v| match v {
                        Ok(msg) => Ok(msg.clone()),
                        Err(err) => Err(match err {
                            Error::AttackAttempt => Error::AttackAttempt,
                            err => unimplemented!("Cannot clone this error: {err}"),
                        }),
                    })
                    .collect(),
            )
        }
    }

    impl Stream for FakeStream {
        type Item = Result<Message, Error>;

        fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
            let this = self.get_mut();

            Poll::Ready(this.0.pop())
        }
    }

    /// A sink that must never be used: every method panics.
    ///
    /// Used by connectors whose tests never expect the stream to send a reply.
    #[derive(Clone)]
    struct NoopSink;

    impl<T> Sink<T> for NoopSink {
        type Error = ();

        fn poll_ready(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Result<(), Self::Error>> {
            unimplemented!()
        }

        fn start_send(self: Pin<&mut Self>, _item: T) -> Result<(), Self::Error> {
            unimplemented!()
        }

        fn poll_flush(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Result<(), Self::Error>> {
            unimplemented!()
        }

        fn poll_close(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Result<(), Self::Error>> {
            unimplemented!()
        }
    }

    /// Receives [`Message`]s and stores them. A call to `start_send` first buffers the message
    /// to simulate flushing behavior.
    #[derive(Clone, Default)]
    struct FakeSink(Option<Message>, Vec<Message>);

    impl Sink<Message> for FakeSink {
        type Error = ();

        fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            self.poll_flush(cx)
        }

        fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
            self.get_mut().0.replace(item);
            Ok(())
        }

        fn poll_flush(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Result<(), Self::Error>> {
            let this = self.get_mut();
            if let Some(item) = this.0.take() {
                this.1.push(item);
            }
            Poll::Ready(Ok(()))
        }

        fn poll_close(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }
    }

    impl WsConnect for FakeConnector {
        type Stream = FakeStream;
        type Sink = NoopSink;

        fn connect(
            &mut self,
            _ws_url: Url,
        ) -> impl Future<Output = eyre::Result<(Self::Sink, Self::Stream)>> + Send {
            future::ready(Ok((NoopSink, self.0.clone())))
        }
    }

    impl<T: IntoIterator<Item = Result<Message, Error>>> From<T> for FakeConnector {
        fn from(value: T) -> Self {
            Self(FakeStream::new(value.into_iter().collect()))
        }
    }

    impl WsConnect for FakeConnectorWithSink {
        type Stream = FakeStream;
        type Sink = FakeSink;

        fn connect(
            &mut self,
            _ws_url: Url,
        ) -> impl Future<Output = eyre::Result<(Self::Sink, Self::Stream)>> + Send {
            future::ready(Ok((FakeSink::default(), self.0.clone())))
        }
    }

    impl<T: IntoIterator<Item = Result<Message, Error>>> From<T> for FakeConnectorWithSink {
        fn from(value: T) -> Self {
            Self(FakeStream::new(value.into_iter().collect()))
        }
    }

    /// Repeatedly fails to connect with the given error message.
    #[derive(Clone)]
    struct FailingConnector(String);

    impl WsConnect for FailingConnector {
        type Stream = FakeStream;
        type Sink = NoopSink;

        fn connect(
            &mut self,
            _ws_url: Url,
        ) -> impl Future<Output = eyre::Result<(Self::Sink, Self::Stream)>> + Send {
            future::ready(Err(eyre::eyre!("{}", &self.0)))
        }
    }

    fn to_json_message<B: TryFrom<Bytes, Error: Debug>, F: Fn(B) -> Message>(
        wrapper_f: F,
    ) -> impl Fn(&FlashBlock) -> Result<Message, Error> + use<F, B> {
        move |block| to_json_message_using(block, &wrapper_f)
    }

    fn to_json_binary_message(block: &FlashBlock) -> Result<Message, Error> {
        to_json_message_using(block, Message::Binary)
    }

    fn to_json_message_using<B: TryFrom<Bytes, Error: Debug>, F: Fn(B) -> Message>(
        block: &FlashBlock,
        wrapper_f: F,
    ) -> Result<Message, Error> {
        Ok(wrapper_f(B::try_from(Bytes::from(serde_json::to_vec(block).unwrap())).unwrap()))
    }

    fn to_brotli_message(block: &FlashBlock) -> Result<Message, Error> {
        let json = serde_json::to_vec(block).unwrap();
        let mut compressed = Vec::new();
        brotli::BrotliCompress(
            &mut json.as_slice(),
            &mut compressed,
            &BrotliEncoderParams::default(),
        )?;

        Ok(Message::Binary(Bytes::from(compressed)))
    }

    fn flashblock() -> FlashBlock {
        FlashBlock {
            payload_id: Default::default(),
            index: 0,
            base: Some(ExecutionPayloadBaseV1 {
                parent_beacon_block_root: Default::default(),
                parent_hash: Default::default(),
                fee_recipient: Default::default(),
                prev_randao: Default::default(),
                block_number: 0,
                gas_limit: 0,
                timestamp: 0,
                extra_data: Default::default(),
                base_fee_per_gas: Default::default(),
            }),
            diff: Default::default(),
            metadata: Default::default(),
        }
    }

    #[test_case::test_case(to_json_message(Message::Binary); "json binary")]
    #[test_case::test_case(to_json_message(Message::Text); "json UTF-8")]
    #[test_case::test_case(to_brotli_message; "brotli")]
    #[tokio::test]
    async fn test_stream_decodes_messages_successfully(
        to_message: impl Fn(&FlashBlock) -> Result<Message, Error>,
    ) {
        let flashblocks = [flashblock()];
        let connector = FakeConnector::from(flashblocks.iter().map(to_message));
        let ws_url = "http://localhost".parse().unwrap();
        let stream = WsFlashBlockStream::with_connector(ws_url, connector);

        let actual_messages: Vec<_> = stream.take(1).map(Result::unwrap).collect().await;
        let expected_messages = flashblocks.to_vec();

        assert_eq!(actual_messages, expected_messages);
    }

    #[test_case::test_case(Message::Pong(Bytes::from(b"test".as_slice())); "pong")]
    #[test_case::test_case(Message::Frame(Frame::pong(b"test".as_slice())); "frame")]
    #[tokio::test]
    async fn test_stream_ignores_unexpected_message(message: Message) {
        let flashblock = flashblock();
        let connector = FakeConnector::from([Ok(message), to_json_binary_message(&flashblock)]);
        let ws_url = "http://localhost".parse().unwrap();
        let mut stream = WsFlashBlockStream::with_connector(ws_url, connector);

        let expected_message = flashblock;
        let actual_message =
            stream.next().await.expect("Binary message should not be ignored").unwrap();

        assert_eq!(actual_message, expected_message)
    }

    #[tokio::test]
    async fn test_stream_passes_errors_through() {
        let connector = FakeConnector::from([Err(Error::AttackAttempt)]);
        let ws_url = "http://localhost".parse().unwrap();
        let stream = WsFlashBlockStream::with_connector(ws_url, connector);

        let actual_messages: Vec<_> =
            stream.take(1).map(Result::unwrap_err).map(|e| format!("{e}")).collect().await;
        let expected_messages = vec!["Attack attempt detected".to_owned()];

        assert_eq!(actual_messages, expected_messages);
    }

    #[tokio::test]
    async fn test_connect_error_causes_retries() {
        let tries = 3;
        let error_msg = "test".to_owned();
        let connector = FailingConnector(error_msg.clone());
        let ws_url = "http://localhost".parse().unwrap();
        let stream = WsFlashBlockStream::with_connector(ws_url, connector);

        let actual_errors: Vec<_> =
            stream.take(tries).map(Result::unwrap_err).map(|e| format!("{e}")).collect().await;
        let expected_errors: Vec<_> = iter::repeat_n(error_msg, tries).collect();

        assert_eq!(actual_errors, expected_errors);
    }

    #[tokio::test]
    async fn test_stream_reconnects_after_stream_ends() {
        let flashblock = flashblock();
        // Every connection yields exactly one block and then ends, so receiving a second block
        // proves that the stream reconnected.
        let connector = FakeConnector::from([to_json_binary_message(&flashblock)]);
        let ws_url = "http://localhost".parse().unwrap();
        let stream = WsFlashBlockStream::with_connector(ws_url, connector);

        let actual_messages: Vec<_> = stream.take(2).map(Result::unwrap).collect().await;
        let expected_messages = vec![flashblock.clone(), flashblock];

        assert_eq!(actual_messages, expected_messages);
    }

    #[test_case::test_case(
        Message::Close(Some(CloseFrame { code: CloseCode::Normal, reason: "test".into() })),
        Message::Close(Some(CloseFrame { code: CloseCode::Normal, reason: "test".into() }));
        "close"
    )]
    #[test_case::test_case(
        Message::Ping(Bytes::from_static(&[1u8, 2, 3])),
        Message::Pong(Bytes::from_static(&[1u8, 2, 3]));
        "ping"
    )]
    #[tokio::test]
    async fn test_stream_responds_to_messages(msg: Message, expected_response: Message) {
        let flashblock = flashblock();
        let messages = [Ok(msg), to_json_binary_message(&flashblock)];
        let connector = FakeConnectorWithSink::from(messages);
        let ws_url = "http://localhost".parse().unwrap();
        let mut stream = WsFlashBlockStream::with_connector(ws_url, connector);

        let _ = stream.next().await;

        let expected_response = vec![expected_response];
        let FakeSink(actual_buffer, actual_response) = stream.sink.unwrap();

        assert!(actual_buffer.is_none(), "buffer not flushed: {actual_buffer:#?}");
        assert_eq!(actual_response, expected_response);
    }
}