reth_eth_wire/
p2pstream.rs

1use crate::{
2    capability::SharedCapabilities,
3    disconnect::CanDisconnect,
4    errors::{P2PHandshakeError, P2PStreamError},
5    pinger::{Pinger, PingerEvent},
6    DisconnectReason, HelloMessage, HelloMessageWithProtocols,
7};
8use alloy_primitives::{
9    bytes::{Buf, BufMut, Bytes, BytesMut},
10    hex,
11};
12use alloy_rlp::{Decodable, Encodable, Error as RlpError, EMPTY_LIST_CODE};
13use futures::{Sink, SinkExt, StreamExt};
14use pin_project::pin_project;
15use reth_codecs::add_arbitrary_tests;
16use reth_metrics::metrics::counter;
17use reth_primitives_traits::GotExpected;
18use std::{
19    collections::VecDeque,
20    future::Future,
21    io,
22    pin::Pin,
23    task::{ready, Context, Poll},
24    time::Duration,
25};
26use tokio_stream::Stream;
27use tracing::{debug, trace};
28
29#[cfg(feature = "serde")]
30use serde::{Deserialize, Serialize};
31
/// [`MAX_PAYLOAD_SIZE`] is the maximum size of an uncompressed message payload.
/// This is defined in [EIP-706](https://eips.ethereum.org/EIPS/eip-706).
///
/// In this file the limit is applied to the first message received during the handshake, to the
/// declared decompressed length of every incoming message, and to outgoing items before they are
/// compressed.
const MAX_PAYLOAD_SIZE: usize = 16 * 1024 * 1024;

/// [`MAX_RESERVED_MESSAGE_ID`] is the maximum message ID reserved for the `p2p` subprotocol. If
/// there are any incoming messages with an ID greater than this, they are subprotocol messages.
///
/// Subprotocol message IDs are shifted by `MAX_RESERVED_MESSAGE_ID + 1` on the wire.
pub const MAX_RESERVED_MESSAGE_ID: u8 = 0x0f;

/// [`MAX_P2P_MESSAGE_ID`] is the maximum message ID in use for the `p2p` subprotocol.
///
/// IDs above this value but not above [`MAX_RESERVED_MESSAGE_ID`] are reserved and unknown.
const MAX_P2P_MESSAGE_ID: u8 = P2PMessageID::Pong as u8;

/// [`HANDSHAKE_TIMEOUT`] determines the amount of time to wait before determining that a `p2p`
/// handshake has timed out.
pub const HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(10);

/// [`PING_TIMEOUT`] determines the amount of time to wait before determining that a `p2p` ping has
/// timed out.
const PING_TIMEOUT: Duration = Duration::from_secs(15);

/// [`PING_INTERVAL`] determines the amount of time to wait between sending `p2p` ping messages
/// when the peer is responsive.
const PING_INTERVAL: Duration = Duration::from_secs(60);

/// [`MAX_P2P_CAPACITY`] is the maximum number of messages that can be buffered to be sent in the
/// `p2p` stream.
///
/// Note: this default is rather low because it is expected that the [`P2PStream`] wraps an
/// [`ECIESStream`](reth_ecies::stream::ECIESStream) which internally already buffers a few MB of
/// encoded data.
const MAX_P2P_CAPACITY: usize = 2;
62
/// An un-authenticated [`P2PStream`]. This is consumed and returns a [`P2PStream`] after the
/// `Hello` handshake is completed.
///
/// It only wraps the underlying transport; no `p2p` state is kept until the handshake succeeds.
#[pin_project]
#[derive(Debug)]
pub struct UnauthedP2PStream<S> {
    /// The underlying transport the handshake messages are sent over and read from.
    #[pin]
    inner: S,
}
71
impl<S> UnauthedP2PStream<S> {
    /// Create a new `UnauthedP2PStream` from a type `S` which implements `Stream` and `Sink`.
    ///
    /// No bounds are enforced here; they are required by the methods that use the transport.
    pub const fn new(inner: S) -> Self {
        Self { inner }
    }

