1use std::{
11 collections::VecDeque,
12 fmt,
13 future::Future,
14 io,
15 pin::{pin, Pin},
16 sync::Arc,
17 task::{ready, Context, Poll},
18};
19
20use crate::{
21 capability::{SharedCapabilities, SharedCapability, UnsupportedCapabilityError},
22 errors::{EthStreamError, P2PStreamError},
23 handshake::EthRlpxHandshake,
24 p2pstream::DisconnectP2P,
25 CanDisconnect, Capability, DisconnectReason, EthStream, P2PStream, UnifiedStatus,
26 HANDSHAKE_TIMEOUT,
27};
28use bytes::{Bytes, BytesMut};
29use futures::{Sink, SinkExt, Stream, StreamExt, TryStream, TryStreamExt};
30use reth_eth_wire_types::NetworkPrimitives;
31use reth_ethereum_forks::ForkFilter;
32use tokio::sync::{mpsc, mpsc::UnboundedSender};
33use tokio_stream::wrappers::UnboundedReceiverStream;
34
/// Multiplexes several RLPx sub-protocols over a single [`P2PStream`].
///
/// Satellite protocols are registered with [`RlpxProtocolMultiplexer::install_protocol`].
/// One of the `into_satellite_stream*` methods then selects the primary protocol and turns the
/// multiplexer into an [`RlpxSatelliteStream`].
#[derive(Debug)]
pub struct RlpxProtocolMultiplexer<St> {
    inner: MultiplexInner<St>,
}
41
impl<St> RlpxProtocolMultiplexer<St> {
    /// Wraps an established [`P2PStream`].
    ///
    /// The new multiplexer has no satellite protocols installed and an empty outgoing buffer.
    pub fn new(conn: P2PStream<St>) -> Self {
        Self {
            inner: MultiplexInner {
                conn,
                protocols: Default::default(),
                out_buffer: Default::default(),
            },
        }
    }

    /// Installs a satellite protocol for `cap`.
    ///
    /// `f` receives a [`ProtocolConnection`] that yields the incoming messages routed to this
    /// capability. The stream returned by `f` supplies the messages to send for it.
    ///
    /// # Errors
    ///
    /// Returns [`UnsupportedCapabilityError`] if `cap` is not one of the connection's shared
    /// capabilities.
    pub fn install_protocol<F, Proto>(
        &mut self,
        cap: &Capability,
        f: F,
    ) -> Result<(), UnsupportedCapabilityError>
    where
        F: FnOnce(ProtocolConnection) -> Proto,
        Proto: Stream<Item = BytesMut> + Send + 'static,
    {
        self.inner.install_protocol(cap, f)
    }

    /// Returns the capabilities shared with the remote peer.
    pub const fn shared_capabilities(&self) -> &SharedCapabilities {
        self.inner.shared_capabilities()
    }

    /// Converts the multiplexer into an [`RlpxSatelliteStream`].
    ///
    /// The primary protocol is built by calling `primary` with a [`ProtocolProxy`] bound to
    /// `cap`. No handshake is driven here; `primary` is called synchronously.
    ///
    /// # Errors
    ///
    /// Returns [`P2PStreamError::CapabilityNotShared`] if `cap` is not a shared capability.
    pub fn into_satellite_stream<F, Primary>(
        self,
        cap: &Capability,
        primary: F,
    ) -> Result<RlpxSatelliteStream<St, Primary>, P2PStreamError>
    where
        F: FnOnce(ProtocolProxy) -> Primary,
    {
        let Ok(shared_cap) = self.shared_capabilities().ensure_matching_capability(cap).cloned()
        else {
            return Err(P2PStreamError::CapabilityNotShared)
        };

        // Two channels connect the proxy to the multiplexer:
        // wire -> primary (`to_primary`/`from_wire`) and primary -> wire (`to_wire`/`from_primary`).
        let (to_primary, from_wire) = mpsc::unbounded_channel();
        let (to_wire, from_primary) = mpsc::unbounded_channel();
        let proxy = ProtocolProxy {
            shared_cap: shared_cap.clone(),
            from_wire: UnboundedReceiverStream::new(from_wire),
            to_wire,
        };

        let st = primary(proxy);
        Ok(RlpxSatelliteStream {
            inner: self.inner,
            primary: PrimaryProtocol {
                to_primary,
                from_primary: UnboundedReceiverStream::new(from_primary),
                st,
                shared_cap,
            },
        })
    }

