1use crate::{ws::FlashBlockDecoder, FlashBlock};
2use futures_util::{
3 stream::{SplitSink, SplitStream},
4 FutureExt, Sink, Stream, StreamExt,
5};
6use std::{
7 fmt::{Debug, Formatter},
8 future::Future,
9 pin::Pin,
10 task::{ready, Context, Poll},
11};
12use tokio::net::TcpStream;
13use tokio_tungstenite::{
14 connect_async,
15 tungstenite::{protocol::CloseFrame, Bytes, Error, Message},
16 MaybeTlsStream, WebSocketStream,
17};
18use tracing::debug;
19use url::Url;
20
/// A stream of [`FlashBlock`]s received over a websocket connection.
///
/// The connection is opened lazily on the first poll and re-opened whenever the underlying
/// message stream ends or a connection attempt fails (after that failure has been yielded).
/// Incoming pings are answered with pongs and incoming close frames are echoed back.
pub struct WsFlashBlockStream<Stream, Sink, Connector> {
    /// Websocket endpoint passed to the connector on every connection attempt.
    ws_url: Url,
    /// Current position in the connection lifecycle.
    state: State,
    /// Opens new connections; a clone of it is moved into each connection attempt.
    connector: Connector,
    /// Turns received binary/text payloads into [`FlashBlock`]s.
    decoder: Box<dyn FlashBlockDecoder>,
    /// The in-flight connection attempt; only polled while `state` is `State::Connect`.
    connect: ConnectFuture<Sink, Stream>,
    /// Read half of the current connection, if one has been established.
    stream: Option<Stream>,
    /// Write half of the current connection, if one has been established; used to send
    /// pong/close replies.
    sink: Option<Sink>,
}
36
37impl WsFlashBlockStream<WsStream, WsSink, WsConnector> {
38 pub fn new(ws_url: Url) -> Self {
40 Self {
41 ws_url,
42 state: State::default(),
43 connector: WsConnector,
44 decoder: Box::new(()),
45 connect: Box::pin(async move { Err(Error::ConnectionClosed)? }),
46 stream: None,
47 sink: None,
48 }
49 }
50
51 pub fn with_decoder(self, decoder: Box<dyn FlashBlockDecoder>) -> Self {
53 Self { decoder, ..self }
54 }
55}
56
57impl<Stream, S, C> WsFlashBlockStream<Stream, S, C> {
58 pub fn with_connector(ws_url: Url, connector: C) -> Self {
60 Self {
61 ws_url,
62 state: State::default(),
63 decoder: Box::new(()),
64 connector,
65 connect: Box::pin(async move { Err(Error::ConnectionClosed)? }),
66 stream: None,
67 sink: None,
68 }
69 }
70}
71
impl<Str, S, C> Stream for WsFlashBlockStream<Str, S, C>
where
    Str: Stream<Item = Result<Message, Error>> + Unpin,
    S: Sink<Message> + Send + Unpin,
    C: WsConnect<Stream = Str, Sink = S> + Clone + Send + 'static + Unpin,
{
    type Item = eyre::Result<FlashBlock>;

    /// Drives the connection state machine `Initial -> Connect -> Stream` and yields decoded
    /// flashblocks.
    ///
    /// This never returns `Poll::Ready(None)`. When the connection ends, it loops back and
    /// reconnects. Connection failures and read errors are yielded as `Some(Err(_))`.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();

        'start: loop {
            // No connection and no attempt in progress: start one.
            if this.state == State::Initial {
                this.connect();
            }

            if this.state == State::Connect {
                match ready!(this.connect.poll_unpin(cx)) {
                    Ok((sink, stream)) => this.stream(sink, stream),
                    Err(err) => {
                        // Reset so the next poll starts a fresh attempt. There is no delay or
                        // backoff between attempts.
                        this.state = State::Initial;

                        return Poll::Ready(Some(Err(err)));
                    }
                }
            }

            while let State::Stream(msg) = &mut this.state {
                // A reply (pong or close) was queued by the previous iteration. Send it before
                // reading further frames.
                if msg.is_some() {
                    let mut sink = Pin::new(this.sink.as_mut().unwrap());
                    let _ = ready!(sink.as_mut().poll_ready(cx));
                    if let Some(pong) = msg.take() {
                        let _ = sink.as_mut().start_send(pong);
                    }
                    // NOTE(review): if this flush returns `Pending`, the reply has already been
                    // taken out of `msg`. Later polls skip this branch and never re-drive the
                    // flush, so the reply may stay buffered in the sink until another reply is
                    // queued. Sink errors are also discarded throughout this branch.
                    let _ = ready!(sink.as_mut().poll_flush(cx));
                }

                let Some(msg) = ready!(this
                    .stream
                    .as_mut()
                    .expect("Stream state should be unreachable without stream")
                    .poll_next_unpin(cx))
                else {
                    // The connection's message stream ended: go back and reconnect.
                    this.state = State::Initial;

                    continue 'start;
                };

                match msg {
                    Ok(Message::Binary(bytes)) => {
                        return Poll::Ready(Some(this.decoder.decode(bytes)))
                    }
                    // Text payloads are converted to bytes and decoded like binary ones.
                    Ok(Message::Text(bytes)) => {
                        return Poll::Ready(Some(this.decoder.decode(bytes.into())))
                    }
                    // Queue a reply; it is sent at the top of the next loop iteration.
                    Ok(Message::Ping(bytes)) => this.ping(bytes),
                    Ok(Message::Close(frame)) => this.close(frame),
                    // Remaining variants (e.g. `Pong`, `Frame`) are logged and skipped.
                    Ok(msg) => {
                        debug!(target: "flashblocks", "Received unexpected message: {:?}", msg)
                    }
                    // The state is left unchanged, so the next poll keeps reading the same
                    // connection after a read error.
                    Err(err) => return Poll::Ready(Some(Err(err.into()))),
                }
            }
        }
    }
}
138
139impl<Stream, S, C> WsFlashBlockStream<Stream, S, C>
140where
141 C: WsConnect<Stream = Stream, Sink = S> + Clone + Send + 'static,
142{
143 fn connect(&mut self) {
144 let ws_url = self.ws_url.clone();
145 let mut connector = self.connector.clone();
146
147 Pin::new(&mut self.connect).set(Box::pin(async move { connector.connect(ws_url).await }));
148
149 self.state = State::Connect;
150 }
151
152 fn stream(&mut self, sink: S, stream: Stream) {
153 self.sink.replace(sink);
154 self.stream.replace(stream);
155
156 self.state = State::Stream(None);
157 }
158
159 fn ping(&mut self, pong: Bytes) {
160 if let State::Stream(current) = &mut self.state {
161 current.replace(Message::Pong(pong));
162 }
163 }
164
165 fn close(&mut self, frame: Option<CloseFrame>) {
166 if let State::Stream(current) = &mut self.state {
167 current.replace(Message::Close(frame));
168 }
169 }
170}
171
172impl<Stream: Debug, S: Debug, C: Debug> Debug for WsFlashBlockStream<Stream, S, C> {
173 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
174 f.debug_struct("FlashBlockStream")
175 .field("ws_url", &self.ws_url)
176 .field("state", &self.state)
177 .field("connector", &self.connector)
178 .field("connect", &"Pin<Box<dyn Future<..>>>")
179 .field("stream", &self.stream)
180 .finish()
181 }
182}
183
/// Connection lifecycle of a `WsFlashBlockStream`.
#[derive(Default, Debug, Eq, PartialEq)]
enum State {
    /// No connection and no attempt in progress; the next poll starts a connection attempt.
    #[default]
    Initial,
    /// A connection attempt (the stored connect future) is in progress.
    Connect,
    /// Connected and reading frames.
    ///
    /// Holds an optional reply (pong or close) that still has to be written to the sink.
    Stream(Option<Message>),
}
191
/// Websocket type returned by `connect_async` in [`WsConnector`].
type Ws = WebSocketStream<MaybeTlsStream<TcpStream>>;
/// Read half of a split [`Ws`].
type WsStream = SplitStream<Ws>;
/// Write half of a split [`Ws`].
type WsSink = SplitSink<Ws, Message>;
/// Boxed, type-erased connection attempt resolving to a `(sink, stream)` pair.
type ConnectFuture<Sink, Stream> =
    Pin<Box<dyn Future<Output = eyre::Result<(Sink, Stream)>> + Send + 'static>>;
