reth_eth_wire/
multiplex.rs

1//! Rlpx protocol multiplexer and satellite stream
2//!
3//! A Satellite is a Stream that primarily drives a single `RLPx` subprotocol but can also handle
4//! additional subprotocols.
5//!
6//! Most other subprotocols are "dependent satellite" protocols of "eth" rather than fully standalone protocols, for example "snap". See also [snap protocol](https://github.com/ethereum/devp2p/blob/298d7a77c3bf833641579ecbbb5b13f0311eeeea/caps/snap.md?plain=1#L71)
7//! Hence it is expected that the primary protocol is "eth" and the additional protocols are
8//! "dependent satellite" protocols.
9
10use std::{
11    collections::VecDeque,
12    fmt,
13    future::Future,
14    io,
15    pin::{pin, Pin},
16    sync::Arc,
17    task::{ready, Context, Poll},
18};
19
20use crate::{
21    capability::{SharedCapabilities, SharedCapability, UnsupportedCapabilityError},
22    errors::{EthStreamError, P2PStreamError},
23    handshake::EthRlpxHandshake,
24    p2pstream::DisconnectP2P,
25    CanDisconnect, Capability, DisconnectReason, EthStream, P2PStream, UnifiedStatus,
26    HANDSHAKE_TIMEOUT,
27};
28use bytes::{Bytes, BytesMut};
29use futures::{Sink, SinkExt, Stream, StreamExt, TryStream, TryStreamExt};
30use reth_eth_wire_types::NetworkPrimitives;
31use reth_ethereum_forks::ForkFilter;
32use tokio::sync::{mpsc, mpsc::UnboundedSender};
33use tokio_stream::wrappers::UnboundedReceiverStream;
34
35/// A Stream and Sink type that wraps a raw rlpx stream [`P2PStream`] and handles message ID
36/// multiplexing.
37#[derive(Debug)]
38pub struct RlpxProtocolMultiplexer<St> {
39    inner: MultiplexInner<St>,
40}
41
42impl<St> RlpxProtocolMultiplexer<St> {
43    /// Wraps the raw p2p stream
44    pub fn new(conn: P2PStream<St>) -> Self {
45        Self {
46            inner: MultiplexInner {
47                conn,
48                protocols: Default::default(),
49                out_buffer: Default::default(),
50            },
51        }
52    }
53
54    /// Installs a new protocol on top of the raw p2p stream.
55    ///
56    /// This accepts a closure that receives a [`ProtocolConnection`] that will yield messages for
57    /// the given capability.
58    pub fn install_protocol<F, Proto>(
59        &mut self,
60        cap: &Capability,
61        f: F,
62    ) -> Result<(), UnsupportedCapabilityError>
63    where
64        F: FnOnce(ProtocolConnection) -> Proto,
65        Proto: Stream<Item = BytesMut> + Send + 'static,
66    {
67        self.inner.install_protocol(cap, f)
68    }
69
70    /// Returns the [`SharedCapabilities`] of the underlying raw p2p stream
71    pub const fn shared_capabilities(&self) -> &SharedCapabilities {
72        self.inner.shared_capabilities()
73    }
74
75    /// Converts this multiplexer into a [`RlpxSatelliteStream`] with the given primary protocol.
76    pub fn into_satellite_stream<F, Primary>(
77        self,
78        cap: &Capability,
79        primary: F,
80    ) -> Result<RlpxSatelliteStream<St, Primary>, P2PStreamError>
81    where
82        F: FnOnce(ProtocolProxy) -> Primary,
83    {
84        let Ok(shared_cap) = self.shared_capabilities().ensure_matching_capability(cap).cloned()
85        else {
86            return Err(P2PStreamError::CapabilityNotShared)
87        };
88
89        let (to_primary, from_wire) = mpsc::unbounded_channel();
90        let (to_wire, from_primary) = mpsc::unbounded_channel();
91        let proxy = ProtocolProxy {
92            shared_cap: shared_cap.clone(),
93            from_wire: UnboundedReceiverStream::new(from_wire),
94            to_wire,
95        };
96
97        let st = primary(proxy);
98        Ok(RlpxSatelliteStream {
99            inner: self.inner,
100            primary: PrimaryProtocol {
101                to_primary,
102                from_primary: UnboundedReceiverStream::new(from_primary),
103                st,
104                shared_cap,
105            },
106        })
107    }
108
109    /// Converts this multiplexer into a [`RlpxSatelliteStream`] with the given primary protocol.
110    ///
111    /// Returns an error if the primary protocol is not supported by the remote or the handshake
112    /// failed.
113    pub async fn into_satellite_stream_with_handshake<F, Fut, Err, Primary>(
114        self,
115        cap: &Capability,
116        handshake: F,
117    ) -> Result<RlpxSatelliteStream<St, Primary>, Err>
118    where
119        F: FnOnce(ProtocolProxy) -> Fut,
120        Fut: Future<Output = Result<Primary, Err>>,
121        St: Stream<Item = io::Result<BytesMut>> + Sink<Bytes, Error = io::Error> + Unpin,
122        P2PStreamError: Into<Err>,
123    {
124        self.into_satellite_stream_with_tuple_handshake(cap, move |proxy| async move {
125            let st = handshake(proxy).await?;
126            Ok((st, ()))
127        })
128        .await
129        .map(|(st, _)| st)
130    }
131
132    /// Converts this multiplexer into a [`RlpxSatelliteStream`] with the given primary protocol.
133    ///
134    /// Returns an error if the primary protocol is not supported by the remote or the handshake
135    /// failed.
136    ///
137    /// This accepts a closure that does a handshake with the remote peer and returns a tuple of the
138    /// primary stream and extra data.
139    ///
140    /// See also [`UnauthedEthStream::handshake`](crate::UnauthedEthStream)
141    pub async fn into_satellite_stream_with_tuple_handshake<F, Fut, Err, Primary, Extra>(
142        mut self,
143        cap: &Capability,
144        handshake: F,
145    ) -> Result<(RlpxSatelliteStream<St, Primary>, Extra), Err>
146    where
147        F: FnOnce(ProtocolProxy) -> Fut,
148        Fut: Future<Output = Result<(Primary, Extra), Err>>,
149        St: Stream<Item = io::Result<BytesMut>> + Sink<Bytes, Error = io::Error> + Unpin,
150        P2PStreamError: Into<Err>,
151    {
152        let Ok(shared_cap) = self.shared_capabilities().ensure_matching_capability(cap).cloned()
153        else {
154            return Err(P2PStreamError::CapabilityNotShared.into())
155        };
156
157        let (to_primary, from_wire) = mpsc::unbounded_channel();
158        let (to_wire, mut from_primary) = mpsc::unbounded_channel();
159        let proxy = ProtocolProxy {
160            shared_cap: shared_cap.clone(),
161            from_wire: UnboundedReceiverStream::new(from_wire),
162            to_wire,
163        };
164
165        let f = handshake(proxy);
166        let mut f = pin!(f);
167
168        // this polls the connection and the primary stream concurrently until the handshake is
169        // complete
170        loop {
171            tokio::select! {
172                biased;
173                Some(Ok(msg)) = self.inner.conn.next() => {
174                    // Ensure the message belongs to the primary protocol
175                    let Some(offset) = msg.first().copied()
176                    else {
177                        return Err(P2PStreamError::EmptyProtocolMessage.into())
178                    };
179                    if let Some(cap) = self.shared_capabilities().find_by_relative_offset(offset).cloned() {
180                            if cap == shared_cap {
181                                // delegate to primary
182                                let _ = to_primary.send(msg);
183                            } else {
184                                // delegate to satellite
185                                self.inner.delegate_message(&cap, msg);
186                            }
187                        } else {
188                           return Err(P2PStreamError::UnknownReservedMessageId(offset).into())
189                        }
190                }
191                Some(msg) = from_primary.recv() => {
192                    self.inner.conn.send(msg).await.map_err(Into::into)?;
193                }
194                // Poll all subprotocols for new messages
195                msg = ProtocolsPoller::new(&mut self.inner.protocols) => {
196                     self.inner.conn.send(msg.map_err(Into::into)?).await.map_err(Into::into)?;
197                }
198                res = &mut f => {
199                    let (st, extra) = res?;
200                    return Ok((RlpxSatelliteStream {
201                            inner: self.inner,
202                            primary: PrimaryProtocol {
203                                to_primary,
204                                from_primary: UnboundedReceiverStream::new(from_primary),
205                                st,
206                                shared_cap,
207                            }
208                    }, extra))
209                }
210            }
211        }
212    }
213
214    /// Converts this multiplexer into a [`RlpxSatelliteStream`] with eth protocol as the given
215    /// primary protocol and the handshake implementation.
216    pub async fn into_eth_satellite_stream<N: NetworkPrimitives>(
217        self,
218        status: UnifiedStatus,
219        fork_filter: ForkFilter,
220        handshake: Arc<dyn EthRlpxHandshake>,
221    ) -> Result<(RlpxSatelliteStream<St, EthStream<ProtocolProxy, N>>, UnifiedStatus), EthStreamError>
222    where
223        St: Stream<Item = io::Result<BytesMut>> + Sink<Bytes, Error = io::Error> + Unpin,
224    {
225        let eth_cap = self.inner.conn.shared_capabilities().eth_version()?;
226        self.into_satellite_stream_with_tuple_handshake(&Capability::eth(eth_cap), move |proxy| {
227            let handshake = handshake.clone();
228            async move {
229                let mut unauth = UnauthProxy { inner: proxy };
230                let their_status = handshake
231                    .handshake(&mut unauth, status, fork_filter, HANDSHAKE_TIMEOUT)
232                    .await?;
233                let eth_stream = EthStream::new(eth_cap, unauth.into_inner());
234                Ok((eth_stream, their_status))
235            }
236        })
237        .await
238    }
239}
240
241#[derive(Debug)]
242struct MultiplexInner<St> {
243    /// The raw p2p stream
244    conn: P2PStream<St>,
245    /// All the subprotocols that are multiplexed on top of the raw p2p stream
246    protocols: Vec<ProtocolStream>,
247    /// Buffer for outgoing messages on the wire.
248    out_buffer: VecDeque<Bytes>,
249}
250
251impl<St> MultiplexInner<St> {
252    const fn shared_capabilities(&self) -> &SharedCapabilities {
253        self.conn.shared_capabilities()
254    }
255
256    /// Delegates a message to the matching protocol.
257    fn delegate_message(&self, cap: &SharedCapability, msg: BytesMut) -> bool {
258        for proto in &self.protocols {
259            if proto.shared_cap == *cap {
260                proto.send_raw(msg);
261                return true
262            }
263        }
264        false
265    }
266
267    fn install_protocol<F, Proto>(
268        &mut self,
269        cap: &Capability,
270        f: F,
271    ) -> Result<(), UnsupportedCapabilityError>
272    where
273        F: FnOnce(ProtocolConnection) -> Proto,
274        Proto: Stream<Item = BytesMut> + Send + 'static,
275    {
276        let shared_cap =
277            self.conn.shared_capabilities().ensure_matching_capability(cap).cloned()?;
278        let (to_satellite, rx) = mpsc::unbounded_channel();
279        let proto_conn = ProtocolConnection { from_wire: UnboundedReceiverStream::new(rx) };
280        let st = f(proto_conn);
281        let st = ProtocolStream { shared_cap, to_satellite, satellite_st: Box::pin(st) };
282        self.protocols.push(st);
283        Ok(())
284    }
285}
286
/// Represents a protocol in the multiplexer that is used as the primary protocol.
///
/// The two channels connect the multiplexer to the [`ProtocolProxy`] handed to the primary
/// stream: incoming wire messages for the primary capability are forwarded through `to_primary`,
/// and messages the primary sends are collected from `from_primary`.
#[derive(Debug)]
struct PrimaryProtocol<Primary> {
    /// Channel to send messages to the primary protocol.
    to_primary: UnboundedSender<BytesMut>,
    /// Receiver for messages from the primary protocol.
    from_primary: UnboundedReceiverStream<Bytes>,
    /// Shared capability of the primary protocol; incoming messages whose ID offset maps to this
    /// capability are routed to `to_primary`.
    shared_cap: SharedCapability,
    /// The primary stream.
    st: Primary,
}
299
300/// A Stream and Sink type that acts as a wrapper around a primary `RLPx` subprotocol (e.g. "eth")
301///
302/// Only emits and sends _non-empty_ messages
303#[derive(Debug)]
304pub struct ProtocolProxy {
305    shared_cap: SharedCapability,
306    /// Receives _non-empty_ messages from the wire
307    from_wire: UnboundedReceiverStream<BytesMut>,
308    /// Sends _non-empty_ messages from the wire
309    to_wire: UnboundedSender<Bytes>,
310}
311
312impl ProtocolProxy {
313    /// Sends a _non-empty_ message on the wire.
314    fn try_send(&self, msg: Bytes) -> Result<(), io::Error> {
315        if msg.is_empty() {
316            // message must not be empty
317            return Err(io::ErrorKind::InvalidInput.into())
318        }
319        self.to_wire.send(self.mask_msg_id(msg)?).map_err(|_| io::ErrorKind::BrokenPipe.into())
320    }
321
322    /// Masks the message ID of a message to be sent on the wire.
323    #[inline]
324    fn mask_msg_id(&self, msg: Bytes) -> Result<Bytes, io::Error> {
325        if msg.is_empty() {
326            // message must not be empty
327            return Err(io::ErrorKind::InvalidInput.into())
328        }
329
330        let offset = self.shared_cap.relative_message_id_offset();
331        if offset == 0 {
332            return Ok(msg);
333        }
334
335        let mut masked = Vec::from(msg);
336        masked[0] = masked[0].checked_add(offset).ok_or(io::ErrorKind::InvalidInput)?;
337        Ok(masked.into())
338    }
339
340    /// Unmasks the message ID of a message received from the wire.
341    #[inline]
342    fn unmask_id(&self, mut msg: BytesMut) -> Result<BytesMut, io::Error> {
343        if msg.is_empty() {
344            // message must not be empty
345            return Err(io::ErrorKind::InvalidInput.into())
346        }
347        msg[0] = msg[0]
348            .checked_sub(self.shared_cap.relative_message_id_offset())
349            .ok_or(io::ErrorKind::InvalidInput)?;
350        Ok(msg)
351    }
352}
353
354impl Stream for ProtocolProxy {
355    type Item = Result<BytesMut, io::Error>;
356
357    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
358        let msg = ready!(self.from_wire.poll_next_unpin(cx));
359        Poll::Ready(msg.map(|msg| self.get_mut().unmask_id(msg)))
360    }
361}
362
363impl Sink<Bytes> for ProtocolProxy {
364    type Error = io::Error;
365
366    fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
367        Poll::Ready(Ok(()))
368    }
369
370    fn start_send(self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> {
371        self.get_mut().try_send(item)
372    }
373
374    fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
375        Poll::Ready(Ok(()))
376    }
377
378    fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
379        Poll::Ready(Ok(()))
380    }
381}
382
383impl CanDisconnect<Bytes> for ProtocolProxy {
384    fn disconnect(
385        &mut self,
386        _reason: DisconnectReason,
387    ) -> Pin<Box<dyn Future<Output = Result<(), <Self as Sink<Bytes>>::Error>> + Send + '_>> {
388        // TODO handle disconnects
389        Box::pin(async move { Ok(()) })
390    }
391}
392
393/// Adapter so the injected `EthRlpxHandshake` can run over a multiplexed `ProtocolProxy`
394/// using the same error type expectations (`P2PStreamError`).
395#[derive(Debug)]
396struct UnauthProxy {
397    inner: ProtocolProxy,
398}
399
400impl UnauthProxy {
401    fn into_inner(self) -> ProtocolProxy {
402        self.inner
403    }
404}
405
406impl Stream for UnauthProxy {
407    type Item = Result<BytesMut, P2PStreamError>;
408
409    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
410        self.inner.poll_next_unpin(cx).map(|opt| opt.map(|res| res.map_err(P2PStreamError::from)))
411    }
412}
413
414impl Sink<Bytes> for UnauthProxy {
415    type Error = P2PStreamError;
416
417    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
418        self.inner.poll_ready_unpin(cx).map_err(P2PStreamError::from)
419    }
420
421    fn start_send(mut self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> {
422        self.inner.start_send_unpin(item).map_err(P2PStreamError::from)
423    }
424
425    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
426        self.inner.poll_flush_unpin(cx).map_err(P2PStreamError::from)
427    }
428
429    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
430        self.inner.poll_close_unpin(cx).map_err(P2PStreamError::from)
431    }
432}
433
434impl CanDisconnect<Bytes> for UnauthProxy {
435    fn disconnect(
436        &mut self,
437        reason: DisconnectReason,
438    ) -> Pin<Box<dyn Future<Output = Result<(), <Self as Sink<Bytes>>::Error>> + Send + '_>> {
439        let fut = self.inner.disconnect(reason);
440        Box::pin(async move { fut.await.map_err(P2PStreamError::from) })
441    }
442}
443
444/// A connection channel to receive _`non_empty`_ messages for the negotiated protocol.
445///
446/// This is a [Stream] that returns raw bytes of the received messages for this protocol.
447#[derive(Debug)]
448pub struct ProtocolConnection {
449    from_wire: UnboundedReceiverStream<BytesMut>,
450}
451
452impl Stream for ProtocolConnection {
453    type Item = BytesMut;
454
455    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
456        self.from_wire.poll_next_unpin(cx)
457    }
458}
459
460/// A Stream and Sink type that acts as a wrapper around a primary `RLPx` subprotocol (e.g. "eth")
461/// [`EthStream`] and can also handle additional subprotocols.
462#[derive(Debug)]
463pub struct RlpxSatelliteStream<St, Primary> {
464    inner: MultiplexInner<St>,
465    primary: PrimaryProtocol<Primary>,
466}
467
468impl<St, Primary> RlpxSatelliteStream<St, Primary> {
469    /// Installs a new protocol on top of the raw p2p stream.
470    ///
471    /// This accepts a closure that receives a [`ProtocolConnection`] that will yield messages for
472    /// the given capability.
473    pub fn install_protocol<F, Proto>(
474        &mut self,
475        cap: &Capability,
476        f: F,
477    ) -> Result<(), UnsupportedCapabilityError>
478    where
479        F: FnOnce(ProtocolConnection) -> Proto,
480        Proto: Stream<Item = BytesMut> + Send + 'static,
481    {
482        self.inner.install_protocol(cap, f)
483    }
484
485    /// Returns the primary protocol.
486    #[inline]
487    pub const fn primary(&self) -> &Primary {
488        &self.primary.st
489    }
490
491    /// Returns mutable access to the primary protocol.
492    #[inline]
493    pub const fn primary_mut(&mut self) -> &mut Primary {
494        &mut self.primary.st
495    }
496
497    /// Returns the underlying [`P2PStream`].
498    #[inline]
499    pub const fn inner(&self) -> &P2PStream<St> {
500        &self.inner.conn
501    }
502
503    /// Returns mutable access to the underlying [`P2PStream`].
504    #[inline]
505    pub const fn inner_mut(&mut self) -> &mut P2PStream<St> {
506        &mut self.inner.conn
507    }
508
509    /// Consumes this type and returns the wrapped [`P2PStream`].
510    #[inline]
511    pub fn into_inner(self) -> P2PStream<St> {
512        self.inner.conn
513    }
514}
515
516impl<St, Primary, PrimaryErr> Stream for RlpxSatelliteStream<St, Primary>
517where
518    St: Stream<Item = io::Result<BytesMut>> + Sink<Bytes, Error = io::Error> + Unpin,
519    Primary: TryStream<Error = PrimaryErr> + Unpin,
520    P2PStreamError: Into<PrimaryErr>,
521{
522    type Item = Result<Primary::Ok, Primary::Error>;
523
524    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
525        let this = self.get_mut();
526
527        loop {
528            // first drain the primary stream
529            if let Poll::Ready(Some(msg)) = this.primary.st.try_poll_next_unpin(cx) {
530                return Poll::Ready(Some(msg))
531            }
532
533            let mut conn_ready = true;
534            loop {
535                match this.inner.conn.poll_ready_unpin(cx) {
536                    Poll::Ready(Ok(())) => {
537                        if let Some(msg) = this.inner.out_buffer.pop_front() {
538                            if let Err(err) = this.inner.conn.start_send_unpin(msg) {
539                                return Poll::Ready(Some(Err(err.into())))
540                            }
541                        } else {
542                            break
543                        }
544                    }
545                    Poll::Ready(Err(err)) => {
546                        if let Err(disconnect_err) =
547                            this.inner.conn.start_disconnect(DisconnectReason::DisconnectRequested)
548                        {
549                            return Poll::Ready(Some(Err(disconnect_err.into())))
550                        }
551                        return Poll::Ready(Some(Err(err.into())))
552                    }
553                    Poll::Pending => {
554                        conn_ready = false;
555                        break
556                    }
557                }
558            }
559
560            // advance primary out
561            loop {
562                match this.primary.from_primary.poll_next_unpin(cx) {
563                    Poll::Ready(Some(msg)) => {
564                        this.inner.out_buffer.push_back(msg);
565                    }
566                    Poll::Ready(None) => {
567                        // primary closed
568                        return Poll::Ready(None)
569                    }
570                    Poll::Pending => break,
571                }
572            }
573
574            // advance all satellites
575            for idx in (0..this.inner.protocols.len()).rev() {
576                let mut proto = this.inner.protocols.swap_remove(idx);
577                loop {
578                    match proto.poll_next_unpin(cx) {
579                        Poll::Ready(Some(Err(err))) => {
580                            return Poll::Ready(Some(Err(P2PStreamError::Io(err).into())))
581                        }
582                        Poll::Ready(Some(Ok(msg))) => {
583                            this.inner.out_buffer.push_back(msg);
584                        }
585                        Poll::Ready(None) => return Poll::Ready(None),
586                        Poll::Pending => {
587                            this.inner.protocols.push(proto);
588                            break
589                        }
590                    }
591                }
592            }
593
594            let mut delegated = false;
595            loop {
596                // pull messages from connection
597                match this.inner.conn.poll_next_unpin(cx) {
598                    Poll::Ready(Some(Ok(msg))) => {
599                        delegated = true;
600                        let Some(offset) = msg.first().copied() else {
601                            return Poll::Ready(Some(Err(
602                                P2PStreamError::EmptyProtocolMessage.into()
603                            )))
604                        };
605                        // delegate the multiplexed message to the correct protocol
606                        if let Some(cap) =
607                            this.inner.conn.shared_capabilities().find_by_relative_offset(offset)
608                        {
609                            if cap == &this.primary.shared_cap {
610                                // delegate to primary
611                                let _ = this.primary.to_primary.send(msg);
612                            } else {
613                                // delegate to installed satellite if any
614                                for proto in &this.inner.protocols {
615                                    if proto.shared_cap == *cap {
616                                        proto.send_raw(msg);
617                                        break
618                                    }
619                                }
620                            }
621                        } else {
622                            return Poll::Ready(Some(Err(P2PStreamError::UnknownReservedMessageId(
623                                offset,
624                            )
625                            .into())))
626                        }
627                    }
628                    Poll::Ready(Some(Err(err))) => return Poll::Ready(Some(Err(err.into()))),
629                    Poll::Ready(None) => {
630                        // connection closed
631                        return Poll::Ready(None)
632                    }
633                    Poll::Pending => break,
634                }
635            }
636
637            if !conn_ready || (!delegated && this.inner.out_buffer.is_empty()) {
638                return Poll::Pending
639            }
640        }
641    }
642}
643
644impl<St, Primary, T> Sink<T> for RlpxSatelliteStream<St, Primary>
645where
646    St: Stream<Item = io::Result<BytesMut>> + Sink<Bytes, Error = io::Error> + Unpin,
647    Primary: Sink<T> + Unpin,
648    P2PStreamError: Into<<Primary as Sink<T>>::Error>,
649{
650    type Error = <Primary as Sink<T>>::Error;
651
652    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
653        let this = self.get_mut();
654        if let Err(err) = ready!(this.inner.conn.poll_ready_unpin(cx)) {
655            return Poll::Ready(Err(err.into()))
656        }
657        if let Err(err) = ready!(this.primary.st.poll_ready_unpin(cx)) {
658            return Poll::Ready(Err(err))
659        }
660        Poll::Ready(Ok(()))
661    }
662
663    fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
664        self.get_mut().primary.st.start_send_unpin(item)
665    }
666
667    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
668        self.get_mut().inner.conn.poll_flush_unpin(cx).map_err(Into::into)
669    }
670
671    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
672        self.get_mut().inner.conn.poll_close_unpin(cx).map_err(Into::into)
673    }
674}
675
676/// Wraps a `RLPx` subprotocol and handles message ID multiplexing.
677struct ProtocolStream {
678    shared_cap: SharedCapability,
679    /// the channel shared with the satellite stream
680    to_satellite: UnboundedSender<BytesMut>,
681    satellite_st: Pin<Box<dyn Stream<Item = BytesMut> + Send>>,
682}
683
684impl ProtocolStream {
685    /// Masks the message ID of a message to be sent on the wire.
686    #[inline]
687    fn mask_msg_id(&self, mut msg: BytesMut) -> Result<Bytes, io::Error> {
688        if msg.is_empty() {
689            // message must not be empty
690            return Err(io::ErrorKind::InvalidInput.into())
691        }
692        msg[0] = msg[0]
693            .checked_add(self.shared_cap.relative_message_id_offset())
694            .ok_or(io::ErrorKind::InvalidInput)?;
695        Ok(msg.freeze())
696    }
697
698    /// Unmasks the message ID of a message received from the wire.
699    #[inline]
700    fn unmask_id(&self, mut msg: BytesMut) -> Result<BytesMut, io::Error> {
701        if msg.is_empty() {
702            // message must not be empty
703            return Err(io::ErrorKind::InvalidInput.into())
704        }
705        msg[0] = msg[0]
706            .checked_sub(self.shared_cap.relative_message_id_offset())
707            .ok_or(io::ErrorKind::InvalidInput)?;
708        Ok(msg)
709    }
710
711    /// Sends the message to the satellite stream.
712    fn send_raw(&self, msg: BytesMut) {
713        let _ = self.unmask_id(msg).map(|msg| self.to_satellite.send(msg));
714    }
715}
716
717impl Stream for ProtocolStream {
718    type Item = Result<Bytes, io::Error>;
719
720    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
721        let this = self.get_mut();
722        let msg = ready!(this.satellite_st.as_mut().poll_next(cx));
723        Poll::Ready(msg.map(|msg| this.mask_msg_id(msg)))
724    }
725}
726
727impl fmt::Debug for ProtocolStream {
728    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
729        f.debug_struct("ProtocolStream").field("cap", &self.shared_cap).finish_non_exhaustive()
730    }
731}
732
733/// Helper to poll multiple protocol streams in a `tokio::select`! branch
734struct ProtocolsPoller<'a> {
735    protocols: &'a mut Vec<ProtocolStream>,
736}
737
738impl<'a> ProtocolsPoller<'a> {
739    const fn new(protocols: &'a mut Vec<ProtocolStream>) -> Self {
740        Self { protocols }
741    }
742}
743
744impl<'a> Future for ProtocolsPoller<'a> {
745    type Output = Result<Bytes, P2PStreamError>;
746
747    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
748        // Process protocols in reverse order, like the existing pattern
749        for idx in (0..self.protocols.len()).rev() {
750            let mut proto = self.protocols.swap_remove(idx);
751            match proto.poll_next_unpin(cx) {
752                Poll::Ready(Some(Err(err))) => {
753                    self.protocols.push(proto);
754                    return Poll::Ready(Err(P2PStreamError::from(err)))
755                }
756                Poll::Ready(Some(Ok(msg))) => {
757                    // Got a message, put protocol back and return the message
758                    self.protocols.push(proto);
759                    return Poll::Ready(Ok(msg));
760                }
761                _ => {
762                    // push it back because we still want to complete the handshake first
763                    self.protocols.push(proto);
764                }
765            }
766        }
767
768        // All protocols processed, nothing ready
769        Poll::Pending
770    }
771}
772
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        handshake::EthHandshake,
        test_utils::{
            connect_passthrough, eth_handshake, eth_hello,
            proto::{test_hello, TestProtoMessage},
        },
        UnauthedEthStream, UnauthedP2PStream,
    };
    use reth_eth_wire_types::EthNetworkPrimitives;
    use tokio::{net::TcpListener, sync::oneshot};
    use tokio_util::codec::Decoder;

    /// Installs an `eth` satellite stream on the client side using a custom handshake closure
    /// (`into_satellite_stream_with_handshake`) against a plain `eth` server, and checks that the
    /// handshake succeeds.
    #[tokio::test]
    async fn eth_satellite() {
        reth_tracing::init_test_tracing();
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();
        let (status, fork_filter) = eth_handshake();
        let other_status = status;
        let other_fork_filter = fork_filter.clone();
        let _handle = tokio::spawn(async move {
            let (incoming, _) = listener.accept().await.unwrap();
            let stream = crate::PassthroughCodec::default().framed(incoming);
            let (server_hello, _) = eth_hello();
            let (p2p_stream, _) =
                UnauthedP2PStream::new(stream).handshake(server_hello).await.unwrap();

            let (_eth_stream, _) = UnauthedEthStream::new(p2p_stream)
                .handshake::<EthNetworkPrimitives>(other_status, other_fork_filter)
                .await
                .unwrap();

            // Keep the server's connection alive briefly so dropping it does not race the
            // client side, which may still be finishing its handshake.
            tokio::time::sleep(std::time::Duration::from_millis(100)).await;
        });

        let conn = connect_passthrough(local_addr, eth_hello().0).await;
        let eth = conn.shared_capabilities().eth().unwrap().clone();

        let multiplexer = RlpxProtocolMultiplexer::new(conn);
        let _satellite = multiplexer
            .into_satellite_stream_with_handshake(
                eth.capability().as_ref(),
                move |proxy| async move {
                    UnauthedEthStream::new(proxy)
                        .handshake::<EthNetworkPrimitives>(status, fork_filter)
                        .await
                },
            )
            .await
            .unwrap();
    }

    /// A test that installs a satellite stream eth+test protocol and sends messages between them.
    #[tokio::test(flavor = "multi_thread")]
    async fn eth_test_protocol_satellite() {
        reth_tracing::init_test_tracing();
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();
        let (status, fork_filter) = eth_handshake();
        let other_status = status;
        let other_fork_filter = fork_filter.clone();
        let _handle = tokio::spawn(async move {
            let (incoming, _) = listener.accept().await.unwrap();
            let stream = crate::PassthroughCodec::default().framed(incoming);
            let (server_hello, _) = test_hello();
            let (conn, _) = UnauthedP2PStream::new(stream).handshake(server_hello).await.unwrap();

            let (mut st, _their_status) = RlpxProtocolMultiplexer::new(conn)
                .into_eth_satellite_stream::<EthNetworkPrimitives>(
                    other_status,
                    other_fork_filter,
                    Arc::new(EthHandshake::default()),
                )
                .await
                .unwrap();

            st.install_protocol(&TestProtoMessage::capability(), |mut conn| {
                async_stream::stream! {
                    yield TestProtoMessage::ping().encoded();
                    let msg = conn.next().await.unwrap();
                    let msg = TestProtoMessage::decode_message(&mut &msg[..]).unwrap();
                    assert_eq!(msg, TestProtoMessage::pong());

                    yield TestProtoMessage::message("hello").encoded();
                    let msg = conn.next().await.unwrap();
                    let msg = TestProtoMessage::decode_message(&mut &msg[..]).unwrap();
                    assert_eq!(msg, TestProtoMessage::message("good bye!"));

                    yield TestProtoMessage::message("good bye!").encoded();

                    futures::future::pending::<()>().await;
                    unreachable!()
                }
            })
            .unwrap();

            // Drive the server-side satellite until it terminates. Stopping on `None` avoids
            // busy-polling a finished stream, which would never yield back to the runtime.
            while st.next().await.is_some() {}
        });

        let conn = connect_passthrough(local_addr, test_hello().0).await;
        let (mut st, _their_status) = RlpxProtocolMultiplexer::new(conn)
            .into_eth_satellite_stream::<EthNetworkPrimitives>(
                status,
                fork_filter,
                Arc::new(EthHandshake::default()),
            )
            .await
            .unwrap();

        let (tx, mut rx) = oneshot::channel();

        st.install_protocol(&TestProtoMessage::capability(), |mut conn| {
            async_stream::stream! {
                let msg = conn.next().await.unwrap();
                let msg = TestProtoMessage::decode_message(&mut &msg[..]).unwrap();
                assert_eq!(msg, TestProtoMessage::ping());

                yield TestProtoMessage::pong().encoded();

                let msg = conn.next().await.unwrap();
                let msg = TestProtoMessage::decode_message(&mut &msg[..]).unwrap();
                assert_eq!(msg, TestProtoMessage::message("hello"));

                yield TestProtoMessage::message("good bye!").encoded();

                let msg = conn.next().await.unwrap();
                let msg = TestProtoMessage::decode_message(&mut &msg[..]).unwrap();
                assert_eq!(msg, TestProtoMessage::message("good bye!"));

                tx.send(()).unwrap();

                futures::future::pending::<()>().await;
                unreachable!()
            }
        })
        .unwrap();

        loop {
            tokio::select! {
                res = &mut rx => {
                    // `Err` means the sender was dropped before the exchange finished; that must
                    // fail the test rather than be mistaken for a successful completion.
                    res.expect("test protocol was dropped before completing the message exchange");
                    break
                }
                item = st.next() => {
                    // The satellite must stay open until the exchange completes; otherwise this
                    // loop would spin forever on a terminated stream.
                    assert!(
                        item.is_some(),
                        "satellite stream ended before the test protocol exchange completed"
                    );
                }
            }
        }
    }
}