    /// Returns a reference to the inner stream.
    pub const fn inner(&self) -> &S {
        &self.inner
    }
}
83
84impl<S> UnauthedP2PStream<S>
85where
86    S: Stream<Item = io::Result<BytesMut>> + Sink<Bytes, Error = io::Error> + Unpin,
87{
88    /// Consumes the `UnauthedP2PStream` and returns a `P2PStream` after the `Hello` handshake is
89    /// completed successfully. This also returns the `Hello` message sent by the remote peer.
90    pub async fn handshake(
91        mut self,
92        hello: HelloMessageWithProtocols,
93    ) -> Result<(P2PStream<S>, HelloMessage), P2PStreamError> {
94        trace!(?hello, "sending p2p hello to peer");
95
96        // send our hello message with the Sink
97        self.inner.send(alloy_rlp::encode(P2PMessage::Hello(hello.message())).into()).await?;
98
99        let first_message_bytes = tokio::time::timeout(HANDSHAKE_TIMEOUT, self.inner.next())
100            .await
101            .or(Err(P2PStreamError::HandshakeError(P2PHandshakeError::Timeout)))?
102            .ok_or(P2PStreamError::HandshakeError(P2PHandshakeError::NoResponse))??;
103
104        // let's check the compressed length first, we will need to check again once confirming
105        // that it contains snappy-compressed data (this will be the case for all non-p2p messages).
106        if first_message_bytes.len() > MAX_PAYLOAD_SIZE {
107            return Err(P2PStreamError::MessageTooBig {
108                message_size: first_message_bytes.len(),
109                max_size: MAX_PAYLOAD_SIZE,
110            })
111        }
112
113        // The first message sent MUST be a hello OR disconnect message
114        //
115        // If the first message is a disconnect message, we should not decode using
116        // Decodable::decode, because the first message (either Disconnect or Hello) is not snappy
117        // compressed, and the Decodable implementation assumes that non-hello messages are snappy
118        // compressed.
119        let their_hello = match P2PMessage::decode(&mut &first_message_bytes[..]) {
120            Ok(P2PMessage::Hello(hello)) => Ok(hello),
121            Ok(P2PMessage::Disconnect(reason)) => {
122                if matches!(reason, DisconnectReason::TooManyPeers) {
123                    // Too many peers is a very common disconnect reason that spams the DEBUG logs
124                    trace!(%reason, "Disconnected by peer during handshake");
125                } else {
126                    debug!(%reason, "Disconnected by peer during handshake");
127                };
128                counter!("p2pstream.disconnected_errors").increment(1);
129                Err(P2PStreamError::HandshakeError(P2PHandshakeError::Disconnected(reason)))
130            }
131            Err(err) => {
132                debug!(%err, msg=%hex::encode(&first_message_bytes), "Failed to decode first message from peer");
133                Err(P2PStreamError::HandshakeError(err.into()))
134            }
135            Ok(msg) => {
136                debug!(?msg, "expected hello message but received another message");
137                Err(P2PStreamError::HandshakeError(P2PHandshakeError::NonHelloMessageInHandshake))
138            }
139        }?;
140
141        trace!(
142            hello=?their_hello,
143            "validating incoming p2p hello from peer"
144        );
145
146        if (hello.protocol_version as u8) != their_hello.protocol_version as u8 {
147            // send a disconnect message notifying the peer of the protocol version mismatch
148            self.send_disconnect(DisconnectReason::IncompatibleP2PProtocolVersion).await?;
149            return Err(P2PStreamError::MismatchedProtocolVersion(GotExpected {
150                got: their_hello.protocol_version,
151                expected: hello.protocol_version,
152            }))
153        }
154
155        // determine shared capabilities (currently returns only one capability)
156        let capability_res =
157            SharedCapabilities::try_new(hello.protocols, their_hello.capabilities.clone());
158
159        let shared_capability = match capability_res {
160            Err(err) => {
161                // we don't share any capabilities, send a disconnect message
162                self.send_disconnect(DisconnectReason::UselessPeer).await?;
163                Err(err)
164            }
165            Ok(cap) => Ok(cap),
166        }?;
167
168        let stream = P2PStream::new(self.inner, shared_capability);
169
170        Ok((stream, their_hello))
171    }
172}
173
174impl<S> UnauthedP2PStream<S>
175where
176    S: Sink<Bytes, Error = io::Error> + Unpin,
177{
178    /// Send a disconnect message during the handshake. This is sent without snappy compression.
179    pub async fn send_disconnect(
180        &mut self,
181        reason: DisconnectReason,
182    ) -> Result<(), P2PStreamError> {
183        trace!(
184            %reason,
185            "Sending disconnect message during the handshake",
186        );
187        self.inner
188            .send(Bytes::from(alloy_rlp::encode(P2PMessage::Disconnect(reason))))
189            .await
190            .map_err(P2PStreamError::Io)
191    }
192}
193
impl<S> CanDisconnect<Bytes> for P2PStream<S>
where
    S: Sink<Bytes, Error = io::Error> + Unpin + Send + Sync,
{
    /// Boxes the future produced by calling `disconnect` on the stream with the given reason.
    // NOTE(review): `self.disconnect(reason)` is presumably meant to resolve to the inherent
    // async `P2PStream::disconnect` rather than recursing into this trait method — confirm.
    fn disconnect(
        &mut self,
        reason: DisconnectReason,
    ) -> Pin<Box<dyn Future<Output = Result<(), P2PStreamError>> + Send + '_>> {
        Box::pin(async move { self.disconnect(reason).await })
    }
}
205
/// A P2PStream wraps over any `Stream` that yields bytes and makes it compatible with `p2p`
/// protocol messages.
///
/// This stream supports multiple shared capabilities, that were negotiated during the handshake.
///
/// ### Message-ID based multiplexing
///
/// > Each capability is given as much of the message-ID space as it needs. All such capabilities
/// > must statically specify how many message IDs they require. On connection and reception of the
/// > Hello message, both peers have equivalent information about what capabilities they share
/// > (including versions) and are able to form consensus over the composition of message ID space.
///
/// > Message IDs are assumed to be compact from ID 0x10 onwards (0x00-0x0f is reserved for the
/// > "p2p" capability) and given to each shared (equal-version, equal-name) capability in
/// > alphabetic order. Capability names are case-sensitive. Capabilities which are not shared are
/// > ignored. If multiple versions are shared of the same (equal name) capability, the numerically
/// > highest wins, others are ignored.
///
/// See also <https://github.com/ethereum/devp2p/blob/master/rlpx.md#message-id-based-multiplexing>
///
/// This stream emits _non-empty_ Bytes that start with the normalized message id, so that the first
/// byte of each message starts from 0. If this stream only supports a single capability, for
/// example `eth` then the first byte of each message will match
/// [EthMessageID](reth_eth_wire_types::message::EthMessageID).
#[pin_project]
#[derive(Debug)]
pub struct P2PStream<S> {
    /// The underlying transport that compressed `p2p` frames are read from and written to.
    #[pin]
    inner: S,