197
/// Opens a websocket connection and hands back its write and read halves.
///
/// Abstracting this step lets `WsFlashBlockStream` run against something other than a real
/// network connection. The tests in this file use in-memory fakes.
pub trait WsConnect {
    /// The read half.
    ///
    /// `WsFlashBlockStream` requires it to yield `Result<Message, Error>` items.
    type Stream;

    /// The write half, used by `WsFlashBlockStream` to send pong/close replies.
    type Sink;

    /// Connects to `ws_url`, resolving to the `(sink, stream)` pair or an error.
    fn connect(
        &mut self,
        ws_url: Url,
    ) -> impl Future<Output = eyre::Result<(Self::Sink, Self::Stream)>> + Send;
}
222
/// Default [`WsConnect`] implementation.
///
/// It opens a real websocket connection with `tokio_tungstenite::connect_async`.
#[derive(Debug, Clone)]
pub struct WsConnector;
228
229impl WsConnect for WsConnector {
230 type Stream = WsStream;
231 type Sink = WsSink;
232
233 async fn connect(&mut self, ws_url: Url) -> eyre::Result<(WsSink, WsStream)> {
234 let (stream, _response) = connect_async(ws_url.as_str()).await?;
235
236 Ok(stream.split())
237 }
238}
239
#[cfg(test)]
mod tests {
    use super::*;
    use alloy_primitives::bytes::Bytes;
    use brotli::enc::BrotliEncoderParams;
    use std::{future, iter};
    use tokio_tungstenite::tungstenite::{
        protocol::frame::{coding::CloseCode, Frame},
        Error,
    };

