1use std::{
11 collections::VecDeque,
12 fmt,
13 future::Future,
14 io,
15 pin::{pin, Pin},
16 sync::Arc,
17 task::{ready, Context, Poll},
18};
19
20use crate::{
21 capability::{SharedCapabilities, SharedCapability, UnsupportedCapabilityError},
22 errors::{EthStreamError, P2PStreamError},
23 handshake::EthRlpxHandshake,
24 p2pstream::DisconnectP2P,
25 CanDisconnect, Capability, DisconnectReason, EthStream, P2PStream, UnifiedStatus,
26 HANDSHAKE_TIMEOUT,
27};
28use bytes::{Bytes, BytesMut};
29use futures::{Sink, SinkExt, Stream, StreamExt, TryStream, TryStreamExt};
30use reth_eth_wire_types::NetworkPrimitives;
31use reth_ethereum_forks::ForkFilter;
32use tokio::sync::{mpsc, mpsc::UnboundedSender};
33use tokio_stream::wrappers::UnboundedReceiverStream;
34
/// Multiplexes multiple RLPx subprotocols over a single [`P2PStream`].
///
/// Satellite protocols can be installed with [`Self::install_protocol`]. The multiplexer is then
/// turned into a [`RlpxSatelliteStream`] driven by a primary protocol via one of the
/// `into_*satellite_stream*` methods.
#[derive(Debug)]
pub struct RlpxProtocolMultiplexer<St> {
    inner: MultiplexInner<St>,
}

impl<St> RlpxProtocolMultiplexer<St> {
    /// Wraps `conn` with no satellite protocols installed and an empty outbound buffer.
    pub fn new(conn: P2PStream<St>) -> Self {
        Self {
            inner: MultiplexInner {
                conn,
                protocols: Default::default(),
                out_buffer: Default::default(),
            },
        }
    }

    /// Installs a satellite protocol for `cap`.
    ///
    /// `f` receives a [`ProtocolConnection`] that yields the inbound messages routed to this
    /// capability. It returns the stream of the protocol's outbound messages.
    ///
    /// # Errors
    ///
    /// Returns an error if `cap` does not match a capability shared with the peer (as decided by
    /// `ensure_matching_capability`).
    pub fn install_protocol<F, Proto>(
        &mut self,
        cap: &Capability,
        f: F,
    ) -> Result<(), UnsupportedCapabilityError>
    where
        F: FnOnce(ProtocolConnection) -> Proto,
        Proto: Stream<Item = BytesMut> + Send + 'static,
    {
        self.inner.install_protocol(cap, f)
    }

    /// Capabilities shared with the peer on the underlying connection.
    pub const fn shared_capabilities(&self) -> &SharedCapabilities {
        self.inner.shared_capabilities()
    }

    /// Converts the multiplexer into a [`RlpxSatelliteStream`] whose primary protocol runs on
    /// `cap`, without performing any handshake.
    ///
    /// `primary` is called with the [`ProtocolProxy`] for `cap` and returns the primary protocol.
    ///
    /// # Errors
    ///
    /// Returns [`P2PStreamError::CapabilityNotShared`] if `cap` is not shared with the peer.
    pub fn into_satellite_stream<F, Primary>(
        self,
        cap: &Capability,
        primary: F,
    ) -> Result<RlpxSatelliteStream<St, Primary>, P2PStreamError>
    where
        F: FnOnce(ProtocolProxy) -> Primary,
    {
        let Ok(shared_cap) = self.shared_capabilities().ensure_matching_capability(cap).cloned()
        else {
            return Err(P2PStreamError::CapabilityNotShared)
        };

        // Set up two channels: wire -> primary (inbound) and primary -> wire (outbound).
        let (to_primary, from_wire) = mpsc::unbounded_channel();
        let (to_wire, from_primary) = mpsc::unbounded_channel();
        let proxy = ProtocolProxy {
            shared_cap: shared_cap.clone(),
            from_wire: UnboundedReceiverStream::new(from_wire),
            to_wire,
        };

        let st = primary(proxy);
        Ok(RlpxSatelliteStream {
            inner: self.inner,
            primary: PrimaryProtocol {
                to_primary,
                from_primary: UnboundedReceiverStream::new(from_primary),
                st,
                shared_cap,
            },
        })
    }

