Skip to main content

reth_eth_wire/
multiplex.rs

//! `RLPx` protocol multiplexer and satellite stream.
//!
//! A satellite is a stream that primarily drives a single `RLPx` subprotocol but can also handle
//! additional subprotocols.
//!
//! Most other subprotocols are "dependent satellite" protocols of "eth" rather than fully
//! standalone protocols, for example "snap". See also the [snap protocol](https://github.com/ethereum/devp2p/blob/298d7a77c3bf833641579ecbbb5b13f0311eeeea/caps/snap.md?plain=1#L71) specification.
//! Hence it is expected that the primary protocol is "eth" and the additional protocols are
//! "dependent satellite" protocols.
9
10use std::{
11    collections::VecDeque,
12    fmt,
13    future::Future,
14    io,
15    pin::{pin, Pin},
16    sync::Arc,
17    task::{ready, Context, Poll},
18};
19
20use crate::{
21    capability::{SharedCapabilities, SharedCapability, UnsupportedCapabilityError},
22    errors::{EthStreamError, P2PStreamError},
23    handshake::EthRlpxHandshake,
24    p2pstream::DisconnectP2P,
25    CanDisconnect, Capability, DisconnectReason, EthStream, P2PStream, UnifiedStatus,
26    HANDSHAKE_TIMEOUT,
27};
28use bytes::{Bytes, BytesMut};
29use futures::{Sink, SinkExt, Stream, StreamExt, TryStream, TryStreamExt};
30use reth_eth_wire_types::NetworkPrimitives;
31use reth_ethereum_forks::ForkFilter;
32use tokio::sync::{mpsc, mpsc::UnboundedSender};
33use tokio_stream::wrappers::UnboundedReceiverStream;
34
/// Wraps a raw `RLPx` stream [`P2PStream`] and handles message ID multiplexing.
///
/// Satellite subprotocols are registered with [`RlpxProtocolMultiplexer::install_protocol`];
/// the multiplexer is then converted into a [`RlpxSatelliteStream`] by choosing a primary
/// protocol (see [`RlpxProtocolMultiplexer::into_satellite_stream`] and its handshake variants).
#[derive(Debug)]
pub struct RlpxProtocolMultiplexer<St> {
    /// The raw connection together with all installed satellite protocols.
    inner: MultiplexInner<St>,
}
41
42impl<St> RlpxProtocolMultiplexer<St> {
43    /// Wraps the raw p2p stream
44    pub fn new(conn: P2PStream<St>) -> Self {
45        Self {
46            inner: MultiplexInner {
47                conn,
48                protocols: Default::default(),
49                out_buffer: Default::default(),
50            },
51        }
52    }
53
54    /// Installs a new protocol on top of the raw p2p stream.
55    ///
56    /// This accepts a closure that receives a [`ProtocolConnection`] that will yield messages for
57    /// the given capability.
58    pub fn install_protocol<F, Proto>(
59        &mut self,
60        cap: &Capability,
61        f: F,
62    ) -> Result<(), UnsupportedCapabilityError>
63    where
64        F: FnOnce(ProtocolConnection) -> Proto,
65        Proto: Stream<Item = BytesMut> + Send + 'static,
66    {
67        self.inner.install_protocol(cap, f)
68    }
69
70    /// Returns the [`SharedCapabilities`] of the underlying raw p2p stream
71    pub const fn shared_capabilities(&self) -> &SharedCapabilities {
72        self.inner.shared_capabilities()
73    }
74
75    /// Converts this multiplexer into a [`RlpxSatelliteStream`] with the given primary protocol.
76    pub fn into_satellite_stream<F, Primary>(
77        self,
78        cap: &Capability,
79        primary: F,
80    ) -> Result<RlpxSatelliteStream<St, Primary>, P2PStreamError>
81    where
82        F: FnOnce(ProtocolProxy) -> Primary,
83    {
84        let Ok(shared_cap) = self.shared_capabilities().ensure_matching_capability(cap).cloned()
85        else {
86            return Err(P2PStreamError::CapabilityNotShared)
87        };
88
89        let (to_primary, from_wire) = mpsc::unbounded_channel();
90        let (to_wire, from_primary) = mpsc::unbounded_channel();
91        let proxy = ProtocolProxy {
92            shared_cap: shared_cap.clone(),
93            from_wire: UnboundedReceiverStream::new(from_wire),
94            to_wire,
95        };
96
97        let st = primary(proxy);
98        Ok(RlpxSatelliteStream {
99            inner: self.inner,
100            primary: PrimaryProtocol {
101                to_primary,
102                from_primary: UnboundedReceiverStream::new(from_primary),
103                st,
104                shared_cap,
105            },
106        })
107    }
108
109    /// Converts this multiplexer into a [`RlpxSatelliteStream`] with the given primary protocol.
110    ///
111    /// Returns an error if the primary protocol is not supported by the remote or the handshake
112    /// failed.
113    pub async fn into_satellite_stream_with_handshake<F, Fut, Err, Primary>(
114        self,
115        cap: &Capability,
116        handshake: F,
117    ) -> Result<RlpxSatelliteStream<St, Primary>, Err>
118    where
119        F: FnOnce(ProtocolProxy) -> Fut,
120        Fut: Future<Output = Result<Primary, Err>>,
121        St: Stream<Item = io::Result<BytesMut>> + Sink<Bytes, Error = io::Error> + Unpin,
122        P2PStreamError: Into<Err>,
123    {
124        self.into_satellite_stream_with_tuple_handshake(cap, async move |proxy| {
125            let st = handshake(proxy).await?;
126            Ok((st, ()))
127        })
128        .await
129        .map(|(st, _)| st)
130    }
131
132    /// Converts this multiplexer into a [`RlpxSatelliteStream`] with the given primary protocol.
133    ///
134    /// Returns an error if the primary protocol is not supported by the remote or the handshake
135    /// failed.
136    ///
137    /// This accepts a closure that does a handshake with the remote peer and returns a tuple of the
138    /// primary stream and extra data.
139    ///
140    /// See also [`UnauthedEthStream::handshake`](crate::UnauthedEthStream)
141    pub async fn into_satellite_stream_with_tuple_handshake<F, Fut, Err, Primary, Extra>(
142        mut self,
143        cap: &Capability,
144        handshake: F,
145    ) -> Result<(RlpxSatelliteStream<St, Primary>, Extra), Err>
146    where
147        F: FnOnce(ProtocolProxy) -> Fut,
148        Fut: Future<Output = Result<(Primary, Extra), Err>>,
149        St: Stream<Item = io::Result<BytesMut>> + Sink<Bytes, Error = io::Error> + Unpin,
150        P2PStreamError: Into<Err>,
151    {
152        let Ok(shared_cap) = self.shared_capabilities().ensure_matching_capability(cap).cloned()
153        else {
154            return Err(P2PStreamError::CapabilityNotShared.into())
155        };
156
157        let (to_primary, from_wire) = mpsc::unbounded_channel();
158        let (to_wire, mut from_primary) = mpsc::unbounded_channel();
159        let proxy = ProtocolProxy {
160            shared_cap: shared_cap.clone(),
161            from_wire: UnboundedReceiverStream::new(from_wire),
162            to_wire,
163        };
164
165        let f = handshake(proxy);
166        let mut f = pin!(f);
167
168        // this polls the connection and the primary stream concurrently until the handshake is
169        // complete
170        loop {
171            tokio::select! {
172                biased;
173                Some(Ok(msg)) = self.inner.conn.next() => {
174                    // Ensure the message belongs to the primary protocol
175                    let Some(offset) = msg.first().copied()
176                    else {
177                        return Err(P2PStreamError::EmptyProtocolMessage.into())
178                    };
179                    if let Some(cap) = self.shared_capabilities().find_by_relative_offset(offset).cloned() {
180                            if cap == shared_cap {
181                                // delegate to primary
182                                let _ = to_primary.send(msg);
183                            } else {
184                                // delegate to satellite
185                                self.inner.delegate_message(&cap, msg);
186                            }
187                        } else {
188                           return Err(P2PStreamError::UnknownReservedMessageId(offset).into())
189                        }
190                }
191                Some(msg) = from_primary.recv() => {
192                    self.inner.conn.send(msg).await.map_err(Into::into)?;
193                }
194                // Poll all subprotocols for new messages
195                msg = ProtocolsPoller::new(&mut self.inner.protocols) => {
196                     self.inner.conn.send(msg.map_err(Into::into)?).await.map_err(Into::into)?;
197                }
198                res = &mut f => {
199                    let (st, extra) = res?;
200                    return Ok((RlpxSatelliteStream {
201                            inner: self.inner,
202                            primary: PrimaryProtocol {
203                                to_primary,
204                                from_primary: UnboundedReceiverStream::new(from_primary),
205                                st,
206                                shared_cap,
207                            }
208                    }, extra))
209                }
210            }
211        }
212    }
213
214    /// Converts this multiplexer into a [`RlpxSatelliteStream`] with eth protocol as the given
215    /// primary protocol and the handshake implementation.
216    pub async fn into_eth_satellite_stream<N: NetworkPrimitives>(
217        self,
218        status: UnifiedStatus,
219        fork_filter: ForkFilter,
220        handshake: Arc<dyn EthRlpxHandshake>,
221    ) -> Result<(RlpxSatelliteStream<St, EthStream<ProtocolProxy, N>>, UnifiedStatus), EthStreamError>
222    where
223        St: Stream<Item = io::Result<BytesMut>> + Sink<Bytes, Error = io::Error> + Unpin,
224    {
225        let eth_cap = self.inner.conn.shared_capabilities().eth_version()?;
226        self.into_satellite_stream_with_tuple_handshake(
227            &Capability::eth(eth_cap),
228            async move |proxy| {
229                let handshake = handshake.clone();
230                let mut unauth = UnauthProxy { inner: proxy };
231                let their_status = handshake
232                    .handshake(&mut unauth, status, fork_filter, HANDSHAKE_TIMEOUT)
233                    .await?;
234                let eth_stream = EthStream::new(eth_cap, unauth.into_inner());
235                Ok((eth_stream, their_status))
236            },
237        )
238        .await
239    }
240}
241
/// State shared by [`RlpxProtocolMultiplexer`] and [`RlpxSatelliteStream`].
///
/// Holds the raw connection, the installed satellite protocols and the queue of outgoing
/// messages.
#[derive(Debug)]
struct MultiplexInner<St> {
    /// The raw p2p stream
    conn: P2PStream<St>,
    /// All the subprotocols that are multiplexed on top of the raw p2p stream
    protocols: Vec<ProtocolStream>,
    /// Buffer for outgoing (already masked) messages waiting to be written to the wire.
    out_buffer: VecDeque<Bytes>,
}
251
252impl<St> MultiplexInner<St> {
253    const fn shared_capabilities(&self) -> &SharedCapabilities {
254        self.conn.shared_capabilities()
255    }
256
257    /// Delegates a message to the matching protocol.
258    fn delegate_message(&self, cap: &SharedCapability, msg: BytesMut) -> bool {
259        for proto in &self.protocols {
260            if proto.shared_cap == *cap {
261                proto.send_raw(msg);
262                return true
263            }
264        }
265        false
266    }
267
268    fn install_protocol<F, Proto>(
269        &mut self,
270        cap: &Capability,
271        f: F,
272    ) -> Result<(), UnsupportedCapabilityError>
273    where
274        F: FnOnce(ProtocolConnection) -> Proto,
275        Proto: Stream<Item = BytesMut> + Send + 'static,
276    {
277        let shared_cap =
278            self.conn.shared_capabilities().ensure_matching_capability(cap).cloned()?;
279        let (to_satellite, rx) = mpsc::unbounded_channel();
280        let proto_conn = ProtocolConnection { from_wire: UnboundedReceiverStream::new(rx) };
281        let st = f(proto_conn);
282        let st = ProtocolStream { shared_cap, to_satellite, satellite_st: Box::pin(st) };
283        self.protocols.push(st);
284        Ok(())
285    }
286}
287
/// Represents a protocol in the multiplexer that is used as the primary protocol.
///
/// Holds the multiplexer's ends of the channels whose other ends live in the [`ProtocolProxy`]
/// that was handed to the primary stream.
#[derive(Debug)]
struct PrimaryProtocol<Primary> {
    /// Channel to send (raw, still masked) wire messages to the primary protocol.
    to_primary: UnboundedSender<BytesMut>,
    /// Receiver for messages from the primary protocol, already masked by the proxy.
    from_primary: UnboundedReceiverStream<Bytes>,
    /// Shared capability of the primary protocol.
    shared_cap: SharedCapability,
    /// The primary stream.
    st: Primary,
}
300
301/// A Stream and Sink type that acts as a wrapper around a primary `RLPx` subprotocol (e.g. "eth")
302///
303/// Only emits and sends _non-empty_ messages
304#[derive(Debug)]
305pub struct ProtocolProxy {
306    shared_cap: SharedCapability,
307    /// Receives _non-empty_ messages from the wire
308    from_wire: UnboundedReceiverStream<BytesMut>,
309    /// Sends _non-empty_ messages from the wire
310    to_wire: UnboundedSender<Bytes>,
311}
312
313impl ProtocolProxy {
314    /// Sends a _non-empty_ message on the wire.
315    fn try_send(&self, msg: Bytes) -> Result<(), io::Error> {
316        if msg.is_empty() {
317            // message must not be empty
318            return Err(io::ErrorKind::InvalidInput.into())
319        }
320        self.to_wire.send(self.mask_msg_id(msg)?).map_err(|_| io::ErrorKind::BrokenPipe.into())
321    }
322
323    /// Masks the message ID of a message to be sent on the wire.
324    #[inline]
325    fn mask_msg_id(&self, msg: Bytes) -> Result<Bytes, io::Error> {
326        if msg.is_empty() {
327            // message must not be empty
328            return Err(io::ErrorKind::InvalidInput.into())
329        }
330
331        let offset = self.shared_cap.relative_message_id_offset();
332        if offset == 0 {
333            return Ok(msg);
334        }
335
336        let mut masked: BytesMut = msg.into();
337        masked[0] = masked[0].checked_add(offset).ok_or(io::ErrorKind::InvalidInput)?;
338        Ok(masked.freeze())
339    }
340
341    /// Unmasks the message ID of a message received from the wire.
342    #[inline]
343    fn unmask_id(&self, mut msg: BytesMut) -> Result<BytesMut, io::Error> {
344        if msg.is_empty() {
345            // message must not be empty
346            return Err(io::ErrorKind::InvalidInput.into())
347        }
348        msg[0] = msg[0]
349            .checked_sub(self.shared_cap.relative_message_id_offset())
350            .ok_or(io::ErrorKind::InvalidInput)?;
351        Ok(msg)
352    }
353}
354
355impl Stream for ProtocolProxy {
356    type Item = Result<BytesMut, io::Error>;
357
358    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
359        let msg = ready!(self.from_wire.poll_next_unpin(cx));
360        Poll::Ready(msg.map(|msg| self.get_mut().unmask_id(msg)))
361    }
362}
363
364impl Sink<Bytes> for ProtocolProxy {
365    type Error = io::Error;
366
367    fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
368        Poll::Ready(Ok(()))
369    }
370
371    fn start_send(self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> {
372        self.get_mut().try_send(item)
373    }
374
375    fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
376        Poll::Ready(Ok(()))
377    }
378
379    fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
380        Poll::Ready(Ok(()))
381    }
382}
383
384impl CanDisconnect<Bytes> for ProtocolProxy {
385    fn disconnect(
386        &mut self,
387        _reason: DisconnectReason,
388    ) -> Pin<Box<dyn Future<Output = Result<(), <Self as Sink<Bytes>>::Error>> + Send + '_>> {
389        Box::pin(async move { Ok(()) })
390    }
391}
392
/// Adapter so the injected `EthRlpxHandshake` can run over a multiplexed `ProtocolProxy`
/// using the same error type expectations (`P2PStreamError`).
///
/// Every stream, sink and disconnect operation delegates to the wrapped proxy. Only the error
/// type is converted, from [`io::Error`] to [`P2PStreamError`].
#[derive(Debug)]
struct UnauthProxy {
    /// The wrapped proxy.
    inner: ProtocolProxy,
}
399
400impl UnauthProxy {
401    fn into_inner(self) -> ProtocolProxy {
402        self.inner
403    }
404}
405
406impl Stream for UnauthProxy {
407    type Item = Result<BytesMut, P2PStreamError>;
408
409    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
410        self.inner.poll_next_unpin(cx).map(|opt| opt.map(|res| res.map_err(P2PStreamError::from)))
411    }
412}
413
414impl Sink<Bytes> for UnauthProxy {
415    type Error = P2PStreamError;
416
417    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
418        self.inner.poll_ready_unpin(cx).map_err(P2PStreamError::from)
419    }
420
421    fn start_send(mut self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> {
422        self.inner.start_send_unpin(item).map_err(P2PStreamError::from)
423    }
424
425    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
426        self.inner.poll_flush_unpin(cx).map_err(P2PStreamError::from)
427    }
428
429    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
430        self.inner.poll_close_unpin(cx).map_err(P2PStreamError::from)
431    }
432}
433
434impl CanDisconnect<Bytes> for UnauthProxy {
435    fn disconnect(
436        &mut self,
437        reason: DisconnectReason,
438    ) -> Pin<Box<dyn Future<Output = Result<(), <Self as Sink<Bytes>>::Error>> + Send + '_>> {
439        let fut = self.inner.disconnect(reason);
440        Box::pin(async move { fut.await.map_err(P2PStreamError::from) })
441    }
442}
443
444/// A connection channel to receive _`non_empty`_ messages for the negotiated protocol.
445///
446/// This is a [Stream] that returns raw bytes of the received messages for this protocol.
447#[derive(Debug)]
448pub struct ProtocolConnection {
449    from_wire: UnboundedReceiverStream<BytesMut>,
450}
451
452impl Stream for ProtocolConnection {
453    type Item = BytesMut;
454
455    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
456        self.from_wire.poll_next_unpin(cx)
457    }
458}
459
/// A Stream and Sink type that acts as a wrapper around a primary `RLPx` subprotocol (e.g. "eth")
/// [`EthStream`] and can also handle additional subprotocols.
///
/// Items are yielded from the primary stream, and items sent through the [`Sink`] are forwarded
/// to the primary stream. Satellite protocols are driven as part of polling this stream.
#[derive(Debug)]
pub struct RlpxSatelliteStream<St, Primary> {
    /// The raw connection and the installed satellite protocols.
    inner: MultiplexInner<St>,
    /// The primary protocol and its channels.
    primary: PrimaryProtocol<Primary>,
}
467
468impl<St, Primary> RlpxSatelliteStream<St, Primary> {
469    /// Installs a new protocol on top of the raw p2p stream.
470    ///
471    /// This accepts a closure that receives a [`ProtocolConnection`] that will yield messages for
472    /// the given capability.
473    pub fn install_protocol<F, Proto>(
474        &mut self,
475        cap: &Capability,
476        f: F,
477    ) -> Result<(), UnsupportedCapabilityError>
478    where
479        F: FnOnce(ProtocolConnection) -> Proto,
480        Proto: Stream<Item = BytesMut> + Send + 'static,
481    {
482        self.inner.install_protocol(cap, f)
483    }
484
485    /// Returns the primary protocol.
486    #[inline]
487    pub const fn primary(&self) -> &Primary {
488        &self.primary.st
489    }
490
491    /// Returns mutable access to the primary protocol.
492    #[inline]
493    pub const fn primary_mut(&mut self) -> &mut Primary {
494        &mut self.primary.st
495    }
496
497    /// Returns the underlying [`P2PStream`].
498    #[inline]
499    pub const fn inner(&self) -> &P2PStream<St> {
500        &self.inner.conn
501    }
502
503    /// Returns mutable access to the underlying [`P2PStream`].
504    #[inline]
505    pub const fn inner_mut(&mut self) -> &mut P2PStream<St> {
506        &mut self.inner.conn
507    }
508
509    /// Consumes this type and returns the wrapped [`P2PStream`].
510    #[inline]
511    pub fn into_inner(self) -> P2PStream<St> {
512        self.inner.conn
513    }
514}
515
516impl<St, Primary, PrimaryErr> Stream for RlpxSatelliteStream<St, Primary>
517where
518    St: Stream<Item = io::Result<BytesMut>> + Sink<Bytes, Error = io::Error> + Unpin,
519    Primary: TryStream<Error = PrimaryErr> + Unpin,
520    P2PStreamError: Into<PrimaryErr>,
521{
522    type Item = Result<Primary::Ok, Primary::Error>;
523
524    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
525        let this = self.get_mut();
526
527        loop {
528            // first drain the primary stream
529            if let Poll::Ready(Some(msg)) = this.primary.st.try_poll_next_unpin(cx) {
530                return Poll::Ready(Some(msg))
531            }
532
533            let mut conn_ready = true;
534            loop {
535                match this.inner.conn.poll_ready_unpin(cx) {
536                    Poll::Ready(Ok(())) => {
537                        if let Some(msg) = this.inner.out_buffer.pop_front() {
538                            if let Err(err) = this.inner.conn.start_send_unpin(msg) {
539                                return Poll::Ready(Some(Err(err.into())))
540                            }
541                        } else {
542                            break
543                        }
544                    }
545                    Poll::Ready(Err(err)) => {
546                        if let Err(disconnect_err) =
547                            this.inner.conn.start_disconnect(DisconnectReason::DisconnectRequested)
548                        {
549                            return Poll::Ready(Some(Err(disconnect_err.into())))
550                        }
551                        return Poll::Ready(Some(Err(err.into())))
552                    }
553                    Poll::Pending => {
554                        conn_ready = false;
555                        break
556                    }
557                }
558            }
559
560            // advance primary out
561            loop {
562                match this.primary.from_primary.poll_next_unpin(cx) {
563                    Poll::Ready(Some(msg)) => {
564                        this.inner.out_buffer.push_back(msg);
565                    }
566                    Poll::Ready(None) => {
567                        // primary closed
568                        return Poll::Ready(None)
569                    }
570                    Poll::Pending => break,
571                }
572            }
573
574            // advance all satellites
575            for idx in (0..this.inner.protocols.len()).rev() {
576                let mut proto = this.inner.protocols.swap_remove(idx);
577                loop {
578                    match proto.poll_next_unpin(cx) {
579                        Poll::Ready(Some(Err(err))) => {
580                            return Poll::Ready(Some(Err(P2PStreamError::Io(err).into())))
581                        }
582                        Poll::Ready(Some(Ok(msg))) => {
583                            this.inner.out_buffer.push_back(msg);
584                        }
585                        Poll::Ready(None) => return Poll::Ready(None),
586                        Poll::Pending => {
587                            this.inner.protocols.push(proto);
588                            break
589                        }
590                    }
591                }
592            }
593
594            let mut delegated = false;
595            loop {
596                // pull messages from connection
597                match this.inner.conn.poll_next_unpin(cx) {
598                    Poll::Ready(Some(Ok(msg))) => {
599                        delegated = true;
600                        let Some(offset) = msg.first().copied() else {
601                            return Poll::Ready(Some(Err(
602                                P2PStreamError::EmptyProtocolMessage.into()
603                            )))
604                        };
605                        // delegate the multiplexed message to the correct protocol
606                        if let Some(cap) =
607                            this.inner.conn.shared_capabilities().find_by_relative_offset(offset)
608                        {
609                            if cap == &this.primary.shared_cap {
610                                // delegate to primary
611                                let _ = this.primary.to_primary.send(msg);
612                            } else {
613                                // delegate to installed satellite if any
614                                for proto in &this.inner.protocols {
615                                    if proto.shared_cap == *cap {
616                                        proto.send_raw(msg);
617                                        break
618                                    }
619                                }
620                            }
621                        } else {
622                            return Poll::Ready(Some(Err(P2PStreamError::UnknownReservedMessageId(
623                                offset,
624                            )
625                            .into())))
626                        }
627                    }
628                    Poll::Ready(Some(Err(err))) => return Poll::Ready(Some(Err(err.into()))),
629                    Poll::Ready(None) => {
630                        // connection closed
631                        return Poll::Ready(None)
632                    }
633                    Poll::Pending => break,
634                }
635            }
636
637            if !conn_ready || (!delegated && this.inner.out_buffer.is_empty()) {
638                return Poll::Pending
639            }
640        }
641    }
642}
643
644impl<St, Primary, T> Sink<T> for RlpxSatelliteStream<St, Primary>
645where
646    St: Stream<Item = io::Result<BytesMut>> + Sink<Bytes, Error = io::Error> + Unpin,
647    Primary: Sink<T> + Unpin,
648    P2PStreamError: Into<<Primary as Sink<T>>::Error>,
649{
650    type Error = <Primary as Sink<T>>::Error;
651
652    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
653        let this = self.get_mut();
654        if let Err(err) = ready!(this.inner.conn.poll_ready_unpin(cx)) {
655            return Poll::Ready(Err(err.into()))
656        }
657        if let Err(err) = ready!(this.primary.st.poll_ready_unpin(cx)) {
658            return Poll::Ready(Err(err))
659        }
660        Poll::Ready(Ok(()))
661    }
662
663    fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
664        self.get_mut().primary.st.start_send_unpin(item)
665    }
666
667    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
668        self.get_mut().inner.conn.poll_flush_unpin(cx).map_err(Into::into)
669    }
670
671    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
672        self.get_mut().inner.conn.poll_close_unpin(cx).map_err(Into::into)
673    }
674}
675
/// Wraps a `RLPx` subprotocol and handles message ID multiplexing.
///
/// Message flow:
/// - incoming wire messages are unmasked and forwarded to the satellite via `to_satellite`;
/// - messages yielded by `satellite_st` are masked before they are written to the wire.
struct ProtocolStream {
    /// The capability this satellite was installed for.
    shared_cap: SharedCapability,
    /// the channel shared with the satellite stream
    to_satellite: UnboundedSender<BytesMut>,
    /// The satellite's stream of outgoing (not yet masked) messages.
    satellite_st: Pin<Box<dyn Stream<Item = BytesMut> + Send>>,
}
683
684impl ProtocolStream {
685    /// Masks the message ID of a message to be sent on the wire.
686    #[inline]
687    fn mask_msg_id(&self, mut msg: BytesMut) -> Result<Bytes, io::Error> {
688        if msg.is_empty() {
689            // message must not be empty
690            return Err(io::ErrorKind::InvalidInput.into())
691        }
692        msg[0] = msg[0]
693            .checked_add(self.shared_cap.relative_message_id_offset())
694            .ok_or(io::ErrorKind::InvalidInput)?;
695        Ok(msg.freeze())
696    }
697
698    /// Unmasks the message ID of a message received from the wire.
699    #[inline]
700    fn unmask_id(&self, mut msg: BytesMut) -> Result<BytesMut, io::Error> {
701        if msg.is_empty() {
702            // message must not be empty
703            return Err(io::ErrorKind::InvalidInput.into())
704        }
705        msg[0] = msg[0]
706            .checked_sub(self.shared_cap.relative_message_id_offset())
707            .ok_or(io::ErrorKind::InvalidInput)?;
708        Ok(msg)
709    }
710
711    /// Sends the message to the satellite stream.
712    fn send_raw(&self, msg: BytesMut) {
713        let _ = self.unmask_id(msg).map(|msg| self.to_satellite.send(msg));
714    }
715}
716
717impl Stream for ProtocolStream {
718    type Item = Result<Bytes, io::Error>;
719
720    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
721        let this = self.get_mut();
722        let msg = ready!(this.satellite_st.as_mut().poll_next(cx));
723        Poll::Ready(msg.map(|msg| this.mask_msg_id(msg)))
724    }
725}
726
727impl fmt::Debug for ProtocolStream {
728    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
729        f.debug_struct("ProtocolStream").field("cap", &self.shared_cap).finish_non_exhaustive()
730    }
731}
732
/// Helper to poll multiple protocol streams in a `tokio::select!` branch.
///
/// Resolves with the first (already masked) message or error yielded by any of the borrowed
/// protocols. It is used while the primary handshake is still in progress.
struct ProtocolsPoller<'a> {
    /// The installed satellite protocols, borrowed from the multiplexer.
    protocols: &'a mut Vec<ProtocolStream>,
}
737
738impl<'a> ProtocolsPoller<'a> {
739    const fn new(protocols: &'a mut Vec<ProtocolStream>) -> Self {
740        Self { protocols }
741    }
742}
743
744impl<'a> Future for ProtocolsPoller<'a> {
745    type Output = Result<Bytes, P2PStreamError>;
746
747    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
748        // Process protocols in reverse order, like the existing pattern
749        for idx in (0..self.protocols.len()).rev() {
750            let mut proto = self.protocols.swap_remove(idx);
751            match proto.poll_next_unpin(cx) {
752                Poll::Ready(Some(Err(err))) => {
753                    self.protocols.push(proto);
754                    return Poll::Ready(Err(P2PStreamError::from(err)))
755                }
756                Poll::Ready(Some(Ok(msg))) => {
757                    // Got a message, put protocol back and return the message
758                    self.protocols.push(proto);
759                    return Poll::Ready(Ok(msg));
760                }
761                _ => {
762                    // push it back because we still want to complete the handshake first
763                    self.protocols.push(proto);
764                }
765            }
766        }
767
768        // All protocols processed, nothing ready
769        Poll::Pending
770    }
771}
772
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        handshake::EthHandshake,
        test_utils::{
            connect_passthrough, eth_handshake, eth_hello,
            proto::{test_hello, TestProtoMessage},
        },
        UnauthedEthStream, UnauthedP2PStream,
    };
    use reth_eth_wire_types::EthNetworkPrimitives;
    use std::time::Duration;
    use tokio::{net::TcpListener, sync::oneshot};
    use tokio_util::codec::Decoder;

    /// Upper bound for waiting on a peer or a message exchange over loopback.
    ///
    /// Turns a stalled exchange into a test failure instead of a hung test run; deliberately
    /// generous so slow CI machines do not produce spurious failures.
    const TEST_TIMEOUT: Duration = Duration::from_secs(30);

    /// Performs an `eth` handshake through a satellite stream on the client, against a server
    /// that performs the handshake with a plain [`UnauthedEthStream`] over its [`P2PStream`].
    #[tokio::test]
    async fn eth_satellite() {
        reth_tracing::init_test_tracing();
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();
        let (status, fork_filter) = eth_handshake();
        let other_status = status;
        let other_fork_filter = fork_filter.clone();
        let handle = tokio::spawn(async move {
            let (incoming, _) = listener.accept().await.unwrap();
            let stream = crate::PassthroughCodec::default().framed(incoming);
            let (server_hello, _) = eth_hello();
            let (p2p_stream, _) =
                UnauthedP2PStream::new(stream).handshake(server_hello).await.unwrap();

            let (_eth_stream, _) = UnauthedEthStream::new(p2p_stream)
                .handshake::<EthNetworkPrimitives>(other_status, other_fork_filter)
                .await
                .unwrap();

            // Keep the server's stream alive for a moment before the task ends and drops it.
            tokio::time::sleep(Duration::from_millis(100)).await;
        });

        let conn = connect_passthrough(local_addr, eth_hello().0).await;
        let eth = conn.shared_capabilities().eth().unwrap().clone();

        let multiplexer = RlpxProtocolMultiplexer::new(conn);
        let _satellite = multiplexer
            .into_satellite_stream_with_handshake(eth.capability().as_ref(), async move |proxy| {
                UnauthedEthStream::new(proxy)
                    .handshake::<EthNetworkPrimitives>(status, fork_filter)
                    .await
            })
            .await
            .unwrap();

        // Surface failures on the server side (e.g. a rejected handshake); a panic inside a
        // spawned task would otherwise be silently discarded with its `JoinHandle`.
        tokio::time::timeout(TEST_TIMEOUT, handle)
            .await
            .expect("server task did not finish in time")
            .expect("server task panicked");
    }

    /// Installs the test protocol next to `eth` on satellite streams on both peers and checks
    /// that the scripted ping/pong and text-message exchange arrives in order on each side.
    #[tokio::test(flavor = "multi_thread")]
    async fn eth_test_protocol_satellite() {
        reth_tracing::init_test_tracing();
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();
        let (status, fork_filter) = eth_handshake();
        let other_status = status;
        let other_fork_filter = fork_filter.clone();
        let _handle = tokio::spawn(async move {
            let (incoming, _) = listener.accept().await.unwrap();
            let stream = crate::PassthroughCodec::default().framed(incoming);
            let (server_hello, _) = test_hello();
            let (conn, _) = UnauthedP2PStream::new(stream).handshake(server_hello).await.unwrap();

            let (mut st, _their_status) = RlpxProtocolMultiplexer::new(conn)
                .into_eth_satellite_stream::<EthNetworkPrimitives>(
                    other_status,
                    other_fork_filter,
                    Arc::new(EthHandshake::default()),
                )
                .await
                .unwrap();

            st.install_protocol(&TestProtoMessage::capability(), |mut conn| {
                async_stream::stream! {
                    yield TestProtoMessage::ping().encoded();
                    let msg = conn.next().await.unwrap();
                    let msg = TestProtoMessage::decode_message(&mut &msg[..]).unwrap();
                    assert_eq!(msg, TestProtoMessage::pong());

                    yield TestProtoMessage::message("hello").encoded();
                    let msg = conn.next().await.unwrap();
                    let msg = TestProtoMessage::decode_message(&mut &msg[..]).unwrap();
                    assert_eq!(msg, TestProtoMessage::message("good bye!"));

                    yield TestProtoMessage::message("good bye!").encoded();

                    futures::future::pending::<()>().await;
                    unreachable!()
                }
            })
            .unwrap();

            // Drive the server's satellite stream (and thereby its installed protocol) until it
            // ends, rather than re-polling a finished stream in a tight loop.
            while st.next().await.is_some() {}
        });

        let conn = connect_passthrough(local_addr, test_hello().0).await;
        let (mut st, _their_status) = RlpxProtocolMultiplexer::new(conn)
            .into_eth_satellite_stream::<EthNetworkPrimitives>(
                status,
                fork_filter,
                Arc::new(EthHandshake::default()),
            )
            .await
            .unwrap();

        let (tx, mut rx) = oneshot::channel();

        st.install_protocol(&TestProtoMessage::capability(), |mut conn| {
            async_stream::stream! {
                let msg = conn.next().await.unwrap();
                let msg = TestProtoMessage::decode_message(&mut &msg[..]).unwrap();
                assert_eq!(msg, TestProtoMessage::ping());

                yield TestProtoMessage::pong().encoded();

                let msg = conn.next().await.unwrap();
                let msg = TestProtoMessage::decode_message(&mut &msg[..]).unwrap();
                assert_eq!(msg, TestProtoMessage::message("hello"));

                yield TestProtoMessage::message("good bye!").encoded();

                let msg = conn.next().await.unwrap();
                let msg = TestProtoMessage::decode_message(&mut &msg[..]).unwrap();
                assert_eq!(msg, TestProtoMessage::message("good bye!"));

                tx.send(()).unwrap();

                futures::future::pending::<()>().await;
                unreachable!()
            }
        })
        .unwrap();

        // Drive the client satellite stream (which polls the installed protocol) until the
        // protocol reports, via `tx`, that the whole scripted exchange has completed.
        let exchange = async {
            loop {
                tokio::select! {
                    res = &mut rx => {
                        // `Err` means the protocol stream was dropped without ever sending;
                        // that must fail the test instead of ending it as a success.
                        res.expect("test protocol stream dropped before finishing the exchange");
                        break
                    }
                    msg = st.next() => {
                        // Once the satellite stream has ended the exchange can no longer
                        // complete, so fail fast instead of polling the finished stream again.
                        assert!(
                            msg.is_some(),
                            "satellite stream ended before the test protocol exchange finished"
                        );
                    }
                }
            }
        };
        tokio::time::timeout(TEST_TIMEOUT, exchange)
            .await
            .expect("test protocol exchange timed out");
    }
}