    /// The snappy encoder used for compressing outgoing messages
    encoder: snap::raw::Encoder,

    /// The snappy decoder used for decompressing incoming messages
    decoder: snap::raw::Decoder,

    /// The state machine used for keeping track of the peer's ping status.
    pinger: Pinger,

    /// The capabilities shared with the peer on this stream.
    shared_capabilities: SharedCapabilities,

    /// Outgoing messages buffered for sending to the underlying stream.
    outgoing_messages: VecDeque<Bytes>,

    /// Maximum number of messages that we can buffer here before the [Sink] impl returns
    /// [`Poll::Pending`].
    outgoing_message_buffer_capacity: usize,

    /// Whether this stream is currently in the process of disconnecting by sending a disconnect
    /// message.
    disconnecting: bool,
}
259
260impl<S> P2PStream<S> {
261    /// Create a new [`P2PStream`] from the provided stream.
262    /// New [`P2PStream`]s are assumed to have completed the `p2p` handshake successfully and are
263    /// ready to send and receive subprotocol messages.
264    pub fn new(inner: S, shared_capabilities: SharedCapabilities) -> Self {
265        Self {
266            inner,
267            encoder: snap::raw::Encoder::new(),
268            decoder: snap::raw::Decoder::new(),
269            pinger: Pinger::new(PING_INTERVAL, PING_TIMEOUT),
270            shared_capabilities,
271            outgoing_messages: VecDeque::new(),
272            outgoing_message_buffer_capacity: MAX_P2P_CAPACITY,
273            disconnecting: false,
274        }
275    }
276
277    /// Returns a reference to the inner stream.
278    pub const fn inner(&self) -> &S {
279        &self.inner
280    }
281
282    /// Sets a custom outgoing message buffer capacity.
283    ///
284    /// # Panics
285    ///
286    /// If the provided capacity is `0`.
287    pub const fn set_outgoing_message_buffer_capacity(&mut self, capacity: usize) {
288        self.outgoing_message_buffer_capacity = capacity;
289    }
290
291    /// Returns the shared capabilities for this stream.
292    ///
293    /// This includes all the shared capabilities that were negotiated during the handshake and
294    /// their offsets based on the number of messages of each capability.
295    pub const fn shared_capabilities(&self) -> &SharedCapabilities {
296        &self.shared_capabilities
297    }
298
299    /// Returns `true` if the stream has outgoing capacity.
300    fn has_outgoing_capacity(&self) -> bool {
301        self.outgoing_messages.len() < self.outgoing_message_buffer_capacity
302    }
303
304    /// Queues in a _snappy_ encoded [`P2PMessage::Pong`] message.
305    fn send_pong(&mut self) {
306        self.outgoing_messages.push_back(Bytes::from(alloy_rlp::encode(P2PMessage::Pong)));
307    }
308
309    /// Queues in a _snappy_ encoded [`P2PMessage::Ping`] message.
310    pub fn send_ping(&mut self) {
311        self.outgoing_messages.push_back(Bytes::from(alloy_rlp::encode(P2PMessage::Ping)));
312    }
313}
314
/// Gracefully disconnects the connection by sending a disconnect message and stop reading new
/// messages.
pub trait DisconnectP2P {
    /// Starts to gracefully disconnect.
    ///
    /// # Errors
    ///
    /// Returns an error if the disconnect could not be started.
    fn start_disconnect(&mut self, reason: DisconnectReason) -> Result<(), P2PStreamError>;

