1use crate::FlashBlock;
2use futures_util::{
3 stream::{SplitSink, SplitStream},
4 FutureExt, Sink, Stream, StreamExt,
5};
6use std::{
7 fmt::{Debug, Formatter},
8 future::Future,
9 pin::Pin,
10 task::{ready, Context, Poll},
11};
12use tokio::net::TcpStream;
13use tokio_tungstenite::{
14 connect_async,
15 tungstenite::{protocol::CloseFrame, Bytes, Error, Message},
16 MaybeTlsStream, WebSocketStream,
17};
18use tracing::debug;
19use url::Url;
20
/// A stream of [`FlashBlock`]s decoded from messages received over a websocket at `ws_url`.
///
/// Connections are opened lazily through the `Connector` on the first poll. When the websocket
/// stream ends, or a connection attempt fails, a new connection is attempted on the next poll.
pub struct WsFlashBlockStream<Stream, Sink, Connector> {
    /// Websocket endpoint to connect to.
    ws_url: Url,
    /// Current position in the connection lifecycle.
    state: State,
    /// Opens new connections; it is cloned for every connection attempt.
    connector: Connector,
    /// The in-flight connection attempt; only polled while in the `Connect` state.
    connect: ConnectFuture<Sink, Stream>,
    /// Read half of the current connection, once one has been established.
    stream: Option<Stream>,
    /// Write half of the current connection, used to send pong/close replies.
    sink: Option<Sink>,
}
35
36impl WsFlashBlockStream<WsStream, WsSink, WsConnector> {
37 pub fn new(ws_url: Url) -> Self {
39 Self {
40 ws_url,
41 state: State::default(),
42 connector: WsConnector,
43 connect: Box::pin(async move { Err(Error::ConnectionClosed)? }),
44 stream: None,
45 sink: None,
46 }
47 }
48}
49
50impl<Stream, S, C> WsFlashBlockStream<Stream, S, C> {
51 pub fn with_connector(ws_url: Url, connector: C) -> Self {
53 Self {
54 ws_url,
55 state: State::default(),
56 connector,
57 connect: Box::pin(async move { Err(Error::ConnectionClosed)? }),
58 stream: None,
59 sink: None,
60 }
61 }
62}
63
64impl<Str, S, C> Stream for WsFlashBlockStream<Str, S, C>
65where
66 Str: Stream<Item = Result<Message, Error>> + Unpin,
67 S: Sink<Message> + Send + Sync + Unpin,
68 C: WsConnect<Stream = Str, Sink = S> + Clone + Send + Sync + 'static + Unpin,
69{
70 type Item = eyre::Result<FlashBlock>;
71
72 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
73 let this = self.get_mut();
74
75 'start: loop {
76 if this.state == State::Initial {
77 this.connect();
78 }
79
80 if this.state == State::Connect {
81 match ready!(this.connect.poll_unpin(cx)) {
82 Ok((sink, stream)) => this.stream(sink, stream),
83 Err(err) => {
84 this.state = State::Initial;
85
86 return Poll::Ready(Some(Err(err)));
87 }
88 }
89 }
90
91 while let State::Stream(msg) = &mut this.state {
92 if msg.is_some() {
93 let mut sink = Pin::new(this.sink.as_mut().unwrap());
94 let _ = ready!(sink.as_mut().poll_ready(cx));
95 if let Some(pong) = msg.take() {
96 let _ = sink.as_mut().start_send(pong);
97 }
98 let _ = ready!(sink.as_mut().poll_flush(cx));
99 }
100
101 let Some(msg) = ready!(this
102 .stream
103 .as_mut()
104 .expect("Stream state should be unreachable without stream")
105 .poll_next_unpin(cx))
106 else {
107 this.state = State::Initial;
108
109 continue 'start;
110 };
111
112 match msg {
113 Ok(Message::Binary(bytes)) => {
114 return Poll::Ready(Some(FlashBlock::decode(bytes)))
115 }
116 Ok(Message::Text(bytes)) => {
117 return Poll::Ready(Some(FlashBlock::decode(bytes.into())))
118 }
119 Ok(Message::Ping(bytes)) => this.ping(bytes),
120 Ok(Message::Close(frame)) => this.close(frame),
121 Ok(msg) => debug!("Received unexpected message: {:?}", msg),
122 Err(err) => return Poll::Ready(Some(Err(err.into()))),
123 }
124 }
125 }
126 }
127}
128
129impl<Stream, S, C> WsFlashBlockStream<Stream, S, C>
130where
131 C: WsConnect<Stream = Stream, Sink = S> + Clone + Send + Sync + 'static,
132{
133 fn connect(&mut self) {
134 let ws_url = self.ws_url.clone();
135 let mut connector = self.connector.clone();
136
137 Pin::new(&mut self.connect).set(Box::pin(async move { connector.connect(ws_url).await }));
138
139 self.state = State::Connect;
140 }
141
142 fn stream(&mut self, sink: S, stream: Stream) {
143 self.sink.replace(sink);
144 self.stream.replace(stream);
145
146 self.state = State::Stream(None);
147 }
148
149 fn ping(&mut self, pong: Bytes) {
150 if let State::Stream(current) = &mut self.state {
151 current.replace(Message::Pong(pong));
152 }
153 }
154
155 fn close(&mut self, frame: Option<CloseFrame>) {
156 if let State::Stream(current) = &mut self.state {
157 current.replace(Message::Close(frame));
158 }
159 }
160}
161
162impl<Stream: Debug, S: Debug, C: Debug> Debug for WsFlashBlockStream<Stream, S, C> {
163 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
164 f.debug_struct("FlashBlockStream")
165 .field("ws_url", &self.ws_url)
166 .field("state", &self.state)
167 .field("connector", &self.connector)
168 .field("connect", &"Pin<Box<dyn Future<..>>>")
169 .field("stream", &self.stream)
170 .finish()
171 }
172}
173
/// Connection lifecycle of a `WsFlashBlockStream`.
#[derive(Default, Debug, Eq, PartialEq)]
enum State {
    /// Not connected; the next poll starts a new connection attempt.
    #[default]
    Initial,
    /// A connection attempt is in flight in the stream's `connect` future.
    Connect,
    /// Connected and reading messages. This holds at most one queued control reply (pong or
    /// close), which is sent before the next message is read.
    Stream(Option<Message>),
}
181
182type Ws = WebSocketStream<MaybeTlsStream<TcpStream>>;
183type WsStream = SplitStream<Ws>;
184type WsSink = SplitSink<Ws, Message>;
185type ConnectFuture<Sink, Stream> =
186 Pin<Box<dyn Future<Output = eyre::Result<(Sink, Stream)>> + Send + Sync + 'static>>;
187
/// Opens websocket connections for a `WsFlashBlockStream`.
///
/// Abstracting this allows the transport to be substituted, e.g. with scripted fakes in tests.
pub trait WsConnect {
    /// Read half of an established connection.
    ///
    /// To drive a `WsFlashBlockStream`, it must yield `Result<Message, Error>` items.
    type Stream;