    /// Converts the multiplexer into a [`RlpxSatelliteStream`] after running a handshake for the
    /// primary protocol on `cap`.
    ///
    /// Same as [`Self::into_satellite_stream_with_tuple_handshake`], except that `handshake`
    /// returns only the primary protocol and no extra value.
    pub async fn into_satellite_stream_with_handshake<F, Fut, Err, Primary>(
        self,
        cap: &Capability,
        handshake: F,
    ) -> Result<RlpxSatelliteStream<St, Primary>, Err>
    where
        F: FnOnce(ProtocolProxy) -> Fut,
        Fut: Future<Output = Result<Primary, Err>>,
        St: Stream<Item = io::Result<BytesMut>> + Sink<Bytes, Error = io::Error> + Unpin,
        P2PStreamError: Into<Err>,
    {
        self.into_satellite_stream_with_tuple_handshake(cap, async move |proxy| {
            let st = handshake(proxy).await?;
            Ok((st, ()))
        })
        .await
        .map(|(st, _)| st)
    }

    /// Converts the multiplexer into a [`RlpxSatelliteStream`] after running a handshake for the
    /// primary protocol on `cap`. The extra value produced by the handshake is returned as well.
    ///
    /// `handshake` is called with the [`ProtocolProxy`] for `cap`. Its future is polled to
    /// completion while this method keeps the connection moving:
    /// - inbound messages for `cap` are forwarded to the proxy, and messages for other shared
    ///   capabilities are handed to installed satellite protocols;
    /// - messages sent through the proxy and messages produced by satellites are written to the
    ///   connection.
    ///
    /// # Errors
    ///
    /// - [`P2PStreamError::CapabilityNotShared`] if `cap` is not shared with the peer.
    /// - [`P2PStreamError::EmptyProtocolMessage`] or
    ///   [`P2PStreamError::UnknownReservedMessageId`] for inbound messages that cannot be routed.
    /// - Errors from writing to the connection, from satellite protocols, or from `handshake`.
    pub async fn into_satellite_stream_with_tuple_handshake<F, Fut, Err, Primary, Extra>(
        mut self,
        cap: &Capability,
        handshake: F,
    ) -> Result<(RlpxSatelliteStream<St, Primary>, Extra), Err>
    where
        F: FnOnce(ProtocolProxy) -> Fut,
        Fut: Future<Output = Result<(Primary, Extra), Err>>,
        St: Stream<Item = io::Result<BytesMut>> + Sink<Bytes, Error = io::Error> + Unpin,
        P2PStreamError: Into<Err>,
    {
        let Ok(shared_cap) = self.shared_capabilities().ensure_matching_capability(cap).cloned()
        else {
            return Err(P2PStreamError::CapabilityNotShared.into())
        };

        // Set up two channels: wire -> primary (inbound) and primary -> wire (outbound).
        let (to_primary, from_wire) = mpsc::unbounded_channel();
        let (to_wire, mut from_primary) = mpsc::unbounded_channel();
        let proxy = ProtocolProxy {
            shared_cap: shared_cap.clone(),
            from_wire: UnboundedReceiverStream::new(from_wire),
            to_wire,
        };

        let f = handshake(proxy);
        let mut f = pin!(f);

        loop {
            // `biased` polls the branches in the order written: inbound traffic first and the
            // handshake future last.
            tokio::select! {
                biased;
                // Connection errors and end-of-stream do not match this pattern, so they are not
                // reported from here; the branch is only disabled for this `select!` round.
                Some(Ok(msg)) = self.inner.conn.next() => {
                    // The first byte is the multiplexed message id used for routing.
                    let Some(offset) = msg.first().copied()
                    else {
                        return Err(P2PStreamError::EmptyProtocolMessage.into())
                    };
                    if let Some(cap) = self.shared_capabilities().find_by_relative_offset(offset).cloned() {
                        if cap == shared_cap {
                            // Message for the primary: hand it to the proxy still masked.
                            let _ = to_primary.send(msg);
                        } else {
                            // Message for a satellite; dropped if none is installed for `cap`.
                            self.inner.delegate_message(&cap, msg);
                        }
                    } else {
                        return Err(P2PStreamError::UnknownReservedMessageId(offset).into())
                    }
                }
                // Messages the handshake sent through the proxy (already masked).
                Some(msg) = from_primary.recv() => {
                    self.inner.conn.send(msg).await.map_err(Into::into)?;
                }
                // Outbound messages produced by installed satellite protocols.
                msg = ProtocolsPoller::new(&mut self.inner.protocols) => {
                    self.inner.conn.send(msg.map_err(Into::into)?).await.map_err(Into::into)?;
                }
                res = &mut f => {
                    let (st, extra) = res?;
                    return Ok((RlpxSatelliteStream {
                        inner: self.inner,
                        primary: PrimaryProtocol {
                            to_primary,
                            from_primary: UnboundedReceiverStream::new(from_primary),
                            st,
                            shared_cap,
                        }
                    }, extra))
                }
            }
        }
    }