    /// Returns `true` if the connection is about to disconnect.
    fn is_disconnecting(&self) -> bool;
}
324
325impl<S> DisconnectP2P for P2PStream<S> {
326    /// Starts to gracefully disconnect the connection by sending a Disconnect message and stop
327    /// reading new messages.
328    ///
329    /// Once disconnect process has started, the [`Stream`] will terminate immediately.
330    ///
331    /// # Errors
332    ///
333    /// Returns an error only if the message fails to compress.
334    fn start_disconnect(&mut self, reason: DisconnectReason) -> Result<(), P2PStreamError> {
335        // clear any buffered messages and queue in
336        self.outgoing_messages.clear();
337        let disconnect = P2PMessage::Disconnect(reason);
338        let mut buf = Vec::with_capacity(disconnect.length());
339        disconnect.encode(&mut buf);
340
341        let mut compressed = vec![0u8; 1 + snap::raw::max_compress_len(buf.len() - 1)];
342        let compressed_size =
343            self.encoder.compress(&buf[1..], &mut compressed[1..]).map_err(|err| {
344                debug!(
345                    %err,
346                    msg=%hex::encode(&buf[1..]),
347                    "error compressing disconnect"
348                );
349                err
350            })?;
351
352        // truncate the compressed buffer to the actual compressed size (plus one for the message
353        // id)
354        compressed.truncate(compressed_size + 1);
355
356        // we do not add the capability offset because the disconnect message is a `p2p` reserved
357        // message
358        compressed[0] = buf[0];
359
360        self.outgoing_messages.push_back(compressed.into());
361        self.disconnecting = true;
362        Ok(())
363    }
364
365    fn is_disconnecting(&self) -> bool {
366        self.disconnecting
367    }
368}
369
impl<S> P2PStream<S>
where
    S: Sink<Bytes, Error = io::Error> + Unpin + Send,
{
    /// Disconnects the connection by sending a disconnect message.
    ///
    /// This future resolves once the disconnect message has been sent and the stream has been
    /// closed.
    ///
    /// # Errors
    ///
    /// Returns an error if starting the disconnect fails or if closing the sink fails.
    pub async fn disconnect(&mut self, reason: DisconnectReason) -> Result<(), P2PStreamError> {
        // queue the disconnect message (dropping anything else still buffered) ...
        self.start_disconnect(reason)?;
        // ... then flush it and close the underlying sink
        self.close().await
    }
}
383
// S must also be `Sink` because we need to be able to respond with ping messages to follow the
// protocol
impl<S> Stream for P2PStream<S>
where
    S: Stream<Item = io::Result<BytesMut>> + Sink<Bytes, Error = io::Error> + Unpin,
{
    type Item = Result<BytesMut, P2PStreamError>;

    /// Polls the inner stream for the next message.
    ///
    /// Reserved `p2p` messages are handled here: a ping queues a pong, a pong is forwarded to the
    /// pinger, and a hello, a disconnect or an unknown reserved id is surfaced as an error.
    /// Subprotocol messages are decompressed and returned with their message id shifted down so
    /// that the first subprotocol id is `0`.
    ///
    /// Returns `None` once a disconnect has been started or the inner stream has ended.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();

        if this.disconnecting {
            // if disconnecting, stop reading messages
            return Poll::Ready(None)
        }

        // we should loop here to ensure we don't return Poll::Pending if we have a message to
        // return behind any pings we need to respond to
        while let Poll::Ready(res) = this.inner.poll_next_unpin(cx) {
            let bytes = match res {
                Some(Ok(bytes)) => bytes,
                Some(Err(err)) => return Poll::Ready(Some(Err(err.into()))),
                None => return Poll::Ready(None),
            };

            if bytes.is_empty() {
                // empty messages are not allowed
                return Poll::Ready(Some(Err(P2PStreamError::EmptyProtocolMessage)))
            }

            // first decode disconnect reasons, because they can be encoded in a variety of forms
            // over the wire, in both snappy compressed and uncompressed forms.
            //
            // see: [crate::disconnect::tests::test_decode_known_reasons]
            let id = bytes[0];
            if id == P2PMessageID::Disconnect as u8 {
                // We can't handle the error here because disconnect reasons are encoded as both:
                // * snappy compressed, AND
                // * uncompressed
                // over the network.
                //
                // If the decoding succeeds, we already checked the id and know this is a
                // disconnect message, so we can return with the reason.
                //
                // If the decoding fails, we continue, and will attempt to decode it again if the
                // message is snappy compressed. Failure handling in that step is the primary point
                // where an error is returned if the disconnect reason is malformed.
                if let Ok(reason) = DisconnectReason::decode(&mut &bytes[1..]) {
                    return Poll::Ready(Some(Err(P2PStreamError::Disconnected(reason))))
                }
            }

            // check that the decompressed length declared by the snappy payload does not exceed
            // the max payload size, before a buffer of that size is allocated
            let decompressed_len = snap::raw::decompress_len(&bytes[1..])?;
            if decompressed_len > MAX_PAYLOAD_SIZE {
                return Poll::Ready(Some(Err(P2PStreamError::MessageTooBig {
                    message_size: decompressed_len,
                    max_size: MAX_PAYLOAD_SIZE,
                })))
            }

            // create a buffer to hold the decompressed message, adding a byte to the length for
            // the message ID byte, which is the first byte in this buffer
            let mut decompress_buf = BytesMut::zeroed(decompressed_len + 1);

            // each message following a successful handshake is compressed with snappy, so we need
            // to decompress the message before we can decode it.
            this.decoder.decompress(&bytes[1..], &mut decompress_buf[1..]).map_err(|err| {
                debug!(
                    %err,
                    msg=%hex::encode(&bytes[1..]),
                    "error decompressing p2p message"
                );
                err
            })?;

            match id {
                _ if id == P2PMessageID::Ping as u8 => {
                    trace!("Received Ping, Sending Pong");
                    this.send_pong();
                    // This is required because the `Sink` may not be polled externally, and if
                    // that happens, the pong will never be sent.
                    cx.waker().wake_by_ref();
                }
                _ if id == P2PMessageID::Hello as u8 => {
                    // we have received a hello message outside of the handshake, so we will return
                    // an error
                    return Poll::Ready(Some(Err(P2PStreamError::HandshakeError(
                        P2PHandshakeError::HelloNotInHandshake,
                    ))))
                }
                _ if id == P2PMessageID::Pong as u8 => {
                    // if we were waiting for a pong, this will reset the pinger state
                    this.pinger.on_pong()?
                }
                _ if id == P2PMessageID::Disconnect as u8 => {
                    // At this point, the `decompress_buf` contains the snappy decompressed
                    // disconnect message.
                    //
                    // It's possible we already tried to RLP decode this, but it was snappy
                    // compressed, so we need to RLP decode it again.
                    let reason = DisconnectReason::decode(&mut &decompress_buf[1..]).inspect_err(|err| {
                        debug!(
                            %err, msg=%hex::encode(&decompress_buf[1..]), "Failed to decode disconnect message from peer"
                        );
                    })?;
                    return Poll::Ready(Some(Err(P2PStreamError::Disconnected(reason))))
                }
                _ if id > MAX_P2P_MESSAGE_ID && id <= MAX_RESERVED_MESSAGE_ID => {
                    // we have received an unknown reserved message
                    return Poll::Ready(Some(Err(P2PStreamError::UnknownReservedMessageId(id))))
                }
                _ => {
                    // we have received a message that is outside the `p2p` reserved message space,
                    // so it is a subprotocol message.

                    // Peers must be able to identify messages meant for different subprotocols
                    // using a single message ID byte, and those messages must be distinct from the
                    // lower-level `p2p` messages.
                    //
                    // To ensure that messages for subprotocols are distinct from messages meant
                    // for the `p2p` capability, message IDs 0x00 - 0x0f are reserved for `p2p`
                    // messages, so subprotocol messages must have an ID of 0x10 or higher.
                    //
                    // To ensure that messages for two different capabilities are distinct from
                    // each other, all shared capabilities are first ordered lexicographically.
                    // Message IDs are then reserved in this order, starting at 0x10, reserving a
                    // message ID for each message the capability supports.
                    //
                    // For example, if the shared capabilities are `eth/67` (containing 10
                    // messages), and "qrs/65" (containing 8 messages):
                    //
                    //  * The special case of `p2p`: `p2p` is reserved message IDs 0x00 - 0x0f.
                    //  * `eth/67` is reserved message IDs 0x10 - 0x19.
                    //  * `qrs/65` is reserved message IDs 0x1a - 0x21.
                    //
                    // All ids up to `MAX_RESERVED_MESSAGE_ID` were handled by the arms above, so
                    // this subtraction cannot underflow.
                    decompress_buf[0] = bytes[0] - MAX_RESERVED_MESSAGE_ID - 1;

                    return Poll::Ready(Some(Ok(decompress_buf)))
                }
            }
        }

        Poll::Pending
    }
}
531
532impl<S> Sink<Bytes> for P2PStream<S>
533where
534    S: Sink<Bytes, Error = io::Error> + Unpin,
535{
536    type Error = P2PStreamError;
537
538    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
539        let mut this = self.as_mut();
540
541        // poll the pinger to determine if we should send a ping
542        match this.pinger.poll_ping(cx) {
543            Poll::Pending => {}
544            Poll::Ready(Ok(PingerEvent::Ping)) => {
545                this.send_ping();
546            }
547            _ => {
548                // encode the disconnect message
549                this.start_disconnect(DisconnectReason::PingTimeout)?;
550
551                // End the stream after ping related error
552                return Poll::Ready(Ok(()))
553            }
554        }
555
556        match this.inner.poll_ready_unpin(cx) {
557            Poll::Pending => {}
558            Poll::Ready(Err(err)) => return Poll::Ready(Err(P2PStreamError::Io(err))),
559            Poll::Ready(Ok(())) => {
560                let flushed = this.poll_flush(cx);
561                if flushed.is_ready() {
562                    return flushed
563                }
564            }
565        }
566
567        if self.has_outgoing_capacity() {
568            // still has capacity
569            Poll::Ready(Ok(()))
570        } else {
571            Poll::Pending
572        }
573    }
574
575    fn start_send(self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> {
576        if item.len() > MAX_PAYLOAD_SIZE {
577            return Err(P2PStreamError::MessageTooBig {
578                message_size: item.len(),
579                max_size: MAX_PAYLOAD_SIZE,
580            })
581        }
582
583        if item.is_empty() {
584            // empty messages are not allowed
585            return Err(P2PStreamError::EmptyProtocolMessage)
586        }
587
588        // ensure we have free capacity
589        if !self.has_outgoing_capacity() {
590            return Err(P2PStreamError::SendBufferFull)
591        }
592
593        let this = self.project();
594
595        let mut compressed = BytesMut::zeroed(1 + snap::raw::max_compress_len(item.len() - 1));
596        let compressed_size =
597            this.encoder.compress(&item[1..], &mut compressed[1..]).map_err(|err| {
598                debug!(
599                    %err,
600                    msg=%hex::encode(&item[1..]),
601                    "error compressing p2p message"
602                );
603                err
604            })?;
605
606        // truncate the compressed buffer to the actual compressed size (plus one for the message
607        // id)
608        compressed.truncate(compressed_size + 1);
609
610        // all messages sent in this stream are subprotocol messages, so we need to switch the
611        // message id based on the offset
612        compressed[0] = item[0] + MAX_RESERVED_MESSAGE_ID + 1;
613        this.outgoing_messages.push_back(compressed.freeze());
614
615        Ok(())
616    }
617
618    /// Returns `Poll::Ready(Ok(()))` when no buffered items remain.
619    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
620        let mut this = self.project();
621        let poll_res = loop {
622            match this.inner.as_mut().poll_ready(cx) {
623                Poll::Pending => break Poll::Pending,
624                Poll::Ready(Err(err)) => break Poll::Ready(Err(err.into())),
625                Poll::Ready(Ok(())) => {
626                    let Some(message) = this.outgoing_messages.pop_front() else {
627                        break Poll::Ready(Ok(()))
628                    };
629                    if let Err(err) = this.inner.as_mut().start_send(message) {
630                        break Poll::Ready(Err(err.into()))
631                    }
632                }
633            }
634        };
635
636        ready!(this.inner.as_mut().poll_flush(cx))?;
637
638        poll_res
639    }
640
641    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
642        ready!(self.as_mut().poll_flush(cx))?;
643        ready!(self.project().inner.poll_close(cx))?;
644
645        Poll::Ready(Ok(()))
646    }
647}
648
/// This represents only the reserved `p2p` subprotocol messages.
///
/// Each variant corresponds to one [`P2PMessageID`].
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(any(test, feature = "arbitrary"), derive(arbitrary::Arbitrary))]
#[add_arbitrary_tests(rlp)]
pub enum P2PMessage {
    /// The first packet sent over the connection, and sent once by both sides.
    Hello(HelloMessage),

