reth_optimism_flashblocks/ws/
stream.rs1use crate::{FlashBlock, FlashBlockDecoder};
2use futures_util::{
3 stream::{SplitSink, SplitStream},
4 FutureExt, Sink, Stream, StreamExt,
5};
6use std::{
7 fmt::{Debug, Formatter},
8 future::Future,
9 pin::Pin,
10 task::{ready, Context, Poll},
11};
12use tokio::net::TcpStream;
13use tokio_tungstenite::{
14 connect_async,
15 tungstenite::{protocol::CloseFrame, Bytes, Error, Message},
16 MaybeTlsStream, WebSocketStream,
17};
18use tracing::debug;
19use url::Url;
20
21pub struct WsFlashBlockStream<Stream, Sink, Connector> {
28 ws_url: Url,
29 state: State,
30 connector: Connector,
31 decoder: Box<dyn FlashBlockDecoder>,
32 connect: ConnectFuture<Sink, Stream>,
33 stream: Option<Stream>,
34 sink: Option<Sink>,
35}
36
37impl WsFlashBlockStream<WsStream, WsSink, WsConnector> {
38 pub fn new(ws_url: Url) -> Self {
40 Self {
41 ws_url,
42 state: State::default(),
43 connector: WsConnector,
44 decoder: Box::new(()),
45 connect: Box::pin(async move { Err(Error::ConnectionClosed)? }),
46 stream: None,
47 sink: None,
48 }
49 }
50
51 pub fn with_decoder(self, decoder: Box<dyn FlashBlockDecoder>) -> Self {
53 Self { decoder, ..self }
54 }
55}
56
57impl<Stream, S, C> WsFlashBlockStream<Stream, S, C> {
58 pub fn with_connector(ws_url: Url, connector: C) -> Self {
60 Self {
61 ws_url,
62 state: State::default(),
63 decoder: Box::new(()),
64 connector,
65 connect: Box::pin(async move { Err(Error::ConnectionClosed)? }),
66 stream: None,
67 sink: None,
68 }
69 }
70}
71
72impl<Str, S, C> Stream for WsFlashBlockStream<Str, S, C>
73where
74 Str: Stream<Item = Result<Message, Error>> + Unpin,
75 S: Sink<Message> + Send + Unpin,
76 C: WsConnect<Stream = Str, Sink = S> + Clone + Send + 'static + Unpin,
77{
78 type Item = eyre::Result<FlashBlock>;
79
80 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
81 let this = self.get_mut();
82
83 'start: loop {
84 if this.state == State::Initial {
85 this.connect();
86 }
87
88 if this.state == State::Connect {
89 match ready!(this.connect.poll_unpin(cx)) {
90 Ok((sink, stream)) => this.stream(sink, stream),
91 Err(err) => {
92 this.state = State::Initial;
93
94 return Poll::Ready(Some(Err(err)));
95 }
96 }
97 }
98
99 while let State::Stream(msg) = &mut this.state {
100 if msg.is_some() {
101 let mut sink = Pin::new(this.sink.as_mut().unwrap());
102 let _ = ready!(sink.as_mut().poll_ready(cx));
103 if let Some(pong) = msg.take() {
104 let _ = sink.as_mut().start_send(pong);
105 }
106 let _ = ready!(sink.as_mut().poll_flush(cx));
107 }
108
109 let Some(msg) = ready!(this
110 .stream
111 .as_mut()
112 .expect("Stream state should be unreachable without stream")
113 .poll_next_unpin(cx))
114 else {
115 this.state = State::Initial;
116
117 continue 'start;
118 };
119
120 match msg {
121 Ok(Message::Binary(bytes)) => {
122 return Poll::Ready(Some(this.decoder.decode(bytes)))
123 }
124 Ok(Message::Text(bytes)) => {
125 return Poll::Ready(Some(this.decoder.decode(bytes.into())))
126 }
127 Ok(Message::Ping(bytes)) => this.ping(bytes),
128 Ok(Message::Close(frame)) => this.close(frame),
129 Ok(msg) => debug!("Received unexpected message: {:?}", msg),
130 Err(err) => return Poll::Ready(Some(Err(err.into()))),
131 }
132 }
133 }
134 }
135}
136
137impl<Stream, S, C> WsFlashBlockStream<Stream, S, C>
138where
139 C: WsConnect<Stream = Stream, Sink = S> + Clone + Send + 'static,
140{
141 fn connect(&mut self) {
142 let ws_url = self.ws_url.clone();
143 let mut connector = self.connector.clone();
144
145 Pin::new(&mut self.connect).set(Box::pin(async move { connector.connect(ws_url).await }));
146
147 self.state = State::Connect;
148 }
149
150 fn stream(&mut self, sink: S, stream: Stream) {
151 self.sink.replace(sink);
152 self.stream.replace(stream);
153
154 self.state = State::Stream(None);
155 }
156
157 fn ping(&mut self, pong: Bytes) {
158 if let State::Stream(current) = &mut self.state {
159 current.replace(Message::Pong(pong));
160 }
161 }
162
163 fn close(&mut self, frame: Option<CloseFrame>) {
164 if let State::Stream(current) = &mut self.state {
165 current.replace(Message::Close(frame));
166 }
167 }
168}
169
170impl<Stream: Debug, S: Debug, C: Debug> Debug for WsFlashBlockStream<Stream, S, C> {
171 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
172 f.debug_struct("FlashBlockStream")
173 .field("ws_url", &self.ws_url)
174 .field("state", &self.state)
175 .field("connector", &self.connector)
176 .field("connect", &"Pin<Box<dyn Future<..>>>")
177 .field("stream", &self.stream)
178 .finish()
179 }
180}
181
182#[derive(Default, Debug, Eq, PartialEq)]
183enum State {
184 #[default]
185 Initial,
186 Connect,
187 Stream(Option<Message>),
188}
189
190type Ws = WebSocketStream<MaybeTlsStream<TcpStream>>;
191type WsStream = SplitStream<Ws>;
192type WsSink = SplitSink<Ws, Message>;
193type ConnectFuture<Sink, Stream> =
194 Pin<Box<dyn Future<Output = eyre::Result<(Sink, Stream)>> + Send + 'static>>;
195
196pub trait WsConnect {
206 type Stream;
208
209 type Sink;
211
212 fn connect(
216 &mut self,
217 ws_url: Url,
218 ) -> impl Future<Output = eyre::Result<(Self::Sink, Self::Stream)>> + Send;
219}
220
/// Default [`WsConnect`] implementation backed by `tokio-tungstenite`.
///
/// It carries no state, so it is cheap to copy and trivially comparable.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct WsConnector;
226
227impl WsConnect for WsConnector {
228 type Stream = WsStream;
229 type Sink = WsSink;
230
231 async fn connect(&mut self, ws_url: Url) -> eyre::Result<(WsSink, WsStream)> {
232 let (stream, _response) = connect_async(ws_url.as_str()).await?;
233
234 Ok(stream.split())
235 }
236}
237
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ExecutionPayloadBaseV1;
    use alloy_primitives::bytes::Bytes;
    use brotli::enc::BrotliEncoderParams;
    use std::{future, iter};
    use tokio_tungstenite::tungstenite::{
        protocol::frame::{coding::CloseCode, Frame},
        Error,
    };

    /// Connector whose every connection yields a fresh copy of the same scripted messages
    /// together with a [`NoopSink`], which panics if the stream tries to write to it.
    #[derive(Clone)]
    struct FakeConnector(FakeStream);

    /// Like [`FakeConnector`], but hands out a recording [`FakeSink`] so tests can inspect
    /// what the stream writes back.
    #[derive(Clone)]
    struct FakeConnectorWithSink(FakeStream);

    /// Scripted message source; messages are stored reversed so `pop` yields them in order.
    #[derive(Default)]
    struct FakeStream(Vec<Result<Message, Error>>);

    impl FakeStream {
        /// Builds a stream that yields `messages` in the given order.
        fn new(mut messages: Vec<Result<Message, Error>>) -> Self {
            messages.reverse();

            Self(messages)
        }
    }

    // Errors are cloned by hand: only `Error::AttackAttempt` is supported, any other
    // variant panics.
    impl Clone for FakeStream {
        fn clone(&self) -> Self {
            Self(
                self.0
                    .iter()
                    .map(|v| match v {
                        Ok(msg) => Ok(msg.clone()),
                        Err(err) => Err(match err {
                            Error::AttackAttempt => Error::AttackAttempt,
                            err => unimplemented!("Cannot clone this error: {err}"),
                        }),
                    })
                    .collect(),
            )
        }
    }

    impl Stream for FakeStream {
        type Item = Result<Message, Error>;

        // Never returns `Pending`: yields the next scripted message, or `None` once exhausted.
        fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
            let this = self.get_mut();

            Poll::Ready(this.0.pop())
        }
    }

    /// Sink for tests that expect nothing to be written; every method panics.
    #[derive(Clone)]
    struct NoopSink;

    impl<T> Sink<T> for NoopSink {
        type Error = ();

        fn poll_ready(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Result<(), Self::Error>> {
            unimplemented!()
        }

        fn start_send(self: Pin<&mut Self>, _item: T) -> Result<(), Self::Error> {
            unimplemented!()
        }

        fn poll_flush(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Result<(), Self::Error>> {
            unimplemented!()
        }

        fn poll_close(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Result<(), Self::Error>> {
            unimplemented!()
        }
    }

    /// Recording sink: `.0` holds an item accepted by `start_send` but not yet flushed,
    /// `.1` collects every flushed item in order.
    #[derive(Clone, Default)]
    struct FakeSink(Option<Message>, Vec<Message>);

    impl Sink<Message> for FakeSink {
        type Error = ();

        // Readiness flushes any buffered item first, so the buffer holds at most one item.
        fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            self.poll_flush(cx)
        }

        fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
            self.get_mut().0.replace(item);
            Ok(())
        }

        // Moves the buffered item (if any) into the record of sent messages.
        fn poll_flush(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Result<(), Self::Error>> {
            let this = self.get_mut();
            if let Some(item) = this.0.take() {
                this.1.push(item);
            }
            Poll::Ready(Ok(()))
        }

        fn poll_close(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }
    }

    impl WsConnect for FakeConnector {
        type Stream = FakeStream;
        type Sink = NoopSink;

        fn connect(
            &mut self,
            _ws_url: Url,
        ) -> impl Future<Output = eyre::Result<(Self::Sink, Self::Stream)>> + Send {
            future::ready(Ok((NoopSink, self.0.clone())))
        }
    }

    impl<T: IntoIterator<Item = Result<Message, Error>>> From<T> for FakeConnector {
        fn from(value: T) -> Self {
            Self(FakeStream::new(value.into_iter().collect()))
        }
    }

    impl WsConnect for FakeConnectorWithSink {
        type Stream = FakeStream;
        type Sink = FakeSink;

        // Each connection gets a new, empty recording sink.
        fn connect(
            &mut self,
            _ws_url: Url,
        ) -> impl Future<Output = eyre::Result<(Self::Sink, Self::Stream)>> + Send {
            future::ready(Ok((FakeSink::default(), self.0.clone())))
        }
    }

    impl<T: IntoIterator<Item = Result<Message, Error>>> From<T> for FakeConnectorWithSink {
        fn from(value: T) -> Self {
            Self(FakeStream::new(value.into_iter().collect()))
        }
    }

    /// Connector whose every connection attempt fails with the contained message.
    #[derive(Clone)]
    struct FailingConnector(String);

    impl WsConnect for FailingConnector {
        type Stream = FakeStream;
        type Sink = NoopSink;

        fn connect(
            &mut self,
            _ws_url: Url,
        ) -> impl Future<Output = eyre::Result<(Self::Sink, Self::Stream)>> + Send {
            future::ready(Err(eyre::eyre!("{}", &self.0)))
        }
    }

    /// Returns an encoder that JSON-serializes a block and wraps the bytes with `wrapper_f`
    /// (e.g. `Message::Binary` or `Message::Text`).
    fn to_json_message<B: TryFrom<Bytes, Error: Debug>, F: Fn(B) -> Message>(
        wrapper_f: F,
    ) -> impl Fn(&FlashBlock) -> Result<Message, Error> + use<F, B> {
        move |block| to_json_message_using(block, &wrapper_f)
    }

    /// JSON-serializes a block into a binary message.
    fn to_json_binary_message(block: &FlashBlock) -> Result<Message, Error> {
        to_json_message_using(block, Message::Binary)
    }

    /// JSON-serializes a block and wraps the payload with `wrapper_f`.
    fn to_json_message_using<B: TryFrom<Bytes, Error: Debug>, F: Fn(B) -> Message>(
        block: &FlashBlock,
        wrapper_f: F,
    ) -> Result<Message, Error> {
        Ok(wrapper_f(B::try_from(Bytes::from(serde_json::to_vec(block).unwrap())).unwrap()))
    }

    /// JSON-serializes a block, brotli-compresses it and wraps it in a binary message.
    fn to_brotli_message(block: &FlashBlock) -> Result<Message, Error> {
        let json = serde_json::to_vec(block).unwrap();
        let mut compressed = Vec::new();
        brotli::BrotliCompress(
            &mut json.as_slice(),
            &mut compressed,
            &BrotliEncoderParams::default(),
        )?;

        Ok(Message::Binary(Bytes::from(compressed)))
    }

    /// Minimal [`FlashBlock`] fixture: default values with a base payload present.
    fn flashblock() -> FlashBlock {
        FlashBlock {
            payload_id: Default::default(),
            index: 0,
            base: Some(ExecutionPayloadBaseV1 {
                parent_beacon_block_root: Default::default(),
                parent_hash: Default::default(),
                fee_recipient: Default::default(),
                prev_randao: Default::default(),
                block_number: 0,
                gas_limit: 0,
                timestamp: 0,
                extra_data: Default::default(),
                base_fee_per_gas: Default::default(),
            }),
            diff: Default::default(),
            metadata: Default::default(),
        }
    }

    // Each encoding (JSON as binary or text frames, brotli-compressed JSON) must decode back
    // into the original block.
    #[test_case::test_case(to_json_message(Message::Binary); "json binary")]
    #[test_case::test_case(to_json_message(Message::Text); "json UTF-8")]
    #[test_case::test_case(to_brotli_message; "brotli")]
    #[tokio::test]
    async fn test_stream_decodes_messages_successfully(
        to_message: impl Fn(&FlashBlock) -> Result<Message, Error>,
    ) {
        let flashblocks = [flashblock()];
        let connector = FakeConnector::from(flashblocks.iter().map(to_message));
        let ws_url = "http://localhost".parse().unwrap();
        let stream = WsFlashBlockStream::with_connector(ws_url, connector);

        let actual_messages: Vec<_> = stream.take(1).map(Result::unwrap).collect().await;
        let expected_messages = flashblocks.to_vec();

        assert_eq!(actual_messages, expected_messages);
    }

    // Messages other than binary/text/ping/close are skipped; the next block still arrives.
    #[test_case::test_case(Message::Pong(Bytes::from(b"test".as_slice())); "pong")]
    #[test_case::test_case(Message::Frame(Frame::pong(b"test".as_slice())); "frame")]
    #[tokio::test]
    async fn test_stream_ignores_unexpected_message(message: Message) {
        let flashblock = flashblock();
        let connector = FakeConnector::from([Ok(message), to_json_binary_message(&flashblock)]);
        let ws_url = "http://localhost".parse().unwrap();
        let mut stream = WsFlashBlockStream::with_connector(ws_url, connector);

        let expected_message = flashblock;
        let actual_message =
            stream.next().await.expect("Binary message should not be ignored").unwrap();

        assert_eq!(actual_message, expected_message)
    }

    // Errors from the underlying message stream are yielded to the consumer.
    #[tokio::test]
    async fn test_stream_passes_errors_through() {
        let connector = FakeConnector::from([Err(Error::AttackAttempt)]);
        let ws_url = "http://localhost".parse().unwrap();
        let stream = WsFlashBlockStream::with_connector(ws_url, connector);

        let actual_messages: Vec<_> =
            stream.take(1).map(Result::unwrap_err).map(|e| format!("{e}")).collect().await;
        let expected_messages = vec!["Attack attempt detected".to_owned()];

        assert_eq!(actual_messages, expected_messages);
    }

    // Every failed connection attempt yields its error and the next poll tries again.
    #[tokio::test]
    async fn test_connect_error_causes_retries() {
        let tries = 3;
        let error_msg = "test".to_owned();
        let connector = FailingConnector(error_msg.clone());
        let ws_url = "http://localhost".parse().unwrap();
        let stream = WsFlashBlockStream::with_connector(ws_url, connector);

        let actual_errors: Vec<_> =
            stream.take(tries).map(Result::unwrap_err).map(|e| format!("{e}")).collect().await;
        let expected_errors: Vec<_> = iter::repeat_n(error_msg, tries).collect();

        assert_eq!(actual_errors, expected_errors);
    }

    // A close frame is echoed back and a ping is answered with a pong carrying the same
    // payload; the reply must be fully flushed to the sink.
    #[test_case::test_case(
        Message::Close(Some(CloseFrame { code: CloseCode::Normal, reason: "test".into() })),
        Message::Close(Some(CloseFrame { code: CloseCode::Normal, reason: "test".into() }));
        "close"
    )]
    #[test_case::test_case(
        Message::Ping(Bytes::from_static(&[1u8, 2, 3])),
        Message::Pong(Bytes::from_static(&[1u8, 2, 3]));
        "ping"
    )]
    #[tokio::test]
    async fn test_stream_responds_to_messages(msg: Message, expected_response: Message) {
        let flashblock = flashblock();
        let messages = [Ok(msg), to_json_binary_message(&flashblock)];
        let connector = FakeConnectorWithSink::from(messages);
        let ws_url = "http://localhost".parse().unwrap();
        let mut stream = WsFlashBlockStream::with_connector(ws_url, connector);

        // Reading the following block forces the queued reply to be sent first.
        let _ = stream.next().await;

        let expected_response = vec![expected_response];
        let FakeSink(actual_buffer, actual_response) = stream.sink.unwrap();

        assert!(actual_buffer.is_none(), "buffer not flushed: {actual_buffer:#?}");
        assert_eq!(actual_response, expected_response);
    }
}
559}