reth_eth_wire/
multiplex.rs

//! `RLPx` protocol multiplexer and satellite stream.
//!
//! A satellite is a stream that primarily drives a single `RLPx` subprotocol but can also handle
//! additional subprotocols.
//!
//! Most other subprotocols are "dependent satellite" protocols of "eth" rather than fully
//! standalone protocols; "snap" is one example. See also the [snap protocol](https://github.com/ethereum/devp2p/blob/298d7a77c3bf833641579ecbbb5b13f0311eeeea/caps/snap.md?plain=1#L71).
//! Hence it is expected that the primary protocol is "eth" and the additional protocols are
//! "dependent satellite" protocols.
9
10use std::{
11    collections::VecDeque,
12    fmt,
13    future::Future,
14    io,
15    pin::{pin, Pin},
16    sync::Arc,
17    task::{ready, Context, Poll},
18};
19
20use crate::{
21    capability::{SharedCapabilities, SharedCapability, UnsupportedCapabilityError},
22    errors::{EthStreamError, P2PStreamError},
23    handshake::EthRlpxHandshake,
24    p2pstream::DisconnectP2P,
25    CanDisconnect, Capability, DisconnectReason, EthStream, P2PStream, UnifiedStatus,
26    HANDSHAKE_TIMEOUT,
27};
28use bytes::{Bytes, BytesMut};
29use futures::{Sink, SinkExt, Stream, StreamExt, TryStream, TryStreamExt};
30use reth_eth_wire_types::NetworkPrimitives;
31use reth_ethereum_forks::ForkFilter;
32use tokio::sync::{mpsc, mpsc::UnboundedSender};
33use tokio_stream::wrappers::UnboundedReceiverStream;
34
35/// A Stream and Sink type that wraps a raw rlpx stream [`P2PStream`] and handles message ID
36/// multiplexing.
37#[derive(Debug)]
38pub struct RlpxProtocolMultiplexer<St> {
39    inner: MultiplexInner<St>,
40}
41
42impl<St> RlpxProtocolMultiplexer<St> {
43    /// Wraps the raw p2p stream
44    pub fn new(conn: P2PStream<St>) -> Self {
45        Self {
46            inner: MultiplexInner {
47                conn,
48                protocols: Default::default(),
49                out_buffer: Default::default(),
50            },
51        }
52    }
53
54    /// Installs a new protocol on top of the raw p2p stream.
55    ///
56    /// This accepts a closure that receives a [`ProtocolConnection`] that will yield messages for
57    /// the given capability.
58    pub fn install_protocol<F, Proto>(
59        &mut self,
60        cap: &Capability,
61        f: F,
62    ) -> Result<(), UnsupportedCapabilityError>
63    where
64        F: FnOnce(ProtocolConnection) -> Proto,
65        Proto: Stream<Item = BytesMut> + Send + 'static,
66    {
67        self.inner.install_protocol(cap, f)
68    }
69
70    /// Returns the [`SharedCapabilities`] of the underlying raw p2p stream
71    pub const fn shared_capabilities(&self) -> &SharedCapabilities {
72        self.inner.shared_capabilities()
73    }
74
75    /// Converts this multiplexer into a [`RlpxSatelliteStream`] with the given primary protocol.
76    pub fn into_satellite_stream<F, Primary>(
77        self,
78        cap: &Capability,
79        primary: F,
80    ) -> Result<RlpxSatelliteStream<St, Primary>, P2PStreamError>
81    where
82        F: FnOnce(ProtocolProxy) -> Primary,
83    {
84        let Ok(shared_cap) = self.shared_capabilities().ensure_matching_capability(cap).cloned()
85        else {
86            return Err(P2PStreamError::CapabilityNotShared)
87        };
88
89        let (to_primary, from_wire) = mpsc::unbounded_channel();
90        let (to_wire, from_primary) = mpsc::unbounded_channel();
91        let proxy = ProtocolProxy {
92            shared_cap: shared_cap.clone(),
93            from_wire: UnboundedReceiverStream::new(from_wire),
94            to_wire,
95        };
96
97        let st = primary(proxy);
98        Ok(RlpxSatelliteStream {
99            inner: self.inner,
100            primary: PrimaryProtocol {
101                to_primary,
102                from_primary: UnboundedReceiverStream::new(from_primary),
103                st,
104                shared_cap,
105            },
106        })
107    }
108
109    /// Converts this multiplexer into a [`RlpxSatelliteStream`] with the given primary protocol.
110    ///
111    /// Returns an error if the primary protocol is not supported by the remote or the handshake
112    /// failed.
113    pub async fn into_satellite_stream_with_handshake<F, Fut, Err, Primary>(
114        self,
115        cap: &Capability,
116        handshake: F,
117    ) -> Result<RlpxSatelliteStream<St, Primary>, Err>
118    where
119        F: FnOnce(ProtocolProxy) -> Fut,
120        Fut: Future<Output = Result<Primary, Err>>,
121        St: Stream<Item = io::Result<BytesMut>> + Sink<Bytes, Error = io::Error> + Unpin,
122        P2PStreamError: Into<Err>,
123    {
124        self.into_satellite_stream_with_tuple_handshake(cap, move |proxy| async move {
125            let st = handshake(proxy).await?;
126            Ok((st, ()))
127        })
128        .await
129        .map(|(st, _)| st)
130    }
131
132    /// Converts this multiplexer into a [`RlpxSatelliteStream`] with the given primary protocol.
133    ///
134    /// Returns an error if the primary protocol is not supported by the remote or the handshake
135    /// failed.
136    ///
137    /// This accepts a closure that does a handshake with the remote peer and returns a tuple of the
138    /// primary stream and extra data.
139    ///
140    /// See also [`UnauthedEthStream::handshake`](crate::UnauthedEthStream)
141    pub async fn into_satellite_stream_with_tuple_handshake<F, Fut, Err, Primary, Extra>(
142        mut self,
143        cap: &Capability,
144        handshake: F,
145    ) -> Result<(RlpxSatelliteStream<St, Primary>, Extra), Err>
146    where
147        F: FnOnce(ProtocolProxy) -> Fut,
148        Fut: Future<Output = Result<(Primary, Extra), Err>>,
149        St: Stream<Item = io::Result<BytesMut>> + Sink<Bytes, Error = io::Error> + Unpin,
150        P2PStreamError: Into<Err>,
151    {
152        let Ok(shared_cap) = self.shared_capabilities().ensure_matching_capability(cap).cloned()
153        else {
154            return Err(P2PStreamError::CapabilityNotShared.into())
155        };
156
157        let (to_primary, from_wire) = mpsc::unbounded_channel();
158        let (to_wire, mut from_primary) = mpsc::unbounded_channel();
159        let proxy = ProtocolProxy {
160            shared_cap: shared_cap.clone(),
161            from_wire: UnboundedReceiverStream::new(from_wire),
162            to_wire,
163        };
164
165        let f = handshake(proxy);
166        let mut f = pin!(f);
167
168        // this polls the connection and the primary stream concurrently until the handshake is
169        // complete
170        loop {
171            tokio::select! {
172                biased;
173                Some(Ok(msg)) = self.inner.conn.next() => {
174                    // Ensure the message belongs to the primary protocol
175                    let Some(offset) = msg.first().copied()
176                    else {
177                        return Err(P2PStreamError::EmptyProtocolMessage.into())
178                    };
179                    if let Some(cap) = self.shared_capabilities().find_by_relative_offset(offset).cloned() {
180                            if cap == shared_cap {
181                                // delegate to primary
182                                let _ = to_primary.send(msg);
183                            } else {
184                                // delegate to satellite
185                                self.inner.delegate_message(&cap, msg);
186                            }
187                        } else {
188                           return Err(P2PStreamError::UnknownReservedMessageId(offset).into())
189                        }
190                }
191                Some(msg) = from_primary.recv() => {
192                    self.inner.conn.send(msg).await.map_err(Into::into)?;
193                }
194                // Poll all subprotocols for new messages
195                msg = ProtocolsPoller::new(&mut self.inner.protocols) => {
196                     self.inner.conn.send(msg.map_err(Into::into)?).await.map_err(Into::into)?;
197                }
198                res = &mut f => {
199                    let (st, extra) = res?;
200                    return Ok((RlpxSatelliteStream {
201                            inner: self.inner,
202                            primary: PrimaryProtocol {
203                                to_primary,
204                                from_primary: UnboundedReceiverStream::new(from_primary),
205                                st,
206                                shared_cap,
207                            }
208                    }, extra))
209                }
210            }
211        }
212    }
213
214    /// Converts this multiplexer into a [`RlpxSatelliteStream`] with eth protocol as the given
215    /// primary protocol and the handshake implementation.
216    pub async fn into_eth_satellite_stream<N: NetworkPrimitives>(
217        self,
218        status: UnifiedStatus,
219        fork_filter: ForkFilter,
220        handshake: Arc<dyn EthRlpxHandshake>,
221    ) -> Result<(RlpxSatelliteStream<St, EthStream<ProtocolProxy, N>>, UnifiedStatus), EthStreamError>
222    where
223        St: Stream<Item = io::Result<BytesMut>> + Sink<Bytes, Error = io::Error> + Unpin,
224    {
225        let eth_cap = self.inner.conn.shared_capabilities().eth_version()?;
226        self.into_satellite_stream_with_tuple_handshake(&Capability::eth(eth_cap), move |proxy| {
227            let handshake = handshake.clone();
228            async move {
229                let mut unauth = UnauthProxy { inner: proxy };
230                let their_status = handshake
231                    .handshake(&mut unauth, status, fork_filter, HANDSHAKE_TIMEOUT)
232                    .await?;
233                let eth_stream = EthStream::new(eth_cap, unauth.into_inner());
234                Ok((eth_stream, their_status))
235            }
236        })
237        .await
238    }
239}
240
241#[derive(Debug)]
242struct MultiplexInner<St> {
243    /// The raw p2p stream
244    conn: P2PStream<St>,
245    /// All the subprotocols that are multiplexed on top of the raw p2p stream
246    protocols: Vec<ProtocolStream>,
247    /// Buffer for outgoing messages on the wire.
248    out_buffer: VecDeque<Bytes>,
249}
250
251impl<St> MultiplexInner<St> {
252    const fn shared_capabilities(&self) -> &SharedCapabilities {
253        self.conn.shared_capabilities()
254    }
255
256    /// Delegates a message to the matching protocol.
257    fn delegate_message(&self, cap: &SharedCapability, msg: BytesMut) -> bool {
258        for proto in &self.protocols {
259            if proto.shared_cap == *cap {
260                proto.send_raw(msg);
261                return true
262            }
263        }
264        false
265    }
266
267    fn install_protocol<F, Proto>(
268        &mut self,
269        cap: &Capability,
270        f: F,
271    ) -> Result<(), UnsupportedCapabilityError>
272    where
273        F: FnOnce(ProtocolConnection) -> Proto,
274        Proto: Stream<Item = BytesMut> + Send + 'static,
275    {
276        let shared_cap =
277            self.conn.shared_capabilities().ensure_matching_capability(cap).cloned()?;
278        let (to_satellite, rx) = mpsc::unbounded_channel();
279        let proto_conn = ProtocolConnection { from_wire: UnboundedReceiverStream::new(rx) };
280        let st = f(proto_conn);
281        let st = ProtocolStream { shared_cap, to_satellite, satellite_st: Box::pin(st) };
282        self.protocols.push(st);
283        Ok(())
284    }
285}
286
287/// Represents a protocol in the multiplexer that is used as the primary protocol.
288#[derive(Debug)]
289struct PrimaryProtocol<Primary> {
290    /// Channel to send messages to the primary protocol.
291    to_primary: UnboundedSender<BytesMut>,
292    /// Receiver for messages from the primary protocol.
293    from_primary: UnboundedReceiverStream<Bytes>,
294    /// Shared capability of the primary protocol.
295    shared_cap: SharedCapability,
296    /// The primary stream.
297    st: Primary,
298}
299
300/// A Stream and Sink type that acts as a wrapper around a primary `RLPx` subprotocol (e.g. "eth")
301///
302/// Only emits and sends _non-empty_ messages
303#[derive(Debug)]
304pub struct ProtocolProxy {
305    shared_cap: SharedCapability,
306    /// Receives _non-empty_ messages from the wire
307    from_wire: UnboundedReceiverStream<BytesMut>,
308    /// Sends _non-empty_ messages from the wire
309    to_wire: UnboundedSender<Bytes>,
310}
311
312impl ProtocolProxy {
313    /// Sends a _non-empty_ message on the wire.
314    fn try_send(&self, msg: Bytes) -> Result<(), io::Error> {
315        if msg.is_empty() {
316            // message must not be empty
317            return Err(io::ErrorKind::InvalidInput.into())
318        }
319        self.to_wire.send(self.mask_msg_id(msg)?).map_err(|_| io::ErrorKind::BrokenPipe.into())
320    }
321
322    /// Masks the message ID of a message to be sent on the wire.
323    #[inline]
324    fn mask_msg_id(&self, msg: Bytes) -> Result<Bytes, io::Error> {
325        if msg.is_empty() {
326            // message must not be empty
327            return Err(io::ErrorKind::InvalidInput.into())
328        }
329
330        let offset = self.shared_cap.relative_message_id_offset();
331        if offset == 0 {
332            return Ok(msg);
333        }
334
335        let mut masked: BytesMut = msg.into();
336        masked[0] = masked[0].checked_add(offset).ok_or(io::ErrorKind::InvalidInput)?;
337        Ok(masked.freeze())
338    }
339
340    /// Unmasks the message ID of a message received from the wire.
341    #[inline]
342    fn unmask_id(&self, mut msg: BytesMut) -> Result<BytesMut, io::Error> {
343        if msg.is_empty() {
344            // message must not be empty
345            return Err(io::ErrorKind::InvalidInput.into())
346        }
347        msg[0] = msg[0]
348            .checked_sub(self.shared_cap.relative_message_id_offset())
349            .ok_or(io::ErrorKind::InvalidInput)?;
350        Ok(msg)
351    }
352}
353
354impl Stream for ProtocolProxy {
355    type Item = Result<BytesMut, io::Error>;
356
357    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
358        let msg = ready!(self.from_wire.poll_next_unpin(cx));
359        Poll::Ready(msg.map(|msg| self.get_mut().unmask_id(msg)))
360    }
361}
362
363impl Sink<Bytes> for ProtocolProxy {
364    type Error = io::Error;
365
366    fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
367        Poll::Ready(Ok(()))
368    }
369
370    fn start_send(self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> {
371        self.get_mut().try_send(item)
372    }
373
374    fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
375        Poll::Ready(Ok(()))
376    }
377
378    fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
379        Poll::Ready(Ok(()))
380    }
381}
382
383impl CanDisconnect<Bytes> for ProtocolProxy {
384    fn disconnect(
385        &mut self,
386        _reason: DisconnectReason,
387    ) -> Pin<Box<dyn Future<Output = Result<(), <Self as Sink<Bytes>>::Error>> + Send + '_>> {
388        Box::pin(async move { Ok(()) })
389    }
390}
391
392/// Adapter so the injected `EthRlpxHandshake` can run over a multiplexed `ProtocolProxy`
393/// using the same error type expectations (`P2PStreamError`).
394#[derive(Debug)]
395struct UnauthProxy {
396    inner: ProtocolProxy,
397}
398
399impl UnauthProxy {
400    fn into_inner(self) -> ProtocolProxy {
401        self.inner
402    }
403}
404
405impl Stream for UnauthProxy {
406    type Item = Result<BytesMut, P2PStreamError>;
407
408    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
409        self.inner.poll_next_unpin(cx).map(|opt| opt.map(|res| res.map_err(P2PStreamError::from)))
410    }
411}
412
413impl Sink<Bytes> for UnauthProxy {
414    type Error = P2PStreamError;
415
416    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
417        self.inner.poll_ready_unpin(cx).map_err(P2PStreamError::from)
418    }
419
420    fn start_send(mut self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> {
421        self.inner.start_send_unpin(item).map_err(P2PStreamError::from)
422    }
423
424    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
425        self.inner.poll_flush_unpin(cx).map_err(P2PStreamError::from)
426    }
427
428    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
429        self.inner.poll_close_unpin(cx).map_err(P2PStreamError::from)
430    }
431}
432
433impl CanDisconnect<Bytes> for UnauthProxy {
434    fn disconnect(
435        &mut self,
436        reason: DisconnectReason,
437    ) -> Pin<Box<dyn Future<Output = Result<(), <Self as Sink<Bytes>>::Error>> + Send + '_>> {
438        let fut = self.inner.disconnect(reason);
439        Box::pin(async move { fut.await.map_err(P2PStreamError::from) })
440    }
441}
442
443/// A connection channel to receive _`non_empty`_ messages for the negotiated protocol.
444///
445/// This is a [Stream] that returns raw bytes of the received messages for this protocol.
446#[derive(Debug)]
447pub struct ProtocolConnection {
448    from_wire: UnboundedReceiverStream<BytesMut>,
449}
450
451impl Stream for ProtocolConnection {
452    type Item = BytesMut;
453
454    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
455        self.from_wire.poll_next_unpin(cx)
456    }
457}
458
459/// A Stream and Sink type that acts as a wrapper around a primary `RLPx` subprotocol (e.g. "eth")
460/// [`EthStream`] and can also handle additional subprotocols.
461#[derive(Debug)]
462pub struct RlpxSatelliteStream<St, Primary> {
463    inner: MultiplexInner<St>,
464    primary: PrimaryProtocol<Primary>,
465}
466
467impl<St, Primary> RlpxSatelliteStream<St, Primary> {
468    /// Installs a new protocol on top of the raw p2p stream.
469    ///
470    /// This accepts a closure that receives a [`ProtocolConnection`] that will yield messages for
471    /// the given capability.
472    pub fn install_protocol<F, Proto>(
473        &mut self,
474        cap: &Capability,
475        f: F,
476    ) -> Result<(), UnsupportedCapabilityError>
477    where
478        F: FnOnce(ProtocolConnection) -> Proto,
479        Proto: Stream<Item = BytesMut> + Send + 'static,
480    {
481        self.inner.install_protocol(cap, f)
482    }
483
484    /// Returns the primary protocol.
485    #[inline]
486    pub const fn primary(&self) -> &Primary {
487        &self.primary.st
488    }
489
490    /// Returns mutable access to the primary protocol.
491    #[inline]
492    pub const fn primary_mut(&mut self) -> &mut Primary {
493        &mut self.primary.st
494    }
495
496    /// Returns the underlying [`P2PStream`].
497    #[inline]
498    pub const fn inner(&self) -> &P2PStream<St> {
499        &self.inner.conn
500    }
501
502    /// Returns mutable access to the underlying [`P2PStream`].
503    #[inline]
504    pub const fn inner_mut(&mut self) -> &mut P2PStream<St> {
505        &mut self.inner.conn
506    }
507
508    /// Consumes this type and returns the wrapped [`P2PStream`].
509    #[inline]
510    pub fn into_inner(self) -> P2PStream<St> {
511        self.inner.conn
512    }
513}
514
515impl<St, Primary, PrimaryErr> Stream for RlpxSatelliteStream<St, Primary>
516where
517    St: Stream<Item = io::Result<BytesMut>> + Sink<Bytes, Error = io::Error> + Unpin,
518    Primary: TryStream<Error = PrimaryErr> + Unpin,
519    P2PStreamError: Into<PrimaryErr>,
520{
521    type Item = Result<Primary::Ok, Primary::Error>;
522
523    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
524        let this = self.get_mut();
525
526        loop {
527            // first drain the primary stream
528            if let Poll::Ready(Some(msg)) = this.primary.st.try_poll_next_unpin(cx) {
529                return Poll::Ready(Some(msg))
530            }
531
532            let mut conn_ready = true;
533            loop {
534                match this.inner.conn.poll_ready_unpin(cx) {
535                    Poll::Ready(Ok(())) => {
536                        if let Some(msg) = this.inner.out_buffer.pop_front() {
537                            if let Err(err) = this.inner.conn.start_send_unpin(msg) {
538                                return Poll::Ready(Some(Err(err.into())))
539                            }
540                        } else {
541                            break
542                        }
543                    }
544                    Poll::Ready(Err(err)) => {
545                        if let Err(disconnect_err) =
546                            this.inner.conn.start_disconnect(DisconnectReason::DisconnectRequested)
547                        {
548                            return Poll::Ready(Some(Err(disconnect_err.into())))
549                        }
550                        return Poll::Ready(Some(Err(err.into())))
551                    }
552                    Poll::Pending => {
553                        conn_ready = false;
554                        break
555                    }
556                }
557            }
558
559            // advance primary out
560            loop {
561                match this.primary.from_primary.poll_next_unpin(cx) {
562                    Poll::Ready(Some(msg)) => {
563                        this.inner.out_buffer.push_back(msg);
564                    }
565                    Poll::Ready(None) => {
566                        // primary closed
567                        return Poll::Ready(None)
568                    }
569                    Poll::Pending => break,
570                }
571            }
572
573            // advance all satellites
574            for idx in (0..this.inner.protocols.len()).rev() {
575                let mut proto = this.inner.protocols.swap_remove(idx);
576                loop {
577                    match proto.poll_next_unpin(cx) {
578                        Poll::Ready(Some(Err(err))) => {
579                            return Poll::Ready(Some(Err(P2PStreamError::Io(err).into())))
580                        }
581                        Poll::Ready(Some(Ok(msg))) => {
582                            this.inner.out_buffer.push_back(msg);
583                        }
584                        Poll::Ready(None) => return Poll::Ready(None),
585                        Poll::Pending => {
586                            this.inner.protocols.push(proto);
587                            break
588                        }
589                    }
590                }
591            }
592
593            let mut delegated = false;
594            loop {
595                // pull messages from connection
596                match this.inner.conn.poll_next_unpin(cx) {
597                    Poll::Ready(Some(Ok(msg))) => {
598                        delegated = true;
599                        let Some(offset) = msg.first().copied() else {
600                            return Poll::Ready(Some(Err(
601                                P2PStreamError::EmptyProtocolMessage.into()
602                            )))
603                        };
604                        // delegate the multiplexed message to the correct protocol
605                        if let Some(cap) =
606                            this.inner.conn.shared_capabilities().find_by_relative_offset(offset)
607                        {
608                            if cap == &this.primary.shared_cap {
609                                // delegate to primary
610                                let _ = this.primary.to_primary.send(msg);
611                            } else {
612                                // delegate to installed satellite if any
613                                for proto in &this.inner.protocols {
614                                    if proto.shared_cap == *cap {
615                                        proto.send_raw(msg);
616                                        break
617                                    }
618                                }
619                            }
620                        } else {
621                            return Poll::Ready(Some(Err(P2PStreamError::UnknownReservedMessageId(
622                                offset,
623                            )
624                            .into())))
625                        }
626                    }
627                    Poll::Ready(Some(Err(err))) => return Poll::Ready(Some(Err(err.into()))),
628                    Poll::Ready(None) => {
629                        // connection closed
630                        return Poll::Ready(None)
631                    }
632                    Poll::Pending => break,
633                }
634            }
635
636            if !conn_ready || (!delegated && this.inner.out_buffer.is_empty()) {
637                return Poll::Pending
638            }
639        }
640    }
641}
642
643impl<St, Primary, T> Sink<T> for RlpxSatelliteStream<St, Primary>
644where
645    St: Stream<Item = io::Result<BytesMut>> + Sink<Bytes, Error = io::Error> + Unpin,
646    Primary: Sink<T> + Unpin,
647    P2PStreamError: Into<<Primary as Sink<T>>::Error>,
648{
649    type Error = <Primary as Sink<T>>::Error;
650
651    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
652        let this = self.get_mut();
653        if let Err(err) = ready!(this.inner.conn.poll_ready_unpin(cx)) {
654            return Poll::Ready(Err(err.into()))
655        }
656        if let Err(err) = ready!(this.primary.st.poll_ready_unpin(cx)) {
657            return Poll::Ready(Err(err))
658        }
659        Poll::Ready(Ok(()))
660    }
661
662    fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
663        self.get_mut().primary.st.start_send_unpin(item)
664    }
665
666    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
667        self.get_mut().inner.conn.poll_flush_unpin(cx).map_err(Into::into)
668    }
669
670    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
671        self.get_mut().inner.conn.poll_close_unpin(cx).map_err(Into::into)
672    }
673}
674
675/// Wraps a `RLPx` subprotocol and handles message ID multiplexing.
676struct ProtocolStream {
677    shared_cap: SharedCapability,
678    /// the channel shared with the satellite stream
679    to_satellite: UnboundedSender<BytesMut>,
680    satellite_st: Pin<Box<dyn Stream<Item = BytesMut> + Send>>,
681}
682
683impl ProtocolStream {
684    /// Masks the message ID of a message to be sent on the wire.
685    #[inline]
686    fn mask_msg_id(&self, mut msg: BytesMut) -> Result<Bytes, io::Error> {
687        if msg.is_empty() {
688            // message must not be empty
689            return Err(io::ErrorKind::InvalidInput.into())
690        }
691        msg[0] = msg[0]
692            .checked_add(self.shared_cap.relative_message_id_offset())
693            .ok_or(io::ErrorKind::InvalidInput)?;
694        Ok(msg.freeze())
695    }
696
697    /// Unmasks the message ID of a message received from the wire.
698    #[inline]
699    fn unmask_id(&self, mut msg: BytesMut) -> Result<BytesMut, io::Error> {
700        if msg.is_empty() {
701            // message must not be empty
702            return Err(io::ErrorKind::InvalidInput.into())
703        }
704        msg[0] = msg[0]
705            .checked_sub(self.shared_cap.relative_message_id_offset())
706            .ok_or(io::ErrorKind::InvalidInput)?;
707        Ok(msg)
708    }
709
710    /// Sends the message to the satellite stream.
711    fn send_raw(&self, msg: BytesMut) {
712        let _ = self.unmask_id(msg).map(|msg| self.to_satellite.send(msg));
713    }
714}
715
716impl Stream for ProtocolStream {
717    type Item = Result<Bytes, io::Error>;
718
719    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
720        let this = self.get_mut();
721        let msg = ready!(this.satellite_st.as_mut().poll_next(cx));
722        Poll::Ready(msg.map(|msg| this.mask_msg_id(msg)))
723    }
724}
725
726impl fmt::Debug for ProtocolStream {
727    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
728        f.debug_struct("ProtocolStream").field("cap", &self.shared_cap).finish_non_exhaustive()
729    }
730}
731
732/// Helper to poll multiple protocol streams in a `tokio::select`! branch
733struct ProtocolsPoller<'a> {
734    protocols: &'a mut Vec<ProtocolStream>,
735}
736
737impl<'a> ProtocolsPoller<'a> {
738    const fn new(protocols: &'a mut Vec<ProtocolStream>) -> Self {
739        Self { protocols }
740    }
741}
742
743impl<'a> Future for ProtocolsPoller<'a> {
744    type Output = Result<Bytes, P2PStreamError>;
745
746    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
747        // Process protocols in reverse order, like the existing pattern
748        for idx in (0..self.protocols.len()).rev() {
749            let mut proto = self.protocols.swap_remove(idx);
750            match proto.poll_next_unpin(cx) {
751                Poll::Ready(Some(Err(err))) => {
752                    self.protocols.push(proto);
753                    return Poll::Ready(Err(P2PStreamError::from(err)))
754                }
755                Poll::Ready(Some(Ok(msg))) => {
756                    // Got a message, put protocol back and return the message
757                    self.protocols.push(proto);
758                    return Poll::Ready(Ok(msg));
759                }
760                _ => {
761                    // push it back because we still want to complete the handshake first
762                    self.protocols.push(proto);
763                }
764            }
765        }
766
767        // All protocols processed, nothing ready
768        Poll::Pending
769    }
770}
771
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        handshake::EthHandshake,
        test_utils::{
            connect_passthrough, eth_handshake, eth_hello,
            proto::{test_hello, TestProtoMessage},
        },
        UnauthedEthStream, UnauthedP2PStream,
    };
    use reth_eth_wire_types::EthNetworkPrimitives;
    use tokio::{net::TcpListener, sync::oneshot};
    use tokio_util::codec::Decoder;

    /// Upper bound for the message exchange in `eth_test_protocol_satellite`, so a lost message
    /// fails the test instead of hanging it forever.
    const EXCHANGE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);

    /// Runs the `eth` handshake through a satellite stream on the client side against a plain
    /// [`UnauthedEthStream`] handshake on the server side.
    #[tokio::test]
    async fn eth_satellite() {
        reth_tracing::init_test_tracing();
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();
        let (status, fork_filter) = eth_handshake();
        let other_status = status;
        let other_fork_filter = fork_filter.clone();
        let handle = tokio::spawn(async move {
            let (incoming, _) = listener.accept().await.unwrap();
            let stream = crate::PassthroughCodec::default().framed(incoming);
            let (server_hello, _) = eth_hello();
            let (p2p_stream, _) =
                UnauthedP2PStream::new(stream).handshake(server_hello).await.unwrap();

            let (_eth_stream, _) = UnauthedEthStream::new(p2p_stream)
                .handshake::<EthNetworkPrimitives>(other_status, other_fork_filter)
                .await
                .unwrap();

            // Keep the connection open for a moment after the server-side handshake.
            tokio::time::sleep(std::time::Duration::from_millis(100)).await;
        });

        let conn = connect_passthrough(local_addr, eth_hello().0).await;
        let eth = conn.shared_capabilities().eth().unwrap().clone();

        let multiplexer = RlpxProtocolMultiplexer::new(conn);
        let _satellite = multiplexer
            .into_satellite_stream_with_handshake(
                eth.capability().as_ref(),
                move |proxy| async move {
                    UnauthedEthStream::new(proxy)
                        .handshake::<EthNetworkPrimitives>(status, fork_filter)
                        .await
                },
            )
            .await
            .unwrap();

        // Surface a failed server-side handshake (a panic inside the spawned task), which would
        // otherwise be silently discarded. `_satellite` is still alive here, so the connection
        // stays open while the server finishes.
        handle.await.unwrap();
    }

    /// Installs a satellite stream with the eth and test protocols on both peers and exchanges
    /// test-protocol messages between them.
    #[tokio::test(flavor = "multi_thread")]
    async fn eth_test_protocol_satellite() {
        reth_tracing::init_test_tracing();
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();
        let (status, fork_filter) = eth_handshake();
        let other_status = status;
        let other_fork_filter = fork_filter.clone();
        let _handle = tokio::spawn(async move {
            let (incoming, _) = listener.accept().await.unwrap();
            let stream = crate::PassthroughCodec::default().framed(incoming);
            let (server_hello, _) = test_hello();
            let (conn, _) = UnauthedP2PStream::new(stream).handshake(server_hello).await.unwrap();

            let (mut st, _their_status) = RlpxProtocolMultiplexer::new(conn)
                .into_eth_satellite_stream::<EthNetworkPrimitives>(
                    other_status,
                    other_fork_filter,
                    Arc::new(EthHandshake::default()),
                )
                .await
                .unwrap();

            st.install_protocol(&TestProtoMessage::capability(), |mut conn| {
                async_stream::stream! {
                    yield TestProtoMessage::ping().encoded();
                    let msg = conn.next().await.unwrap();
                    let msg = TestProtoMessage::decode_message(&mut &msg[..]).unwrap();
                    assert_eq!(msg, TestProtoMessage::pong());

                    yield TestProtoMessage::message("hello").encoded();
                    let msg = conn.next().await.unwrap();
                    let msg = TestProtoMessage::decode_message(&mut &msg[..]).unwrap();
                    assert_eq!(msg, TestProtoMessage::message("good bye!"));

                    yield TestProtoMessage::message("good bye!").encoded();

                    futures::future::pending::<()>().await;
                    unreachable!()
                }
            })
            .unwrap();

            // Drive the connection until the peer disconnects. Stopping on `None` keeps the task
            // from spinning on a closed stream once the client side is done.
            while st.next().await.is_some() {}
        });

        let conn = connect_passthrough(local_addr, test_hello().0).await;
        let (mut st, _their_status) = RlpxProtocolMultiplexer::new(conn)
            .into_eth_satellite_stream::<EthNetworkPrimitives>(
                status,
                fork_filter,
                Arc::new(EthHandshake::default()),
            )
            .await
            .unwrap();

        let (tx, mut rx) = oneshot::channel();

        st.install_protocol(&TestProtoMessage::capability(), |mut conn| {
            async_stream::stream! {
                let msg = conn.next().await.unwrap();
                let msg = TestProtoMessage::decode_message(&mut &msg[..]).unwrap();
                assert_eq!(msg, TestProtoMessage::ping());

                yield TestProtoMessage::pong().encoded();

                let msg = conn.next().await.unwrap();
                let msg = TestProtoMessage::decode_message(&mut &msg[..]).unwrap();
                assert_eq!(msg, TestProtoMessage::message("hello"));

                yield TestProtoMessage::message("good bye!").encoded();

                let msg = conn.next().await.unwrap();
                let msg = TestProtoMessage::decode_message(&mut &msg[..]).unwrap();
                assert_eq!(msg, TestProtoMessage::message("good bye!"));

                tx.send(()).unwrap();

                futures::future::pending::<()>().await;
                unreachable!()
            }
        })
        .unwrap();

        // The satellite stream must be polled to make progress on the installed protocol; keep
        // polling it until the protocol signals completion through the oneshot channel.
        tokio::time::timeout(EXCHANGE_TIMEOUT, async {
            loop {
                tokio::select! {
                    res = &mut rx => {
                        // A dropped sender means the protocol stream ended without finishing.
                        res.expect("test protocol stream dropped before finishing the exchange");
                        break
                    }
                    msg = st.next() => {
                        // A closed connection keeps yielding `None`; fail instead of spinning.
                        assert!(
                            msg.is_some(),
                            "connection closed before the test protocol exchange completed"
                        );
                    }
                }
            }
        })
        .await
        .expect("test protocol exchange timed out");
    }
}