Skip to main content

reth_eth_wire/
multiplex.rs

//! `RLPx` protocol multiplexer and satellite stream.
//!
//! A satellite is a stream that primarily drives a single `RLPx` subprotocol but can also handle
//! additional subprotocols.
//!
//! Most other subprotocols are "dependent satellite" protocols of "eth" rather than fully
//! standalone protocols, for example "snap". See also the
//! [snap protocol](https://github.com/ethereum/devp2p/blob/298d7a77c3bf833641579ecbbb5b13f0311eeeea/caps/snap.md?plain=1#L71).
//! Hence it is expected that the primary protocol is "eth" and the additional protocols are
//! "dependent satellite" protocols.
9
10use std::{
11    collections::VecDeque,
12    fmt,
13    future::Future,
14    io,
15    pin::{pin, Pin},
16    sync::Arc,
17    task::{ready, Context, Poll},
18};
19
20use crate::{
21    capability::{SharedCapabilities, SharedCapability, UnsupportedCapabilityError},
22    errors::{EthStreamError, P2PStreamError},
23    handshake::EthRlpxHandshake,
24    p2pstream::DisconnectP2P,
25    CanDisconnect, Capability, DisconnectReason, EthStream, P2PStream, UnifiedStatus,
26    HANDSHAKE_TIMEOUT,
27};
28use bytes::{Bytes, BytesMut};
29use futures::{Sink, SinkExt, Stream, StreamExt, TryStream, TryStreamExt};
30use reth_eth_wire_types::NetworkPrimitives;
31use reth_ethereum_forks::ForkFilter;
32use tokio::sync::{mpsc, mpsc::UnboundedSender};
33use tokio_stream::wrappers::UnboundedReceiverStream;
34
35/// A Stream and Sink type that wraps a raw rlpx stream [`P2PStream`] and handles message ID
36/// multiplexing.
37#[derive(Debug)]
38pub struct RlpxProtocolMultiplexer<St> {
39    inner: MultiplexInner<St>,
40}
41
42impl<St> RlpxProtocolMultiplexer<St> {
43    /// Wraps the raw p2p stream
44    pub fn new(conn: P2PStream<St>) -> Self {
45        Self {
46            inner: MultiplexInner {
47                conn,
48                protocols: Default::default(),
49                out_buffer: Default::default(),
50            },
51        }
52    }
53
54    /// Installs a new protocol on top of the raw p2p stream.
55    ///
56    /// This accepts a closure that receives a [`ProtocolConnection`] that will yield messages for
57    /// the given capability.
58    pub fn install_protocol<F, Proto>(
59        &mut self,
60        cap: &Capability,
61        f: F,
62    ) -> Result<(), UnsupportedCapabilityError>
63    where
64        F: FnOnce(ProtocolConnection) -> Proto,
65        Proto: Stream<Item = BytesMut> + Send + 'static,
66    {
67        self.inner.install_protocol(cap, f)
68    }
69
70    /// Returns the [`SharedCapabilities`] of the underlying raw p2p stream
71    pub const fn shared_capabilities(&self) -> &SharedCapabilities {
72        self.inner.shared_capabilities()
73    }
74
75    /// Converts this multiplexer into a [`RlpxSatelliteStream`] with the given primary protocol.
76    pub fn into_satellite_stream<F, Primary>(
77        self,
78        cap: &Capability,
79        primary: F,
80    ) -> Result<RlpxSatelliteStream<St, Primary>, P2PStreamError>
81    where
82        F: FnOnce(ProtocolProxy) -> Primary,
83    {
84        let Ok(shared_cap) = self.shared_capabilities().ensure_matching_capability(cap).cloned()
85        else {
86            return Err(P2PStreamError::CapabilityNotShared)
87        };
88
89        let (to_primary, from_wire) = mpsc::unbounded_channel();
90        let (to_wire, from_primary) = mpsc::unbounded_channel();
91        let proxy = ProtocolProxy {
92            shared_cap: shared_cap.clone(),
93            from_wire: UnboundedReceiverStream::new(from_wire),
94            to_wire,
95        };
96
97        let st = primary(proxy);
98        Ok(RlpxSatelliteStream {
99            inner: self.inner,
100            primary: PrimaryProtocol {
101                to_primary,
102                from_primary: UnboundedReceiverStream::new(from_primary),
103                st,
104                shared_cap,
105            },
106        })
107    }
108
109    /// Converts this multiplexer into a [`RlpxSatelliteStream`] with the given primary protocol.
110    ///
111    /// Returns an error if the primary protocol is not supported by the remote or the handshake
112    /// failed.
113    pub async fn into_satellite_stream_with_handshake<F, Fut, Err, Primary>(
114        self,
115        cap: &Capability,
116        handshake: F,
117    ) -> Result<RlpxSatelliteStream<St, Primary>, Err>
118    where
119        F: FnOnce(ProtocolProxy) -> Fut,
120        Fut: Future<Output = Result<Primary, Err>>,
121        St: Stream<Item = io::Result<BytesMut>> + Sink<Bytes, Error = io::Error> + Unpin,
122        P2PStreamError: Into<Err>,
123    {
124        self.into_satellite_stream_with_tuple_handshake(cap, async move |proxy| {
125            let st = handshake(proxy).await?;
126            Ok((st, ()))
127        })
128        .await
129        .map(|(st, _)| st)
130    }
131
132    /// Converts this multiplexer into a [`RlpxSatelliteStream`] with the given primary protocol.
133    ///
134    /// Returns an error if the primary protocol is not supported by the remote or the handshake
135    /// failed.
136    ///
137    /// This accepts a closure that does a handshake with the remote peer and returns a tuple of the
138    /// primary stream and extra data.
139    ///
140    /// See also [`UnauthedEthStream::handshake`](crate::UnauthedEthStream)
141    pub async fn into_satellite_stream_with_tuple_handshake<F, Fut, Err, Primary, Extra>(
142        mut self,
143        cap: &Capability,
144        handshake: F,
145    ) -> Result<(RlpxSatelliteStream<St, Primary>, Extra), Err>
146    where
147        F: FnOnce(ProtocolProxy) -> Fut,
148        Fut: Future<Output = Result<(Primary, Extra), Err>>,
149        St: Stream<Item = io::Result<BytesMut>> + Sink<Bytes, Error = io::Error> + Unpin,
150        P2PStreamError: Into<Err>,
151    {
152        let Ok(shared_cap) = self.shared_capabilities().ensure_matching_capability(cap).cloned()
153        else {
154            return Err(P2PStreamError::CapabilityNotShared.into())
155        };
156
157        let (to_primary, from_wire) = mpsc::unbounded_channel();
158        let (to_wire, mut from_primary) = mpsc::unbounded_channel();
159        let proxy = ProtocolProxy {
160            shared_cap: shared_cap.clone(),
161            from_wire: UnboundedReceiverStream::new(from_wire),
162            to_wire,
163        };
164
165        let f = handshake(proxy);
166        let mut f = pin!(f);
167
168        // this polls the connection and the primary stream concurrently until the handshake is
169        // complete
170        loop {
171            tokio::select! {
172                biased;
173                Some(Ok(msg)) = self.inner.conn.next() => {
174                    // Ensure the message belongs to the primary protocol
175                    let Some(offset) = msg.first().copied()
176                    else {
177                        return Err(P2PStreamError::EmptyProtocolMessage.into())
178                    };
179                    if let Some(cap) = self.shared_capabilities().find_by_relative_offset(offset).cloned() {
180                            if cap == shared_cap {
181                                // delegate to primary
182                                let _ = to_primary.send(msg);
183                            } else {
184                                // delegate to satellite
185                                self.inner.delegate_message(&cap, msg);
186                            }
187                        } else {
188                           return Err(P2PStreamError::UnknownReservedMessageId(offset).into())
189                        }
190                }
191                Some(msg) = from_primary.recv() => {
192                    self.inner.conn.send(msg).await.map_err(Into::into)?;
193                }
194                // Poll all subprotocols for new messages
195                msg = ProtocolsPoller::new(&mut self.inner.protocols) => {
196                     self.inner.conn.send(msg.map_err(Into::into)?).await.map_err(Into::into)?;
197                }
198                res = &mut f => {
199                    let (st, extra) = res?;
200                    return Ok((RlpxSatelliteStream {
201                            inner: self.inner,
202                            primary: PrimaryProtocol {
203                                to_primary,
204                                from_primary: UnboundedReceiverStream::new(from_primary),
205                                st,
206                                shared_cap,
207                            }
208                    }, extra))
209                }
210            }
211        }
212    }
213
214    /// Converts this multiplexer into a [`RlpxSatelliteStream`] with eth protocol as the given
215    /// primary protocol and the handshake implementation.
216    pub async fn into_eth_satellite_stream<N: NetworkPrimitives>(
217        self,
218        status: UnifiedStatus,
219        fork_filter: ForkFilter,
220        handshake: Arc<dyn EthRlpxHandshake>,
221        eth_max_message_size: usize,
222    ) -> Result<(RlpxSatelliteStream<St, EthStream<ProtocolProxy, N>>, UnifiedStatus), EthStreamError>
223    where
224        St: Stream<Item = io::Result<BytesMut>> + Sink<Bytes, Error = io::Error> + Unpin,
225    {
226        let eth_cap = self.inner.conn.shared_capabilities().eth_version()?;
227        self.into_satellite_stream_with_tuple_handshake(
228            &Capability::eth(eth_cap),
229            async move |proxy| {
230                let handshake = handshake.clone();
231                let mut unauth = UnauthProxy { inner: proxy };
232                let their_status = handshake
233                    .handshake(&mut unauth, status, fork_filter, HANDSHAKE_TIMEOUT)
234                    .await?;
235                let eth_stream = EthStream::with_max_message_size(
236                    eth_cap,
237                    unauth.into_inner(),
238                    eth_max_message_size,
239                );
240                Ok((eth_stream, their_status))
241            },
242        )
243        .await
244    }
245}
246
247#[derive(Debug)]
248struct MultiplexInner<St> {
249    /// The raw p2p stream
250    conn: P2PStream<St>,
251    /// All the subprotocols that are multiplexed on top of the raw p2p stream
252    protocols: Vec<ProtocolStream>,
253    /// Buffer for outgoing messages on the wire.
254    out_buffer: VecDeque<Bytes>,
255}
256
257impl<St> MultiplexInner<St> {
258    const fn shared_capabilities(&self) -> &SharedCapabilities {
259        self.conn.shared_capabilities()
260    }
261
262    /// Delegates a message to the matching protocol.
263    fn delegate_message(&self, cap: &SharedCapability, msg: BytesMut) -> bool {
264        for proto in &self.protocols {
265            if proto.shared_cap == *cap {
266                proto.send_raw(msg);
267                return true
268            }
269        }
270        false
271    }
272
273    fn install_protocol<F, Proto>(
274        &mut self,
275        cap: &Capability,
276        f: F,
277    ) -> Result<(), UnsupportedCapabilityError>
278    where
279        F: FnOnce(ProtocolConnection) -> Proto,
280        Proto: Stream<Item = BytesMut> + Send + 'static,
281    {
282        let shared_cap =
283            self.conn.shared_capabilities().ensure_matching_capability(cap).cloned()?;
284        let (to_satellite, rx) = mpsc::unbounded_channel();
285        let proto_conn = ProtocolConnection { from_wire: UnboundedReceiverStream::new(rx) };
286        let st = f(proto_conn);
287        let st = ProtocolStream { shared_cap, to_satellite, satellite_st: Box::pin(st) };
288        self.protocols.push(st);
289        Ok(())
290    }
291}
292
293/// Represents a protocol in the multiplexer that is used as the primary protocol.
294#[derive(Debug)]
295struct PrimaryProtocol<Primary> {
296    /// Channel to send messages to the primary protocol.
297    to_primary: UnboundedSender<BytesMut>,
298    /// Receiver for messages from the primary protocol.
299    from_primary: UnboundedReceiverStream<Bytes>,
300    /// Shared capability of the primary protocol.
301    shared_cap: SharedCapability,
302    /// The primary stream.
303    st: Primary,
304}
305
306/// A Stream and Sink type that acts as a wrapper around a primary `RLPx` subprotocol (e.g. "eth")
307///
308/// Only emits and sends _non-empty_ messages
309#[derive(Debug)]
310pub struct ProtocolProxy {
311    shared_cap: SharedCapability,
312    /// Receives _non-empty_ messages from the wire
313    from_wire: UnboundedReceiverStream<BytesMut>,
314    /// Sends _non-empty_ messages from the wire
315    to_wire: UnboundedSender<Bytes>,
316}
317
318impl ProtocolProxy {
319    /// Sends a _non-empty_ message on the wire.
320    fn try_send(&self, msg: Bytes) -> Result<(), io::Error> {
321        if msg.is_empty() {
322            // message must not be empty
323            return Err(io::ErrorKind::InvalidInput.into())
324        }
325        self.to_wire.send(self.mask_msg_id(msg)?).map_err(|_| io::ErrorKind::BrokenPipe.into())
326    }
327
328    /// Masks the message ID of a message to be sent on the wire.
329    #[inline]
330    fn mask_msg_id(&self, msg: Bytes) -> Result<Bytes, io::Error> {
331        if msg.is_empty() {
332            // message must not be empty
333            return Err(io::ErrorKind::InvalidInput.into())
334        }
335
336        let offset = self.shared_cap.relative_message_id_offset();
337        if offset == 0 {
338            return Ok(msg);
339        }
340
341        let mut masked: BytesMut = msg.into();
342        masked[0] = masked[0].checked_add(offset).ok_or(io::ErrorKind::InvalidInput)?;
343        Ok(masked.freeze())
344    }
345
346    /// Unmasks the message ID of a message received from the wire.
347    #[inline]
348    fn unmask_id(&self, mut msg: BytesMut) -> Result<BytesMut, io::Error> {
349        if msg.is_empty() {
350            // message must not be empty
351            return Err(io::ErrorKind::InvalidInput.into())
352        }
353        msg[0] = msg[0]
354            .checked_sub(self.shared_cap.relative_message_id_offset())
355            .ok_or(io::ErrorKind::InvalidInput)?;
356        Ok(msg)
357    }
358}
359
360impl Stream for ProtocolProxy {
361    type Item = Result<BytesMut, io::Error>;
362
363    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
364        let msg = ready!(self.from_wire.poll_next_unpin(cx));
365        Poll::Ready(msg.map(|msg| self.get_mut().unmask_id(msg)))
366    }
367}
368
369impl Sink<Bytes> for ProtocolProxy {
370    type Error = io::Error;
371
372    fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
373        Poll::Ready(Ok(()))
374    }
375
376    fn start_send(self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> {
377        self.get_mut().try_send(item)
378    }
379
380    fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
381        Poll::Ready(Ok(()))
382    }
383
384    fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
385        Poll::Ready(Ok(()))
386    }
387}
388
389impl CanDisconnect<Bytes> for ProtocolProxy {
390    fn disconnect(
391        &mut self,
392        _reason: DisconnectReason,
393    ) -> Pin<Box<dyn Future<Output = Result<(), <Self as Sink<Bytes>>::Error>> + Send + '_>> {
394        Box::pin(async move { Ok(()) })
395    }
396}
397
398/// Adapter so the injected `EthRlpxHandshake` can run over a multiplexed `ProtocolProxy`
399/// using the same error type expectations (`P2PStreamError`).
400#[derive(Debug)]
401struct UnauthProxy {
402    inner: ProtocolProxy,
403}
404
405impl UnauthProxy {
406    fn into_inner(self) -> ProtocolProxy {
407        self.inner
408    }
409}
410
411impl Stream for UnauthProxy {
412    type Item = Result<BytesMut, P2PStreamError>;
413
414    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
415        self.inner.poll_next_unpin(cx).map(|opt| opt.map(|res| res.map_err(P2PStreamError::from)))
416    }
417}
418
419impl Sink<Bytes> for UnauthProxy {
420    type Error = P2PStreamError;
421
422    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
423        self.inner.poll_ready_unpin(cx).map_err(P2PStreamError::from)
424    }
425
426    fn start_send(mut self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> {
427        self.inner.start_send_unpin(item).map_err(P2PStreamError::from)
428    }
429
430    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
431        self.inner.poll_flush_unpin(cx).map_err(P2PStreamError::from)
432    }
433
434    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
435        self.inner.poll_close_unpin(cx).map_err(P2PStreamError::from)
436    }
437}
438
439impl CanDisconnect<Bytes> for UnauthProxy {
440    fn disconnect(
441        &mut self,
442        reason: DisconnectReason,
443    ) -> Pin<Box<dyn Future<Output = Result<(), <Self as Sink<Bytes>>::Error>> + Send + '_>> {
444        let fut = self.inner.disconnect(reason);
445        Box::pin(async move { fut.await.map_err(P2PStreamError::from) })
446    }
447}
448
449/// A connection channel to receive _`non_empty`_ messages for the negotiated protocol.
450///
451/// This is a [Stream] that returns raw bytes of the received messages for this protocol.
452#[derive(Debug)]
453pub struct ProtocolConnection {
454    from_wire: UnboundedReceiverStream<BytesMut>,
455}
456
457impl Stream for ProtocolConnection {
458    type Item = BytesMut;
459
460    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
461        self.from_wire.poll_next_unpin(cx)
462    }
463}
464
465/// A Stream and Sink type that acts as a wrapper around a primary `RLPx` subprotocol (e.g. "eth")
466/// [`EthStream`] and can also handle additional subprotocols.
467#[derive(Debug)]
468pub struct RlpxSatelliteStream<St, Primary> {
469    inner: MultiplexInner<St>,
470    primary: PrimaryProtocol<Primary>,
471}
472
473impl<St, Primary> RlpxSatelliteStream<St, Primary> {
474    /// Installs a new protocol on top of the raw p2p stream.
475    ///
476    /// This accepts a closure that receives a [`ProtocolConnection`] that will yield messages for
477    /// the given capability.
478    pub fn install_protocol<F, Proto>(
479        &mut self,
480        cap: &Capability,
481        f: F,
482    ) -> Result<(), UnsupportedCapabilityError>
483    where
484        F: FnOnce(ProtocolConnection) -> Proto,
485        Proto: Stream<Item = BytesMut> + Send + 'static,
486    {
487        self.inner.install_protocol(cap, f)
488    }
489
490    /// Returns the primary protocol.
491    #[inline]
492    pub const fn primary(&self) -> &Primary {
493        &self.primary.st
494    }
495
496    /// Returns mutable access to the primary protocol.
497    #[inline]
498    pub const fn primary_mut(&mut self) -> &mut Primary {
499        &mut self.primary.st
500    }
501
502    /// Returns the underlying [`P2PStream`].
503    #[inline]
504    pub const fn inner(&self) -> &P2PStream<St> {
505        &self.inner.conn
506    }
507
508    /// Returns mutable access to the underlying [`P2PStream`].
509    #[inline]
510    pub const fn inner_mut(&mut self) -> &mut P2PStream<St> {
511        &mut self.inner.conn
512    }
513
514    /// Consumes this type and returns the wrapped [`P2PStream`].
515    #[inline]
516    pub fn into_inner(self) -> P2PStream<St> {
517        self.inner.conn
518    }
519}
520
521impl<St, Primary, PrimaryErr> Stream for RlpxSatelliteStream<St, Primary>
522where
523    St: Stream<Item = io::Result<BytesMut>> + Sink<Bytes, Error = io::Error> + Unpin,
524    Primary: TryStream<Error = PrimaryErr> + Unpin,
525    P2PStreamError: Into<PrimaryErr>,
526{
527    type Item = Result<Primary::Ok, Primary::Error>;
528
529    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
530        let this = self.get_mut();
531
532        loop {
533            // first drain the primary stream
534            if let Poll::Ready(Some(msg)) = this.primary.st.try_poll_next_unpin(cx) {
535                return Poll::Ready(Some(msg))
536            }
537
538            let mut conn_ready = true;
539            loop {
540                match this.inner.conn.poll_ready_unpin(cx) {
541                    Poll::Ready(Ok(())) => {
542                        if let Some(msg) = this.inner.out_buffer.pop_front() {
543                            if let Err(err) = this.inner.conn.start_send_unpin(msg) {
544                                return Poll::Ready(Some(Err(err.into())))
545                            }
546                        } else {
547                            break
548                        }
549                    }
550                    Poll::Ready(Err(err)) => {
551                        if let Err(disconnect_err) =
552                            this.inner.conn.start_disconnect(DisconnectReason::DisconnectRequested)
553                        {
554                            return Poll::Ready(Some(Err(disconnect_err.into())))
555                        }
556                        return Poll::Ready(Some(Err(err.into())))
557                    }
558                    Poll::Pending => {
559                        conn_ready = false;
560                        break
561                    }
562                }
563            }
564
565            // advance primary out
566            loop {
567                match this.primary.from_primary.poll_next_unpin(cx) {
568                    Poll::Ready(Some(msg)) => {
569                        this.inner.out_buffer.push_back(msg);
570                    }
571                    Poll::Ready(None) => {
572                        // primary closed
573                        return Poll::Ready(None)
574                    }
575                    Poll::Pending => break,
576                }
577            }
578
579            // advance all satellites
580            for idx in (0..this.inner.protocols.len()).rev() {
581                let mut proto = this.inner.protocols.swap_remove(idx);
582                loop {
583                    match proto.poll_next_unpin(cx) {
584                        Poll::Ready(Some(Err(err))) => {
585                            return Poll::Ready(Some(Err(P2PStreamError::Io(err).into())))
586                        }
587                        Poll::Ready(Some(Ok(msg))) => {
588                            this.inner.out_buffer.push_back(msg);
589                        }
590                        Poll::Ready(None) => return Poll::Ready(None),
591                        Poll::Pending => {
592                            this.inner.protocols.push(proto);
593                            break
594                        }
595                    }
596                }
597            }
598
599            let mut delegated = false;
600            loop {
601                // pull messages from connection
602                match this.inner.conn.poll_next_unpin(cx) {
603                    Poll::Ready(Some(Ok(msg))) => {
604                        delegated = true;
605                        let Some(offset) = msg.first().copied() else {
606                            return Poll::Ready(Some(Err(
607                                P2PStreamError::EmptyProtocolMessage.into()
608                            )))
609                        };
610                        // delegate the multiplexed message to the correct protocol
611                        if let Some(cap) =
612                            this.inner.conn.shared_capabilities().find_by_relative_offset(offset)
613                        {
614                            if cap == &this.primary.shared_cap {
615                                // delegate to primary
616                                let _ = this.primary.to_primary.send(msg);
617                            } else {
618                                // delegate to installed satellite if any
619                                for proto in &this.inner.protocols {
620                                    if proto.shared_cap == *cap {
621                                        proto.send_raw(msg);
622                                        break
623                                    }
624                                }
625                            }
626                        } else {
627                            return Poll::Ready(Some(Err(P2PStreamError::UnknownReservedMessageId(
628                                offset,
629                            )
630                            .into())))
631                        }
632                    }
633                    Poll::Ready(Some(Err(err))) => return Poll::Ready(Some(Err(err.into()))),
634                    Poll::Ready(None) => {
635                        // connection closed
636                        return Poll::Ready(None)
637                    }
638                    Poll::Pending => break,
639                }
640            }
641
642            if !conn_ready || (!delegated && this.inner.out_buffer.is_empty()) {
643                return Poll::Pending
644            }
645        }
646    }
647}
648
649impl<St, Primary, T> Sink<T> for RlpxSatelliteStream<St, Primary>
650where
651    St: Stream<Item = io::Result<BytesMut>> + Sink<Bytes, Error = io::Error> + Unpin,
652    Primary: Sink<T> + Unpin,
653    P2PStreamError: Into<<Primary as Sink<T>>::Error>,
654{
655    type Error = <Primary as Sink<T>>::Error;
656
657    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
658        let this = self.get_mut();
659        if let Err(err) = ready!(this.inner.conn.poll_ready_unpin(cx)) {
660            return Poll::Ready(Err(err.into()))
661        }
662        if let Err(err) = ready!(this.primary.st.poll_ready_unpin(cx)) {
663            return Poll::Ready(Err(err))
664        }
665        Poll::Ready(Ok(()))
666    }
667
668    fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
669        self.get_mut().primary.st.start_send_unpin(item)
670    }
671
672    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
673        self.get_mut().inner.conn.poll_flush_unpin(cx).map_err(Into::into)
674    }
675
676    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
677        self.get_mut().inner.conn.poll_close_unpin(cx).map_err(Into::into)
678    }
679}
680
681/// Wraps a `RLPx` subprotocol and handles message ID multiplexing.
682struct ProtocolStream {
683    shared_cap: SharedCapability,
684    /// the channel shared with the satellite stream
685    to_satellite: UnboundedSender<BytesMut>,
686    satellite_st: Pin<Box<dyn Stream<Item = BytesMut> + Send>>,
687}
688
689impl ProtocolStream {
690    /// Masks the message ID of a message to be sent on the wire.
691    #[inline]
692    fn mask_msg_id(&self, mut msg: BytesMut) -> Result<Bytes, io::Error> {
693        if msg.is_empty() {
694            // message must not be empty
695            return Err(io::ErrorKind::InvalidInput.into())
696        }
697        msg[0] = msg[0]
698            .checked_add(self.shared_cap.relative_message_id_offset())
699            .ok_or(io::ErrorKind::InvalidInput)?;
700        Ok(msg.freeze())
701    }
702
703    /// Unmasks the message ID of a message received from the wire.
704    #[inline]
705    fn unmask_id(&self, mut msg: BytesMut) -> Result<BytesMut, io::Error> {
706        if msg.is_empty() {
707            // message must not be empty
708            return Err(io::ErrorKind::InvalidInput.into())
709        }
710        msg[0] = msg[0]
711            .checked_sub(self.shared_cap.relative_message_id_offset())
712            .ok_or(io::ErrorKind::InvalidInput)?;
713        Ok(msg)
714    }
715
716    /// Sends the message to the satellite stream.
717    fn send_raw(&self, msg: BytesMut) {
718        let _ = self.unmask_id(msg).map(|msg| self.to_satellite.send(msg));
719    }
720}
721
722impl Stream for ProtocolStream {
723    type Item = Result<Bytes, io::Error>;
724
725    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
726        let this = self.get_mut();
727        let msg = ready!(this.satellite_st.as_mut().poll_next(cx));
728        Poll::Ready(msg.map(|msg| this.mask_msg_id(msg)))
729    }
730}
731
732impl fmt::Debug for ProtocolStream {
733    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
734        f.debug_struct("ProtocolStream").field("cap", &self.shared_cap).finish_non_exhaustive()
735    }
736}
737
738/// Helper to poll multiple protocol streams in a `tokio::select`! branch
739struct ProtocolsPoller<'a> {
740    protocols: &'a mut Vec<ProtocolStream>,
741}
742
743impl<'a> ProtocolsPoller<'a> {
744    const fn new(protocols: &'a mut Vec<ProtocolStream>) -> Self {
745        Self { protocols }
746    }
747}
748
749impl<'a> Future for ProtocolsPoller<'a> {
750    type Output = Result<Bytes, P2PStreamError>;
751
752    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
753        // Process protocols in reverse order, like the existing pattern
754        for idx in (0..self.protocols.len()).rev() {
755            let mut proto = self.protocols.swap_remove(idx);
756            match proto.poll_next_unpin(cx) {
757                Poll::Ready(Some(Err(err))) => {
758                    self.protocols.push(proto);
759                    return Poll::Ready(Err(P2PStreamError::from(err)))
760                }
761                Poll::Ready(Some(Ok(msg))) => {
762                    // Got a message, put protocol back and return the message
763                    self.protocols.push(proto);
764                    return Poll::Ready(Ok(msg));
765                }
766                _ => {
767                    // push it back because we still want to complete the handshake first
768                    self.protocols.push(proto);
769                }
770            }
771        }
772
773        // All protocols processed, nothing ready
774        Poll::Pending
775    }
776}
777
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        handshake::EthHandshake,
        message::MAX_MESSAGE_SIZE,
        test_utils::{
            connect_passthrough, eth_handshake, eth_hello,
            proto::{test_hello, TestProtoMessage},
        },
        UnauthedEthStream, UnauthedP2PStream,
    };
    use reth_eth_wire_types::EthNetworkPrimitives;
    use tokio::{net::TcpListener, sync::oneshot};
    use tokio_util::codec::Decoder;

    /// Performs an `eth` handshake through a satellite stream against a plain `eth` peer.
    ///
    /// The remote peer runs a regular [`UnauthedEthStream`] handshake directly on its p2p stream,
    /// while the local side routes its `eth` handshake through
    /// [`RlpxProtocolMultiplexer::into_satellite_stream_with_handshake`].
    #[tokio::test]
    async fn eth_satellite() {
        reth_tracing::init_test_tracing();
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();
        let (status, fork_filter) = eth_handshake();
        let other_status = status;
        let other_fork_filter = fork_filter.clone();
        let handle = tokio::spawn(async move {
            let (incoming, _) = listener.accept().await.unwrap();
            let stream = crate::PassthroughCodec::default().framed(incoming);
            let (server_hello, _) = eth_hello();
            let (p2p_stream, _) =
                UnauthedP2PStream::new(stream).handshake(server_hello).await.unwrap();

            let (_eth_stream, _) = UnauthedEthStream::new(p2p_stream)
                .handshake::<EthNetworkPrimitives>(other_status, other_fork_filter)
                .await
                .unwrap();

            // Keep the connection alive briefly so this end is not dropped while the client is
            // still finishing its side of the handshake.
            tokio::time::sleep(std::time::Duration::from_millis(100)).await;
        });

        let conn = connect_passthrough(local_addr, eth_hello().0).await;
        let eth = conn.shared_capabilities().eth().unwrap().clone();

        let multiplexer = RlpxProtocolMultiplexer::new(conn);
        let _satellite = multiplexer
            .into_satellite_stream_with_handshake(eth.capability().as_ref(), async move |proxy| {
                UnauthedEthStream::new(proxy)
                    .handshake::<EthNetworkPrimitives>(status, fork_filter)
                    .await
            })
            .await
            .unwrap();

        // Propagate any panic from the remote peer task. Otherwise a failing server-side
        // handshake would go unnoticed and the test would still pass.
        handle.await.unwrap();
    }

    /// A test that installs a satellite stream with eth + the test protocol on both ends and
    /// exchanges test-protocol messages between them.
    #[tokio::test(flavor = "multi_thread")]
    async fn eth_test_protocol_satellite() {
        reth_tracing::init_test_tracing();
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();
        let (status, fork_filter) = eth_handshake();
        let other_status = status;
        let other_fork_filter = fork_filter.clone();
        let _handle = tokio::spawn(async move {
            let (incoming, _) = listener.accept().await.unwrap();
            let stream = crate::PassthroughCodec::default().framed(incoming);
            let (server_hello, _) = test_hello();
            let (conn, _) = UnauthedP2PStream::new(stream).handshake(server_hello).await.unwrap();

            let (mut st, _their_status) = RlpxProtocolMultiplexer::new(conn)
                .into_eth_satellite_stream::<EthNetworkPrimitives>(
                    other_status,
                    other_fork_filter,
                    Arc::new(EthHandshake::default()),
                    MAX_MESSAGE_SIZE,
                )
                .await
                .unwrap();

            st.install_protocol(&TestProtoMessage::capability(), |mut conn| {
                async_stream::stream! {
                    yield TestProtoMessage::ping().encoded();
                    let msg = conn.next().await.unwrap();
                    let msg = TestProtoMessage::decode_message(&mut &msg[..]).unwrap();
                    assert_eq!(msg, TestProtoMessage::pong());

                    yield TestProtoMessage::message("hello").encoded();
                    let msg = conn.next().await.unwrap();
                    let msg = TestProtoMessage::decode_message(&mut &msg[..]).unwrap();
                    assert_eq!(msg, TestProtoMessage::message("good bye!"));

                    yield TestProtoMessage::message("good bye!").encoded();

                    futures::future::pending::<()>().await;
                    unreachable!()
                }
            })
            .unwrap();

            // Drive the satellite so the installed protocol makes progress. Stop once the stream
            // terminates instead of spinning on `None` forever.
            while st.next().await.is_some() {}
        });

        let conn = connect_passthrough(local_addr, test_hello().0).await;
        let (mut st, _their_status) = RlpxProtocolMultiplexer::new(conn)
            .into_eth_satellite_stream::<EthNetworkPrimitives>(
                status,
                fork_filter,
                Arc::new(EthHandshake::default()),
                MAX_MESSAGE_SIZE,
            )
            .await
            .unwrap();

        let (tx, mut rx) = oneshot::channel();

        st.install_protocol(&TestProtoMessage::capability(), |mut conn| {
            async_stream::stream! {
                let msg = conn.next().await.unwrap();
                let msg = TestProtoMessage::decode_message(&mut &msg[..]).unwrap();
                assert_eq!(msg, TestProtoMessage::ping());

                yield TestProtoMessage::pong().encoded();

                let msg = conn.next().await.unwrap();
                let msg = TestProtoMessage::decode_message(&mut &msg[..]).unwrap();
                assert_eq!(msg, TestProtoMessage::message("hello"));

                yield TestProtoMessage::message("good bye!").encoded();

                let msg = conn.next().await.unwrap();
                let msg = TestProtoMessage::decode_message(&mut &msg[..]).unwrap();
                assert_eq!(msg, TestProtoMessage::message("good bye!"));

                tx.send(()).unwrap();

                futures::future::pending::<()>().await;
                unreachable!()
            }
        })
        .unwrap();

        // Drive the satellite until the installed protocol signals that the full exchange
        // completed.
        loop {
            tokio::select! {
                res = &mut rx => {
                    // The sender being dropped (`Err`) means the protocol stream went away
                    // without finishing the exchange; that is a failure, not a success.
                    res.expect("test protocol stream was dropped before completing the exchange");
                    break
                }
                msg = st.next() => {
                    // A terminated satellite stream can never deliver the remaining messages;
                    // fail instead of spinning forever.
                    assert!(
                        msg.is_some(),
                        "satellite stream ended before the test protocol exchange completed"
                    );
                }
            }
        }
    }
}