    /// Connector that hands out a [`NoopSink`] and a clone of the wrapped [`FakeStream`] on
    /// every connection attempt.
    #[derive(Clone)]
    struct FakeConnector(FakeStream);

    /// Like [`FakeConnector`], but pairs the stream with a recording [`FakeSink`] so tests can
    /// inspect the replies sent by the stream under test.
    #[derive(Clone)]
    struct FakeConnectorWithSink(FakeStream);

    /// Scripted in-memory message source.
    ///
    /// Items are stored in reverse so that `pop` returns them in the order they were passed to
    /// [`FakeStream::new`].
    #[derive(Default)]
    struct FakeStream(Vec<Result<Message, Error>>);

    impl FakeStream {
        fn new(mut messages: Vec<Result<Message, Error>>) -> Self {
            messages.reverse();

            Self(messages)
        }
    }

    // Hand-written clone. Among error items only `Error::AttackAttempt` is supported; any other
    // error variant panics via `unimplemented!`.
    impl Clone for FakeStream {
        fn clone(&self) -> Self {
            Self(
                self.0
                    .iter()
                    .map(|v| match v {
                        Ok(msg) => Ok(msg.clone()),
                        Err(err) => Err(match err {
                            Error::AttackAttempt => Error::AttackAttempt,
                            err => unimplemented!("Cannot clone this error: {err}"),
                        }),
                    })
                    .collect(),
            )
        }
    }

    impl Stream for FakeStream {
        type Item = Result<Message, Error>;

        // Never pending: yields the next scripted item, then `None` once the script is exhausted.
        fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
            let this = self.get_mut();