    /// Write half of an established connection, used to send control replies.
    type Sink;

    /// Connects to `ws_url` and returns the connection split into `(sink, stream)` halves.
    ///
    /// # Errors
    ///
    /// Returns an error if the connection cannot be established.
    fn connect(
        &mut self,
        ws_url: Url,
    ) -> impl Future<Output = eyre::Result<(Self::Sink, Self::Stream)>> + Send + Sync;
}
212
/// The default [`WsConnect`] implementation.
///
/// It opens a real websocket connection with `tokio_tungstenite::connect_async` and splits it
/// into its write and read halves. It is a stateless unit type, so it is also `Copy` and
/// `Default`.
#[derive(Debug, Clone, Copy, Default)]
pub struct WsConnector;
218
219impl WsConnect for WsConnector {
220 type Stream = WsStream;
221 type Sink = WsSink;
222
223 async fn connect(&mut self, ws_url: Url) -> eyre::Result<(WsSink, WsStream)> {
224 let (stream, _response) = connect_async(ws_url.as_str()).await?;
225
226 Ok(stream.split())
227 }
228}
229
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ExecutionPayloadBaseV1;
    use alloy_primitives::bytes::Bytes;
    use brotli::enc::BrotliEncoderParams;
    use std::{future, iter};
    use tokio_tungstenite::tungstenite::{
        protocol::frame::{coding::CloseCode, Frame},
        Error,
    };

    /// Connector that hands out a fresh clone of the same scripted [`FakeStream`] on every
    /// connection, paired with a [`NoopSink`] (so any write attempt panics).
    #[derive(Clone)]
    struct FakeConnector(FakeStream);

    /// Like [`FakeConnector`], but paired with a recording [`FakeSink`] so that tests can
    /// inspect the replies written back by the stream.
    #[derive(Clone)]
    struct FakeConnectorWithSink(FakeStream);

    /// Scripted message source. Messages are stored in reverse order so that `pop` yields them
    /// in their original order; once empty, the stream yields `None`.
    #[derive(Default)]
    struct FakeStream(Vec<Result<Message, Error>>);

    impl FakeStream {
        /// Builds a stream that yields `messages` in the given order.
        fn new(mut messages: Vec<Result<Message, Error>>) -> Self {
            messages.reverse();

            Self(messages)
        }
    }

    // `tungstenite::Error` is not `Clone`, so only the error variants the tests use are
    // cloned by hand; any other error panics.
    impl Clone for FakeStream {
        fn clone(&self) -> Self {
            Self(
                self.0
                    .iter()
                    .map(|v| match v {
                        Ok(msg) => Ok(msg.clone()),
                        Err(err) => Err(match err {
                            Error::AttackAttempt => Error::AttackAttempt,
                            err => unimplemented!("Cannot clone this error: {err}"),
                        }),
                    })
                    .collect(),
            )
        }
    }

    impl Stream for FakeStream {
        type Item = Result<Message, Error>;

        /// Always ready: yields the next scripted message, or `None` once exhausted.
        fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
            let this = self.get_mut();

            Poll::Ready(this.0.pop())
        }
    }

    /// Sink for tests that must never write anything: every method panics.
    #[derive(Clone)]
    struct NoopSink;

    impl<T> Sink<T> for NoopSink {
        type Error = ();

        fn poll_ready(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Result<(), Self::Error>> {
            unimplemented!()
        }

        fn start_send(self: Pin<&mut Self>, _item: T) -> Result<(), Self::Error> {
            unimplemented!()
        }

        fn poll_flush(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Result<(), Self::Error>> {
            unimplemented!()
        }

        fn poll_close(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Result<(), Self::Error>> {
            unimplemented!()
        }
    }

    /// Recording sink.
    ///
    /// Field `0` holds the item buffered by `start_send`. Field `1` collects every item that
    /// has been flushed.
    #[derive(Clone, Default)]
    struct FakeSink(Option<Message>, Vec<Message>);

    impl Sink<Message> for FakeSink {
        type Error = ();

        // Becoming ready flushes any buffered item first.
        fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            self.poll_flush(cx)
        }

        fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
            self.get_mut().0.replace(item);
            Ok(())
        }

        // Moves the buffered item (if any) into the list of sent items.
        fn poll_flush(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Result<(), Self::Error>> {
            let this = self.get_mut();
            if let Some(item) = this.0.take() {
                this.1.push(item);
            }
            Poll::Ready(Ok(()))
        }

        fn poll_close(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }
    }

    impl WsConnect for FakeConnector {
        type Stream = FakeStream;
        type Sink = NoopSink;

        fn connect(
            &mut self,
            _ws_url: Url,
        ) -> impl Future<Output = eyre::Result<(Self::Sink, Self::Stream)>> + Send + Sync {
            future::ready(Ok((NoopSink, self.0.clone())))
        }
    }

    impl<T: IntoIterator<Item = Result<Message, Error>>> From<T> for FakeConnector {
        fn from(value: T) -> Self {
            Self(FakeStream::new(value.into_iter().collect()))
        }
    }

    impl WsConnect for FakeConnectorWithSink {
        type Stream = FakeStream;
        type Sink = FakeSink;

        fn connect(
            &mut self,
            _ws_url: Url,
        ) -> impl Future<Output = eyre::Result<(Self::Sink, Self::Stream)>> + Send + Sync {
            future::ready(Ok((FakeSink::default(), self.0.clone())))
        }
    }

    impl<T: IntoIterator<Item = Result<Message, Error>>> From<T> for FakeConnectorWithSink {
        fn from(value: T) -> Self {
            Self(FakeStream::new(value.into_iter().collect()))
        }
    }

    /// Connector whose every connection attempt fails with the contained message.
    #[derive(Clone)]
    struct FailingConnector(String);

    impl WsConnect for FailingConnector {
        type Stream = FakeStream;
        type Sink = NoopSink;

        fn connect(
            &mut self,
            _ws_url: Url,
        ) -> impl Future<Output = eyre::Result<(Self::Sink, Self::Stream)>> + Send + Sync {
            future::ready(Err(eyre::eyre!("{}", &self.0)))
        }
    }

    /// Returns an encoder that serializes a block to JSON and wraps the bytes with `wrapper_f`
    /// (e.g. `Message::Binary` or `Message::Text`).
    fn to_json_message<B: TryFrom<Bytes, Error: Debug>, F: Fn(B) -> Message>(
        wrapper_f: F,
    ) -> impl Fn(&FlashBlock) -> Result<Message, Error> + use<F, B> {
        move |block| to_json_message_using(block, &wrapper_f)
    }

    /// Encodes a block as a JSON binary message.
    fn to_json_binary_message(block: &FlashBlock) -> Result<Message, Error> {
        to_json_message_using(block, Message::Binary)
    }

    /// Serializes `block` to JSON, converts the bytes to `B`, and wraps them with `wrapper_f`.
    fn to_json_message_using<B: TryFrom<Bytes, Error: Debug>, F: Fn(B) -> Message>(
        block: &FlashBlock,
        wrapper_f: F,
    ) -> Result<Message, Error> {
        Ok(wrapper_f(B::try_from(Bytes::from(serde_json::to_vec(block).unwrap())).unwrap()))
    }

    /// Encodes a block as brotli-compressed JSON in a binary message; compression errors are
    /// propagated via `?`.
    fn to_brotli_message(block: &FlashBlock) -> Result<Message, Error> {
        let json = serde_json::to_vec(block).unwrap();
        let mut compressed = Vec::new();
        brotli::BrotliCompress(
            &mut json.as_slice(),
            &mut compressed,
            &BrotliEncoderParams::default(),
        )?;

        Ok(Message::Binary(Bytes::from(compressed)))
    }

    /// A minimal flashblock fixture with a base payload and default values elsewhere.
    fn flashblock() -> FlashBlock {
        FlashBlock {
            payload_id: Default::default(),
            index: 0,
            base: Some(ExecutionPayloadBaseV1 {
                parent_beacon_block_root: Default::default(),
                parent_hash: Default::default(),
                fee_recipient: Default::default(),
                prev_randao: Default::default(),
                block_number: 0,
                gas_limit: 0,
                timestamp: 0,
                extra_data: Default::default(),
                base_fee_per_gas: Default::default(),
            }),
            diff: Default::default(),
            metadata: Default::default(),
        }
    }

    // Every supported encoding (JSON binary, JSON text, brotli) decodes to the original block.
    #[test_case::test_case(to_json_message(Message::Binary); "json binary")]
    #[test_case::test_case(to_json_message(Message::Text); "json UTF-8")]
    #[test_case::test_case(to_brotli_message; "brotli")]
    #[tokio::test]
    async fn test_stream_decodes_messages_successfully(
        to_message: impl Fn(&FlashBlock) -> Result<Message, Error>,
    ) {
        let flashblocks = [flashblock()];
        let connector = FakeConnector::from(flashblocks.iter().map(to_message));
        let ws_url = "http://localhost".parse().unwrap();
        let stream = WsFlashBlockStream::with_connector(ws_url, connector);

        let actual_messages: Vec<_> = stream.take(1).map(Result::unwrap).collect().await;
        let expected_messages = flashblocks.to_vec();

        assert_eq!(actual_messages, expected_messages);
    }

    // Messages without a flashblock payload are skipped, and nothing is written back (any
    // write would hit `NoopSink` and panic).
    #[test_case::test_case(Message::Pong(Bytes::from(b"test".as_slice())); "pong")]
    #[test_case::test_case(Message::Frame(Frame::pong(b"test".as_slice())); "frame")]
    #[tokio::test]
    async fn test_stream_ignores_unexpected_message(message: Message) {
        let flashblock = flashblock();
        let connector = FakeConnector::from([Ok(message), to_json_binary_message(&flashblock)]);
        let ws_url = "http://localhost".parse().unwrap();
        let mut stream = WsFlashBlockStream::with_connector(ws_url, connector);

        let expected_message = flashblock;
        let actual_message =
            stream.next().await.expect("Binary message should not be ignored").unwrap();

        assert_eq!(actual_message, expected_message)
    }

    // Errors read from the websocket are yielded to the consumer as error items.
    #[tokio::test]
    async fn test_stream_passes_errors_through() {
        let connector = FakeConnector::from([Err(Error::AttackAttempt)]);
        let ws_url = "http://localhost".parse().unwrap();
        let stream = WsFlashBlockStream::with_connector(ws_url, connector);

        let actual_messages: Vec<_> =
            stream.take(1).map(Result::unwrap_err).map(|e| format!("{e}")).collect().await;
        let expected_messages = vec!["Attack attempt detected".to_owned()];

        assert_eq!(actual_messages, expected_messages);
    }

    // Each poll after a failed connection attempt retries, yielding the error every time.
    #[tokio::test]
    async fn test_connect_error_causes_retries() {
        let tries = 3;
        let error_msg = "test".to_owned();
        let connector = FailingConnector(error_msg.clone());
        let ws_url = "http://localhost".parse().unwrap();
        let stream = WsFlashBlockStream::with_connector(ws_url, connector);

        let actual_errors: Vec<_> =
            stream.take(tries).map(Result::unwrap_err).map(|e| format!("{e}")).collect().await;
        let expected_errors: Vec<_> = iter::repeat_n(error_msg, tries).collect();

        assert_eq!(actual_errors, expected_errors);
    }

    // A ping is answered with a pong and a close with a close. By the time the following
    // flashblock is yielded, the reply has been written and flushed to the sink.
    #[test_case::test_case(
        Message::Close(Some(CloseFrame { code: CloseCode::Normal, reason: "test".into() })),
        Message::Close(Some(CloseFrame { code: CloseCode::Normal, reason: "test".into() }));
        "close"
    )]
    #[test_case::test_case(
        Message::Ping(Bytes::from_static(&[1u8, 2, 3])),
        Message::Pong(Bytes::from_static(&[1u8, 2, 3]));
        "ping"
    )]
    #[tokio::test]
    async fn test_stream_responds_to_messages(msg: Message, expected_response: Message) {
        let flashblock = flashblock();
        let messages = [Ok(msg), to_json_binary_message(&flashblock)];
        let connector = FakeConnectorWithSink::from(messages);
        let ws_url = "http://localhost".parse().unwrap();
        let mut stream = WsFlashBlockStream::with_connector(ws_url, connector);

        let _ = stream.next().await;

        let expected_response = vec![expected_response];
        let FakeSink(actual_buffer, actual_response) = stream.sink.unwrap();

        assert!(actual_buffer.is_none(), "buffer not flushed: {actual_buffer:#?}");
        assert_eq!(actual_response, expected_response);
    }
}