1use std::{
11 collections::VecDeque,
12 fmt,
13 future::Future,
14 io,
15 pin::{pin, Pin},
16 sync::Arc,
17 task::{ready, Context, Poll},
18};
19
20use crate::{
21 capability::{SharedCapabilities, SharedCapability, UnsupportedCapabilityError},
22 errors::{EthStreamError, P2PStreamError},
23 handshake::EthRlpxHandshake,
24 p2pstream::DisconnectP2P,
25 CanDisconnect, Capability, DisconnectReason, EthStream, P2PStream, UnifiedStatus,
26 HANDSHAKE_TIMEOUT,
27};
28use bytes::{Bytes, BytesMut};
29use futures::{Sink, SinkExt, Stream, StreamExt, TryStream, TryStreamExt};
30use reth_eth_wire_types::NetworkPrimitives;
31use reth_ethereum_forks::ForkFilter;
32use tokio::sync::{mpsc, mpsc::UnboundedSender};
33use tokio_stream::wrappers::UnboundedReceiverStream;
34
35#[derive(Debug)]
38pub struct RlpxProtocolMultiplexer<St> {
39 inner: MultiplexInner<St>,
40}
41
42impl<St> RlpxProtocolMultiplexer<St> {
43 pub fn new(conn: P2PStream<St>) -> Self {
45 Self {
46 inner: MultiplexInner {
47 conn,
48 protocols: Default::default(),
49 out_buffer: Default::default(),
50 },
51 }
52 }
53
54 pub fn install_protocol<F, Proto>(
59 &mut self,
60 cap: &Capability,
61 f: F,
62 ) -> Result<(), UnsupportedCapabilityError>
63 where
64 F: FnOnce(ProtocolConnection) -> Proto,
65 Proto: Stream<Item = BytesMut> + Send + 'static,
66 {
67 self.inner.install_protocol(cap, f)
68 }
69
70 pub const fn shared_capabilities(&self) -> &SharedCapabilities {
72 self.inner.shared_capabilities()
73 }
74
75 pub fn into_satellite_stream<F, Primary>(
77 self,
78 cap: &Capability,
79 primary: F,
80 ) -> Result<RlpxSatelliteStream<St, Primary>, P2PStreamError>
81 where
82 F: FnOnce(ProtocolProxy) -> Primary,
83 {
84 let Ok(shared_cap) = self.shared_capabilities().ensure_matching_capability(cap).cloned()
85 else {
86 return Err(P2PStreamError::CapabilityNotShared)
87 };
88
89 let (to_primary, from_wire) = mpsc::unbounded_channel();
90 let (to_wire, from_primary) = mpsc::unbounded_channel();
91 let proxy = ProtocolProxy {
92 shared_cap: shared_cap.clone(),
93 from_wire: UnboundedReceiverStream::new(from_wire),
94 to_wire,
95 };
96
97 let st = primary(proxy);
98 Ok(RlpxSatelliteStream {
99 inner: self.inner,
100 primary: PrimaryProtocol {
101 to_primary,
102 from_primary: UnboundedReceiverStream::new(from_primary),
103 st,
104 shared_cap,
105 },
106 })
107 }
108
109 pub async fn into_satellite_stream_with_handshake<F, Fut, Err, Primary>(
114 self,
115 cap: &Capability,
116 handshake: F,
117 ) -> Result<RlpxSatelliteStream<St, Primary>, Err>
118 where
119 F: FnOnce(ProtocolProxy) -> Fut,
120 Fut: Future<Output = Result<Primary, Err>>,
121 St: Stream<Item = io::Result<BytesMut>> + Sink<Bytes, Error = io::Error> + Unpin,
122 P2PStreamError: Into<Err>,
123 {
124 self.into_satellite_stream_with_tuple_handshake(cap, move |proxy| async move {
125 let st = handshake(proxy).await?;
126 Ok((st, ()))
127 })
128 .await
129 .map(|(st, _)| st)
130 }
131
132 pub async fn into_satellite_stream_with_tuple_handshake<F, Fut, Err, Primary, Extra>(
142 mut self,
143 cap: &Capability,
144 handshake: F,
145 ) -> Result<(RlpxSatelliteStream<St, Primary>, Extra), Err>
146 where
147 F: FnOnce(ProtocolProxy) -> Fut,
148 Fut: Future<Output = Result<(Primary, Extra), Err>>,
149 St: Stream<Item = io::Result<BytesMut>> + Sink<Bytes, Error = io::Error> + Unpin,
150 P2PStreamError: Into<Err>,
151 {
152 let Ok(shared_cap) = self.shared_capabilities().ensure_matching_capability(cap).cloned()
153 else {
154 return Err(P2PStreamError::CapabilityNotShared.into())
155 };
156
157 let (to_primary, from_wire) = mpsc::unbounded_channel();
158 let (to_wire, mut from_primary) = mpsc::unbounded_channel();
159 let proxy = ProtocolProxy {
160 shared_cap: shared_cap.clone(),
161 from_wire: UnboundedReceiverStream::new(from_wire),
162 to_wire,
163 };
164
165 let f = handshake(proxy);
166 let mut f = pin!(f);
167
168 loop {
171 tokio::select! {
172 biased;
173 Some(Ok(msg)) = self.inner.conn.next() => {
174 let Some(offset) = msg.first().copied()
176 else {
177 return Err(P2PStreamError::EmptyProtocolMessage.into())
178 };
179 if let Some(cap) = self.shared_capabilities().find_by_relative_offset(offset).cloned() {
180 if cap == shared_cap {
181 let _ = to_primary.send(msg);
183 } else {
184 self.inner.delegate_message(&cap, msg);
186 }
187 } else {
188 return Err(P2PStreamError::UnknownReservedMessageId(offset).into())
189 }
190 }
191 Some(msg) = from_primary.recv() => {
192 self.inner.conn.send(msg).await.map_err(Into::into)?;
193 }
194 msg = ProtocolsPoller::new(&mut self.inner.protocols) => {
196 self.inner.conn.send(msg.map_err(Into::into)?).await.map_err(Into::into)?;
197 }
198 res = &mut f => {
199 let (st, extra) = res?;
200 return Ok((RlpxSatelliteStream {
201 inner: self.inner,
202 primary: PrimaryProtocol {
203 to_primary,
204 from_primary: UnboundedReceiverStream::new(from_primary),
205 st,
206 shared_cap,
207 }
208 }, extra))
209 }
210 }
211 }
212 }
213
214 pub async fn into_eth_satellite_stream<N: NetworkPrimitives>(
217 self,
218 status: UnifiedStatus,
219 fork_filter: ForkFilter,
220 handshake: Arc<dyn EthRlpxHandshake>,
221 ) -> Result<(RlpxSatelliteStream<St, EthStream<ProtocolProxy, N>>, UnifiedStatus), EthStreamError>
222 where
223 St: Stream<Item = io::Result<BytesMut>> + Sink<Bytes, Error = io::Error> + Unpin,
224 {
225 let eth_cap = self.inner.conn.shared_capabilities().eth_version()?;
226 self.into_satellite_stream_with_tuple_handshake(&Capability::eth(eth_cap), move |proxy| {
227 let handshake = handshake.clone();
228 async move {
229 let mut unauth = UnauthProxy { inner: proxy };
230 let their_status = handshake
231 .handshake(&mut unauth, status, fork_filter, HANDSHAKE_TIMEOUT)
232 .await?;
233 let eth_stream = EthStream::new(eth_cap, unauth.into_inner());
234 Ok((eth_stream, their_status))
235 }
236 })
237 .await
238 }
239}
240
241#[derive(Debug)]
242struct MultiplexInner<St> {
243 conn: P2PStream<St>,
245 protocols: Vec<ProtocolStream>,
247 out_buffer: VecDeque<Bytes>,
249}
250
251impl<St> MultiplexInner<St> {
252 const fn shared_capabilities(&self) -> &SharedCapabilities {
253 self.conn.shared_capabilities()
254 }
255
256 fn delegate_message(&self, cap: &SharedCapability, msg: BytesMut) -> bool {
258 for proto in &self.protocols {
259 if proto.shared_cap == *cap {
260 proto.send_raw(msg);
261 return true
262 }
263 }
264 false
265 }
266
267 fn install_protocol<F, Proto>(
268 &mut self,
269 cap: &Capability,
270 f: F,
271 ) -> Result<(), UnsupportedCapabilityError>
272 where
273 F: FnOnce(ProtocolConnection) -> Proto,
274 Proto: Stream<Item = BytesMut> + Send + 'static,
275 {
276 let shared_cap =
277 self.conn.shared_capabilities().ensure_matching_capability(cap).cloned()?;
278 let (to_satellite, rx) = mpsc::unbounded_channel();
279 let proto_conn = ProtocolConnection { from_wire: UnboundedReceiverStream::new(rx) };
280 let st = f(proto_conn);
281 let st = ProtocolStream { shared_cap, to_satellite, satellite_st: Box::pin(st) };
282 self.protocols.push(st);
283 Ok(())
284 }
285}
286
287#[derive(Debug)]
289struct PrimaryProtocol<Primary> {
290 to_primary: UnboundedSender<BytesMut>,
292 from_primary: UnboundedReceiverStream<Bytes>,
294 shared_cap: SharedCapability,
296 st: Primary,
298}
299
300#[derive(Debug)]
304pub struct ProtocolProxy {
305 shared_cap: SharedCapability,
306 from_wire: UnboundedReceiverStream<BytesMut>,
308 to_wire: UnboundedSender<Bytes>,
310}
311
312impl ProtocolProxy {
313 fn try_send(&self, msg: Bytes) -> Result<(), io::Error> {
315 if msg.is_empty() {
316 return Err(io::ErrorKind::InvalidInput.into())
318 }
319 self.to_wire.send(self.mask_msg_id(msg)?).map_err(|_| io::ErrorKind::BrokenPipe.into())
320 }
321
322 #[inline]
324 fn mask_msg_id(&self, msg: Bytes) -> Result<Bytes, io::Error> {
325 if msg.is_empty() {
326 return Err(io::ErrorKind::InvalidInput.into())
328 }
329
330 let offset = self.shared_cap.relative_message_id_offset();
331 if offset == 0 {
332 return Ok(msg);
333 }
334
335 let mut masked: BytesMut = msg.into();
336 masked[0] = masked[0].checked_add(offset).ok_or(io::ErrorKind::InvalidInput)?;
337 Ok(masked.freeze())
338 }
339
340 #[inline]
342 fn unmask_id(&self, mut msg: BytesMut) -> Result<BytesMut, io::Error> {
343 if msg.is_empty() {
344 return Err(io::ErrorKind::InvalidInput.into())
346 }
347 msg[0] = msg[0]
348 .checked_sub(self.shared_cap.relative_message_id_offset())
349 .ok_or(io::ErrorKind::InvalidInput)?;
350 Ok(msg)
351 }
352}
353
354impl Stream for ProtocolProxy {
355 type Item = Result<BytesMut, io::Error>;
356
357 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
358 let msg = ready!(self.from_wire.poll_next_unpin(cx));
359 Poll::Ready(msg.map(|msg| self.get_mut().unmask_id(msg)))
360 }
361}
362
363impl Sink<Bytes> for ProtocolProxy {
364 type Error = io::Error;
365
366 fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
367 Poll::Ready(Ok(()))
368 }
369
370 fn start_send(self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> {
371 self.get_mut().try_send(item)
372 }
373
374 fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
375 Poll::Ready(Ok(()))
376 }
377
378 fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
379 Poll::Ready(Ok(()))
380 }
381}
382
383impl CanDisconnect<Bytes> for ProtocolProxy {
384 fn disconnect(
385 &mut self,
386 _reason: DisconnectReason,
387 ) -> Pin<Box<dyn Future<Output = Result<(), <Self as Sink<Bytes>>::Error>> + Send + '_>> {
388 Box::pin(async move { Ok(()) })
389 }
390}
391
392#[derive(Debug)]
395struct UnauthProxy {
396 inner: ProtocolProxy,
397}
398
399impl UnauthProxy {
400 fn into_inner(self) -> ProtocolProxy {
401 self.inner
402 }
403}
404
405impl Stream for UnauthProxy {
406 type Item = Result<BytesMut, P2PStreamError>;
407
408 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
409 self.inner.poll_next_unpin(cx).map(|opt| opt.map(|res| res.map_err(P2PStreamError::from)))
410 }
411}
412
413impl Sink<Bytes> for UnauthProxy {
414 type Error = P2PStreamError;
415
416 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
417 self.inner.poll_ready_unpin(cx).map_err(P2PStreamError::from)
418 }
419
420 fn start_send(mut self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> {
421 self.inner.start_send_unpin(item).map_err(P2PStreamError::from)
422 }
423
424 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
425 self.inner.poll_flush_unpin(cx).map_err(P2PStreamError::from)
426 }
427
428 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
429 self.inner.poll_close_unpin(cx).map_err(P2PStreamError::from)
430 }
431}
432
433impl CanDisconnect<Bytes> for UnauthProxy {
434 fn disconnect(
435 &mut self,
436 reason: DisconnectReason,
437 ) -> Pin<Box<dyn Future<Output = Result<(), <Self as Sink<Bytes>>::Error>> + Send + '_>> {
438 let fut = self.inner.disconnect(reason);
439 Box::pin(async move { fut.await.map_err(P2PStreamError::from) })
440 }
441}
442
443#[derive(Debug)]
447pub struct ProtocolConnection {
448 from_wire: UnboundedReceiverStream<BytesMut>,
449}
450
451impl Stream for ProtocolConnection {
452 type Item = BytesMut;
453
454 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
455 self.from_wire.poll_next_unpin(cx)
456 }
457}
458
459#[derive(Debug)]
462pub struct RlpxSatelliteStream<St, Primary> {
463 inner: MultiplexInner<St>,
464 primary: PrimaryProtocol<Primary>,
465}
466
467impl<St, Primary> RlpxSatelliteStream<St, Primary> {
468 pub fn install_protocol<F, Proto>(
473 &mut self,
474 cap: &Capability,
475 f: F,
476 ) -> Result<(), UnsupportedCapabilityError>
477 where
478 F: FnOnce(ProtocolConnection) -> Proto,
479 Proto: Stream<Item = BytesMut> + Send + 'static,
480 {
481 self.inner.install_protocol(cap, f)
482 }
483
484 #[inline]
486 pub const fn primary(&self) -> &Primary {
487 &self.primary.st
488 }
489
490 #[inline]
492 pub const fn primary_mut(&mut self) -> &mut Primary {
493 &mut self.primary.st
494 }
495
496 #[inline]
498 pub const fn inner(&self) -> &P2PStream<St> {
499 &self.inner.conn
500 }
501
502 #[inline]
504 pub const fn inner_mut(&mut self) -> &mut P2PStream<St> {
505 &mut self.inner.conn
506 }
507
508 #[inline]
510 pub fn into_inner(self) -> P2PStream<St> {
511 self.inner.conn
512 }
513}
514
515impl<St, Primary, PrimaryErr> Stream for RlpxSatelliteStream<St, Primary>
516where
517 St: Stream<Item = io::Result<BytesMut>> + Sink<Bytes, Error = io::Error> + Unpin,
518 Primary: TryStream<Error = PrimaryErr> + Unpin,
519 P2PStreamError: Into<PrimaryErr>,
520{
521 type Item = Result<Primary::Ok, Primary::Error>;
522
523 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
524 let this = self.get_mut();
525
526 loop {
527 if let Poll::Ready(Some(msg)) = this.primary.st.try_poll_next_unpin(cx) {
529 return Poll::Ready(Some(msg))
530 }
531
532 let mut conn_ready = true;
533 loop {
534 match this.inner.conn.poll_ready_unpin(cx) {
535 Poll::Ready(Ok(())) => {
536 if let Some(msg) = this.inner.out_buffer.pop_front() {
537 if let Err(err) = this.inner.conn.start_send_unpin(msg) {
538 return Poll::Ready(Some(Err(err.into())))
539 }
540 } else {
541 break
542 }
543 }
544 Poll::Ready(Err(err)) => {
545 if let Err(disconnect_err) =
546 this.inner.conn.start_disconnect(DisconnectReason::DisconnectRequested)
547 {
548 return Poll::Ready(Some(Err(disconnect_err.into())))
549 }
550 return Poll::Ready(Some(Err(err.into())))
551 }
552 Poll::Pending => {
553 conn_ready = false;
554 break
555 }
556 }
557 }
558
559 loop {
561 match this.primary.from_primary.poll_next_unpin(cx) {
562 Poll::Ready(Some(msg)) => {
563 this.inner.out_buffer.push_back(msg);
564 }
565 Poll::Ready(None) => {
566 return Poll::Ready(None)
568 }
569 Poll::Pending => break,
570 }
571 }
572
573 for idx in (0..this.inner.protocols.len()).rev() {
575 let mut proto = this.inner.protocols.swap_remove(idx);
576 loop {
577 match proto.poll_next_unpin(cx) {
578 Poll::Ready(Some(Err(err))) => {
579 return Poll::Ready(Some(Err(P2PStreamError::Io(err).into())))
580 }
581 Poll::Ready(Some(Ok(msg))) => {
582 this.inner.out_buffer.push_back(msg);
583 }
584 Poll::Ready(None) => return Poll::Ready(None),
585 Poll::Pending => {
586 this.inner.protocols.push(proto);
587 break
588 }
589 }
590 }
591 }
592
593 let mut delegated = false;
594 loop {
595 match this.inner.conn.poll_next_unpin(cx) {
597 Poll::Ready(Some(Ok(msg))) => {
598 delegated = true;
599 let Some(offset) = msg.first().copied() else {
600 return Poll::Ready(Some(Err(
601 P2PStreamError::EmptyProtocolMessage.into()
602 )))
603 };
604 if let Some(cap) =
606 this.inner.conn.shared_capabilities().find_by_relative_offset(offset)
607 {
608 if cap == &this.primary.shared_cap {
609 let _ = this.primary.to_primary.send(msg);
611 } else {
612 for proto in &this.inner.protocols {
614 if proto.shared_cap == *cap {
615 proto.send_raw(msg);
616 break
617 }
618 }
619 }
620 } else {
621 return Poll::Ready(Some(Err(P2PStreamError::UnknownReservedMessageId(
622 offset,
623 )
624 .into())))
625 }
626 }
627 Poll::Ready(Some(Err(err))) => return Poll::Ready(Some(Err(err.into()))),
628 Poll::Ready(None) => {
629 return Poll::Ready(None)
631 }
632 Poll::Pending => break,
633 }
634 }
635
636 if !conn_ready || (!delegated && this.inner.out_buffer.is_empty()) {
637 return Poll::Pending
638 }
639 }
640 }
641}
642
643impl<St, Primary, T> Sink<T> for RlpxSatelliteStream<St, Primary>
644where
645 St: Stream<Item = io::Result<BytesMut>> + Sink<Bytes, Error = io::Error> + Unpin,
646 Primary: Sink<T> + Unpin,
647 P2PStreamError: Into<<Primary as Sink<T>>::Error>,
648{
649 type Error = <Primary as Sink<T>>::Error;
650
651 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
652 let this = self.get_mut();
653 if let Err(err) = ready!(this.inner.conn.poll_ready_unpin(cx)) {
654 return Poll::Ready(Err(err.into()))
655 }
656 if let Err(err) = ready!(this.primary.st.poll_ready_unpin(cx)) {
657 return Poll::Ready(Err(err))
658 }
659 Poll::Ready(Ok(()))
660 }
661
662 fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
663 self.get_mut().primary.st.start_send_unpin(item)
664 }
665
666 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
667 self.get_mut().inner.conn.poll_flush_unpin(cx).map_err(Into::into)
668 }
669
670 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
671 self.get_mut().inner.conn.poll_close_unpin(cx).map_err(Into::into)
672 }
673}
674
675struct ProtocolStream {
677 shared_cap: SharedCapability,
678 to_satellite: UnboundedSender<BytesMut>,
680 satellite_st: Pin<Box<dyn Stream<Item = BytesMut> + Send>>,
681}
682
683impl ProtocolStream {
684 #[inline]
686 fn mask_msg_id(&self, mut msg: BytesMut) -> Result<Bytes, io::Error> {
687 if msg.is_empty() {
688 return Err(io::ErrorKind::InvalidInput.into())
690 }
691 msg[0] = msg[0]
692 .checked_add(self.shared_cap.relative_message_id_offset())
693 .ok_or(io::ErrorKind::InvalidInput)?;
694 Ok(msg.freeze())
695 }
696
697 #[inline]
699 fn unmask_id(&self, mut msg: BytesMut) -> Result<BytesMut, io::Error> {
700 if msg.is_empty() {
701 return Err(io::ErrorKind::InvalidInput.into())
703 }
704 msg[0] = msg[0]
705 .checked_sub(self.shared_cap.relative_message_id_offset())
706 .ok_or(io::ErrorKind::InvalidInput)?;
707 Ok(msg)
708 }
709
710 fn send_raw(&self, msg: BytesMut) {
712 let _ = self.unmask_id(msg).map(|msg| self.to_satellite.send(msg));
713 }
714}
715
716impl Stream for ProtocolStream {
717 type Item = Result<Bytes, io::Error>;
718
719 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
720 let this = self.get_mut();
721 let msg = ready!(this.satellite_st.as_mut().poll_next(cx));
722 Poll::Ready(msg.map(|msg| this.mask_msg_id(msg)))
723 }
724}
725
726impl fmt::Debug for ProtocolStream {
727 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
728 f.debug_struct("ProtocolStream").field("cap", &self.shared_cap).finish_non_exhaustive()
729 }
730}
731
732struct ProtocolsPoller<'a> {
734 protocols: &'a mut Vec<ProtocolStream>,
735}
736
737impl<'a> ProtocolsPoller<'a> {
738 const fn new(protocols: &'a mut Vec<ProtocolStream>) -> Self {
739 Self { protocols }
740 }
741}
742
743impl<'a> Future for ProtocolsPoller<'a> {
744 type Output = Result<Bytes, P2PStreamError>;
745
746 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
747 for idx in (0..self.protocols.len()).rev() {
749 let mut proto = self.protocols.swap_remove(idx);
750 match proto.poll_next_unpin(cx) {
751 Poll::Ready(Some(Err(err))) => {
752 self.protocols.push(proto);
753 return Poll::Ready(Err(P2PStreamError::from(err)))
754 }
755 Poll::Ready(Some(Ok(msg))) => {
756 self.protocols.push(proto);
758 return Poll::Ready(Ok(msg));
759 }
760 _ => {
761 self.protocols.push(proto);
763 }
764 }
765 }
766
767 Poll::Pending
769 }
770}
771
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        handshake::EthHandshake,
        test_utils::{
            connect_passthrough, eth_handshake, eth_hello,
            proto::{test_hello, TestProtoMessage},
        },
        UnauthedEthStream, UnauthedP2PStream,
    };
    use reth_eth_wire_types::EthNetworkPrimitives;
    use tokio::{net::TcpListener, sync::oneshot};
    use tokio_util::codec::Decoder;

    /// Runs the eth handshake through the multiplexer against a peer that uses a plain
    /// [`UnauthedEthStream`].
    #[tokio::test]
    async fn eth_satellite() {
        reth_tracing::init_test_tracing();
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let server_addr = listener.local_addr().unwrap();
        let (status, fork_filter) = eth_handshake();
        let server_status = status;
        let server_fork_filter = fork_filter.clone();

        // server side: p2p handshake followed by a regular eth handshake
        let _server = tokio::spawn(async move {
            let (incoming, _) = listener.accept().await.unwrap();
            let framed = crate::PassthroughCodec::default().framed(incoming);
            let (server_hello, _) = eth_hello();
            let (p2p_stream, _) =
                UnauthedP2PStream::new(framed).handshake(server_hello).await.unwrap();

            let (_eth_stream, _) = UnauthedEthStream::new(p2p_stream)
                .handshake::<EthNetworkPrimitives>(server_status, server_fork_filter)
                .await
                .unwrap();

            tokio::time::sleep(std::time::Duration::from_millis(100)).await;
        });

        // client side: the same handshake, but driven through the multiplexer
        let conn = connect_passthrough(server_addr, eth_hello().0).await;
        let eth_cap = conn.shared_capabilities().eth().unwrap().clone();

        let _satellite = RlpxProtocolMultiplexer::new(conn)
            .into_satellite_stream_with_handshake(
                eth_cap.capability().as_ref(),
                move |proxy| async move {
                    UnauthedEthStream::new(proxy)
                        .handshake::<EthNetworkPrimitives>(status, fork_filter)
                        .await
                },
            )
            .await
            .unwrap();
    }

    /// Both peers run eth as the primary protocol and exchange test protocol messages over an
    /// installed satellite.
    #[tokio::test(flavor = "multi_thread")]
    async fn eth_test_protocol_satellite() {
        reth_tracing::init_test_tracing();
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let server_addr = listener.local_addr().unwrap();
        let (status, fork_filter) = eth_handshake();
        let server_status = status;
        let server_fork_filter = fork_filter.clone();

        // server side: sends the ping, answers the reply and never finishes
        let _server = tokio::spawn(async move {
            let (incoming, _) = listener.accept().await.unwrap();
            let framed = crate::PassthroughCodec::default().framed(incoming);
            let (server_hello, _) = test_hello();
            let (conn, _) = UnauthedP2PStream::new(framed).handshake(server_hello).await.unwrap();

            let (mut satellite, _their_status) = RlpxProtocolMultiplexer::new(conn)
                .into_eth_satellite_stream::<EthNetworkPrimitives>(
                    server_status,
                    server_fork_filter,
                    Arc::new(EthHandshake::default()),
                )
                .await
                .unwrap();

            satellite
                .install_protocol(&TestProtoMessage::capability(), |mut conn| {
                    async_stream::stream! {
                        yield TestProtoMessage::ping().encoded();
                        let raw = conn.next().await.unwrap();
                        let decoded = TestProtoMessage::decode_message(&mut &raw[..]).unwrap();
                        assert_eq!(decoded, TestProtoMessage::pong());

                        yield TestProtoMessage::message("hello").encoded();
                        let raw = conn.next().await.unwrap();
                        let decoded = TestProtoMessage::decode_message(&mut &raw[..]).unwrap();
                        assert_eq!(decoded, TestProtoMessage::message("good bye!"));

                        yield TestProtoMessage::message("good bye!").encoded();

                        futures::future::pending::<()>().await;
                        unreachable!()
                    }
                })
                .unwrap();

            // keep driving the connection
            loop {
                let _ = satellite.next().await;
            }
        });

        // client side: answers the ping and signals once the final message arrived
        let conn = connect_passthrough(server_addr, test_hello().0).await;
        let (mut satellite, _their_status) = RlpxProtocolMultiplexer::new(conn)
            .into_eth_satellite_stream::<EthNetworkPrimitives>(
                status,
                fork_filter,
                Arc::new(EthHandshake::default()),
            )
            .await
            .unwrap();

        let (done_tx, mut done_rx) = oneshot::channel();

        satellite
            .install_protocol(&TestProtoMessage::capability(), |mut conn| {
                async_stream::stream! {
                    let raw = conn.next().await.unwrap();
                    let decoded = TestProtoMessage::decode_message(&mut &raw[..]).unwrap();
                    assert_eq!(decoded, TestProtoMessage::ping());

                    yield TestProtoMessage::pong().encoded();

                    let raw = conn.next().await.unwrap();
                    let decoded = TestProtoMessage::decode_message(&mut &raw[..]).unwrap();
                    assert_eq!(decoded, TestProtoMessage::message("hello"));

                    yield TestProtoMessage::message("good bye!").encoded();

                    let raw = conn.next().await.unwrap();
                    let decoded = TestProtoMessage::decode_message(&mut &raw[..]).unwrap();
                    assert_eq!(decoded, TestProtoMessage::message("good bye!"));

                    done_tx.send(()).unwrap();

                    futures::future::pending::<()>().await;
                    unreachable!()
                }
            })
            .unwrap();

        // drive the connection until the satellite reports completion
        loop {
            tokio::select! {
                _ = &mut done_rx => {
                    break
                }
                _ = satellite.next() => {
                }
            }
        }
    }
}