1use std::{
11 collections::VecDeque,
12 fmt,
13 future::Future,
14 io,
15 pin::{pin, Pin},
16 sync::Arc,
17 task::{ready, Context, Poll},
18};
19
20use crate::{
21 capability::{SharedCapabilities, SharedCapability, UnsupportedCapabilityError},
22 errors::{EthStreamError, P2PStreamError},
23 handshake::EthRlpxHandshake,
24 p2pstream::DisconnectP2P,
25 CanDisconnect, Capability, DisconnectReason, EthStream, P2PStream, UnifiedStatus,
26 HANDSHAKE_TIMEOUT,
27};
28use bytes::{Bytes, BytesMut};
29use futures::{Sink, SinkExt, Stream, StreamExt, TryStream, TryStreamExt};
30use reth_eth_wire_types::NetworkPrimitives;
31use reth_ethereum_forks::ForkFilter;
32use tokio::sync::{mpsc, mpsc::UnboundedSender};
33use tokio_stream::wrappers::UnboundedReceiverStream;
34
35#[derive(Debug)]
38pub struct RlpxProtocolMultiplexer<St> {
39 inner: MultiplexInner<St>,
40}
41
42impl<St> RlpxProtocolMultiplexer<St> {
43 pub fn new(conn: P2PStream<St>) -> Self {
45 Self {
46 inner: MultiplexInner {
47 conn,
48 protocols: Default::default(),
49 out_buffer: Default::default(),
50 },
51 }
52 }
53
54 pub fn install_protocol<F, Proto>(
59 &mut self,
60 cap: &Capability,
61 f: F,
62 ) -> Result<(), UnsupportedCapabilityError>
63 where
64 F: FnOnce(ProtocolConnection) -> Proto,
65 Proto: Stream<Item = BytesMut> + Send + 'static,
66 {
67 self.inner.install_protocol(cap, f)
68 }
69
70 pub const fn shared_capabilities(&self) -> &SharedCapabilities {
72 self.inner.shared_capabilities()
73 }
74
75 pub fn into_satellite_stream<F, Primary>(
77 self,
78 cap: &Capability,
79 primary: F,
80 ) -> Result<RlpxSatelliteStream<St, Primary>, P2PStreamError>
81 where
82 F: FnOnce(ProtocolProxy) -> Primary,
83 {
84 let Ok(shared_cap) = self.shared_capabilities().ensure_matching_capability(cap).cloned()
85 else {
86 return Err(P2PStreamError::CapabilityNotShared)
87 };
88
89 let (to_primary, from_wire) = mpsc::unbounded_channel();
90 let (to_wire, from_primary) = mpsc::unbounded_channel();
91 let proxy = ProtocolProxy {
92 shared_cap: shared_cap.clone(),
93 from_wire: UnboundedReceiverStream::new(from_wire),
94 to_wire,
95 };
96
97 let st = primary(proxy);
98 Ok(RlpxSatelliteStream {
99 inner: self.inner,
100 primary: PrimaryProtocol {
101 to_primary,
102 from_primary: UnboundedReceiverStream::new(from_primary),
103 st,
104 shared_cap,
105 },
106 })
107 }
108
109 pub async fn into_satellite_stream_with_handshake<F, Fut, Err, Primary>(
114 self,
115 cap: &Capability,
116 handshake: F,
117 ) -> Result<RlpxSatelliteStream<St, Primary>, Err>
118 where
119 F: FnOnce(ProtocolProxy) -> Fut,
120 Fut: Future<Output = Result<Primary, Err>>,
121 St: Stream<Item = io::Result<BytesMut>> + Sink<Bytes, Error = io::Error> + Unpin,
122 P2PStreamError: Into<Err>,
123 {
124 self.into_satellite_stream_with_tuple_handshake(cap, async move |proxy| {
125 let st = handshake(proxy).await?;
126 Ok((st, ()))
127 })
128 .await
129 .map(|(st, _)| st)
130 }
131
132 pub async fn into_satellite_stream_with_tuple_handshake<F, Fut, Err, Primary, Extra>(
142 mut self,
143 cap: &Capability,
144 handshake: F,
145 ) -> Result<(RlpxSatelliteStream<St, Primary>, Extra), Err>
146 where
147 F: FnOnce(ProtocolProxy) -> Fut,
148 Fut: Future<Output = Result<(Primary, Extra), Err>>,
149 St: Stream<Item = io::Result<BytesMut>> + Sink<Bytes, Error = io::Error> + Unpin,
150 P2PStreamError: Into<Err>,
151 {
152 let Ok(shared_cap) = self.shared_capabilities().ensure_matching_capability(cap).cloned()
153 else {
154 return Err(P2PStreamError::CapabilityNotShared.into())
155 };
156
157 let (to_primary, from_wire) = mpsc::unbounded_channel();
158 let (to_wire, mut from_primary) = mpsc::unbounded_channel();
159 let proxy = ProtocolProxy {
160 shared_cap: shared_cap.clone(),
161 from_wire: UnboundedReceiverStream::new(from_wire),
162 to_wire,
163 };
164
165 let f = handshake(proxy);
166 let mut f = pin!(f);
167
168 loop {
171 tokio::select! {
172 biased;
173 Some(Ok(msg)) = self.inner.conn.next() => {
174 let Some(offset) = msg.first().copied()
176 else {
177 return Err(P2PStreamError::EmptyProtocolMessage.into())
178 };
179 if let Some(cap) = self.shared_capabilities().find_by_relative_offset(offset).cloned() {
180 if cap == shared_cap {
181 let _ = to_primary.send(msg);
183 } else {
184 self.inner.delegate_message(&cap, msg);
186 }
187 } else {
188 return Err(P2PStreamError::UnknownReservedMessageId(offset).into())
189 }
190 }
191 Some(msg) = from_primary.recv() => {
192 self.inner.conn.send(msg).await.map_err(Into::into)?;
193 }
194 msg = ProtocolsPoller::new(&mut self.inner.protocols) => {
196 self.inner.conn.send(msg.map_err(Into::into)?).await.map_err(Into::into)?;
197 }
198 res = &mut f => {
199 let (st, extra) = res?;
200 return Ok((RlpxSatelliteStream {
201 inner: self.inner,
202 primary: PrimaryProtocol {
203 to_primary,
204 from_primary: UnboundedReceiverStream::new(from_primary),
205 st,
206 shared_cap,
207 }
208 }, extra))
209 }
210 }
211 }
212 }
213
214 pub async fn into_eth_satellite_stream<N: NetworkPrimitives>(
217 self,
218 status: UnifiedStatus,
219 fork_filter: ForkFilter,
220 handshake: Arc<dyn EthRlpxHandshake>,
221 ) -> Result<(RlpxSatelliteStream<St, EthStream<ProtocolProxy, N>>, UnifiedStatus), EthStreamError>
222 where
223 St: Stream<Item = io::Result<BytesMut>> + Sink<Bytes, Error = io::Error> + Unpin,
224 {
225 let eth_cap = self.inner.conn.shared_capabilities().eth_version()?;
226 self.into_satellite_stream_with_tuple_handshake(
227 &Capability::eth(eth_cap),
228 async move |proxy| {
229 let handshake = handshake.clone();
230 let mut unauth = UnauthProxy { inner: proxy };
231 let their_status = handshake
232 .handshake(&mut unauth, status, fork_filter, HANDSHAKE_TIMEOUT)
233 .await?;
234 let eth_stream = EthStream::new(eth_cap, unauth.into_inner());
235 Ok((eth_stream, their_status))
236 },
237 )
238 .await
239 }
240}
241
242#[derive(Debug)]
243struct MultiplexInner<St> {
244 conn: P2PStream<St>,
246 protocols: Vec<ProtocolStream>,
248 out_buffer: VecDeque<Bytes>,
250}
251
252impl<St> MultiplexInner<St> {
253 const fn shared_capabilities(&self) -> &SharedCapabilities {
254 self.conn.shared_capabilities()
255 }
256
257 fn delegate_message(&self, cap: &SharedCapability, msg: BytesMut) -> bool {
259 for proto in &self.protocols {
260 if proto.shared_cap == *cap {
261 proto.send_raw(msg);
262 return true
263 }
264 }
265 false
266 }
267
268 fn install_protocol<F, Proto>(
269 &mut self,
270 cap: &Capability,
271 f: F,
272 ) -> Result<(), UnsupportedCapabilityError>
273 where
274 F: FnOnce(ProtocolConnection) -> Proto,
275 Proto: Stream<Item = BytesMut> + Send + 'static,
276 {
277 let shared_cap =
278 self.conn.shared_capabilities().ensure_matching_capability(cap).cloned()?;
279 let (to_satellite, rx) = mpsc::unbounded_channel();
280 let proto_conn = ProtocolConnection { from_wire: UnboundedReceiverStream::new(rx) };
281 let st = f(proto_conn);
282 let st = ProtocolStream { shared_cap, to_satellite, satellite_st: Box::pin(st) };
283 self.protocols.push(st);
284 Ok(())
285 }
286}
287
/// The primary protocol of a satellite stream together with the channels that connect it to
/// the wire.
#[derive(Debug)]
struct PrimaryProtocol<Primary> {
    /// Sender for incoming wire messages that belong to the primary protocol.
    to_primary: UnboundedSender<BytesMut>,
    /// Messages the primary protocol wants to write to the wire.
    from_primary: UnboundedReceiverStream<Bytes>,
    /// The shared capability of the primary protocol.
    shared_cap: SharedCapability,
    /// The primary protocol itself.
    st: Primary,
}
300
301#[derive(Debug)]
305pub struct ProtocolProxy {
306 shared_cap: SharedCapability,
307 from_wire: UnboundedReceiverStream<BytesMut>,
309 to_wire: UnboundedSender<Bytes>,
311}
312
313impl ProtocolProxy {
314 fn try_send(&self, msg: Bytes) -> Result<(), io::Error> {
316 if msg.is_empty() {
317 return Err(io::ErrorKind::InvalidInput.into())
319 }
320 self.to_wire.send(self.mask_msg_id(msg)?).map_err(|_| io::ErrorKind::BrokenPipe.into())
321 }
322
323 #[inline]
325 fn mask_msg_id(&self, msg: Bytes) -> Result<Bytes, io::Error> {
326 if msg.is_empty() {
327 return Err(io::ErrorKind::InvalidInput.into())
329 }
330
331 let offset = self.shared_cap.relative_message_id_offset();
332 if offset == 0 {
333 return Ok(msg);
334 }
335
336 let mut masked: BytesMut = msg.into();
337 masked[0] = masked[0].checked_add(offset).ok_or(io::ErrorKind::InvalidInput)?;
338 Ok(masked.freeze())
339 }
340
341 #[inline]
343 fn unmask_id(&self, mut msg: BytesMut) -> Result<BytesMut, io::Error> {
344 if msg.is_empty() {
345 return Err(io::ErrorKind::InvalidInput.into())
347 }
348 msg[0] = msg[0]
349 .checked_sub(self.shared_cap.relative_message_id_offset())
350 .ok_or(io::ErrorKind::InvalidInput)?;
351 Ok(msg)
352 }
353}
354
355impl Stream for ProtocolProxy {
356 type Item = Result<BytesMut, io::Error>;
357
358 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
359 let msg = ready!(self.from_wire.poll_next_unpin(cx));
360 Poll::Ready(msg.map(|msg| self.get_mut().unmask_id(msg)))
361 }
362}
363
364impl Sink<Bytes> for ProtocolProxy {
365 type Error = io::Error;
366
367 fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
368 Poll::Ready(Ok(()))
369 }
370
371 fn start_send(self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> {
372 self.get_mut().try_send(item)
373 }
374
375 fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
376 Poll::Ready(Ok(()))
377 }
378
379 fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
380 Poll::Ready(Ok(()))
381 }
382}
383
384impl CanDisconnect<Bytes> for ProtocolProxy {
385 fn disconnect(
386 &mut self,
387 _reason: DisconnectReason,
388 ) -> Pin<Box<dyn Future<Output = Result<(), <Self as Sink<Bytes>>::Error>> + Send + '_>> {
389 Box::pin(async move { Ok(()) })
390 }
391}
392
393#[derive(Debug)]
396struct UnauthProxy {
397 inner: ProtocolProxy,
398}
399
400impl UnauthProxy {
401 fn into_inner(self) -> ProtocolProxy {
402 self.inner
403 }
404}
405
406impl Stream for UnauthProxy {
407 type Item = Result<BytesMut, P2PStreamError>;
408
409 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
410 self.inner.poll_next_unpin(cx).map(|opt| opt.map(|res| res.map_err(P2PStreamError::from)))
411 }
412}
413
414impl Sink<Bytes> for UnauthProxy {
415 type Error = P2PStreamError;
416
417 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
418 self.inner.poll_ready_unpin(cx).map_err(P2PStreamError::from)
419 }
420
421 fn start_send(mut self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> {
422 self.inner.start_send_unpin(item).map_err(P2PStreamError::from)
423 }
424
425 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
426 self.inner.poll_flush_unpin(cx).map_err(P2PStreamError::from)
427 }
428
429 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
430 self.inner.poll_close_unpin(cx).map_err(P2PStreamError::from)
431 }
432}
433
434impl CanDisconnect<Bytes> for UnauthProxy {
435 fn disconnect(
436 &mut self,
437 reason: DisconnectReason,
438 ) -> Pin<Box<dyn Future<Output = Result<(), <Self as Sink<Bytes>>::Error>> + Send + '_>> {
439 let fut = self.inner.disconnect(reason);
440 Box::pin(async move { fut.await.map_err(P2PStreamError::from) })
441 }
442}
443
/// The receiving end handed to an installed satellite protocol.
///
/// It is a [`Stream`] of the protocol's incoming messages.
#[derive(Debug)]
pub struct ProtocolConnection {
    /// Incoming messages for this protocol.
    from_wire: UnboundedReceiverStream<BytesMut>,
}
451
452impl Stream for ProtocolConnection {
453 type Item = BytesMut;
454
455 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
456 self.from_wire.poll_next_unpin(cx)
457 }
458}
459
460#[derive(Debug)]
463pub struct RlpxSatelliteStream<St, Primary> {
464 inner: MultiplexInner<St>,
465 primary: PrimaryProtocol<Primary>,
466}
467
468impl<St, Primary> RlpxSatelliteStream<St, Primary> {
469 pub fn install_protocol<F, Proto>(
474 &mut self,
475 cap: &Capability,
476 f: F,
477 ) -> Result<(), UnsupportedCapabilityError>
478 where
479 F: FnOnce(ProtocolConnection) -> Proto,
480 Proto: Stream<Item = BytesMut> + Send + 'static,
481 {
482 self.inner.install_protocol(cap, f)
483 }
484
485 #[inline]
487 pub const fn primary(&self) -> &Primary {
488 &self.primary.st
489 }
490
491 #[inline]
493 pub const fn primary_mut(&mut self) -> &mut Primary {
494 &mut self.primary.st
495 }
496
497 #[inline]
499 pub const fn inner(&self) -> &P2PStream<St> {
500 &self.inner.conn
501 }
502
503 #[inline]
505 pub const fn inner_mut(&mut self) -> &mut P2PStream<St> {
506 &mut self.inner.conn
507 }
508
509 #[inline]
511 pub fn into_inner(self) -> P2PStream<St> {
512 self.inner.conn
513 }
514}
515
516impl<St, Primary, PrimaryErr> Stream for RlpxSatelliteStream<St, Primary>
517where
518 St: Stream<Item = io::Result<BytesMut>> + Sink<Bytes, Error = io::Error> + Unpin,
519 Primary: TryStream<Error = PrimaryErr> + Unpin,
520 P2PStreamError: Into<PrimaryErr>,
521{
522 type Item = Result<Primary::Ok, Primary::Error>;
523
524 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
525 let this = self.get_mut();
526
527 loop {
528 if let Poll::Ready(Some(msg)) = this.primary.st.try_poll_next_unpin(cx) {
530 return Poll::Ready(Some(msg))
531 }
532
533 let mut conn_ready = true;
534 loop {
535 match this.inner.conn.poll_ready_unpin(cx) {
536 Poll::Ready(Ok(())) => {
537 if let Some(msg) = this.inner.out_buffer.pop_front() {
538 if let Err(err) = this.inner.conn.start_send_unpin(msg) {
539 return Poll::Ready(Some(Err(err.into())))
540 }
541 } else {
542 break
543 }
544 }
545 Poll::Ready(Err(err)) => {
546 if let Err(disconnect_err) =
547 this.inner.conn.start_disconnect(DisconnectReason::DisconnectRequested)
548 {
549 return Poll::Ready(Some(Err(disconnect_err.into())))
550 }
551 return Poll::Ready(Some(Err(err.into())))
552 }
553 Poll::Pending => {
554 conn_ready = false;
555 break
556 }
557 }
558 }
559
560 loop {
562 match this.primary.from_primary.poll_next_unpin(cx) {
563 Poll::Ready(Some(msg)) => {
564 this.inner.out_buffer.push_back(msg);
565 }
566 Poll::Ready(None) => {
567 return Poll::Ready(None)
569 }
570 Poll::Pending => break,
571 }
572 }
573
574 for idx in (0..this.inner.protocols.len()).rev() {
576 let mut proto = this.inner.protocols.swap_remove(idx);
577 loop {
578 match proto.poll_next_unpin(cx) {
579 Poll::Ready(Some(Err(err))) => {
580 return Poll::Ready(Some(Err(P2PStreamError::Io(err).into())))
581 }
582 Poll::Ready(Some(Ok(msg))) => {
583 this.inner.out_buffer.push_back(msg);
584 }
585 Poll::Ready(None) => return Poll::Ready(None),
586 Poll::Pending => {
587 this.inner.protocols.push(proto);
588 break
589 }
590 }
591 }
592 }
593
594 let mut delegated = false;
595 loop {
596 match this.inner.conn.poll_next_unpin(cx) {
598 Poll::Ready(Some(Ok(msg))) => {
599 delegated = true;
600 let Some(offset) = msg.first().copied() else {
601 return Poll::Ready(Some(Err(
602 P2PStreamError::EmptyProtocolMessage.into()
603 )))
604 };
605 if let Some(cap) =
607 this.inner.conn.shared_capabilities().find_by_relative_offset(offset)
608 {
609 if cap == &this.primary.shared_cap {
610 let _ = this.primary.to_primary.send(msg);
612 } else {
613 for proto in &this.inner.protocols {
615 if proto.shared_cap == *cap {
616 proto.send_raw(msg);
617 break
618 }
619 }
620 }
621 } else {
622 return Poll::Ready(Some(Err(P2PStreamError::UnknownReservedMessageId(
623 offset,
624 )
625 .into())))
626 }
627 }
628 Poll::Ready(Some(Err(err))) => return Poll::Ready(Some(Err(err.into()))),
629 Poll::Ready(None) => {
630 return Poll::Ready(None)
632 }
633 Poll::Pending => break,
634 }
635 }
636
637 if !conn_ready || (!delegated && this.inner.out_buffer.is_empty()) {
638 return Poll::Pending
639 }
640 }
641 }
642}
643
644impl<St, Primary, T> Sink<T> for RlpxSatelliteStream<St, Primary>
645where
646 St: Stream<Item = io::Result<BytesMut>> + Sink<Bytes, Error = io::Error> + Unpin,
647 Primary: Sink<T> + Unpin,
648 P2PStreamError: Into<<Primary as Sink<T>>::Error>,
649{
650 type Error = <Primary as Sink<T>>::Error;
651
652 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
653 let this = self.get_mut();
654 if let Err(err) = ready!(this.inner.conn.poll_ready_unpin(cx)) {
655 return Poll::Ready(Err(err.into()))
656 }
657 if let Err(err) = ready!(this.primary.st.poll_ready_unpin(cx)) {
658 return Poll::Ready(Err(err))
659 }
660 Poll::Ready(Ok(()))
661 }
662
663 fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
664 self.get_mut().primary.st.start_send_unpin(item)
665 }
666
667 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
668 self.get_mut().inner.conn.poll_flush_unpin(cx).map_err(Into::into)
669 }
670
671 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
672 self.get_mut().inner.conn.poll_close_unpin(cx).map_err(Into::into)
673 }
674}
675
676struct ProtocolStream {
678 shared_cap: SharedCapability,
679 to_satellite: UnboundedSender<BytesMut>,
681 satellite_st: Pin<Box<dyn Stream<Item = BytesMut> + Send>>,
682}
683
684impl ProtocolStream {
685 #[inline]
687 fn mask_msg_id(&self, mut msg: BytesMut) -> Result<Bytes, io::Error> {
688 if msg.is_empty() {
689 return Err(io::ErrorKind::InvalidInput.into())
691 }
692 msg[0] = msg[0]
693 .checked_add(self.shared_cap.relative_message_id_offset())
694 .ok_or(io::ErrorKind::InvalidInput)?;
695 Ok(msg.freeze())
696 }
697
698 #[inline]
700 fn unmask_id(&self, mut msg: BytesMut) -> Result<BytesMut, io::Error> {
701 if msg.is_empty() {
702 return Err(io::ErrorKind::InvalidInput.into())
704 }
705 msg[0] = msg[0]
706 .checked_sub(self.shared_cap.relative_message_id_offset())
707 .ok_or(io::ErrorKind::InvalidInput)?;
708 Ok(msg)
709 }
710
711 fn send_raw(&self, msg: BytesMut) {
713 let _ = self.unmask_id(msg).map(|msg| self.to_satellite.send(msg));
714 }
715}
716
717impl Stream for ProtocolStream {
718 type Item = Result<Bytes, io::Error>;
719
720 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
721 let this = self.get_mut();
722 let msg = ready!(this.satellite_st.as_mut().poll_next(cx));
723 Poll::Ready(msg.map(|msg| this.mask_msg_id(msg)))
724 }
725}
726
727impl fmt::Debug for ProtocolStream {
728 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
729 f.debug_struct("ProtocolStream").field("cap", &self.shared_cap).finish_non_exhaustive()
730 }
731}
732
733struct ProtocolsPoller<'a> {
735 protocols: &'a mut Vec<ProtocolStream>,
736}
737
738impl<'a> ProtocolsPoller<'a> {
739 const fn new(protocols: &'a mut Vec<ProtocolStream>) -> Self {
740 Self { protocols }
741 }
742}
743
744impl<'a> Future for ProtocolsPoller<'a> {
745 type Output = Result<Bytes, P2PStreamError>;
746
747 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
748 for idx in (0..self.protocols.len()).rev() {
750 let mut proto = self.protocols.swap_remove(idx);
751 match proto.poll_next_unpin(cx) {
752 Poll::Ready(Some(Err(err))) => {
753 self.protocols.push(proto);
754 return Poll::Ready(Err(P2PStreamError::from(err)))
755 }
756 Poll::Ready(Some(Ok(msg))) => {
757 self.protocols.push(proto);
759 return Poll::Ready(Ok(msg));
760 }
761 _ => {
762 self.protocols.push(proto);
764 }
765 }
766 }
767
768 Poll::Pending
770 }
771}
772
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        handshake::EthHandshake,
        test_utils::{
            connect_passthrough, eth_handshake, eth_hello,
            proto::{test_hello, TestProtoMessage},
        },
        UnauthedEthStream, UnauthedP2PStream,
    };
    use reth_eth_wire_types::EthNetworkPrimitives;
    use tokio::{net::TcpListener, sync::oneshot};
    use tokio_util::codec::Decoder;

    /// The multiplexer completes an eth handshake against a peer that uses a plain eth stream.
    #[tokio::test]
    async fn eth_satellite() {
        reth_tracing::init_test_tracing();
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();
        let (status, fork_filter) = eth_handshake();
        let other_status = status;
        let other_fork_filter = fork_filter.clone();
        // server side: p2p handshake followed by a regular eth handshake
        let _handle = tokio::spawn(async move {
            let (incoming, _) = listener.accept().await.unwrap();
            let stream = crate::PassthroughCodec::default().framed(incoming);
            let (server_hello, _) = eth_hello();
            let (p2p_stream, _) =
                UnauthedP2PStream::new(stream).handshake(server_hello).await.unwrap();

            let (_eth_stream, _) = UnauthedEthStream::new(p2p_stream)
                .handshake::<EthNetworkPrimitives>(other_status, other_fork_filter)
                .await
                .unwrap();

            // keep the connection open for a moment after the handshake
            tokio::time::sleep(std::time::Duration::from_millis(100)).await;
        });

        let conn = connect_passthrough(local_addr, eth_hello().0).await;
        let eth = conn.shared_capabilities().eth().unwrap().clone();

        // client side: eth handshake through the multiplexer's proxy
        let multiplexer = RlpxProtocolMultiplexer::new(conn);
        let _satellite = multiplexer
            .into_satellite_stream_with_handshake(eth.capability().as_ref(), async move |proxy| {
                UnauthedEthStream::new(proxy)
                    .handshake::<EthNetworkPrimitives>(status, fork_filter)
                    .await
            })
            .await
            .unwrap();
    }

    /// Two satellite streams with eth as primary protocol exchange messages of an additionally
    /// installed test protocol.
    #[tokio::test(flavor = "multi_thread")]
    async fn eth_test_protocol_satellite() {
        reth_tracing::init_test_tracing();
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();
        let (status, fork_filter) = eth_handshake();
        let other_status = status;
        let other_fork_filter = fork_filter.clone();
        // server side: starts the conversation with a ping
        let _handle = tokio::spawn(async move {
            let (incoming, _) = listener.accept().await.unwrap();
            let stream = crate::PassthroughCodec::default().framed(incoming);
            let (server_hello, _) = test_hello();
            let (conn, _) = UnauthedP2PStream::new(stream).handshake(server_hello).await.unwrap();

            let (mut st, _their_status) = RlpxProtocolMultiplexer::new(conn)
                .into_eth_satellite_stream::<EthNetworkPrimitives>(
                    other_status,
                    other_fork_filter,
                    Arc::new(EthHandshake::default()),
                )
                .await
                .unwrap();

            st.install_protocol(&TestProtoMessage::capability(), |mut conn| {
                async_stream::stream! {
                    yield TestProtoMessage::ping().encoded();
                    let msg = conn.next().await.unwrap();
                    let msg = TestProtoMessage::decode_message(&mut &msg[..]).unwrap();
                    assert_eq!(msg, TestProtoMessage::pong());

                    yield TestProtoMessage::message("hello").encoded();
                    let msg = conn.next().await.unwrap();
                    let msg = TestProtoMessage::decode_message(&mut &msg[..]).unwrap();
                    assert_eq!(msg, TestProtoMessage::message("good bye!"));

                    yield TestProtoMessage::message("good bye!").encoded();

                    // never finish, so the protocol stream does not end
                    futures::future::pending::<()>().await;
                    unreachable!()
                }
            })
            .unwrap();

            // the satellite stream has to be polled to drive the installed protocol
            loop {
                let _ = st.next().await;
            }
        });

        let conn = connect_passthrough(local_addr, test_hello().0).await;
        let (mut st, _their_status) = RlpxProtocolMultiplexer::new(conn)
            .into_eth_satellite_stream::<EthNetworkPrimitives>(
                status,
                fork_filter,
                Arc::new(EthHandshake::default()),
            )
            .await
            .unwrap();

        // signals that the client side has seen the whole conversation
        let (tx, mut rx) = oneshot::channel();

        // client side: answers the ping and mirrors the goodbye
        st.install_protocol(&TestProtoMessage::capability(), |mut conn| {
            async_stream::stream! {
                let msg = conn.next().await.unwrap();
                let msg = TestProtoMessage::decode_message(&mut &msg[..]).unwrap();
                assert_eq!(msg, TestProtoMessage::ping());

                yield TestProtoMessage::pong().encoded();

                let msg = conn.next().await.unwrap();
                let msg = TestProtoMessage::decode_message(&mut &msg[..]).unwrap();
                assert_eq!(msg, TestProtoMessage::message("hello"));

                yield TestProtoMessage::message("good bye!").encoded();

                let msg = conn.next().await.unwrap();
                let msg = TestProtoMessage::decode_message(&mut &msg[..]).unwrap();
                assert_eq!(msg, TestProtoMessage::message("good bye!"));

                tx.send(()).unwrap();

                // never finish, so the protocol stream does not end
                futures::future::pending::<()>().await;
                unreachable!()
            }
        })
        .unwrap();

        // drive the satellite stream until the installed protocol reports completion
        loop {
            tokio::select! {
                _ = &mut rx => {
                    break
                }
                _ = st.next() => {
                }
            }
        }
    }
}