    /// Like [`Self::into_satellite_stream_with_tuple_handshake`], for handshakes that resolve to
    /// the primary stream only.
    ///
    /// # Errors
    ///
    /// Same as [`Self::into_satellite_stream_with_tuple_handshake`].
    pub async fn into_satellite_stream_with_handshake<F, Fut, Err, Primary>(
        self,
        cap: &Capability,
        handshake: F,
    ) -> Result<RlpxSatelliteStream<St, Primary>, Err>
    where
        F: FnOnce(ProtocolProxy) -> Fut,
        Fut: Future<Output = Result<Primary, Err>>,
        St: Stream<Item = io::Result<BytesMut>> + Sink<Bytes, Error = io::Error> + Unpin,
        P2PStreamError: Into<Err>,
    {
        self.into_satellite_stream_with_tuple_handshake(cap, move |proxy| async move {
            let st = handshake(proxy).await?;
            Ok((st, ()))
        })
        .await
        .map(|(st, _)| st)
    }

    /// Runs `handshake` on a [`ProtocolProxy`] for `cap` while driving the connection.
    ///
    /// Until the handshake future resolves, this method:
    /// - routes incoming messages to the handshake or to an installed satellite, based on the
    ///   message ID in their first byte;
    /// - writes the handshake's outgoing messages to the connection;
    /// - writes the satellites' outgoing messages to the connection.
    ///
    /// On success it returns the resulting [`RlpxSatelliteStream`] together with the extra value
    /// produced by the handshake.
    ///
    /// # Errors
    ///
    /// - [`P2PStreamError::CapabilityNotShared`] if `cap` is not shared.
    /// - [`P2PStreamError::EmptyProtocolMessage`] for an incoming empty message.
    /// - [`P2PStreamError::UnknownReservedMessageId`] for a message ID that matches no shared
    ///   capability.
    /// - Any error from sending on the connection, from a satellite, or from `handshake`.
    pub async fn into_satellite_stream_with_tuple_handshake<F, Fut, Err, Primary, Extra>(
        mut self,
        cap: &Capability,
        handshake: F,
    ) -> Result<(RlpxSatelliteStream<St, Primary>, Extra), Err>
    where
        F: FnOnce(ProtocolProxy) -> Fut,
        Fut: Future<Output = Result<(Primary, Extra), Err>>,
        St: Stream<Item = io::Result<BytesMut>> + Sink<Bytes, Error = io::Error> + Unpin,
        P2PStreamError: Into<Err>,
    {
        let Ok(shared_cap) = self.shared_capabilities().ensure_matching_capability(cap).cloned()
        else {
            return Err(P2PStreamError::CapabilityNotShared.into())
        };

        let (to_primary, from_wire) = mpsc::unbounded_channel();
        let (to_wire, mut from_primary) = mpsc::unbounded_channel();
        let proxy = ProtocolProxy {
            shared_cap: shared_cap.clone(),
            from_wire: UnboundedReceiverStream::new(from_wire),
            to_wire,
        };

        let f = handshake(proxy);
        let mut f = pin!(f);

        loop {
            // `biased;` polls the branches top to bottom: the connection is read before anything
            // is written, and the handshake future is polled last.
            tokio::select! {
                biased;
                // NOTE(review): a connection error (`Some(Err(_))`) or end of stream (`None`) does
                // not match this pattern. It only disables the branch for this `select!` round,
                // so such conditions are not surfaced here. The handshake future presumably
                // relies on its own timeout to fail.
                Some(Ok(msg)) = self.inner.conn.next() => {
                    let Some(offset) = msg.first().copied()
                    else {
                        return Err(P2PStreamError::EmptyProtocolMessage.into())
                    };
                    if let Some(cap) = self.shared_capabilities().find_by_relative_offset(offset).cloned() {
                        if cap == shared_cap {
                            // Best effort: the handshake may already have dropped its proxy.
                            let _ = to_primary.send(msg);
                        } else {
                            self.inner.delegate_message(&cap, msg);
                        }
                    } else {
                        return Err(P2PStreamError::UnknownReservedMessageId(offset).into())
                    }
                }
                // Outgoing message from the handshake (already masked by the proxy).
                Some(msg) = from_primary.recv() => {
                    self.inner.conn.send(msg).await.map_err(Into::into)?;
                }
                // Outgoing message from any installed satellite protocol.
                msg = ProtocolsPoller::new(&mut self.inner.protocols) => {
                    self.inner.conn.send(msg.map_err(Into::into)?).await.map_err(Into::into)?;
                }
                // Handshake finished: keep both channel ends alive in the satellite stream.
                res = &mut f => {
                    let (st, extra) = res?;
                    return Ok((RlpxSatelliteStream {
                        inner: self.inner,
                        primary: PrimaryProtocol {
                            to_primary,
                            from_primary: UnboundedReceiverStream::new(from_primary),
                            st,
                            shared_cap,
                        }
                    }, extra))
                }
            }
        }
    }

