Skip to main content

reth_eth_wire/
multiplex.rs

//! `RLPx` protocol multiplexer and satellite stream.
//!
//! A satellite is a stream that primarily drives a single `RLPx` subprotocol but can also handle
//! additional subprotocols.
//!
//! Most other subprotocols are "dependent satellite" protocols of "eth" rather than fully
//! standalone protocols; "snap" is one example. See also the
//! [snap protocol](https://github.com/ethereum/devp2p/blob/298d7a77c3bf833641579ecbbb5b13f0311eeeea/caps/snap.md?plain=1#L71).
//! Hence the primary protocol is expected to be "eth", and the additional protocols are expected
//! to be "dependent satellite" protocols.
9
10use std::{
11    collections::VecDeque,
12    fmt,
13    future::Future,
14    io,
15    pin::{pin, Pin},
16    sync::Arc,
17    task::{ready, Context, Poll},
18};
19
20use crate::{
21    capability::{SharedCapabilities, SharedCapability, UnsupportedCapabilityError},
22    errors::{EthStreamError, P2PStreamError},
23    handshake::EthRlpxHandshake,
24    p2pstream::DisconnectP2P,
25    CanDisconnect, Capability, DisconnectReason, EthStream, P2PStream, UnifiedStatus,
26    HANDSHAKE_TIMEOUT,
27};
28use bytes::{Bytes, BytesMut};
29use futures::{Sink, SinkExt, Stream, StreamExt, TryStream, TryStreamExt};
30use reth_eth_wire_types::NetworkPrimitives;
31use reth_ethereum_forks::ForkFilter;
32use tokio::sync::{mpsc, mpsc::UnboundedSender};
33use tokio_stream::wrappers::UnboundedReceiverStream;
34
35/// A Stream and Sink type that wraps a raw rlpx stream [`P2PStream`] and handles message ID
36/// multiplexing.
37#[derive(Debug)]
38pub struct RlpxProtocolMultiplexer<St> {
39    inner: MultiplexInner<St>,
40}
41
42impl<St> RlpxProtocolMultiplexer<St> {
43    /// Wraps the raw p2p stream
44    pub fn new(conn: P2PStream<St>) -> Self {
45        Self {
46            inner: MultiplexInner {
47                conn,
48                protocols: Default::default(),
49                out_buffer: Default::default(),
50            },
51        }
52    }
53
54    /// Installs a new protocol on top of the raw p2p stream.
55    ///
56    /// This accepts a closure that receives a [`ProtocolConnection`] that will yield messages for
57    /// the given capability.
58    pub fn install_protocol<F, Proto>(
59        &mut self,
60        cap: &Capability,
61        f: F,
62    ) -> Result<(), UnsupportedCapabilityError>
63    where
64        F: FnOnce(ProtocolConnection) -> Proto,
65        Proto: Stream<Item = BytesMut> + Send + 'static,
66    {
67        self.inner.install_protocol(cap, f)
68    }
69
70    /// Returns the [`SharedCapabilities`] of the underlying raw p2p stream
71    pub const fn shared_capabilities(&self) -> &SharedCapabilities {
72        self.inner.shared_capabilities()
73    }
74
75    /// Converts this multiplexer into a [`RlpxSatelliteStream`] with the given primary protocol.
76    pub fn into_satellite_stream<F, Primary>(
77        self,
78        cap: &Capability,
79        primary: F,
80    ) -> Result<RlpxSatelliteStream<St, Primary>, P2PStreamError>
81    where
82        F: FnOnce(ProtocolProxy) -> Primary,
83    {
84        let Ok(shared_cap) = self.shared_capabilities().ensure_matching_capability(cap).cloned()
85        else {
86            return Err(P2PStreamError::CapabilityNotShared)
87        };
88
89        let (to_primary, from_wire) = mpsc::unbounded_channel();
90        let (to_wire, from_primary) = mpsc::unbounded_channel();
91        let proxy = ProtocolProxy {
92            shared_cap: shared_cap.clone(),
93            from_wire: UnboundedReceiverStream::new(from_wire),
94            to_wire,
95        };
96
97        let st = primary(proxy);
98        Ok(RlpxSatelliteStream {
99            inner: self.inner,
100            primary: PrimaryProtocol {
101                to_primary,
102                from_primary: UnboundedReceiverStream::new(from_primary),
103                st,
104                shared_cap,
105            },
106            next_outbound: 0,
107        })
108    }
109
110    /// Converts this multiplexer into a [`RlpxSatelliteStream`] with the given primary protocol.
111    ///
112    /// Returns an error if the primary protocol is not supported by the remote or the handshake
113    /// failed.
114    pub async fn into_satellite_stream_with_handshake<F, Fut, Err, Primary>(
115        self,
116        cap: &Capability,
117        handshake: F,
118    ) -> Result<RlpxSatelliteStream<St, Primary>, Err>
119    where
120        F: FnOnce(ProtocolProxy) -> Fut,
121        Fut: Future<Output = Result<Primary, Err>>,
122        St: Stream<Item = io::Result<BytesMut>> + Sink<Bytes, Error = io::Error> + Unpin,
123        P2PStreamError: Into<Err>,
124    {
125        self.into_satellite_stream_with_tuple_handshake(cap, async move |proxy| {
126            let st = handshake(proxy).await?;
127            Ok((st, ()))
128        })
129        .await
130        .map(|(st, _)| st)
131    }
132
133    /// Converts this multiplexer into a [`RlpxSatelliteStream`] with the given primary protocol.
134    ///
135    /// Returns an error if the primary protocol is not supported by the remote or the handshake
136    /// failed.
137    ///
138    /// This accepts a closure that does a handshake with the remote peer and returns a tuple of the
139    /// primary stream and extra data.
140    ///
141    /// See also [`UnauthedEthStream::handshake`](crate::UnauthedEthStream)
142    pub async fn into_satellite_stream_with_tuple_handshake<F, Fut, Err, Primary, Extra>(
143        mut self,
144        cap: &Capability,
145        handshake: F,
146    ) -> Result<(RlpxSatelliteStream<St, Primary>, Extra), Err>
147    where
148        F: FnOnce(ProtocolProxy) -> Fut,
149        Fut: Future<Output = Result<(Primary, Extra), Err>>,
150        St: Stream<Item = io::Result<BytesMut>> + Sink<Bytes, Error = io::Error> + Unpin,
151        P2PStreamError: Into<Err>,
152    {
153        let Ok(shared_cap) = self.shared_capabilities().ensure_matching_capability(cap).cloned()
154        else {
155            return Err(P2PStreamError::CapabilityNotShared.into())
156        };
157
158        let (to_primary, from_wire) = mpsc::unbounded_channel();
159        let (to_wire, mut from_primary) = mpsc::unbounded_channel();
160        let proxy = ProtocolProxy {
161            shared_cap: shared_cap.clone(),
162            from_wire: UnboundedReceiverStream::new(from_wire),
163            to_wire,
164        };
165
166        let f = handshake(proxy);
167        let mut f = pin!(f);
168
169        // this polls the connection and the primary stream concurrently until the handshake is
170        // complete
171        loop {
172            tokio::select! {
173                biased;
174                Some(Ok(msg)) = self.inner.conn.next() => {
175                    // Ensure the message belongs to the primary protocol
176                    let Some(offset) = msg.first().copied()
177                    else {
178                        return Err(P2PStreamError::EmptyProtocolMessage.into())
179                    };
180                    if let Some(cap) = self.shared_capabilities().find_by_relative_offset(offset).cloned() {
181                            if cap == shared_cap {
182                                // delegate to primary
183                                let _ = to_primary.send(msg);
184                            } else {
185                                // delegate to satellite
186                                self.inner.delegate_message(&cap, msg);
187                            }
188                        } else {
189                           return Err(P2PStreamError::UnknownReservedMessageId(offset).into())
190                        }
191                }
192                Some(msg) = from_primary.recv() => {
193                    self.inner.conn.send(msg).await.map_err(Into::into)?;
194                }
195                // Poll all subprotocols for new messages
196                msg = ProtocolsPoller::new(&mut self.inner.protocols) => {
197                     self.inner.conn.send(msg.map_err(Into::into)?).await.map_err(Into::into)?;
198                }
199                res = &mut f => {
200                    let (st, extra) = res?;
201                    return Ok((
202                        RlpxSatelliteStream {
203                            inner: self.inner,
204                            primary: PrimaryProtocol {
205                                to_primary,
206                                from_primary: UnboundedReceiverStream::new(from_primary),
207                                st,
208                                shared_cap,
209                            },
210                            next_outbound: 0,
211                        },
212                        extra,
213                    ))
214                }
215            }
216        }
217    }
218
219    /// Converts this multiplexer into a [`RlpxSatelliteStream`] with eth protocol as the given
220    /// primary protocol and the handshake implementation.
221    pub async fn into_eth_satellite_stream<N: NetworkPrimitives>(
222        self,
223        status: UnifiedStatus,
224        fork_filter: ForkFilter,
225        handshake: Arc<dyn EthRlpxHandshake>,
226        eth_max_message_size: usize,
227    ) -> Result<(RlpxSatelliteStream<St, EthStream<ProtocolProxy, N>>, UnifiedStatus), EthStreamError>
228    where
229        St: Stream<Item = io::Result<BytesMut>> + Sink<Bytes, Error = io::Error> + Unpin,
230    {
231        let eth_cap = self.inner.conn.shared_capabilities().eth_version()?;
232        self.into_satellite_stream_with_tuple_handshake(
233            &Capability::eth(eth_cap),
234            async move |proxy| {
235                let handshake = handshake.clone();
236                let mut unauth = UnauthProxy { inner: proxy };
237                let their_status = handshake
238                    .handshake(&mut unauth, status, fork_filter, HANDSHAKE_TIMEOUT)
239                    .await?;
240                let eth_stream = EthStream::with_max_message_size(
241                    eth_cap,
242                    unauth.into_inner(),
243                    eth_max_message_size,
244                );
245                Ok((eth_stream, their_status))
246            },
247        )
248        .await
249    }
250}
251
252#[derive(Debug)]
253struct MultiplexInner<St> {
254    /// The raw p2p stream
255    conn: P2PStream<St>,
256    /// All the subprotocols that are multiplexed on top of the raw p2p stream
257    protocols: VecDeque<ProtocolStream>,
258    /// Buffer for outgoing messages on the wire.
259    out_buffer: OutBuffer,
260}
261
262impl<St> MultiplexInner<St> {
263    const fn shared_capabilities(&self) -> &SharedCapabilities {
264        self.conn.shared_capabilities()
265    }
266
267    /// Delegates a message to the matching protocol.
268    fn delegate_message(&self, cap: &SharedCapability, msg: BytesMut) -> bool {
269        for proto in &self.protocols {
270            if proto.shared_cap == *cap {
271                proto.send_raw(msg);
272                return true
273            }
274        }
275        false
276    }
277
278    fn install_protocol<F, Proto>(
279        &mut self,
280        cap: &Capability,
281        f: F,
282    ) -> Result<(), UnsupportedCapabilityError>
283    where
284        F: FnOnce(ProtocolConnection) -> Proto,
285        Proto: Stream<Item = BytesMut> + Send + 'static,
286    {
287        let shared_cap =
288            self.conn.shared_capabilities().ensure_matching_capability(cap).cloned()?;
289        let (to_satellite, rx) = mpsc::unbounded_channel();
290        let proto_conn = ProtocolConnection { from_wire: UnboundedReceiverStream::new(rx) };
291        let st = f(proto_conn);
292        let st = ProtocolStream { shared_cap, to_satellite, satellite_st: Box::pin(st) };
293        self.protocols.push_back(st);
294        Ok(())
295    }
296}
297
298/// Represents a protocol in the multiplexer that is used as the primary protocol.
299#[derive(Debug)]
300struct PrimaryProtocol<Primary> {
301    /// Channel to send messages to the primary protocol.
302    to_primary: UnboundedSender<BytesMut>,
303    /// Receiver for messages from the primary protocol.
304    from_primary: UnboundedReceiverStream<Bytes>,
305    /// Shared capability of the primary protocol.
306    shared_cap: SharedCapability,
307    /// The primary stream.
308    st: Primary,
309}
310
311/// A Stream and Sink type that acts as a wrapper around a primary `RLPx` subprotocol (e.g. "eth")
312///
313/// Only emits and sends _non-empty_ messages
314#[derive(Debug)]
315pub struct ProtocolProxy {
316    shared_cap: SharedCapability,
317    /// Receives _non-empty_ messages from the wire
318    from_wire: UnboundedReceiverStream<BytesMut>,
319    /// Sends _non-empty_ messages from the wire
320    to_wire: UnboundedSender<Bytes>,
321}
322
323impl ProtocolProxy {
324    /// Sends a _non-empty_ message on the wire.
325    fn try_send(&self, msg: Bytes) -> Result<(), io::Error> {
326        if msg.is_empty() {
327            // message must not be empty
328            return Err(io::ErrorKind::InvalidInput.into())
329        }
330        self.to_wire.send(self.mask_msg_id(msg)?).map_err(|_| io::ErrorKind::BrokenPipe.into())
331    }
332
333    /// Masks the message ID of a message to be sent on the wire.
334    #[inline]
335    fn mask_msg_id(&self, msg: Bytes) -> Result<Bytes, io::Error> {
336        if msg.is_empty() {
337            // message must not be empty
338            return Err(io::ErrorKind::InvalidInput.into())
339        }
340
341        let offset = self.shared_cap.relative_message_id_offset();
342        if offset == 0 {
343            return Ok(msg);
344        }
345
346        let mut masked: BytesMut = msg.into();
347        masked[0] = masked[0].checked_add(offset).ok_or(io::ErrorKind::InvalidInput)?;
348        Ok(masked.freeze())
349    }
350
351    /// Unmasks the message ID of a message received from the wire.
352    #[inline]
353    fn unmask_id(&self, mut msg: BytesMut) -> Result<BytesMut, io::Error> {
354        if msg.is_empty() {
355            // message must not be empty
356            return Err(io::ErrorKind::InvalidInput.into())
357        }
358        msg[0] = msg[0]
359            .checked_sub(self.shared_cap.relative_message_id_offset())
360            .ok_or(io::ErrorKind::InvalidInput)?;
361        Ok(msg)
362    }
363}
364
365impl Stream for ProtocolProxy {
366    type Item = Result<BytesMut, io::Error>;
367
368    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
369        let msg = ready!(self.from_wire.poll_next_unpin(cx));
370        Poll::Ready(msg.map(|msg| self.get_mut().unmask_id(msg)))
371    }
372}
373
374impl Sink<Bytes> for ProtocolProxy {
375    type Error = io::Error;
376
377    fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
378        Poll::Ready(Ok(()))
379    }
380
381    fn start_send(self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> {
382        self.get_mut().try_send(item)
383    }
384
385    fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
386        Poll::Ready(Ok(()))
387    }
388
389    fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
390        Poll::Ready(Ok(()))
391    }
392}
393
394impl CanDisconnect<Bytes> for ProtocolProxy {
395    fn disconnect(
396        &mut self,
397        _reason: DisconnectReason,
398    ) -> Pin<Box<dyn Future<Output = Result<(), <Self as Sink<Bytes>>::Error>> + Send + '_>> {
399        Box::pin(async move { Ok(()) })
400    }
401}
402
403/// Adapter so the injected `EthRlpxHandshake` can run over a multiplexed `ProtocolProxy`
404/// using the same error type expectations (`P2PStreamError`).
405#[derive(Debug)]
406struct UnauthProxy {
407    inner: ProtocolProxy,
408}
409
410impl UnauthProxy {
411    fn into_inner(self) -> ProtocolProxy {
412        self.inner
413    }
414}
415
416impl Stream for UnauthProxy {
417    type Item = Result<BytesMut, P2PStreamError>;
418
419    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
420        self.inner.poll_next_unpin(cx).map(|opt| opt.map(|res| res.map_err(P2PStreamError::from)))
421    }
422}
423
424impl Sink<Bytes> for UnauthProxy {
425    type Error = P2PStreamError;
426
427    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
428        self.inner.poll_ready_unpin(cx).map_err(P2PStreamError::from)
429    }
430
431    fn start_send(mut self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> {
432        self.inner.start_send_unpin(item).map_err(P2PStreamError::from)
433    }
434
435    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
436        self.inner.poll_flush_unpin(cx).map_err(P2PStreamError::from)
437    }
438
439    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
440        self.inner.poll_close_unpin(cx).map_err(P2PStreamError::from)
441    }
442}
443
444impl CanDisconnect<Bytes> for UnauthProxy {
445    fn disconnect(
446        &mut self,
447        reason: DisconnectReason,
448    ) -> Pin<Box<dyn Future<Output = Result<(), <Self as Sink<Bytes>>::Error>> + Send + '_>> {
449        let fut = self.inner.disconnect(reason);
450        Box::pin(async move { fut.await.map_err(P2PStreamError::from) })
451    }
452}
453
454/// A connection channel to receive _`non_empty`_ messages for the negotiated protocol.
455///
456/// This is a [Stream] that returns raw bytes of the received messages for this protocol.
457#[derive(Debug)]
458pub struct ProtocolConnection {
459    from_wire: UnboundedReceiverStream<BytesMut>,
460}
461
462impl Stream for ProtocolConnection {
463    type Item = BytesMut;
464
465    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
466        self.from_wire.poll_next_unpin(cx)
467    }
468}
469
470/// A Stream and Sink type that acts as a wrapper around a primary `RLPx` subprotocol (e.g. "eth")
471/// [`EthStream`] and can also handle additional subprotocols.
472#[derive(Debug)]
473pub struct RlpxSatelliteStream<St, Primary> {
474    inner: MultiplexInner<St>,
475    primary: PrimaryProtocol<Primary>,
476    /// Round-robin cursor for the next outbound producer to poll.
477    next_outbound: usize,
478}
479
480impl<St, Primary> RlpxSatelliteStream<St, Primary> {
481    /// Installs a new protocol on top of the raw p2p stream.
482    ///
483    /// This accepts a closure that receives a [`ProtocolConnection`] that will yield messages for
484    /// the given capability.
485    pub fn install_protocol<F, Proto>(
486        &mut self,
487        cap: &Capability,
488        f: F,
489    ) -> Result<(), UnsupportedCapabilityError>
490    where
491        F: FnOnce(ProtocolConnection) -> Proto,
492        Proto: Stream<Item = BytesMut> + Send + 'static,
493    {
494        self.inner.install_protocol(cap, f)
495    }
496
497    /// Returns the primary protocol.
498    #[inline]
499    pub const fn primary(&self) -> &Primary {
500        &self.primary.st
501    }
502
503    /// Returns mutable access to the primary protocol.
504    #[inline]
505    pub const fn primary_mut(&mut self) -> &mut Primary {
506        &mut self.primary.st
507    }
508
509    /// Returns the underlying [`P2PStream`].
510    #[inline]
511    pub const fn inner(&self) -> &P2PStream<St> {
512        &self.inner.conn
513    }
514
515    /// Returns mutable access to the underlying [`P2PStream`].
516    #[inline]
517    pub const fn inner_mut(&mut self) -> &mut P2PStream<St> {
518        &mut self.inner.conn
519    }
520
521    /// Consumes this type and returns the wrapped [`P2PStream`].
522    #[inline]
523    pub fn into_inner(self) -> P2PStream<St> {
524        self.inner.conn
525    }
526
527    /// Polls primary and satellite outbound producers round-robin until the `OutBuffer` is full or
528    /// every producer is pending.
529    ///
530    /// The cursor advances after each producer poll, so a ready producer cannot drain repeatedly
531    /// before later producers get a turn.
532    fn poll_outbound_producers(&mut self, cx: &mut Context<'_>) -> Result<ProducerPoll, io::Error> {
533        let producers = self.inner.protocols.len() + 1;
534        let mut pending = 0;
535
536        while pending < producers {
537            if self.inner.out_buffer.is_full() {
538                return Ok(ProducerPoll::Full)
539            }
540
541            if self.next_outbound >= producers {
542                self.next_outbound = 0;
543            }
544
545            let producer = self.next_outbound;
546            self.next_outbound = (self.next_outbound + 1) % producers;
547
548            let msg = if producer == 0 {
549                match self.primary.from_primary.poll_next_unpin(cx) {
550                    Poll::Ready(Some(msg)) => msg,
551                    Poll::Ready(None) => return Ok(ProducerPoll::Closed),
552                    Poll::Pending => {
553                        pending += 1;
554                        continue
555                    }
556                }
557            } else {
558                let proto = self
559                    .inner
560                    .protocols
561                    .get_mut(producer - 1)
562                    .expect("outbound producer index checked against protocol count");
563                match proto.poll_next_unpin(cx) {
564                    Poll::Ready(Some(msg)) => msg?,
565                    Poll::Ready(None) => return Ok(ProducerPoll::Closed),
566                    Poll::Pending => {
567                        pending += 1;
568                        continue
569                    }
570                }
571            };
572
573            pending = 0;
574            self.inner.out_buffer.push_back(msg);
575        }
576
577        Ok(ProducerPoll::Pending)
578    }
579}
580
581impl<St, Primary, PrimaryErr> Stream for RlpxSatelliteStream<St, Primary>
582where
583    St: Stream<Item = io::Result<BytesMut>> + Sink<Bytes, Error = io::Error> + Unpin,
584    Primary: TryStream<Error = PrimaryErr> + Unpin,
585    P2PStreamError: Into<PrimaryErr>,
586{
587    type Item = Result<Primary::Ok, Primary::Error>;
588
589    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
590        let this = self.get_mut();
591
592        loop {
593            // first drain the primary stream
594            if let Poll::Ready(Some(msg)) = this.primary.st.try_poll_next_unpin(cx) {
595                return Poll::Ready(Some(msg))
596            }
597
598            let mut conn_ready = true;
599            loop {
600                match this.inner.conn.poll_ready_unpin(cx) {
601                    Poll::Ready(Ok(())) => {
602                        if let Some(msg) = this.inner.out_buffer.pop_front() {
603                            if let Err(err) = this.inner.conn.start_send_unpin(msg) {
604                                return Poll::Ready(Some(Err(err.into())))
605                            }
606                        } else {
607                            break
608                        }
609                    }
610                    Poll::Ready(Err(err)) => {
611                        if let Err(disconnect_err) =
612                            this.inner.conn.start_disconnect(DisconnectReason::DisconnectRequested)
613                        {
614                            return Poll::Ready(Some(Err(disconnect_err.into())))
615                        }
616                        return Poll::Ready(Some(Err(err.into())))
617                    }
618                    Poll::Pending => {
619                        conn_ready = false;
620                        break
621                    }
622                }
623            }
624
625            match this.poll_outbound_producers(cx) {
626                Ok(ProducerPoll::Pending | ProducerPoll::Full) => {}
627                Ok(ProducerPoll::Closed) => return Poll::Ready(None),
628                Err(err) => return Poll::Ready(Some(Err(P2PStreamError::Io(err).into()))),
629            }
630
631            let mut delegated = false;
632            loop {
633                // pull messages from connection
634                match this.inner.conn.poll_next_unpin(cx) {
635                    Poll::Ready(Some(Ok(msg))) => {
636                        delegated = true;
637                        let Some(offset) = msg.first().copied() else {
638                            return Poll::Ready(Some(Err(
639                                P2PStreamError::EmptyProtocolMessage.into()
640                            )))
641                        };
642                        // delegate the multiplexed message to the correct protocol
643                        if let Some(cap) =
644                            this.inner.conn.shared_capabilities().find_by_relative_offset(offset)
645                        {
646                            if cap == &this.primary.shared_cap {
647                                // delegate to primary
648                                let _ = this.primary.to_primary.send(msg);
649                            } else {
650                                // delegate to installed satellite if any
651                                for proto in &this.inner.protocols {
652                                    if proto.shared_cap == *cap {
653                                        proto.send_raw(msg);
654                                        break
655                                    }
656                                }
657                            }
658                        } else {
659                            return Poll::Ready(Some(Err(P2PStreamError::UnknownReservedMessageId(
660                                offset,
661                            )
662                            .into())))
663                        }
664                    }
665                    Poll::Ready(Some(Err(err))) => return Poll::Ready(Some(Err(err.into()))),
666                    Poll::Ready(None) => {
667                        // connection closed
668                        return Poll::Ready(None)
669                    }
670                    Poll::Pending => break,
671                }
672            }
673
674            if !conn_ready || (!delegated && this.inner.out_buffer.is_empty()) {
675                return Poll::Pending
676            }
677        }
678    }
679}
680
681impl<St, Primary, T> Sink<T> for RlpxSatelliteStream<St, Primary>
682where
683    St: Stream<Item = io::Result<BytesMut>> + Sink<Bytes, Error = io::Error> + Unpin,
684    Primary: Sink<T> + Unpin,
685    P2PStreamError: Into<<Primary as Sink<T>>::Error>,
686{
687    type Error = <Primary as Sink<T>>::Error;
688
689    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
690        let this = self.get_mut();
691        if let Err(err) = ready!(this.inner.conn.poll_ready_unpin(cx)) {
692            return Poll::Ready(Err(err.into()))
693        }
694        if let Err(err) = ready!(this.primary.st.poll_ready_unpin(cx)) {
695            return Poll::Ready(Err(err))
696        }
697        Poll::Ready(Ok(()))
698    }
699
700    fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
701        self.get_mut().primary.st.start_send_unpin(item)
702    }
703
704    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
705        self.get_mut().inner.conn.poll_flush_unpin(cx).map_err(Into::into)
706    }
707
708    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
709        self.get_mut().inner.conn.poll_close_unpin(cx).map_err(Into::into)
710    }
711}
712
713/// Wraps a `RLPx` subprotocol and handles message ID multiplexing.
714struct ProtocolStream {
715    shared_cap: SharedCapability,
716    /// the channel shared with the satellite stream
717    to_satellite: UnboundedSender<BytesMut>,
718    satellite_st: Pin<Box<dyn Stream<Item = BytesMut> + Send>>,
719}
720
721impl ProtocolStream {
722    /// Masks the message ID of a message to be sent on the wire.
723    #[inline]
724    fn mask_msg_id(&self, mut msg: BytesMut) -> Result<Bytes, io::Error> {
725        if msg.is_empty() {
726            // message must not be empty
727            return Err(io::ErrorKind::InvalidInput.into())
728        }
729        msg[0] = msg[0]
730            .checked_add(self.shared_cap.relative_message_id_offset())
731            .ok_or(io::ErrorKind::InvalidInput)?;
732        Ok(msg.freeze())
733    }
734
735    /// Unmasks the message ID of a message received from the wire.
736    #[inline]
737    fn unmask_id(&self, mut msg: BytesMut) -> Result<BytesMut, io::Error> {
738        if msg.is_empty() {
739            // message must not be empty
740            return Err(io::ErrorKind::InvalidInput.into())
741        }
742        msg[0] = msg[0]
743            .checked_sub(self.shared_cap.relative_message_id_offset())
744            .ok_or(io::ErrorKind::InvalidInput)?;
745        Ok(msg)
746    }
747
748    /// Sends the message to the satellite stream.
749    fn send_raw(&self, msg: BytesMut) {
750        let _ = self.unmask_id(msg).map(|msg| self.to_satellite.send(msg));
751    }
752}
753
754impl Stream for ProtocolStream {
755    type Item = Result<Bytes, io::Error>;
756
757    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
758        let this = self.get_mut();
759        let msg = ready!(this.satellite_st.as_mut().poll_next(cx));
760        Poll::Ready(msg.map(|msg| this.mask_msg_id(msg)))
761    }
762}
763
764impl fmt::Debug for ProtocolStream {
765    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
766        f.debug_struct("ProtocolStream").field("cap", &self.shared_cap).finish_non_exhaustive()
767    }
768}
769
770/// Helper to poll multiple protocol streams in a `tokio::select`! branch
771struct ProtocolsPoller<'a> {
772    protocols: &'a mut VecDeque<ProtocolStream>,
773}
774
775impl<'a> ProtocolsPoller<'a> {
776    const fn new(protocols: &'a mut VecDeque<ProtocolStream>) -> Self {
777        Self { protocols }
778    }
779}
780
781impl<'a> Future for ProtocolsPoller<'a> {
782    type Output = Result<Bytes, P2PStreamError>;
783
784    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
785        let protocols = self.protocols.len();
786        for _ in 0..protocols {
787            let mut proto = self.protocols.pop_front().expect("protocol count checked");
788            match proto.poll_next_unpin(cx) {
789                Poll::Ready(Some(Err(err))) => {
790                    self.protocols.push_back(proto);
791                    return Poll::Ready(Err(P2PStreamError::from(err)))
792                }
793                Poll::Ready(Some(Ok(msg))) => {
794                    // Got a message, put protocol back and return the message
795                    self.protocols.push_back(proto);
796                    return Poll::Ready(Ok(msg));
797                }
798                _ => {
799                    // push it back because we still want to complete the handshake first
800                    self.protocols.push_back(proto);
801                }
802            }
803        }
804
805        // All protocols processed, nothing ready
806        Poll::Pending
807    }
808}
809
/// Soft upper bound, in bytes, on outbound `RLPx` messages queued in the multiplexer for one
/// connection.
///
/// The bound is soft because a message's size is only known after a protocol stream produced
/// it. The buffer can therefore overshoot by at most one message before producer polling
/// pauses.
///
/// The lower [`P2PStream`] sink admits two outbound messages and rejects uncompressed payloads
/// above 16 MiB. 32 MiB therefore mirrors the largest payload volume that layer is already
/// prepared to buffer.
const MAX_MUX_OUT_BUFFER_BYTES: usize = 32 << 20;
819
820#[derive(Debug)]
821struct OutBuffer {
822    messages: VecDeque<Bytes>,
823    bytes: usize,
824    max_bytes: usize,
825}
826
827impl Default for OutBuffer {
828    fn default() -> Self {
829        Self { messages: Default::default(), bytes: 0, max_bytes: MAX_MUX_OUT_BUFFER_BYTES }
830    }
831}
832
833impl OutBuffer {
834    fn push_back(&mut self, msg: Bytes) {
835        self.bytes += msg.len();
836        self.messages.push_back(msg);
837    }
838
839    fn pop_front(&mut self) -> Option<Bytes> {
840        let msg = self.messages.pop_front()?;
841        self.bytes -= msg.len();
842        Some(msg)
843    }
844
845    fn is_empty(&self) -> bool {
846        self.messages.is_empty()
847    }
848
849    const fn is_full(&self) -> bool {
850        self.bytes >= self.max_bytes
851    }
852}
853
/// Outcome of pulling messages from the outbound producers into the multiplexer's buffer.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum ProducerPoll {
    /// No outbound producer currently has a message ready.
    Pending,
    /// The multiplexer's out buffer hit its soft byte limit.
    Full,
    /// One of the outbound producers has closed.
    Closed,
}
864
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        handshake::EthHandshake,
        message::MAX_MESSAGE_SIZE,
        protocol::Protocol,
        test_utils::{
            connect_passthrough, eth_handshake, eth_hello,
            proto::{test_hello, TestProtoMessage},
        },
        UnauthedEthStream, UnauthedP2PStream,
    };
    use futures::{stream, task::noop_waker_ref};
    use reth_eth_wire_types::EthNetworkPrimitives;
    use std::task::Poll;
    use tokio::{net::TcpListener, sync::oneshot};
    use tokio_util::codec::Decoder;

    /// Stand-in primary protocol stream that never yields an item.
    ///
    /// It keeps the `ProtocolProxy` it was built with, unused, only so the proxy is not dropped.
    #[derive(Debug)]
    struct PendingPrimary {
        _proxy: ProtocolProxy,
    }

    impl Stream for PendingPrimary {
        type Item = Result<(), P2PStreamError>;

        // Always pending: the primary never produces anything in these tests.
        fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
            Poll::Pending
        }
    }

    /// Transport that never makes progress in either direction.
    ///
    /// Reads are always pending. `poll_ready`, `poll_flush` and `poll_close` are always pending.
    /// `start_send` accepts and discards. Used so that nothing the multiplexer writes can drain
    /// through it.
    #[derive(Debug)]
    struct StalledTransport;

    impl Stream for StalledTransport {
        type Item = io::Result<BytesMut>;

        fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
            Poll::Pending
        }
    }

    impl Sink<Bytes> for StalledTransport {
        type Error = io::Error;

        fn poll_ready(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Result<(), Self::Error>> {
            Poll::Pending
        }

        fn start_send(self: Pin<&mut Self>, _item: Bytes) -> Result<(), Self::Error> {
            Ok(())
        }

        fn poll_flush(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Result<(), Self::Error>> {
            Poll::Pending
        }

        fn poll_close(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Result<(), Self::Error>> {
            Poll::Pending
        }
    }

    /// Checks that one poll of the satellite stream stops pulling from a protocol producer once
    /// the out buffer's soft limit is crossed.
    ///
    /// The producer can yield far more bytes than the limit. The buffer must end up above the
    /// limit, by no more than one message, and hold fewer than all the messages.
    #[tokio::test]
    async fn satellite_mux_stops_polling_protocols_when_out_buffer_is_full() {
        let (hello, _) = test_hello();
        let shared_capabilities =
            SharedCapabilities::try_new(hello.protocols.clone(), hello.message().capabilities)
                .unwrap();
        let conn = P2PStream::new(StalledTransport, shared_capabilities);
        let eth = conn.shared_capabilities().eth().unwrap().clone();

        let mut st = RlpxProtocolMultiplexer::new(conn)
            .into_satellite_stream(eth.capability().as_ref(), |proxy| PendingPrimary {
                _proxy: proxy,
            })
            .unwrap();
        const MESSAGE_COUNT: usize = 4096;
        const MESSAGE_BYTES: usize = 1024;
        // Not a multiple of MESSAGE_BYTES, so the limit can only be crossed by overshooting it.
        st.inner.out_buffer.max_bytes = 4 * MESSAGE_BYTES + 1;
        st.install_protocol(&TestProtoMessage::capability(), |_conn| {
            stream::iter((0..MESSAGE_COUNT).map(|_| {
                let mut msg = BytesMut::zeroed(MESSAGE_BYTES);
                msg[0] = TestProtoMessage::ping().message_type as u8;
                msg
            }))
        })
        .unwrap();

        // Poll once by hand with a no-op waker. The primary never yields, so this is pending.
        let mut cx = Context::from_waker(noop_waker_ref());
        assert!(Pin::new(&mut st).poll_next(&mut cx).is_pending());

        assert!(st.inner.out_buffer.bytes > st.inner.out_buffer.max_bytes);
        assert!(st.inner.out_buffer.bytes <= st.inner.out_buffer.max_bytes + MESSAGE_BYTES);
        assert!(st.inner.out_buffer.messages.len() < MESSAGE_COUNT);
    }

    /// Checks that two always-ready protocols take turns filling the out buffer.
    ///
    /// The first two buffered messages must come from different protocols. That is identified by
    /// their first byte, which the test expects to equal each capability's relative message id
    /// offset.
    #[tokio::test]
    async fn satellite_mux_round_robins_ready_protocols_when_out_buffer_fills() {
        let (mut hello, _) = eth_hello();
        let cap_a = Capability::new_static("aaa", 1);
        let cap_b = Capability::new_static("bbb", 1);
        hello.protocols.push(Protocol::new(cap_a.clone(), 1));
        hello.protocols.push(Protocol::new(cap_b.clone(), 1));

        let shared_capabilities =
            SharedCapabilities::try_new(hello.protocols.clone(), hello.message().capabilities)
                .unwrap();
        let conn = P2PStream::new(StalledTransport, shared_capabilities);
        let eth = conn.shared_capabilities().eth().unwrap().clone();
        let cap_a_offset =
            conn.shared_capabilities().find(&cap_a).unwrap().relative_message_id_offset();
        let cap_b_offset =
            conn.shared_capabilities().find(&cap_b).unwrap().relative_message_id_offset();

        let mut st = RlpxProtocolMultiplexer::new(conn)
            .into_satellite_stream(eth.capability().as_ref(), |proxy| PendingPrimary {
                _proxy: proxy,
            })
            .unwrap();
        // Tiny limit: only a few 2-byte messages fit before the buffer reports full.
        st.inner.out_buffer.max_bytes = 5;
        // Each producer yields 16 two-byte messages with local message id 0.
        st.install_protocol(&cap_a, |_conn| {
            stream::iter((0..16).map(|_| BytesMut::from(&[0, b'a'][..])))
        })
        .unwrap();
        st.install_protocol(&cap_b, |_conn| {
            stream::iter((0..16).map(|_| BytesMut::from(&[0, b'b'][..])))
        })
        .unwrap();

        let mut cx = Context::from_waker(noop_waker_ref());
        assert!(Pin::new(&mut st).poll_next(&mut cx).is_pending());

        let message_ids =
            st.inner.out_buffer.messages.iter().take(2).map(|msg| msg[0]).collect::<Vec<_>>();
        assert_eq!(message_ids.len(), 2);
        assert_ne!(message_ids[0], message_ids[1]);
        assert!(message_ids.contains(&cap_a_offset));
        assert!(message_ids.contains(&cap_b_offset));
    }

    /// Runs the eth handshake through the primary proxy of a satellite stream.
    ///
    /// The client side uses the satellite stream; the peer is a plain `UnauthedEthStream`
    /// over a TCP connection.
    #[tokio::test]
    async fn eth_satellite() {
        reth_tracing::init_test_tracing();
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();
        let (status, fork_filter) = eth_handshake();
        let other_status = status;
        let other_fork_filter = fork_filter.clone();
        // Peer: p2p handshake, then eth handshake directly on the p2p stream (no multiplexer).
        let _handle = tokio::spawn(async move {
            let (incoming, _) = listener.accept().await.unwrap();
            let stream = crate::PassthroughCodec::default().framed(incoming);
            let (server_hello, _) = eth_hello();
            let (p2p_stream, _) =
                UnauthedP2PStream::new(stream).handshake(server_hello).await.unwrap();

            let (_eth_stream, _) = UnauthedEthStream::new(p2p_stream)
                .handshake::<EthNetworkPrimitives>(other_status, other_fork_filter)
                .await
                .unwrap();

            // Presumably keeps the connection open briefly so the client can finish its side.
            tokio::time::sleep(std::time::Duration::from_millis(100)).await;
        });

        let conn = connect_passthrough(local_addr, eth_hello().0).await;
        let eth = conn.shared_capabilities().eth().unwrap().clone();

        let multiplexer = RlpxProtocolMultiplexer::new(conn);
        let _satellite = multiplexer
            .into_satellite_stream_with_handshake(eth.capability().as_ref(), async move |proxy| {
                UnauthedEthStream::new(proxy)
                    .handshake::<EthNetworkPrimitives>(status, fork_filter)
                    .await
            })
            .await
            .unwrap();
    }

    /// Installs the test protocol next to eth on a satellite stream at both ends of a connection
    /// and exchanges messages between them.
    ///
    /// The peer sends ping, "hello" and "good bye!". The local side replies pong and "good bye!".
    /// The local side signals completion through a oneshot once it has received the peer's
    /// "good bye!".
    #[tokio::test(flavor = "multi_thread")]
    async fn eth_test_protocol_satellite() {
        reth_tracing::init_test_tracing();
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();
        let (status, fork_filter) = eth_handshake();
        let other_status = status;
        let other_fork_filter = fork_filter.clone();
        let _handle = tokio::spawn(async move {
            let (incoming, _) = listener.accept().await.unwrap();
            let stream = crate::PassthroughCodec::default().framed(incoming);
            let (server_hello, _) = test_hello();
            let (conn, _) = UnauthedP2PStream::new(stream).handshake(server_hello).await.unwrap();

            let (mut st, _their_status) = RlpxProtocolMultiplexer::new(conn)
                .into_eth_satellite_stream::<EthNetworkPrimitives>(
                    other_status,
                    other_fork_filter,
                    Arc::new(EthHandshake::default()),
                    MAX_MESSAGE_SIZE,
                )
                .await
                .unwrap();

            // Peer's side of the exchange: it speaks first and expects the matching replies.
            st.install_protocol(&TestProtoMessage::capability(), |mut conn| {
                async_stream::stream! {
                    yield TestProtoMessage::ping().encoded();
                    let msg = conn.next().await.unwrap();
                    let msg = TestProtoMessage::decode_message(&mut &msg[..]).unwrap();
                    assert_eq!(msg, TestProtoMessage::pong());

                    yield TestProtoMessage::message("hello").encoded();
                    let msg = conn.next().await.unwrap();
                    let msg = TestProtoMessage::decode_message(&mut &msg[..]).unwrap();
                    assert_eq!(msg, TestProtoMessage::message("good bye!"));

                    yield TestProtoMessage::message("good bye!").encoded();

                    // Park forever so this producer never terminates.
                    futures::future::pending::<()>().await;
                    unreachable!()
                }
            })
            .unwrap();

            // Drive the peer's satellite stream so its protocol messages get exchanged.
            loop {
                let _ = st.next().await;
            }
        });

        let conn = connect_passthrough(local_addr, test_hello().0).await;
        let (mut st, _their_status) = RlpxProtocolMultiplexer::new(conn)
            .into_eth_satellite_stream::<EthNetworkPrimitives>(
                status,
                fork_filter,
                Arc::new(EthHandshake::default()),
                MAX_MESSAGE_SIZE,
            )
            .await
            .unwrap();

        // Fired by the local protocol once the whole exchange has been observed.
        let (tx, mut rx) = oneshot::channel();

        // Local side of the exchange: reply to each message from the peer.
        st.install_protocol(&TestProtoMessage::capability(), |mut conn| {
            async_stream::stream! {
                let msg = conn.next().await.unwrap();
                let msg = TestProtoMessage::decode_message(&mut &msg[..]).unwrap();
                assert_eq!(msg, TestProtoMessage::ping());

                yield TestProtoMessage::pong().encoded();

                let msg = conn.next().await.unwrap();
                let msg = TestProtoMessage::decode_message(&mut &msg[..]).unwrap();
                assert_eq!(msg, TestProtoMessage::message("hello"));

                yield TestProtoMessage::message("good bye!").encoded();

                let msg = conn.next().await.unwrap();
                let msg = TestProtoMessage::decode_message(&mut &msg[..]).unwrap();
                assert_eq!(msg, TestProtoMessage::message("good bye!"));

                tx.send(()).unwrap();

                // Park forever so this producer never terminates.
                futures::future::pending::<()>().await;
                unreachable!()
            }
        })
        .unwrap();

        // Keep driving the local satellite stream until the protocol signals completion.
        loop {
            tokio::select! {
                _ = &mut rx => {
                    break
                }
               _ = st.next() => {
                }
            }
        }
    }
}