    /// Inform the peer that a disconnection is imminent; if received, a peer should disconnect
    /// immediately.
    Disconnect(DisconnectReason),

    /// Requests an immediate reply of [`P2PMessage::Pong`] from the peer.
    Ping,

    /// Reply to the peer's [`P2PMessage::Ping`] packet.
    Pong,
}
668
impl P2PMessage {
    /// Gets the [`P2PMessageID`] for the given message.
    ///
    /// The match is exhaustive, so every variant maps to exactly one id.
    pub const fn message_id(&self) -> P2PMessageID {
        match self {
            Self::Hello(_) => P2PMessageID::Hello,
            Self::Disconnect(_) => P2PMessageID::Disconnect,
            Self::Ping => P2PMessageID::Ping,
            Self::Pong => P2PMessageID::Pong,
        }
    }
}
680
681impl Encodable for P2PMessage {
682    /// The [`Encodable`] implementation for [`P2PMessage::Ping`] and [`P2PMessage::Pong`] encodes
683    /// the message as RLP, and prepends a snappy header to the RLP bytes for all variants except
684    /// the [`P2PMessage::Hello`] variant, because the hello message is never compressed in the
685    /// `p2p` subprotocol.
686    fn encode(&self, out: &mut dyn BufMut) {
687        (self.message_id() as u8).encode(out);
688        match self {
689            Self::Hello(msg) => msg.encode(out),
690            Self::Disconnect(msg) => msg.encode(out),
691            Self::Ping => {
692                // Ping payload is _always_ snappy encoded
693                out.put_u8(0x01);
694                out.put_u8(0x00);
695                out.put_u8(EMPTY_LIST_CODE);
696            }
697            Self::Pong => {
698                // Pong payload is _always_ snappy encoded
699                out.put_u8(0x01);
700                out.put_u8(0x00);
701                out.put_u8(EMPTY_LIST_CODE);
702            }
703        }
704    }
705
706    fn length(&self) -> usize {
707        let payload_len = match self {
708            Self::Hello(msg) => msg.length(),
709            Self::Disconnect(msg) => msg.length(),
710            // id + snappy encoded payload
711            Self::Ping | Self::Pong => 3, // len([0x01, 0x00, 0xc0]) = 3
712        };
713        payload_len + 1 // (1 for length of p2p message id)
714    }
715}
716
717impl Decodable for P2PMessage {
718    /// The [`Decodable`] implementation for [`P2PMessage`] assumes that each of the message
719    /// variants are snappy compressed, except for the [`P2PMessage::Hello`] variant since the
720    /// hello message is never compressed in the `p2p` subprotocol.
721    ///
722    /// The [`Decodable`] implementation for [`P2PMessage::Ping`] and [`P2PMessage::Pong`] expects
723    /// a snappy encoded payload, see [`Encodable`] implementation.
724    fn decode(buf: &mut &[u8]) -> alloy_rlp::Result<Self> {
725        /// Removes the snappy prefix from the Ping/Pong buffer
726        fn advance_snappy_ping_pong_payload(buf: &mut &[u8]) -> alloy_rlp::Result<()> {
727            if buf.len() < 3 {
728                return Err(RlpError::InputTooShort)
729            }
730            if buf[..3] != [0x01, 0x00, EMPTY_LIST_CODE] {
731                return Err(RlpError::Custom("expected snappy payload"))
732            }
733            buf.advance(3);
734            Ok(())
735        }
736
737        let message_id = u8::decode(&mut &buf[..])?;
738        let id = P2PMessageID::try_from(message_id)
739            .or(Err(RlpError::Custom("unknown p2p message id")))?;
740        buf.advance(1);
741        match id {
742            P2PMessageID::Hello => Ok(Self::Hello(HelloMessage::decode(buf)?)),
743            P2PMessageID::Disconnect => Ok(Self::Disconnect(DisconnectReason::decode(buf)?)),
744            P2PMessageID::Ping => {
745                advance_snappy_ping_pong_payload(buf)?;
746                Ok(Self::Ping)
747            }
748            P2PMessageID::Pong => {
749                advance_snappy_ping_pong_payload(buf)?;
750                Ok(Self::Pong)
751            }
752        }
753    }
754}
755
/// Identifiers of the messages that make up the `p2p` subprotocol.
///
/// Each variant's discriminant is the numeric id of the corresponding message, so a variant can
/// be turned into its id with `as u8`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum P2PMessageID {
    /// Id of the [`P2PMessage::Hello`] message (`0x00`).
    Hello = 0x00,

    /// Id of the [`P2PMessage::Disconnect`] message (`0x01`).
    Disconnect = 0x01,

    /// Id of the [`P2PMessage::Ping`] message (`0x02`).
    Ping = 0x02,

    /// Id of the [`P2PMessage::Pong`] message (`0x03`).
    Pong = 0x03,
}
771
772impl From<P2PMessage> for P2PMessageID {
773    fn from(msg: P2PMessage) -> Self {
774        match msg {
775            P2PMessage::Hello(_) => Self::Hello,
776            P2PMessage::Disconnect(_) => Self::Disconnect,
777            P2PMessage::Ping => Self::Ping,
778            P2PMessage::Pong => Self::Pong,
779        }
780    }
781}
782
783impl TryFrom<u8> for P2PMessageID {
784    type Error = P2PStreamError;
785
786    fn try_from(id: u8) -> Result<Self, Self::Error> {
787        match id {
788            0x00 => Ok(Self::Hello),
789            0x01 => Ok(Self::Disconnect),
790            0x02 => Ok(Self::Ping),
791            0x03 => Ok(Self::Pong),
792            _ => Err(P2PStreamError::UnknownReservedMessageId(id)),
793        }
794    }
795}
796
797#[cfg(test)]
798mod tests {
799    use super::*;
800    use crate::{capability::SharedCapability, test_utils::eth_hello, EthVersion, ProtocolVersion};
801    use tokio::net::{TcpListener, TcpStream};
802    use tokio_util::codec::Decoder;
803
804    #[tokio::test]
805    async fn test_can_disconnect() {
806        reth_tracing::init_test_tracing();
807        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
808        let local_addr = listener.local_addr().unwrap();
809
810        let expected_disconnect = DisconnectReason::UselessPeer;
811
812        let handle = tokio::spawn(async move {
813            // roughly based off of the design of tokio::net::TcpListener
814            let (incoming, _) = listener.accept().await.unwrap();
815            let stream = crate::PassthroughCodec::default().framed(incoming);
816
817            let (server_hello, _) = eth_hello();
818
819            let (mut p2p_stream, _) =
820                UnauthedP2PStream::new(stream).handshake(server_hello).await.unwrap();
821
822            p2p_stream.disconnect(expected_disconnect).await.unwrap();
823        });
824
825        let outgoing = TcpStream::connect(local_addr).await.unwrap();
826        let sink = crate::PassthroughCodec::default().framed(outgoing);
827
828        let (client_hello, _) = eth_hello();
829
830        let (mut p2p_stream, _) =
831            UnauthedP2PStream::new(sink).handshake(client_hello).await.unwrap();
832
833        let err = p2p_stream.next().await.unwrap().unwrap_err();
834        match err {
835            P2PStreamError::Disconnected(reason) => assert_eq!(reason, expected_disconnect),
836            e => panic!("unexpected err: {e}"),
837        }
838
839        handle.await.unwrap();
840    }
841
842    #[tokio::test]
843    async fn test_can_disconnect_weird_disconnect_encoding() {
844        reth_tracing::init_test_tracing();
845        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
846        let local_addr = listener.local_addr().unwrap();
847
848        let expected_disconnect = DisconnectReason::SubprotocolSpecific;
849
850        let handle = tokio::spawn(async move {
851            // roughly based off of the design of tokio::net::TcpListener
852            let (incoming, _) = listener.accept().await.unwrap();
853            let stream = crate::PassthroughCodec::default().framed(incoming);
854
855            let (server_hello, _) = eth_hello();
856
857            let (mut p2p_stream, _) =
858                UnauthedP2PStream::new(stream).handshake(server_hello).await.unwrap();
859
860            // Unrolled `disconnect` method, without compression
861            p2p_stream.outgoing_messages.clear();
862
863            p2p_stream.outgoing_messages.push_back(Bytes::from(alloy_rlp::encode(
864                P2PMessage::Disconnect(DisconnectReason::SubprotocolSpecific),
865            )));
866            p2p_stream.disconnecting = true;
867            p2p_stream.close().await.unwrap();
868        });
869
870        let outgoing = TcpStream::connect(local_addr).await.unwrap();
871        let sink = crate::PassthroughCodec::default().framed(outgoing);
872
873        let (client_hello, _) = eth_hello();
874
875        let (mut p2p_stream, _) =
876            UnauthedP2PStream::new(sink).handshake(client_hello).await.unwrap();
877
878        let err = p2p_stream.next().await.unwrap().unwrap_err();
879        match err {
880            P2PStreamError::Disconnected(reason) => assert_eq!(reason, expected_disconnect),
881            e => panic!("unexpected err: {e}"),
882        }
883
884        handle.await.unwrap();
885    }
886
887    #[tokio::test]
888    async fn test_handshake_passthrough() {
889        // create a p2p stream and server, then confirm that the two are authed
890        // create tcpstream
891        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
892        let local_addr = listener.local_addr().unwrap();
893
894        let handle = tokio::spawn(async move {
895            // roughly based off of the design of tokio::net::TcpListener
896            let (incoming, _) = listener.accept().await.unwrap();
897            let stream = crate::PassthroughCodec::default().framed(incoming);
898
899            let (server_hello, _) = eth_hello();
900
901            let unauthed_stream = UnauthedP2PStream::new(stream);
902            let (p2p_stream, _) = unauthed_stream.handshake(server_hello).await.unwrap();
903
904            // ensure that the two share a single capability, eth67
905            assert_eq!(
906                *p2p_stream.shared_capabilities.iter_caps().next().unwrap(),
907                SharedCapability::Eth {
908                    version: EthVersion::Eth67,
909                    offset: MAX_RESERVED_MESSAGE_ID + 1
910                }
911            );
912        });
913
914        let outgoing = TcpStream::connect(local_addr).await.unwrap();
915        let sink = crate::PassthroughCodec::default().framed(outgoing);
916
917        let (client_hello, _) = eth_hello();
918
919        let unauthed_stream = UnauthedP2PStream::new(sink);
920        let (p2p_stream, _) = unauthed_stream.handshake(client_hello).await.unwrap();
921
922        // ensure that the two share a single capability, eth67
923        assert_eq!(
924            *p2p_stream.shared_capabilities.iter_caps().next().unwrap(),
925            SharedCapability::Eth {
926                version: EthVersion::Eth67,
927                offset: MAX_RESERVED_MESSAGE_ID + 1
928            }
929        );
930
931        // make sure the server receives the message and asserts before ending the test
932        handle.await.unwrap();
933    }
934
935    #[tokio::test]
936    async fn test_handshake_disconnect() {
937        // create a p2p stream and server, then confirm that the two are authed
938        // create tcpstream
939        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
940        let local_addr = listener.local_addr().unwrap();
941
942        let handle = tokio::spawn(Box::pin(async move {
943            // roughly based off of the design of tokio::net::TcpListener
944            let (incoming, _) = listener.accept().await.unwrap();
945            let stream = crate::PassthroughCodec::default().framed(incoming);
946
947            let (server_hello, _) = eth_hello();
948
949            let unauthed_stream = UnauthedP2PStream::new(stream);
950            match unauthed_stream.handshake(server_hello.clone()).await {
951                Ok((_, hello)) => {
952                    panic!("expected handshake to fail, instead got a successful Hello: {hello:?}")
953                }
954                Err(P2PStreamError::MismatchedProtocolVersion(GotExpected { got, expected })) => {
955                    assert_ne!(expected, got);
956                    assert_eq!(expected, server_hello.protocol_version);
957                }
958                Err(other_err) => {
959                    panic!("expected mismatched protocol version error, got {other_err:?}")
960                }
961            }
962        }));
963
964        let outgoing = TcpStream::connect(local_addr).await.unwrap();
965        let sink = crate::PassthroughCodec::default().framed(outgoing);
966
967        let (mut client_hello, _) = eth_hello();
968
969        // modify the hello to include an incompatible p2p protocol version
970        client_hello.protocol_version = ProtocolVersion::V4;
971
972        let unauthed_stream = UnauthedP2PStream::new(sink);
973        match unauthed_stream.handshake(client_hello.clone()).await {
974            Ok((_, hello)) => {
975                panic!("expected handshake to fail, instead got a successful Hello: {hello:?}")
976            }
977            Err(P2PStreamError::MismatchedProtocolVersion(GotExpected { got, expected })) => {
978                assert_ne!(expected, got);
979                assert_eq!(expected, client_hello.protocol_version);
980            }
981            Err(other_err) => {
982                panic!("expected mismatched protocol version error, got {other_err:?}")
983            }
984        }
985
986        // make sure the server receives the message and asserts before ending the test
987        handle.await.unwrap();
988    }
989
990    #[test]
991    fn snappy_decode_encode_ping() {
992        let snappy_ping = b"\x02\x01\0\xc0";
993        let ping = P2PMessage::decode(&mut &snappy_ping[..]).unwrap();
994        assert!(matches!(ping, P2PMessage::Ping));
995        assert_eq!(alloy_rlp::encode(ping), &snappy_ping[..]);
996    }
997
998    #[test]
999    fn snappy_decode_encode_pong() {
1000        let snappy_pong = b"\x03\x01\0\xc0";
1001        let pong = P2PMessage::decode(&mut &snappy_pong[..]).unwrap();
1002        assert!(matches!(pong, P2PMessage::Pong));
1003        assert_eq!(alloy_rlp::encode(pong), &snappy_pong[..]);
1004    }
1005}