    /// Uses `eth` as the primary protocol and performs the eth handshake through the
    /// multiplexer.
    ///
    /// The eth version comes from the shared capabilities. `handshake` is run with `status`,
    /// `fork_filter` and [`HANDSHAKE_TIMEOUT`] over an [`UnauthProxy`]. The resulting proxy is
    /// then wrapped in an [`EthStream`].
    ///
    /// Returns the satellite stream and the status returned by `handshake`.
    ///
    /// # Errors
    ///
    /// Fails if no eth version is shared, or if the handshake or the connection fails.
    pub async fn into_eth_satellite_stream<N: NetworkPrimitives>(
        self,
        status: UnifiedStatus,
        fork_filter: ForkFilter,
        handshake: Arc<dyn EthRlpxHandshake>,
    ) -> Result<(RlpxSatelliteStream<St, EthStream<ProtocolProxy, N>>, UnifiedStatus), EthStreamError>
    where
        St: Stream<Item = io::Result<BytesMut>> + Sink<Bytes, Error = io::Error> + Unpin,
    {
        let eth_cap = self.inner.conn.shared_capabilities().eth_version()?;
        self.into_satellite_stream_with_tuple_handshake(&Capability::eth(eth_cap), move |proxy| {
            let handshake = handshake.clone();
            async move {
                let mut unauth = UnauthProxy { inner: proxy };
                let their_status = handshake
                    .handshake(&mut unauth, status, fork_filter, HANDSHAKE_TIMEOUT)
                    .await?;
                let eth_stream = EthStream::new(eth_cap, unauth.into_inner());
                Ok((eth_stream, their_status))
            }
        })
        .await
    }
}
240
/// State shared by [`RlpxProtocolMultiplexer`] and [`RlpxSatelliteStream`].
#[derive(Debug)]
struct MultiplexInner<St> {
    /// The underlying p2p connection.
    conn: P2PStream<St>,
    /// Installed satellite protocols.
    protocols: Vec<ProtocolStream>,
    /// Outgoing messages waiting until `conn` is ready to accept them.
    out_buffer: VecDeque<Bytes>,
}
250
251impl<St> MultiplexInner<St> {
252 const fn shared_capabilities(&self) -> &SharedCapabilities {
253 self.conn.shared_capabilities()
254 }
255
256 fn delegate_message(&self, cap: &SharedCapability, msg: BytesMut) -> bool {
258 for proto in &self.protocols {
259 if proto.shared_cap == *cap {
260 proto.send_raw(msg);
261 return true
262 }
263 }
264 false
265 }
266
267 fn install_protocol<F, Proto>(
268 &mut self,
269 cap: &Capability,
270 f: F,
271 ) -> Result<(), UnsupportedCapabilityError>
272 where
273 F: FnOnce(ProtocolConnection) -> Proto,
274 Proto: Stream<Item = BytesMut> + Send + 'static,
275 {
276 let shared_cap =
277 self.conn.shared_capabilities().ensure_matching_capability(cap).cloned()?;
278 let (to_satellite, rx) = mpsc::unbounded_channel();
279 let proto_conn = ProtocolConnection { from_wire: UnboundedReceiverStream::new(rx) };
280 let st = f(proto_conn);
281 let st = ProtocolStream { shared_cap, to_satellite, satellite_st: Box::pin(st) };
282 self.protocols.push(st);
283 Ok(())
284 }
285}
286
/// The multiplexer's end of the primary protocol.
#[derive(Debug)]
struct PrimaryProtocol<Primary> {
    /// Sends incoming wire messages for the primary capability to its [`ProtocolProxy`].
    to_primary: UnboundedSender<BytesMut>,
    /// Outgoing messages from the [`ProtocolProxy`], already masked with the message ID offset.
    from_primary: UnboundedReceiverStream<Bytes>,
    /// The capability the primary protocol is bound to.
    shared_cap: SharedCapability,
    /// The primary protocol stream itself.
    st: Primary,
}
299
/// Connection handle given to the primary protocol.
///
/// On receive it removes this capability's message ID offset, and on send it adds the offset
/// back, so the primary protocol works with message IDs relative to its own capability.
#[derive(Debug)]
pub struct ProtocolProxy {
    shared_cap: SharedCapability,
    /// Incoming messages for this capability, still carrying the wire message ID.
    from_wire: UnboundedReceiverStream<BytesMut>,
    /// Outgoing (masked) messages, written to the connection by the multiplexer.
    to_wire: UnboundedSender<Bytes>,
}
311
312impl ProtocolProxy {
313 fn try_send(&self, msg: Bytes) -> Result<(), io::Error> {
315 if msg.is_empty() {
316 return Err(io::ErrorKind::InvalidInput.into())
318 }
319 self.to_wire.send(self.mask_msg_id(msg)?).map_err(|_| io::ErrorKind::BrokenPipe.into())
320 }
321
322 #[inline]
324 fn mask_msg_id(&self, msg: Bytes) -> Result<Bytes, io::Error> {
325 if msg.is_empty() {
326 return Err(io::ErrorKind::InvalidInput.into())
328 }
329
330 let offset = self.shared_cap.relative_message_id_offset();
331 if offset == 0 {
332 return Ok(msg);
333 }
334
335 let mut masked = Vec::from(msg);
336 masked[0] = masked[0].checked_add(offset).ok_or(io::ErrorKind::InvalidInput)?;
337 Ok(masked.into())
338 }
339
340 #[inline]
342 fn unmask_id(&self, mut msg: BytesMut) -> Result<BytesMut, io::Error> {
343 if msg.is_empty() {
344 return Err(io::ErrorKind::InvalidInput.into())
346 }
347 msg[0] = msg[0]
348 .checked_sub(self.shared_cap.relative_message_id_offset())
349 .ok_or(io::ErrorKind::InvalidInput)?;
350 Ok(msg)
351 }
352}
353
354impl Stream for ProtocolProxy {
355 type Item = Result<BytesMut, io::Error>;
356
357 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
358 let msg = ready!(self.from_wire.poll_next_unpin(cx));
359 Poll::Ready(msg.map(|msg| self.get_mut().unmask_id(msg)))
360 }
361}
362
363impl Sink<Bytes> for ProtocolProxy {
364 type Error = io::Error;
365
366 fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
367 Poll::Ready(Ok(()))
368 }
369
370 fn start_send(self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> {
371 self.get_mut().try_send(item)
372 }
373
374 fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
375 Poll::Ready(Ok(()))
376 }
377
378 fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
379 Poll::Ready(Ok(()))
380 }
381}
382
383impl CanDisconnect<Bytes> for ProtocolProxy {
384 fn disconnect(
385 &mut self,
386 _reason: DisconnectReason,
387 ) -> Pin<Box<dyn Future<Output = Result<(), <Self as Sink<Bytes>>::Error>> + Send + '_>> {
388 Box::pin(async move { Ok(()) })
390 }
391}
392
/// [`ProtocolProxy`] wrapper used during the eth handshake.
///
/// It exposes the proxy with [`P2PStreamError`] instead of [`io::Error`] as its error type.
#[derive(Debug)]
struct UnauthProxy {
    inner: ProtocolProxy,
}
399
400impl UnauthProxy {
401 fn into_inner(self) -> ProtocolProxy {
402 self.inner
403 }
404}
405
406impl Stream for UnauthProxy {
407 type Item = Result<BytesMut, P2PStreamError>;
408
409 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
410 self.inner.poll_next_unpin(cx).map(|opt| opt.map(|res| res.map_err(P2PStreamError::from)))
411 }
412}
413
414impl Sink<Bytes> for UnauthProxy {
415 type Error = P2PStreamError;
416
417 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
418 self.inner.poll_ready_unpin(cx).map_err(P2PStreamError::from)
419 }
420
421 fn start_send(mut self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> {
422 self.inner.start_send_unpin(item).map_err(P2PStreamError::from)
423 }
424
425 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
426 self.inner.poll_flush_unpin(cx).map_err(P2PStreamError::from)
427 }
428
429 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
430 self.inner.poll_close_unpin(cx).map_err(P2PStreamError::from)
431 }
432}
433
434impl CanDisconnect<Bytes> for UnauthProxy {
435 fn disconnect(
436 &mut self,
437 reason: DisconnectReason,
438 ) -> Pin<Box<dyn Future<Output = Result<(), <Self as Sink<Bytes>>::Error>> + Send + '_>> {
439 let fut = self.inner.disconnect(reason);
440 Box::pin(async move { fut.await.map_err(P2PStreamError::from) })
441 }
442}
443
/// Incoming side handed to a satellite protocol.
///
/// It yields the messages routed to the satellite's capability, with the message ID offset
/// already removed.
#[derive(Debug)]
pub struct ProtocolConnection {
    from_wire: UnboundedReceiverStream<BytesMut>,
}
451
452impl Stream for ProtocolConnection {
453 type Item = BytesMut;
454
455 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
456 self.from_wire.poll_next_unpin(cx)
457 }
458}
459
/// A [`P2PStream`] multiplexed between a primary protocol and installed satellite protocols.
///
/// As a `Stream` it yields the primary protocol's items while driving all traffic. As a `Sink`
/// it forwards items to the primary protocol.
#[derive(Debug)]
pub struct RlpxSatelliteStream<St, Primary> {
    inner: MultiplexInner<St>,
    primary: PrimaryProtocol<Primary>,
}
467
468impl<St, Primary> RlpxSatelliteStream<St, Primary> {
469 pub fn install_protocol<F, Proto>(
474 &mut self,
475 cap: &Capability,
476 f: F,
477 ) -> Result<(), UnsupportedCapabilityError>
478 where
479 F: FnOnce(ProtocolConnection) -> Proto,
480 Proto: Stream<Item = BytesMut> + Send + 'static,
481 {
482 self.inner.install_protocol(cap, f)
483 }
484
485 #[inline]
487 pub const fn primary(&self) -> &Primary {
488 &self.primary.st
489 }
490
491 #[inline]
493 pub const fn primary_mut(&mut self) -> &mut Primary {
494 &mut self.primary.st
495 }
496
497 #[inline]
499 pub const fn inner(&self) -> &P2PStream<St> {
500 &self.inner.conn
501 }
502
503 #[inline]
505 pub const fn inner_mut(&mut self) -> &mut P2PStream<St> {
506 &mut self.inner.conn
507 }
508
509 #[inline]
511 pub fn into_inner(self) -> P2PStream<St> {
512 self.inner.conn
513 }
514}
515
516impl<St, Primary, PrimaryErr> Stream for RlpxSatelliteStream<St, Primary>
517where
518 St: Stream<Item = io::Result<BytesMut>> + Sink<Bytes, Error = io::Error> + Unpin,
519 Primary: TryStream<Error = PrimaryErr> + Unpin,
520 P2PStreamError: Into<PrimaryErr>,
521{
522 type Item = Result<Primary::Ok, Primary::Error>;
523
524 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
525 let this = self.get_mut();
526
527 loop {
528 if let Poll::Ready(Some(msg)) = this.primary.st.try_poll_next_unpin(cx) {
530 return Poll::Ready(Some(msg))
531 }
532
533 let mut conn_ready = true;
534 loop {
535 match this.inner.conn.poll_ready_unpin(cx) {
536 Poll::Ready(Ok(())) => {
537 if let Some(msg) = this.inner.out_buffer.pop_front() {
538 if let Err(err) = this.inner.conn.start_send_unpin(msg) {
539 return Poll::Ready(Some(Err(err.into())))
540 }
541 } else {
542 break
543 }
544 }
545 Poll::Ready(Err(err)) => {
546 if let Err(disconnect_err) =
547 this.inner.conn.start_disconnect(DisconnectReason::DisconnectRequested)
548 {
549 return Poll::Ready(Some(Err(disconnect_err.into())))
550 }
551 return Poll::Ready(Some(Err(err.into())))
552 }
553 Poll::Pending => {
554 conn_ready = false;
555 break
556 }
557 }
558 }
559
560 loop {
562 match this.primary.from_primary.poll_next_unpin(cx) {
563 Poll::Ready(Some(msg)) => {
564 this.inner.out_buffer.push_back(msg);
565 }
566 Poll::Ready(None) => {
567 return Poll::Ready(None)
569 }
570 Poll::Pending => break,
571 }
572 }
573
574 for idx in (0..this.inner.protocols.len()).rev() {
576 let mut proto = this.inner.protocols.swap_remove(idx);
577 loop {
578 match proto.poll_next_unpin(cx) {
579 Poll::Ready(Some(Err(err))) => {
580 return Poll::Ready(Some(Err(P2PStreamError::Io(err).into())))
581 }
582 Poll::Ready(Some(Ok(msg))) => {
583 this.inner.out_buffer.push_back(msg);
584 }
585 Poll::Ready(None) => return Poll::Ready(None),
586 Poll::Pending => {
587 this.inner.protocols.push(proto);
588 break
589 }
590 }
591 }
592 }
593
594 let mut delegated = false;
595 loop {
596 match this.inner.conn.poll_next_unpin(cx) {
598 Poll::Ready(Some(Ok(msg))) => {
599 delegated = true;
600 let Some(offset) = msg.first().copied() else {
601 return Poll::Ready(Some(Err(
602 P2PStreamError::EmptyProtocolMessage.into()
603 )))
604 };
605 if let Some(cap) =
607 this.inner.conn.shared_capabilities().find_by_relative_offset(offset)
608 {
609 if cap == &this.primary.shared_cap {
610 let _ = this.primary.to_primary.send(msg);
612 } else {
613 for proto in &this.inner.protocols {
615 if proto.shared_cap == *cap {
616 proto.send_raw(msg);
617 break
618 }
619 }
620 }
621 } else {
622 return Poll::Ready(Some(Err(P2PStreamError::UnknownReservedMessageId(
623 offset,
624 )
625 .into())))
626 }
627 }
628 Poll::Ready(Some(Err(err))) => return Poll::Ready(Some(Err(err.into()))),
629 Poll::Ready(None) => {
630 return Poll::Ready(None)
632 }
633 Poll::Pending => break,
634 }
635 }
636
637 if !conn_ready || (!delegated && this.inner.out_buffer.is_empty()) {
638 return Poll::Pending
639 }
640 }
641 }
642}
643
644impl<St, Primary, T> Sink<T> for RlpxSatelliteStream<St, Primary>
645where
646 St: Stream<Item = io::Result<BytesMut>> + Sink<Bytes, Error = io::Error> + Unpin,
647 Primary: Sink<T> + Unpin,
648 P2PStreamError: Into<<Primary as Sink<T>>::Error>,
649{
650 type Error = <Primary as Sink<T>>::Error;
651
652 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
653 let this = self.get_mut();
654 if let Err(err) = ready!(this.inner.conn.poll_ready_unpin(cx)) {
655 return Poll::Ready(Err(err.into()))
656 }
657 if let Err(err) = ready!(this.primary.st.poll_ready_unpin(cx)) {
658 return Poll::Ready(Err(err))
659 }
660 Poll::Ready(Ok(()))
661 }
662
663 fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
664 self.get_mut().primary.st.start_send_unpin(item)
665 }
666
667 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
668 self.get_mut().inner.conn.poll_flush_unpin(cx).map_err(Into::into)
669 }
670
671 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
672 self.get_mut().inner.conn.poll_close_unpin(cx).map_err(Into::into)
673 }
674}
675
/// The multiplexer's end of an installed satellite protocol.
struct ProtocolStream {
    /// The capability this satellite handles.
    shared_cap: SharedCapability,
    /// Sends incoming messages (offset removed) to the satellite's [`ProtocolConnection`].
    to_satellite: UnboundedSender<BytesMut>,
    /// The satellite's outgoing messages, with IDs relative to its capability.
    satellite_st: Pin<Box<dyn Stream<Item = BytesMut> + Send>>,
}
683
684impl ProtocolStream {
685 #[inline]
687 fn mask_msg_id(&self, mut msg: BytesMut) -> Result<Bytes, io::Error> {
688 if msg.is_empty() {
689 return Err(io::ErrorKind::InvalidInput.into())
691 }
692 msg[0] = msg[0]
693 .checked_add(self.shared_cap.relative_message_id_offset())
694 .ok_or(io::ErrorKind::InvalidInput)?;
695 Ok(msg.freeze())
696 }
697
698 #[inline]
700 fn unmask_id(&self, mut msg: BytesMut) -> Result<BytesMut, io::Error> {
701 if msg.is_empty() {
702 return Err(io::ErrorKind::InvalidInput.into())
704 }
705 msg[0] = msg[0]
706 .checked_sub(self.shared_cap.relative_message_id_offset())
707 .ok_or(io::ErrorKind::InvalidInput)?;
708 Ok(msg)
709 }
710
711 fn send_raw(&self, msg: BytesMut) {
713 let _ = self.unmask_id(msg).map(|msg| self.to_satellite.send(msg));
714 }
715}
716
717impl Stream for ProtocolStream {
718 type Item = Result<Bytes, io::Error>;
719
720 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
721 let this = self.get_mut();
722 let msg = ready!(this.satellite_st.as_mut().poll_next(cx));
723 Poll::Ready(msg.map(|msg| this.mask_msg_id(msg)))
724 }
725}
726
727impl fmt::Debug for ProtocolStream {
728 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
729 f.debug_struct("ProtocolStream").field("cap", &self.shared_cap).finish_non_exhaustive()
730 }
731}
732
/// Future that resolves with the next outgoing message, or error, of any installed satellite.
struct ProtocolsPoller<'a> {
    /// The satellites to poll; every one stays in the vector after polling.
    protocols: &'a mut Vec<ProtocolStream>,
}
737
738impl<'a> ProtocolsPoller<'a> {
739 const fn new(protocols: &'a mut Vec<ProtocolStream>) -> Self {
740 Self { protocols }
741 }
742}
743
744impl<'a> Future for ProtocolsPoller<'a> {
745 type Output = Result<Bytes, P2PStreamError>;
746
747 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
748 for idx in (0..self.protocols.len()).rev() {
750 let mut proto = self.protocols.swap_remove(idx);
751 match proto.poll_next_unpin(cx) {
752 Poll::Ready(Some(Err(err))) => {
753 self.protocols.push(proto);
754 return Poll::Ready(Err(P2PStreamError::from(err)))
755 }
756 Poll::Ready(Some(Ok(msg))) => {
757 self.protocols.push(proto);
759 return Poll::Ready(Ok(msg));
760 }
761 _ => {
762 self.protocols.push(proto);
764 }
765 }
766 }
767
768 Poll::Pending
770 }
771}
772
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        handshake::EthHandshake,
        test_utils::{
            connect_passthrough, eth_handshake, eth_hello,
            proto::{test_hello, TestProtoMessage},
        },
        UnauthedEthStream, UnauthedP2PStream,
    };
    use reth_eth_wire_types::EthNetworkPrimitives;
    use tokio::{net::TcpListener, sync::oneshot};
    use tokio_util::codec::Decoder;

    /// The client runs the eth handshake through a satellite stream, against a server that uses
    /// a plain `UnauthedEthStream`.
    #[tokio::test]
    async fn eth_satellite() {
        reth_tracing::init_test_tracing();
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();
        let (status, fork_filter) = eth_handshake();
        let other_status = status;
        let other_fork_filter = fork_filter.clone();
        let _handle = tokio::spawn(async move {
            let (incoming, _) = listener.accept().await.unwrap();
            let stream = crate::PassthroughCodec::default().framed(incoming);
            let (server_hello, _) = eth_hello();
            let (p2p_stream, _) =
                UnauthedP2PStream::new(stream).handshake(server_hello).await.unwrap();

            let (_eth_stream, _) = UnauthedEthStream::new(p2p_stream)
                .handshake::<EthNetworkPrimitives>(other_status, other_fork_filter)
                .await
                .unwrap();

            // Presumably keeps the server's connection open while the client side finishes.
            tokio::time::sleep(std::time::Duration::from_millis(100)).await;
        });

        let conn = connect_passthrough(local_addr, eth_hello().0).await;
        let eth = conn.shared_capabilities().eth().unwrap().clone();

        let multiplexer = RlpxProtocolMultiplexer::new(conn);
        let _satellite = multiplexer
            .into_satellite_stream_with_handshake(
                eth.capability().as_ref(),
                move |proxy| async move {
                    UnauthedEthStream::new(proxy)
                        .handshake::<EthNetworkPrimitives>(status, fork_filter)
                        .await
                },
            )
            .await
            .unwrap();
    }

    /// Both peers use eth as the primary protocol plus the test protocol as a satellite, and
    /// exchange ping/pong and text messages over the satellite.
    #[tokio::test(flavor = "multi_thread")]
    async fn eth_test_protocol_satellite() {
        reth_tracing::init_test_tracing();
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();
        let (status, fork_filter) = eth_handshake();
        let other_status = status;
        let other_fork_filter = fork_filter.clone();
        let _handle = tokio::spawn(async move {
            let (incoming, _) = listener.accept().await.unwrap();
            let stream = crate::PassthroughCodec::default().framed(incoming);
            let (server_hello, _) = test_hello();
            let (conn, _) = UnauthedP2PStream::new(stream).handshake(server_hello).await.unwrap();

            let (mut st, _their_status) = RlpxProtocolMultiplexer::new(conn)
                .into_eth_satellite_stream::<EthNetworkPrimitives>(
                    other_status,
                    other_fork_filter,
                    Arc::new(EthHandshake::default()),
                )
                .await
                .unwrap();

            // Server side: ping -> expect pong -> "hello" -> expect "good bye!" -> "good bye!".
            st.install_protocol(&TestProtoMessage::capability(), |mut conn| {
                async_stream::stream! {
                    yield TestProtoMessage::ping().encoded();
                    let msg = conn.next().await.unwrap();
                    let msg = TestProtoMessage::decode_message(&mut &msg[..]).unwrap();
                    assert_eq!(msg, TestProtoMessage::pong());

                    yield TestProtoMessage::message("hello").encoded();
                    let msg = conn.next().await.unwrap();
                    let msg = TestProtoMessage::decode_message(&mut &msg[..]).unwrap();
                    assert_eq!(msg, TestProtoMessage::message("good bye!"));

                    yield TestProtoMessage::message("good bye!").encoded();

                    futures::future::pending::<()>().await;
                    unreachable!()
                }
            })
            .unwrap();

            // Polling the satellite stream drives both the eth and the test protocol traffic.
            loop {
                let _ = st.next().await;
            }
        });

        let conn = connect_passthrough(local_addr, test_hello().0).await;
        let (mut st, _their_status) = RlpxProtocolMultiplexer::new(conn)
            .into_eth_satellite_stream::<EthNetworkPrimitives>(
                status,
                fork_filter,
                Arc::new(EthHandshake::default()),
            )
            .await
            .unwrap();

        // Fired by the client satellite once the whole exchange has completed.
        let (tx, mut rx) = oneshot::channel();

        // Client side: mirror of the server script above.
        st.install_protocol(&TestProtoMessage::capability(), |mut conn| {
            async_stream::stream! {
                let msg = conn.next().await.unwrap();
                let msg = TestProtoMessage::decode_message(&mut &msg[..]).unwrap();
                assert_eq!(msg, TestProtoMessage::ping());

                yield TestProtoMessage::pong().encoded();

                let msg = conn.next().await.unwrap();
                let msg = TestProtoMessage::decode_message(&mut &msg[..]).unwrap();
                assert_eq!(msg, TestProtoMessage::message("hello"));

                yield TestProtoMessage::message("good bye!").encoded();

                let msg = conn.next().await.unwrap();
                let msg = TestProtoMessage::decode_message(&mut &msg[..]).unwrap();
                assert_eq!(msg, TestProtoMessage::message("good bye!"));

                tx.send(()).unwrap();

                futures::future::pending::<()>().await;
                unreachable!()
            }
        })
        .unwrap();

        // Drive the stream until the client satellite signals completion.
        loop {
            tokio::select! {
                _ = &mut rx => {
                    break
                }
               _ = st.next() => {
               }
            }
        }
    }
}