reth_optimism_flashblocks/ws/
stream.rs

1use crate::FlashBlock;
2use futures_util::{
3    stream::{SplitSink, SplitStream},
4    FutureExt, Sink, Stream, StreamExt,
5};
6use std::{
7    fmt::{Debug, Formatter},
8    future::Future,
9    pin::Pin,
10    task::{ready, Context, Poll},
11};
12use tokio::net::TcpStream;
13use tokio_tungstenite::{
14    connect_async,
15    tungstenite::{protocol::CloseFrame, Bytes, Error, Message},
16    MaybeTlsStream, WebSocketStream,
17};
18use tracing::debug;
19use url::Url;
20
21/// An asynchronous stream of [`FlashBlock`] from a websocket connection.
22///
23/// The stream attempts to connect to a websocket URL and then decode each received item.
24///
25/// If the connection fails, the error is returned and connection retried. The number of retries is
26/// unbounded.
27pub struct WsFlashBlockStream<Stream, Sink, Connector> {
28    ws_url: Url,
29    state: State,
30    connector: Connector,
31    connect: ConnectFuture<Sink, Stream>,
32    stream: Option<Stream>,
33    sink: Option<Sink>,
34}
35
36impl WsFlashBlockStream<WsStream, WsSink, WsConnector> {
37    /// Creates a new websocket stream over `ws_url`.
38    pub fn new(ws_url: Url) -> Self {
39        Self {
40            ws_url,
41            state: State::default(),
42            connector: WsConnector,
43            connect: Box::pin(async move { Err(Error::ConnectionClosed)? }),
44            stream: None,
45            sink: None,
46        }
47    }
48}
49
50impl<Stream, S, C> WsFlashBlockStream<Stream, S, C> {
51    /// Creates a new websocket stream over `ws_url`.
52    pub fn with_connector(ws_url: Url, connector: C) -> Self {
53        Self {
54            ws_url,
55            state: State::default(),
56            connector,
57            connect: Box::pin(async move { Err(Error::ConnectionClosed)? }),
58            stream: None,
59            sink: None,
60        }
61    }
62}
63
64impl<Str, S, C> Stream for WsFlashBlockStream<Str, S, C>
65where
66    Str: Stream<Item = Result<Message, Error>> + Unpin,
67    S: Sink<Message> + Send + Sync + Unpin,
68    C: WsConnect<Stream = Str, Sink = S> + Clone + Send + Sync + 'static + Unpin,
69{
70    type Item = eyre::Result<FlashBlock>;
71
72    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
73        let this = self.get_mut();
74
75        'start: loop {
76            if this.state == State::Initial {
77                this.connect();
78            }
79
80            if this.state == State::Connect {
81                match ready!(this.connect.poll_unpin(cx)) {
82                    Ok((sink, stream)) => this.stream(sink, stream),
83                    Err(err) => {
84                        this.state = State::Initial;
85
86                        return Poll::Ready(Some(Err(err)));
87                    }
88                }
89            }
90
91            while let State::Stream(msg) = &mut this.state {
92                if msg.is_some() {
93                    let mut sink = Pin::new(this.sink.as_mut().unwrap());
94                    let _ = ready!(sink.as_mut().poll_ready(cx));
95                    if let Some(pong) = msg.take() {
96                        let _ = sink.as_mut().start_send(pong);
97                    }
98                    let _ = ready!(sink.as_mut().poll_flush(cx));
99                }
100
101                let Some(msg) = ready!(this
102                    .stream
103                    .as_mut()
104                    .expect("Stream state should be unreachable without stream")
105                    .poll_next_unpin(cx))
106                else {
107                    this.state = State::Initial;
108
109                    continue 'start;
110                };
111
112                match msg {
113                    Ok(Message::Binary(bytes)) => {
114                        return Poll::Ready(Some(FlashBlock::decode(bytes)))
115                    }
116                    Ok(Message::Text(bytes)) => {
117                        return Poll::Ready(Some(FlashBlock::decode(bytes.into())))
118                    }
119                    Ok(Message::Ping(bytes)) => this.ping(bytes),
120                    Ok(Message::Close(frame)) => this.close(frame),
121                    Ok(msg) => debug!("Received unexpected message: {:?}", msg),
122                    Err(err) => return Poll::Ready(Some(Err(err.into()))),
123                }
124            }
125        }
126    }
127}
128
129impl<Stream, S, C> WsFlashBlockStream<Stream, S, C>
130where
131    C: WsConnect<Stream = Stream, Sink = S> + Clone + Send + Sync + 'static,
132{
133    fn connect(&mut self) {
134        let ws_url = self.ws_url.clone();
135        let mut connector = self.connector.clone();
136
137        Pin::new(&mut self.connect).set(Box::pin(async move { connector.connect(ws_url).await }));
138
139        self.state = State::Connect;
140    }
141
142    fn stream(&mut self, sink: S, stream: Stream) {
143        self.sink.replace(sink);
144        self.stream.replace(stream);
145
146        self.state = State::Stream(None);
147    }
148
149    fn ping(&mut self, pong: Bytes) {
150        if let State::Stream(current) = &mut self.state {
151            current.replace(Message::Pong(pong));
152        }
153    }
154
155    fn close(&mut self, frame: Option<CloseFrame>) {
156        if let State::Stream(current) = &mut self.state {
157            current.replace(Message::Close(frame));
158        }
159    }
160}
161
162impl<Stream: Debug, S: Debug, C: Debug> Debug for WsFlashBlockStream<Stream, S, C> {
163    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
164        f.debug_struct("FlashBlockStream")
165            .field("ws_url", &self.ws_url)
166            .field("state", &self.state)
167            .field("connector", &self.connector)
168            .field("connect", &"Pin<Box<dyn Future<..>>>")
169            .field("stream", &self.stream)
170            .finish()
171    }
172}
173
174#[derive(Default, Debug, Eq, PartialEq)]
175enum State {
176    #[default]
177    Initial,
178    Connect,
179    Stream(Option<Message>),
180}
181
182type Ws = WebSocketStream<MaybeTlsStream<TcpStream>>;
183type WsStream = SplitStream<Ws>;
184type WsSink = SplitSink<Ws, Message>;
185type ConnectFuture<Sink, Stream> =
186    Pin<Box<dyn Future<Output = eyre::Result<(Sink, Stream)>> + Send + Sync + 'static>>;
187
188/// The `WsConnect` trait allows for connecting to a websocket.
189///
190/// Implementors of the `WsConnect` trait are called 'connectors'.
191///
192/// Connectors are defined by one method, [`connect()`]. A call to [`connect()`] attempts to
193/// establish a secure websocket connection and return an asynchronous stream of [`Message`]s
194/// wrapped in a [`Result`].
195///
196/// [`connect()`]: Self::connect
197pub trait WsConnect {
198    /// An associated `Stream` of [`Message`]s wrapped in a [`Result`] that this connection returns.
199    type Stream;
200
201    /// An associated `Sink` of [`Message`]s that this connection sends.
202    type Sink;
203
204    /// Asynchronously connects to a websocket hosted on `ws_url`.
205    ///
206    /// See the [`WsConnect`] documentation for details.
207    fn connect(
208        &mut self,
209        ws_url: Url,
210    ) -> impl Future<Output = eyre::Result<(Self::Sink, Self::Stream)>> + Send + Sync;
211}
212
213/// Establishes a secure websocket subscription.
214///
215/// See the [`WsConnect`] documentation for details.
216#[derive(Debug, Clone)]
217pub struct WsConnector;
218
219impl WsConnect for WsConnector {
220    type Stream = WsStream;
221    type Sink = WsSink;
222
223    async fn connect(&mut self, ws_url: Url) -> eyre::Result<(WsSink, WsStream)> {
224        let (stream, _response) = connect_async(ws_url.as_str()).await?;
225
226        Ok(stream.split())
227    }
228}
229
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ExecutionPayloadBaseV1;
    use alloy_primitives::bytes::Bytes;
    use brotli::enc::BrotliEncoderParams;
    use std::{future, iter};
    use tokio_tungstenite::tungstenite::{
        protocol::frame::{coding::CloseCode, Frame},
        Error,
    };

    /// A `FakeConnector` creates [`FakeStream`].
    ///
    /// It simulates the websocket stream instead of connecting to a real websocket. Its sink is a
    /// [`NoopSink`], so the scripted messages must not require a response (no ping / close).
    #[derive(Clone)]
    struct FakeConnector(FakeStream);

    /// A `FakeConnectorWithSink` creates [`FakeStream`] and [`FakeSink`].
    ///
    /// It simulates the websocket stream instead of connecting to a real websocket. It also accepts
    /// messages into an in-memory buffer.
    #[derive(Clone)]
    struct FakeConnectorWithSink(FakeStream);

    /// Simulates a websocket stream while using a preprogrammed set of messages instead.
    ///
    /// Messages are stored in reverse so that `pop` yields them in their original order; once all
    /// messages are consumed the stream ends.
    #[derive(Default)]
    struct FakeStream(Vec<Result<Message, Error>>);

    impl FakeStream {
        fn new(mut messages: Vec<Result<Message, Error>>) -> Self {
            messages.reverse();

            Self(messages)
        }
    }

    impl Clone for FakeStream {
        // Errors are cloned by hand; only the variants used by these tests are supported.
        fn clone(&self) -> Self {
            Self(
                self.0
                    .iter()
                    .map(|v| match v {
                        Ok(msg) => Ok(msg.clone()),
                        Err(err) => Err(match err {
                            Error::AttackAttempt => Error::AttackAttempt,
                            err => unimplemented!("Cannot clone this error: {err}"),
                        }),
                    })
                    .collect(),
            )
        }
    }

    impl Stream for FakeStream {
        type Item = Result<Message, Error>;

        fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
            let this = self.get_mut();

            Poll::Ready(this.0.pop())
        }
    }

    /// A sink that must never be used: every method panics.
    #[derive(Clone)]
    struct NoopSink;

    impl<T> Sink<T> for NoopSink {
        type Error = ();

        fn poll_ready(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Result<(), Self::Error>> {
            unimplemented!()
        }

        fn start_send(self: Pin<&mut Self>, _item: T) -> Result<(), Self::Error> {
            unimplemented!()
        }

        fn poll_flush(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Result<(), Self::Error>> {
            unimplemented!()
        }

        fn poll_close(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Result<(), Self::Error>> {
            unimplemented!()
        }
    }

    /// Receives [`Message`]s and stores them. A call to `start_send` first buffers the message
    /// to simulate flushing behavior.
    #[derive(Clone, Default)]
    struct FakeSink(Option<Message>, Vec<Message>);

    impl Sink<Message> for FakeSink {
        type Error = ();

        fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            self.poll_flush(cx)
        }

        fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
            self.get_mut().0.replace(item);
            Ok(())
        }

        fn poll_flush(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Result<(), Self::Error>> {
            let this = self.get_mut();
            if let Some(item) = this.0.take() {
                this.1.push(item);
            }
            Poll::Ready(Ok(()))
        }

        fn poll_close(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }
    }

    impl WsConnect for FakeConnector {
        type Stream = FakeStream;
        type Sink = NoopSink;

        fn connect(
            &mut self,
            _ws_url: Url,
        ) -> impl Future<Output = eyre::Result<(Self::Sink, Self::Stream)>> + Send + Sync {
            future::ready(Ok((NoopSink, self.0.clone())))
        }
    }

    impl<T: IntoIterator<Item = Result<Message, Error>>> From<T> for FakeConnector {
        fn from(value: T) -> Self {
            Self(FakeStream::new(value.into_iter().collect()))
        }
    }

    impl WsConnect for FakeConnectorWithSink {
        type Stream = FakeStream;
        type Sink = FakeSink;

        fn connect(
            &mut self,
            _ws_url: Url,
        ) -> impl Future<Output = eyre::Result<(Self::Sink, Self::Stream)>> + Send + Sync {
            future::ready(Ok((FakeSink::default(), self.0.clone())))
        }
    }

    impl<T: IntoIterator<Item = Result<Message, Error>>> From<T> for FakeConnectorWithSink {
        fn from(value: T) -> Self {
            Self(FakeStream::new(value.into_iter().collect()))
        }
    }

    /// Repeatedly fails to connect with the given error message.
    #[derive(Clone)]
    struct FailingConnector(String);

    impl WsConnect for FailingConnector {
        type Stream = FakeStream;
        type Sink = NoopSink;

        fn connect(
            &mut self,
            _ws_url: Url,
        ) -> impl Future<Output = eyre::Result<(Self::Sink, Self::Stream)>> + Send + Sync {
            future::ready(Err(eyre::eyre!("{}", &self.0)))
        }
    }

    /// Returns an encoder that serializes a block to JSON and wraps it with `wrapper_f`.
    fn to_json_message<B: TryFrom<Bytes, Error: Debug>, F: Fn(B) -> Message>(
        wrapper_f: F,
    ) -> impl Fn(&FlashBlock) -> Result<Message, Error> + use<F, B> {
        move |block| to_json_message_using(block, &wrapper_f)
    }

    /// Encodes a block as a JSON binary message.
    fn to_json_binary_message(block: &FlashBlock) -> Result<Message, Error> {
        to_json_message_using(block, Message::Binary)
    }

    /// Serializes a block to JSON and wraps the bytes with `wrapper_f`.
    fn to_json_message_using<B: TryFrom<Bytes, Error: Debug>, F: Fn(B) -> Message>(
        block: &FlashBlock,
        wrapper_f: F,
    ) -> Result<Message, Error> {
        Ok(wrapper_f(B::try_from(Bytes::from(serde_json::to_vec(block).unwrap())).unwrap()))
    }

    /// Encodes a block as brotli-compressed JSON in a binary message.
    fn to_brotli_message(block: &FlashBlock) -> Result<Message, Error> {
        let json = serde_json::to_vec(block).unwrap();
        let mut compressed = Vec::new();
        brotli::BrotliCompress(
            &mut json.as_slice(),
            &mut compressed,
            &BrotliEncoderParams::default(),
        )?;

        Ok(Message::Binary(Bytes::from(compressed)))
    }

    /// A minimal flashblock with default values.
    fn flashblock() -> FlashBlock {
        FlashBlock {
            payload_id: Default::default(),
            index: 0,
            base: Some(ExecutionPayloadBaseV1 {
                parent_beacon_block_root: Default::default(),
                parent_hash: Default::default(),
                fee_recipient: Default::default(),
                prev_randao: Default::default(),
                block_number: 0,
                gas_limit: 0,
                timestamp: 0,
                extra_data: Default::default(),
                base_fee_per_gas: Default::default(),
            }),
            diff: Default::default(),
            metadata: Default::default(),
        }
    }

    #[test_case::test_case(to_json_message(Message::Binary); "json binary")]
    #[test_case::test_case(to_json_message(Message::Text); "json UTF-8")]
    #[test_case::test_case(to_brotli_message; "brotli")]
    #[tokio::test]
    async fn test_stream_decodes_messages_successfully(
        to_message: impl Fn(&FlashBlock) -> Result<Message, Error>,
    ) {
        let flashblocks = [flashblock()];
        let connector = FakeConnector::from(flashblocks.iter().map(to_message));
        let ws_url = "http://localhost".parse().unwrap();
        let stream = WsFlashBlockStream::with_connector(ws_url, connector);

        let actual_messages: Vec<_> = stream.take(1).map(Result::unwrap).collect().await;
        let expected_messages = flashblocks.to_vec();

        assert_eq!(actual_messages, expected_messages);
    }

    #[test_case::test_case(Message::Pong(Bytes::from(b"test".as_slice())); "pong")]
    #[test_case::test_case(Message::Frame(Frame::pong(b"test".as_slice())); "frame")]
    #[tokio::test]
    async fn test_stream_ignores_unexpected_message(message: Message) {
        let flashblock = flashblock();
        let connector = FakeConnector::from([Ok(message), to_json_binary_message(&flashblock)]);
        let ws_url = "http://localhost".parse().unwrap();
        let mut stream = WsFlashBlockStream::with_connector(ws_url, connector);

        let expected_message = flashblock;
        let actual_message =
            stream.next().await.expect("Binary message should not be ignored").unwrap();

        assert_eq!(actual_message, expected_message)
    }

    #[tokio::test]
    async fn test_stream_passes_errors_through() {
        let connector = FakeConnector::from([Err(Error::AttackAttempt)]);
        let ws_url = "http://localhost".parse().unwrap();
        let stream = WsFlashBlockStream::with_connector(ws_url, connector);

        let actual_messages: Vec<_> =
            stream.take(1).map(Result::unwrap_err).map(|e| format!("{e}")).collect().await;
        let expected_messages = vec!["Attack attempt detected".to_owned()];

        assert_eq!(actual_messages, expected_messages);
    }

    #[tokio::test]
    async fn test_connect_error_causes_retries() {
        let tries = 3;
        let error_msg = "test".to_owned();
        let connector = FailingConnector(error_msg.clone());
        let ws_url = "http://localhost".parse().unwrap();
        let stream = WsFlashBlockStream::with_connector(ws_url, connector);

        let actual_errors: Vec<_> =
            stream.take(tries).map(Result::unwrap_err).map(|e| format!("{e}")).collect().await;
        let expected_errors: Vec<_> = iter::repeat_n(error_msg, tries).collect();

        assert_eq!(actual_errors, expected_errors);
    }

    #[tokio::test]
    async fn test_stream_reconnects_after_stream_ends() {
        let flashblock = flashblock();
        let connector = FakeConnector::from([to_json_binary_message(&flashblock)]);
        let ws_url = "http://localhost".parse().unwrap();
        let stream = WsFlashBlockStream::with_connector(ws_url, connector);

        // The fake stream yields a single message and then ends, so the second item can only be
        // produced after a reconnect replays the preprogrammed message.
        let actual_messages: Vec<_> = stream.take(2).map(Result::unwrap).collect().await;
        let expected_messages = vec![flashblock.clone(), flashblock];

        assert_eq!(actual_messages, expected_messages);
    }

    #[test_case::test_case(
        Message::Close(Some(CloseFrame { code: CloseCode::Normal, reason: "test".into() })),
        Message::Close(Some(CloseFrame { code: CloseCode::Normal, reason: "test".into() }));
        "close"
    )]
    #[test_case::test_case(
        Message::Ping(Bytes::from_static(&[1u8, 2, 3])),
        Message::Pong(Bytes::from_static(&[1u8, 2, 3]));
        "ping"
    )]
    #[tokio::test]
    async fn test_stream_responds_to_messages(msg: Message, expected_response: Message) {
        let flashblock = flashblock();
        let messages = [Ok(msg), to_json_binary_message(&flashblock)];
        let connector = FakeConnectorWithSink::from(messages);
        let ws_url = "http://localhost".parse().unwrap();
        let mut stream = WsFlashBlockStream::with_connector(ws_url, connector);

        let _ = stream.next().await;

        let expected_response = vec![expected_response];
        let FakeSink(actual_buffer, actual_response) = stream.sink.unwrap();

        assert!(actual_buffer.is_none(), "buffer not flushed: {actual_buffer:#?}");
        assert_eq!(actual_response, expected_response);
    }
}