    /// Converts the multiplexer into a [`RlpxSatelliteStream`] whose primary protocol is an
    /// [`EthStream`] for the negotiated eth version.
    ///
    /// The eth status handshake is performed through the primary proxy with `handshake`, which
    /// is given [`HANDSHAKE_TIMEOUT`] as its timeout argument. On success the peer's status is
    /// returned alongside the stream. The resulting [`EthStream`] uses `eth_max_message_size`.
    pub async fn into_eth_satellite_stream<N: NetworkPrimitives>(
        self,
        status: UnifiedStatus,
        fork_filter: ForkFilter,
        handshake: Arc<dyn EthRlpxHandshake>,
        eth_max_message_size: usize,
    ) -> Result<(RlpxSatelliteStream<St, EthStream<ProtocolProxy, N>>, UnifiedStatus), EthStreamError>
    where
        St: Stream<Item = io::Result<BytesMut>> + Sink<Bytes, Error = io::Error> + Unpin,
    {
        let eth_cap = self.inner.conn.shared_capabilities().eth_version()?;
        self.into_satellite_stream_with_tuple_handshake(
            &Capability::eth(eth_cap),
            async move |proxy| {
                let handshake = handshake.clone();
                // During the handshake the proxy's I/O errors are surfaced as `P2PStreamError`.
                let mut unauth = UnauthProxy { inner: proxy };
                let their_status = handshake
                    .handshake(&mut unauth, status, fork_filter, HANDSHAKE_TIMEOUT)
                    .await?;
                let eth_stream = EthStream::with_max_message_size(
                    eth_cap,
                    unauth.into_inner(),
                    eth_max_message_size,
                );
                Ok((eth_stream, their_status))
            },
        )
        .await
    }
}
246
247#[derive(Debug)]
248struct MultiplexInner<St> {
249 conn: P2PStream<St>,
251 protocols: Vec<ProtocolStream>,
253 out_buffer: VecDeque<Bytes>,
255}
256
257impl<St> MultiplexInner<St> {
258 const fn shared_capabilities(&self) -> &SharedCapabilities {
259 self.conn.shared_capabilities()
260 }
261
262 fn delegate_message(&self, cap: &SharedCapability, msg: BytesMut) -> bool {
264 for proto in &self.protocols {
265 if proto.shared_cap == *cap {
266 proto.send_raw(msg);
267 return true
268 }
269 }
270 false
271 }
272
273 fn install_protocol<F, Proto>(
274 &mut self,
275 cap: &Capability,
276 f: F,
277 ) -> Result<(), UnsupportedCapabilityError>
278 where
279 F: FnOnce(ProtocolConnection) -> Proto,
280 Proto: Stream<Item = BytesMut> + Send + 'static,
281 {
282 let shared_cap =
283 self.conn.shared_capabilities().ensure_matching_capability(cap).cloned()?;
284 let (to_satellite, rx) = mpsc::unbounded_channel();
285 let proto_conn = ProtocolConnection { from_wire: UnboundedReceiverStream::new(rx) };
286 let st = f(proto_conn);
287 let st = ProtocolStream { shared_cap, to_satellite, satellite_st: Box::pin(st) };
288 self.protocols.push(st);
289 Ok(())
290 }
291}
292
/// The primary protocol of a [`RlpxSatelliteStream`].
///
/// It is stored together with the channel endpoints that connect it to the wire through its
/// [`ProtocolProxy`].
#[derive(Debug)]
struct PrimaryProtocol<Primary> {
    /// Forwards inbound wire messages, still carrying the multiplexed id, to the proxy.
    to_primary: UnboundedSender<BytesMut>,
    /// Receives the outbound, already masked messages the primary sent through its proxy.
    from_primary: UnboundedReceiverStream<Bytes>,
    /// Capability the primary protocol runs on.
    shared_cap: SharedCapability,
    /// The primary protocol itself.
    st: Primary,
}
305
306#[derive(Debug)]
310pub struct ProtocolProxy {
311 shared_cap: SharedCapability,
312 from_wire: UnboundedReceiverStream<BytesMut>,
314 to_wire: UnboundedSender<Bytes>,
316}
317
318impl ProtocolProxy {
319 fn try_send(&self, msg: Bytes) -> Result<(), io::Error> {
321 if msg.is_empty() {
322 return Err(io::ErrorKind::InvalidInput.into())
324 }
325 self.to_wire.send(self.mask_msg_id(msg)?).map_err(|_| io::ErrorKind::BrokenPipe.into())
326 }
327
328 #[inline]
330 fn mask_msg_id(&self, msg: Bytes) -> Result<Bytes, io::Error> {
331 if msg.is_empty() {
332 return Err(io::ErrorKind::InvalidInput.into())
334 }
335
336 let offset = self.shared_cap.relative_message_id_offset();
337 if offset == 0 {
338 return Ok(msg);
339 }
340
341 let mut masked: BytesMut = msg.into();
342 masked[0] = masked[0].checked_add(offset).ok_or(io::ErrorKind::InvalidInput)?;
343 Ok(masked.freeze())
344 }
345
346 #[inline]
348 fn unmask_id(&self, mut msg: BytesMut) -> Result<BytesMut, io::Error> {
349 if msg.is_empty() {
350 return Err(io::ErrorKind::InvalidInput.into())
352 }
353 msg[0] = msg[0]
354 .checked_sub(self.shared_cap.relative_message_id_offset())
355 .ok_or(io::ErrorKind::InvalidInput)?;
356 Ok(msg)
357 }
358}
359
360impl Stream for ProtocolProxy {
361 type Item = Result<BytesMut, io::Error>;
362
363 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
364 let msg = ready!(self.from_wire.poll_next_unpin(cx));
365 Poll::Ready(msg.map(|msg| self.get_mut().unmask_id(msg)))
366 }
367}
368
369impl Sink<Bytes> for ProtocolProxy {
370 type Error = io::Error;
371
372 fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
373 Poll::Ready(Ok(()))
374 }
375
376 fn start_send(self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> {
377 self.get_mut().try_send(item)
378 }
379
380 fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
381 Poll::Ready(Ok(()))
382 }
383
384 fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
385 Poll::Ready(Ok(()))
386 }
387}
388
389impl CanDisconnect<Bytes> for ProtocolProxy {
390 fn disconnect(
391 &mut self,
392 _reason: DisconnectReason,
393 ) -> Pin<Box<dyn Future<Output = Result<(), <Self as Sink<Bytes>>::Error>> + Send + '_>> {
394 Box::pin(async move { Ok(()) })
395 }
396}
397
398#[derive(Debug)]
401struct UnauthProxy {
402 inner: ProtocolProxy,
403}
404
405impl UnauthProxy {
406 fn into_inner(self) -> ProtocolProxy {
407 self.inner
408 }
409}
410
411impl Stream for UnauthProxy {
412 type Item = Result<BytesMut, P2PStreamError>;
413
414 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
415 self.inner.poll_next_unpin(cx).map(|opt| opt.map(|res| res.map_err(P2PStreamError::from)))
416 }
417}
418
419impl Sink<Bytes> for UnauthProxy {
420 type Error = P2PStreamError;
421
422 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
423 self.inner.poll_ready_unpin(cx).map_err(P2PStreamError::from)
424 }
425
426 fn start_send(mut self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> {
427 self.inner.start_send_unpin(item).map_err(P2PStreamError::from)
428 }
429
430 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
431 self.inner.poll_flush_unpin(cx).map_err(P2PStreamError::from)
432 }
433
434 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
435 self.inner.poll_close_unpin(cx).map_err(P2PStreamError::from)
436 }
437}
438
439impl CanDisconnect<Bytes> for UnauthProxy {
440 fn disconnect(
441 &mut self,
442 reason: DisconnectReason,
443 ) -> Pin<Box<dyn Future<Output = Result<(), <Self as Sink<Bytes>>::Error>> + Send + '_>> {
444 let fut = self.inner.disconnect(reason);
445 Box::pin(async move { fut.await.map_err(P2PStreamError::from) })
446 }
447}
448
449#[derive(Debug)]
453pub struct ProtocolConnection {
454 from_wire: UnboundedReceiverStream<BytesMut>,
455}
456
457impl Stream for ProtocolConnection {
458 type Item = BytesMut;
459
460 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
461 self.from_wire.poll_next_unpin(cx)
462 }
463}
464
465#[derive(Debug)]
468pub struct RlpxSatelliteStream<St, Primary> {
469 inner: MultiplexInner<St>,
470 primary: PrimaryProtocol<Primary>,
471}
472
473impl<St, Primary> RlpxSatelliteStream<St, Primary> {
474 pub fn install_protocol<F, Proto>(
479 &mut self,
480 cap: &Capability,
481 f: F,
482 ) -> Result<(), UnsupportedCapabilityError>
483 where
484 F: FnOnce(ProtocolConnection) -> Proto,
485 Proto: Stream<Item = BytesMut> + Send + 'static,
486 {
487 self.inner.install_protocol(cap, f)
488 }
489
490 #[inline]
492 pub const fn primary(&self) -> &Primary {
493 &self.primary.st
494 }
495
496 #[inline]
498 pub const fn primary_mut(&mut self) -> &mut Primary {
499 &mut self.primary.st
500 }
501
502 #[inline]
504 pub const fn inner(&self) -> &P2PStream<St> {
505 &self.inner.conn
506 }
507
508 #[inline]
510 pub const fn inner_mut(&mut self) -> &mut P2PStream<St> {
511 &mut self.inner.conn
512 }
513
514 #[inline]
516 pub fn into_inner(self) -> P2PStream<St> {
517 self.inner.conn
518 }
519}
520
/// Polling drives the whole multiplexed connection and yields the primary protocol's items.
///
/// Each round of the loop:
/// 1. returns the primary's next item if one is ready;
/// 2. writes queued outbound messages while the connection accepts them;
/// 3. queues everything the primary sent through its proxy;
/// 4. queues everything the installed satellite protocols produced;
/// 5. reads inbound messages from the connection and routes them by message id offset.
///
/// The stream ends (`None`) when the primary's outbound channel is closed, when a satellite
/// protocol ends, or when the connection ends.
impl<St, Primary, PrimaryErr> Stream for RlpxSatelliteStream<St, Primary>
where
    St: Stream<Item = io::Result<BytesMut>> + Sink<Bytes, Error = io::Error> + Unpin,
    Primary: TryStream<Error = PrimaryErr> + Unpin,
    P2PStreamError: Into<PrimaryErr>,
{
    type Item = Result<Primary::Ok, Primary::Error>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();

        loop {
            // Yield the primary's next item first. Both `Ready(None)` and `Pending` fall through,
            // so traffic keeps being routed.
            if let Poll::Ready(Some(msg)) = this.primary.st.try_poll_next_unpin(cx) {
                return Poll::Ready(Some(msg))
            }

            // Write queued outbound messages for as long as the connection accepts them.
            let mut conn_ready = true;
            loop {
                match this.inner.conn.poll_ready_unpin(cx) {
                    Poll::Ready(Ok(())) => {
                        if let Some(msg) = this.inner.out_buffer.pop_front() {
                            if let Err(err) = this.inner.conn.start_send_unpin(msg) {
                                return Poll::Ready(Some(Err(err.into())))
                            }
                        } else {
                            break
                        }
                    }
                    Poll::Ready(Err(err)) => {
                        // The connection failed: start a disconnect, then report the error. If
                        // starting the disconnect fails, that error is reported instead.
                        if let Err(disconnect_err) =
                            this.inner.conn.start_disconnect(DisconnectReason::DisconnectRequested)
                        {
                            return Poll::Ready(Some(Err(disconnect_err.into())))
                        }
                        return Poll::Ready(Some(Err(err.into())))
                    }
                    Poll::Pending => {
                        conn_ready = false;
                        break
                    }
                }
            }

            // Queue everything the primary sent through its proxy (already masked).
            loop {
                match this.primary.from_primary.poll_next_unpin(cx) {
                    Poll::Ready(Some(msg)) => {
                        this.inner.out_buffer.push_back(msg);
                    }
                    Poll::Ready(None) => {
                        // the primary's outbound channel is closed
                        return Poll::Ready(None)
                    }
                    Poll::Pending => break,
                }
            }

            // Queue everything the satellites produced. Each protocol is taken out with
            // `swap_remove` while it is polled and is pushed back only once it returns
            // `Pending`. A protocol that errors or ends is not re-inserted.
            for idx in (0..this.inner.protocols.len()).rev() {
                let mut proto = this.inner.protocols.swap_remove(idx);
                loop {
                    match proto.poll_next_unpin(cx) {
                        Poll::Ready(Some(Err(err))) => {
                            return Poll::Ready(Some(Err(P2PStreamError::Io(err).into())))
                        }
                        Poll::Ready(Some(Ok(msg))) => {
                            this.inner.out_buffer.push_back(msg);
                        }
                        Poll::Ready(None) => return Poll::Ready(None),
                        Poll::Pending => {
                            this.inner.protocols.push(proto);
                            break
                        }
                    }
                }
            }

            // Read inbound messages and route them by their message id offset.
            let mut delegated = false;
            loop {
                match this.inner.conn.poll_next_unpin(cx) {
                    Poll::Ready(Some(Ok(msg))) => {
                        delegated = true;
                        let Some(offset) = msg.first().copied() else {
                            return Poll::Ready(Some(Err(
                                P2PStreamError::EmptyProtocolMessage.into()
                            )))
                        };
                        if let Some(cap) =
                            this.inner.conn.shared_capabilities().find_by_relative_offset(offset)
                        {
                            if cap == &this.primary.shared_cap {
                                // For the primary: the proxy unmasks the id when it is read.
                                let _ = this.primary.to_primary.send(msg);
                            } else {
                                // For the first satellite installed for `cap`. If there is none,
                                // the message is dropped.
                                for proto in &this.inner.protocols {
                                    if proto.shared_cap == *cap {
                                        proto.send_raw(msg);
                                        break
                                    }
                                }
                            }
                        } else {
                            return Poll::Ready(Some(Err(P2PStreamError::UnknownReservedMessageId(
                                offset,
                            )
                            .into())))
                        }
                    }
                    Poll::Ready(Some(Err(err))) => return Poll::Ready(Some(Err(err.into()))),
                    Poll::Ready(None) => {
                        // the connection ended
                        return Poll::Ready(None)
                    }
                    Poll::Pending => break,
                }
            }

            // Return `Pending` if the connection cannot take more writes, or if nothing was read
            // and nothing is waiting to be written. Otherwise run another round so the newly
            // routed or queued messages are processed.
            if !conn_ready || (!delegated && this.inner.out_buffer.is_empty()) {
                return Poll::Pending
            }
        }
    }
}
648
649impl<St, Primary, T> Sink<T> for RlpxSatelliteStream<St, Primary>
650where
651 St: Stream<Item = io::Result<BytesMut>> + Sink<Bytes, Error = io::Error> + Unpin,
652 Primary: Sink<T> + Unpin,
653 P2PStreamError: Into<<Primary as Sink<T>>::Error>,
654{
655 type Error = <Primary as Sink<T>>::Error;
656
657 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
658 let this = self.get_mut();
659 if let Err(err) = ready!(this.inner.conn.poll_ready_unpin(cx)) {
660 return Poll::Ready(Err(err.into()))
661 }
662 if let Err(err) = ready!(this.primary.st.poll_ready_unpin(cx)) {
663 return Poll::Ready(Err(err))
664 }
665 Poll::Ready(Ok(()))
666 }
667
668 fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
669 self.get_mut().primary.st.start_send_unpin(item)
670 }
671
672 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
673 self.get_mut().inner.conn.poll_flush_unpin(cx).map_err(Into::into)
674 }
675
676 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
677 self.get_mut().inner.conn.poll_close_unpin(cx).map_err(Into::into)
678 }
679}
680
681struct ProtocolStream {
683 shared_cap: SharedCapability,
684 to_satellite: UnboundedSender<BytesMut>,
686 satellite_st: Pin<Box<dyn Stream<Item = BytesMut> + Send>>,
687}
688
689impl ProtocolStream {
690 #[inline]
692 fn mask_msg_id(&self, mut msg: BytesMut) -> Result<Bytes, io::Error> {
693 if msg.is_empty() {
694 return Err(io::ErrorKind::InvalidInput.into())
696 }
697 msg[0] = msg[0]
698 .checked_add(self.shared_cap.relative_message_id_offset())
699 .ok_or(io::ErrorKind::InvalidInput)?;
700 Ok(msg.freeze())
701 }
702
703 #[inline]
705 fn unmask_id(&self, mut msg: BytesMut) -> Result<BytesMut, io::Error> {
706 if msg.is_empty() {
707 return Err(io::ErrorKind::InvalidInput.into())
709 }
710 msg[0] = msg[0]
711 .checked_sub(self.shared_cap.relative_message_id_offset())
712 .ok_or(io::ErrorKind::InvalidInput)?;
713 Ok(msg)
714 }
715
716 fn send_raw(&self, msg: BytesMut) {
718 let _ = self.unmask_id(msg).map(|msg| self.to_satellite.send(msg));
719 }
720}
721
722impl Stream for ProtocolStream {
723 type Item = Result<Bytes, io::Error>;
724
725 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
726 let this = self.get_mut();
727 let msg = ready!(this.satellite_st.as_mut().poll_next(cx));
728 Poll::Ready(msg.map(|msg| this.mask_msg_id(msg)))
729 }
730}
731
732impl fmt::Debug for ProtocolStream {
733 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
734 f.debug_struct("ProtocolStream").field("cap", &self.shared_cap).finish_non_exhaustive()
735 }
736}
737
738struct ProtocolsPoller<'a> {
740 protocols: &'a mut Vec<ProtocolStream>,
741}
742
743impl<'a> ProtocolsPoller<'a> {
744 const fn new(protocols: &'a mut Vec<ProtocolStream>) -> Self {
745 Self { protocols }
746 }
747}
748
749impl<'a> Future for ProtocolsPoller<'a> {
750 type Output = Result<Bytes, P2PStreamError>;
751
752 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
753 for idx in (0..self.protocols.len()).rev() {
755 let mut proto = self.protocols.swap_remove(idx);
756 match proto.poll_next_unpin(cx) {
757 Poll::Ready(Some(Err(err))) => {
758 self.protocols.push(proto);
759 return Poll::Ready(Err(P2PStreamError::from(err)))
760 }
761 Poll::Ready(Some(Ok(msg))) => {
762 self.protocols.push(proto);
764 return Poll::Ready(Ok(msg));
765 }
766 _ => {
767 self.protocols.push(proto);
769 }
770 }
771 }
772
773 Poll::Pending
775 }
776}
777
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        handshake::EthHandshake,
        message::MAX_MESSAGE_SIZE,
        test_utils::{
            connect_passthrough, eth_handshake, eth_hello,
            proto::{test_hello, TestProtoMessage},
        },
        UnauthedEthStream, UnauthedP2PStream,
    };
    use reth_eth_wire_types::EthNetworkPrimitives;
    use tokio::{net::TcpListener, sync::oneshot};
    use tokio_util::codec::Decoder;

    /// Runs an eth handshake through the primary proxy of a satellite stream against a peer that
    /// uses a plain `UnauthedEthStream`.
    #[tokio::test]
    async fn eth_satellite() {
        reth_tracing::init_test_tracing();
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();
        let (status, fork_filter) = eth_handshake();
        let other_status = status;
        let other_fork_filter = fork_filter.clone();
        // Peer side: plain p2p handshake followed by a plain eth handshake.
        let _handle = tokio::spawn(async move {
            let (incoming, _) = listener.accept().await.unwrap();
            let stream = crate::PassthroughCodec::default().framed(incoming);
            let (server_hello, _) = eth_hello();
            let (p2p_stream, _) =
                UnauthedP2PStream::new(stream).handshake(server_hello).await.unwrap();

            let (_eth_stream, _) = UnauthedEthStream::new(p2p_stream)
                .handshake::<EthNetworkPrimitives>(other_status, other_fork_filter)
                .await
                .unwrap();

            // keep the connection open briefly so the local side can finish its handshake
            tokio::time::sleep(std::time::Duration::from_millis(100)).await;
        });

        let conn = connect_passthrough(local_addr, eth_hello().0).await;
        let eth = conn.shared_capabilities().eth().unwrap().clone();

        let multiplexer = RlpxProtocolMultiplexer::new(conn);
        let _satellite = multiplexer
            .into_satellite_stream_with_handshake(eth.capability().as_ref(), async move |proxy| {
                UnauthedEthStream::new(proxy)
                    .handshake::<EthNetworkPrimitives>(status, fork_filter)
                    .await
            })
            .await
            .unwrap();
    }

    /// Both sides run an eth satellite stream with the test protocol installed as a satellite.
    /// They exchange ping/pong and "hello"/"good bye!" messages over it.
    #[tokio::test(flavor = "multi_thread")]
    async fn eth_test_protocol_satellite() {
        reth_tracing::init_test_tracing();
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();
        let (status, fork_filter) = eth_handshake();
        let other_status = status;
        let other_fork_filter = fork_filter.clone();
        // Peer side: initiates the test-protocol exchange.
        let _handle = tokio::spawn(async move {
            let (incoming, _) = listener.accept().await.unwrap();
            let stream = crate::PassthroughCodec::default().framed(incoming);
            let (server_hello, _) = test_hello();
            let (conn, _) = UnauthedP2PStream::new(stream).handshake(server_hello).await.unwrap();

            let (mut st, _their_status) = RlpxProtocolMultiplexer::new(conn)
                .into_eth_satellite_stream::<EthNetworkPrimitives>(
                    other_status,
                    other_fork_filter,
                    Arc::new(EthHandshake::default()),
                    MAX_MESSAGE_SIZE,
                )
                .await
                .unwrap();

            st.install_protocol(&TestProtoMessage::capability(), |mut conn| {
                async_stream::stream! {
                    yield TestProtoMessage::ping().encoded();
                    let msg = conn.next().await.unwrap();
                    let msg = TestProtoMessage::decode_message(&mut &msg[..]).unwrap();
                    assert_eq!(msg, TestProtoMessage::pong());

                    yield TestProtoMessage::message("hello").encoded();
                    let msg = conn.next().await.unwrap();
                    let msg = TestProtoMessage::decode_message(&mut &msg[..]).unwrap();
                    assert_eq!(msg, TestProtoMessage::message("good bye!"));

                    yield TestProtoMessage::message("good bye!").encoded();

                    // keep the satellite alive; if it ended, the satellite stream would end too
                    futures::future::pending::<()>().await;
                    unreachable!()
                }
            })
            .unwrap();

            // drive the multiplexed connection forever
            loop {
                let _ = st.next().await;
            }
        });

        let conn = connect_passthrough(local_addr, test_hello().0).await;
        let (mut st, _their_status) = RlpxProtocolMultiplexer::new(conn)
            .into_eth_satellite_stream::<EthNetworkPrimitives>(
                status,
                fork_filter,
                Arc::new(EthHandshake::default()),
                MAX_MESSAGE_SIZE,
            )
            .await
            .unwrap();

        // signalled by the local satellite once the whole exchange has been verified
        let (tx, mut rx) = oneshot::channel();

        st.install_protocol(&TestProtoMessage::capability(), |mut conn| {
            async_stream::stream! {
                let msg = conn.next().await.unwrap();
                let msg = TestProtoMessage::decode_message(&mut &msg[..]).unwrap();
                assert_eq!(msg, TestProtoMessage::ping());

                yield TestProtoMessage::pong().encoded();

                let msg = conn.next().await.unwrap();
                let msg = TestProtoMessage::decode_message(&mut &msg[..]).unwrap();
                assert_eq!(msg, TestProtoMessage::message("hello"));

                yield TestProtoMessage::message("good bye!").encoded();

                let msg = conn.next().await.unwrap();
                let msg = TestProtoMessage::decode_message(&mut &msg[..]).unwrap();
                assert_eq!(msg, TestProtoMessage::message("good bye!"));

                tx.send(()).unwrap();

                futures::future::pending::<()>().await;
                unreachable!()
            }
        })
        .unwrap();

        // drive the local connection until the satellite reports completion
        loop {
            tokio::select! {
                _ = &mut rx => {
                    break
                }
                _ = st.next() => {
                }
            }
        }
    }
}