            Poll::Ready(this.0.pop())
        }
    }

    /// Sink for tests that never send a reply: every method panics via `unimplemented!`.
    #[derive(Clone)]
    struct NoopSink;

    impl<T> Sink<T> for NoopSink {
        type Error = ();

        fn poll_ready(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Result<(), Self::Error>> {
            unimplemented!()
        }

        fn start_send(self: Pin<&mut Self>, _item: T) -> Result<(), Self::Error> {
            unimplemented!()
        }

        fn poll_flush(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Result<(), Self::Error>> {
            unimplemented!()
        }

        fn poll_close(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Result<(), Self::Error>> {
            unimplemented!()
        }
    }

    /// Recording sink.
    ///
    /// Field 0 buffers the item passed to `start_send` until it is flushed. Field 1 collects
    /// every flushed item in order.
    #[derive(Clone, Default)]
    struct FakeSink(Option<Message>, Vec<Message>);

    impl Sink<Message> for FakeSink {
        type Error = ();

        // Readiness first flushes any previously buffered item.
        fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            self.poll_flush(cx)
        }

        fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
            self.get_mut().0.replace(item);
            Ok(())
        }

        // Moves the buffered item (if any) into the list of sent messages.
        fn poll_flush(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Result<(), Self::Error>> {
            let this = self.get_mut();
            if let Some(item) = this.0.take() {
                this.1.push(item);
            }
            Poll::Ready(Ok(()))
        }

        fn poll_close(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }
    }

    impl WsConnect for FakeConnector {
        type Stream = FakeStream;
        type Sink = NoopSink;

        fn connect(
            &mut self,
            _ws_url: Url,
        ) -> impl Future<Output = eyre::Result<(Self::Sink, Self::Stream)>> + Send {
            future::ready(Ok((NoopSink, self.0.clone())))
        }
    }

    // Builds a connector whose stream replays `value` in iteration order.
    impl<T: IntoIterator<Item = Result<Message, Error>>> From<T> for FakeConnector {
        fn from(value: T) -> Self {
            Self(FakeStream::new(value.into_iter().collect()))
        }
    }

    impl WsConnect for FakeConnectorWithSink {
        type Stream = FakeStream;
        type Sink = FakeSink;

        // Each connection gets a fresh, empty `FakeSink`.
        fn connect(
            &mut self,
            _ws_url: Url,
        ) -> impl Future<Output = eyre::Result<(Self::Sink, Self::Stream)>> + Send {
            future::ready(Ok((FakeSink::default(), self.0.clone())))
        }
    }

    // Builds a connector whose stream replays `value` in iteration order.
    impl<T: IntoIterator<Item = Result<Message, Error>>> From<T> for FakeConnectorWithSink {
        fn from(value: T) -> Self {
            Self(FakeStream::new(value.into_iter().collect()))
        }
    }

    /// Connector whose every attempt fails with an error formatted from the wrapped message.
    #[derive(Clone)]
    struct FailingConnector(String);

    impl WsConnect for FailingConnector {
        type Stream = FakeStream;
        type Sink = NoopSink;

        fn connect(
            &mut self,
            _ws_url: Url,
        ) -> impl Future<Output = eyre::Result<(Self::Sink, Self::Stream)>> + Send {
            future::ready(Err(eyre::eyre!("{}", &self.0)))
        }
    }

    /// Returns a closure that JSON-encodes a block and wraps the bytes with `wrapper_f`
    /// (e.g. `Message::Binary` or `Message::Text`).
    fn to_json_message<B: TryFrom<Bytes, Error: Debug>, F: Fn(B) -> Message>(
        wrapper_f: F,
    ) -> impl Fn(&FlashBlock) -> Result<Message, Error> + use<F, B> {
        move |block| to_json_message_using(block, &wrapper_f)
    }

    /// JSON-encodes a block into a `Message::Binary`.
    fn to_json_binary_message(block: &FlashBlock) -> Result<Message, Error> {
        to_json_message_using(block, Message::Binary)
    }

    /// JSON-encodes a block and wraps the bytes with `wrapper_f`; panics if encoding or the
    /// byte conversion fails.
    fn to_json_message_using<B: TryFrom<Bytes, Error: Debug>, F: Fn(B) -> Message>(
        block: &FlashBlock,
        wrapper_f: F,
    ) -> Result<Message, Error> {
        Ok(wrapper_f(B::try_from(Bytes::from(serde_json::to_vec(block).unwrap())).unwrap()))
    }

    /// Encodes a block as brotli-compressed JSON inside a `Message::Binary`.
    fn to_brotli_message(block: &FlashBlock) -> Result<Message, Error> {
        let json = serde_json::to_vec(block).unwrap();
        let mut compressed = Vec::new();
        brotli::BrotliCompress(
            &mut json.as_slice(),
            &mut compressed,
            &BrotliEncoderParams::default(),
        )?;

        Ok(Message::Binary(Bytes::from(compressed)))
    }

    /// A `FlashBlock` built from `Default`.
    fn flashblock() -> FlashBlock {
        Default::default()
    }

    // Every supported encoding (JSON binary, JSON text, brotli) decodes back to the same block.
    #[test_case::test_case(to_json_message(Message::Binary); "json binary")]
    #[test_case::test_case(to_json_message(Message::Text); "json UTF-8")]
    #[test_case::test_case(to_brotli_message; "brotli")]
    #[tokio::test]
    async fn test_stream_decodes_messages_successfully(
        to_message: impl Fn(&FlashBlock) -> Result<Message, Error>,
    ) {
        let flashblocks = [flashblock()];
        let connector = FakeConnector::from(flashblocks.iter().map(to_message));
        let ws_url = "http://localhost".parse().unwrap();
        let stream = WsFlashBlockStream::with_connector(ws_url, connector);

        let actual_messages: Vec<_> = stream.take(1).map(Result::unwrap).collect().await;
        let expected_messages = flashblocks.to_vec();

        assert_eq!(actual_messages, expected_messages);
    }

    // Messages the stream does not handle are skipped; the following block is still yielded.
    #[test_case::test_case(Message::Pong(Bytes::from(b"test".as_slice())); "pong")]
    #[test_case::test_case(Message::Frame(Frame::pong(b"test".as_slice())); "frame")]
    #[tokio::test]
    async fn test_stream_ignores_unexpected_message(message: Message) {
        let flashblock = flashblock();
        let connector = FakeConnector::from([Ok(message), to_json_binary_message(&flashblock)]);
        let ws_url = "http://localhost".parse().unwrap();
        let mut stream = WsFlashBlockStream::with_connector(ws_url, connector);

        let expected_message = flashblock;
        let actual_message =
            stream.next().await.expect("Binary message should not be ignored").unwrap();

        assert_eq!(actual_message, expected_message)
    }

    // Read errors are yielded to the consumer rather than swallowed.
    #[tokio::test]
    async fn test_stream_passes_errors_through() {
        let connector = FakeConnector::from([Err(Error::AttackAttempt)]);
        let ws_url = "http://localhost".parse().unwrap();
        let stream = WsFlashBlockStream::with_connector(ws_url, connector);

        let actual_messages: Vec<_> =
            stream.take(1).map(Result::unwrap_err).map(|e| format!("{e}")).collect().await;
        let expected_messages = vec!["Attack attempt detected".to_owned()];

        assert_eq!(actual_messages, expected_messages);
    }

    // A failing connector yields one error per poll, i.e. the stream keeps retrying.
    #[tokio::test]
    async fn test_connect_error_causes_retries() {
        let tries = 3;
        let error_msg = "test".to_owned();
        let connector = FailingConnector(error_msg.clone());
        let ws_url = "http://localhost".parse().unwrap();
        let stream = WsFlashBlockStream::with_connector(ws_url, connector);

        let actual_errors: Vec<_> =
            stream.take(tries).map(Result::unwrap_err).map(|e| format!("{e}")).collect().await;
        let expected_errors: Vec<_> = iter::repeat_n(error_msg, tries).collect();

        assert_eq!(actual_errors, expected_errors);
    }

    // A close frame is echoed back. A ping is answered with a pong carrying the same payload.
    // The reply must be fully flushed (nothing left buffered in the sink).
    #[test_case::test_case(
        Message::Close(Some(CloseFrame { code: CloseCode::Normal, reason: "test".into() })),
        Message::Close(Some(CloseFrame { code: CloseCode::Normal, reason: "test".into() }));
        "close"
    )]
    #[test_case::test_case(
        Message::Ping(Bytes::from_static(&[1u8, 2, 3])),
        Message::Pong(Bytes::from_static(&[1u8, 2, 3]));
        "ping"
    )]
    #[tokio::test]
    async fn test_stream_responds_to_messages(msg: Message, expected_response: Message) {
        let flashblock = flashblock();
        let messages = [Ok(msg), to_json_binary_message(&flashblock)];
        let connector = FakeConnectorWithSink::from(messages);
        let ws_url = "http://localhost".parse().unwrap();
        let mut stream = WsFlashBlockStream::with_connector(ws_url, connector);

        let _ = stream.next().await;

        let expected_response = vec![expected_response];
        let FakeSink(actual_buffer, actual_response) = stream.sink.unwrap();

        assert!(actual_buffer.is_none(), "buffer not flushed: {actual_buffer:#?}");
        assert_eq!(actual_response, expected_response);
    }
}