// reth_eth_wire/p2pstream.rs

1use crate::{
2    capability::SharedCapabilities,
3    disconnect::CanDisconnect,
4    errors::{P2PHandshakeError, P2PStreamError},
5    pinger::{Pinger, PingerEvent},
6    DisconnectReason, HelloMessage, HelloMessageWithProtocols,
7};
8use alloy_primitives::{
9    bytes::{Buf, BufMut, Bytes, BytesMut},
10    hex,
11};
12use alloy_rlp::{Decodable, Encodable, Error as RlpError, EMPTY_LIST_CODE};
13use futures::{Sink, SinkExt, StreamExt};
14use pin_project::pin_project;
15use reth_codecs::add_arbitrary_tests;
16use reth_metrics::metrics::counter;
17use reth_primitives_traits::GotExpected;
18use std::{
19    collections::VecDeque,
20    future::Future,
21    io,
22    pin::Pin,
23    task::{ready, Context, Poll},
24    time::Duration,
25};
26use tokio_stream::Stream;
27use tracing::{debug, trace};
28
29#[cfg(feature = "serde")]
30use serde::{Deserialize, Serialize};
31
32/// [`MAX_PAYLOAD_SIZE`] is the maximum size of an uncompressed message payload.
33/// This is defined in [EIP-706](https://eips.ethereum.org/EIPS/eip-706).
34const MAX_PAYLOAD_SIZE: usize = 16 * 1024 * 1024;
35
36/// [`MAX_RESERVED_MESSAGE_ID`] is the maximum message ID reserved for the `p2p` subprotocol. If
37/// there are any incoming messages with an ID greater than this, they are subprotocol messages.
38pub const MAX_RESERVED_MESSAGE_ID: u8 = 0x0f;
39
40/// [`MAX_P2P_MESSAGE_ID`] is the maximum message ID in use for the `p2p` subprotocol.
41const MAX_P2P_MESSAGE_ID: u8 = P2PMessageID::Pong as u8;
42
43/// [`HANDSHAKE_TIMEOUT`] determines the amount of time to wait before determining that a `p2p`
44/// handshake has timed out.
45pub const HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(10);
46
47/// [`PING_TIMEOUT`] determines the amount of time to wait before determining that a `p2p` ping has
48/// timed out.
49const PING_TIMEOUT: Duration = Duration::from_secs(15);
50
51/// [`PING_INTERVAL`] determines the amount of time to wait between sending `p2p` ping messages
52/// when the peer is responsive.
53const PING_INTERVAL: Duration = Duration::from_secs(60);
54
55/// [`MAX_P2P_CAPACITY`] is the maximum number of messages that can be buffered to be sent in the
56/// `p2p` stream.
57///
58/// Note: this default is rather low because it is expected that the [`P2PStream`] wraps an
59/// [`ECIESStream`](reth_ecies::stream::ECIESStream) which internally already buffers a few MB of
60/// encoded data.
61const MAX_P2P_CAPACITY: usize = 2;
62
/// An un-authenticated [`P2PStream`]. This is consumed and returns a [`P2PStream`] after the
/// `Hello` handshake is completed.
///
/// The handshake is only available when the wrapped transport is both a [`Stream`] of
/// `io::Result<BytesMut>` frames and a [`Sink`] of [`Bytes`].
#[pin_project]
#[derive(Debug)]
pub struct UnauthedP2PStream<S> {
    /// The underlying transport (structurally pinned via `pin_project`).
    #[pin]
    inner: S,
}
71
72impl<S> UnauthedP2PStream<S> {
73    /// Create a new `UnauthedP2PStream` from a type `S` which implements `Stream` and `Sink`.
74    pub const fn new(inner: S) -> Self {
75        Self { inner }
76    }
77
78    /// Returns a reference to the inner stream.
79    pub const fn inner(&self) -> &S {
80        &self.inner
81    }
82}
83
84impl<S> UnauthedP2PStream<S>
85where
86    S: Stream<Item = io::Result<BytesMut>> + Sink<Bytes, Error = io::Error> + Unpin,
87{
88    /// Consumes the `UnauthedP2PStream` and returns a `P2PStream` after the `Hello` handshake is
89    /// completed successfully. This also returns the `Hello` message sent by the remote peer.
90    pub async fn handshake(
91        mut self,
92        hello: HelloMessageWithProtocols,
93    ) -> Result<(P2PStream<S>, HelloMessage), P2PStreamError> {
94        trace!(?hello, "sending p2p hello to peer");
95
96        // send our hello message with the Sink
97        self.inner.send(alloy_rlp::encode(P2PMessage::Hello(hello.message())).into()).await?;
98
99        let first_message_bytes = tokio::time::timeout(HANDSHAKE_TIMEOUT, self.inner.next())
100            .await
101            .or(Err(P2PStreamError::HandshakeError(P2PHandshakeError::Timeout)))?
102            .ok_or(P2PStreamError::HandshakeError(P2PHandshakeError::NoResponse))??;
103
104        // Check that the uncompressed message length does not exceed the max payload size.
105        // Note: The first message (Hello/Disconnect) is not snappy compressed. We will check the
106        // decompressed length again for subsequent messages after the handshake.
107        if first_message_bytes.len() > MAX_PAYLOAD_SIZE {
108            return Err(P2PStreamError::MessageTooBig {
109                message_size: first_message_bytes.len(),
110                max_size: MAX_PAYLOAD_SIZE,
111            })
112        }
113
114        // The first message sent MUST be a hello OR disconnect message
115        //
116        // If the first message is a disconnect message, we should not decode using
117        // Decodable::decode, because the first message (either Disconnect or Hello) is not snappy
118        // compressed, and the Decodable implementation assumes that non-hello messages are snappy
119        // compressed.
120        let their_hello = match P2PMessage::decode(&mut &first_message_bytes[..]) {
121            Ok(P2PMessage::Hello(hello)) => Ok(hello),
122            Ok(P2PMessage::Disconnect(reason)) => {
123                if matches!(reason, DisconnectReason::TooManyPeers) {
124                    // Too many peers is a very common disconnect reason that spams the DEBUG logs
125                    trace!(%reason, "Disconnected by peer during handshake");
126                } else {
127                    debug!(%reason, "Disconnected by peer during handshake");
128                };
129                counter!("p2pstream.disconnected_errors").increment(1);
130                Err(P2PStreamError::HandshakeError(P2PHandshakeError::Disconnected(reason)))
131            }
132            Err(err) => {
133                debug!(%err, msg=%hex::encode(&first_message_bytes), "Failed to decode first message from peer");
134                Err(P2PStreamError::HandshakeError(err.into()))
135            }
136            Ok(msg) => {
137                debug!(?msg, "expected hello message but received another message");
138                Err(P2PStreamError::HandshakeError(P2PHandshakeError::NonHelloMessageInHandshake))
139            }
140        }?;
141
142        trace!(
143            hello=?their_hello,
144            "validating incoming p2p hello from peer"
145        );
146
147        if (hello.protocol_version as u8) != their_hello.protocol_version as u8 {
148            // send a disconnect message notifying the peer of the protocol version mismatch
149            self.send_disconnect(DisconnectReason::IncompatibleP2PProtocolVersion).await?;
150            return Err(P2PStreamError::MismatchedProtocolVersion(GotExpected {
151                got: their_hello.protocol_version,
152                expected: hello.protocol_version,
153            }))
154        }
155
156        // determine shared capabilities (currently returns only one capability)
157        let capability_res =
158            SharedCapabilities::try_new(hello.protocols, their_hello.capabilities.clone());
159
160        let shared_capability = match capability_res {
161            Err(err) => {
162                // we don't share any capabilities, send a disconnect message
163                self.send_disconnect(DisconnectReason::UselessPeer).await?;
164                Err(err)
165            }
166            Ok(cap) => Ok(cap),
167        }?;
168
169        let stream = P2PStream::new(self.inner, shared_capability);
170
171        Ok((stream, their_hello))
172    }
173}
174
175impl<S> UnauthedP2PStream<S>
176where
177    S: Sink<Bytes, Error = io::Error> + Unpin,
178{
179    /// Send a disconnect message during the handshake. This is sent without snappy compression.
180    pub async fn send_disconnect(
181        &mut self,
182        reason: DisconnectReason,
183    ) -> Result<(), P2PStreamError> {
184        trace!(
185            %reason,
186            "Sending disconnect message during the handshake",
187        );
188        self.inner
189            .send(Bytes::from(alloy_rlp::encode(P2PMessage::Disconnect(reason))))
190            .await
191            .map_err(P2PStreamError::Io)
192    }
193}
194
/// Lets code that is generic over [`CanDisconnect`] gracefully disconnect a [`P2PStream`].
impl<S> CanDisconnect<Bytes> for P2PStream<S>
where
    S: Sink<Bytes, Error = io::Error> + Unpin + Send + Sync,
{
    /// Returns a boxed future that queues a `Disconnect` frame carrying `reason` and then closes
    /// the sink.
    fn disconnect(
        &mut self,
        reason: DisconnectReason,
    ) -> Pin<Box<dyn Future<Output = Result<(), P2PStreamError>> + Send + '_>> {
        // `self.disconnect(..)` resolves to the inherent async `P2PStream::disconnect` (inherent
        // methods are preferred over trait methods during method resolution), so this delegates
        // rather than recursing into this trait method.
        Box::pin(async move { self.disconnect(reason).await })
    }
}
206
/// A `P2PStream` wraps over any `Stream` that yields bytes and makes it compatible with `p2p`
/// protocol messages.
///
/// This stream supports multiple shared capabilities, that were negotiated during the handshake.
///
/// ### Message-ID based multiplexing
///
/// > Each capability is given as much of the message-ID space as it needs. All such capabilities
/// > must statically specify how many message IDs they require. On connection and reception of the
/// > Hello message, both peers have equivalent information about what capabilities they share
/// > (including versions) and are able to form consensus over the composition of message ID space.
///
/// > Message IDs are assumed to be compact from ID 0x10 onwards (0x00-0x0f is reserved for the
/// > "p2p" capability) and given to each shared (equal-version, equal-name) capability in
/// > alphabetic order. Capability names are case-sensitive. Capabilities which are not shared are
/// > ignored. If multiple versions are shared of the same (equal name) capability, the numerically
/// > highest wins, others are ignored.
///
/// See also <https://github.com/ethereum/devp2p/blob/master/rlpx.md#message-id-based-multiplexing>
///
/// This stream emits _non-empty_ Bytes that start with the normalized message id, so that the first
/// byte of each message starts from 0. If this stream only supports a single capability, for
/// example `eth` then the first byte of each message will match
/// [EthMessageID](reth_eth_wire_types::message::EthMessageID).
///
/// `p2p` control messages (ping/pong/disconnect) are handled internally and are not yielded.
#[pin_project]
#[derive(Debug)]
pub struct P2PStream<S> {
    /// The underlying framed transport (structurally pinned).
    #[pin]
    inner: S,

    /// The snappy encoder used for compressing outgoing messages
    encoder: snap::raw::Encoder,

    /// The snappy decoder used for decompressing incoming messages
    decoder: snap::raw::Decoder,

    /// The state machine used for keeping track of the peer's ping status.
    pinger: Pinger,

    /// The capabilities shared with the peer, as negotiated during the `Hello` handshake.
    shared_capabilities: SharedCapabilities,

    /// Outgoing, already-encoded frames waiting to be written to `inner` on the next flush.
    outgoing_messages: VecDeque<Bytes>,

    /// Maximum number of messages that we can buffer here before the [Sink] impl returns
    /// [`Poll::Pending`].
    outgoing_message_buffer_capacity: usize,

    /// Whether this stream is currently in the process of disconnecting by sending a disconnect
    /// message. Once set, the `Stream` impl stops reading and yields `None`.
    disconnecting: bool,
}
260
261impl<S> P2PStream<S> {
262    /// Create a new [`P2PStream`] from the provided stream.
263    /// New [`P2PStream`]s are assumed to have completed the `p2p` handshake successfully and are
264    /// ready to send and receive subprotocol messages.
265    pub fn new(inner: S, shared_capabilities: SharedCapabilities) -> Self {
266        Self {
267            inner,
268            encoder: snap::raw::Encoder::new(),
269            decoder: snap::raw::Decoder::new(),
270            pinger: Pinger::new(PING_INTERVAL, PING_TIMEOUT),
271            shared_capabilities,
272            outgoing_messages: VecDeque::new(),
273            outgoing_message_buffer_capacity: MAX_P2P_CAPACITY,
274            disconnecting: false,
275        }
276    }
277
278    /// Returns a reference to the inner stream.
279    pub const fn inner(&self) -> &S {
280        &self.inner
281    }
282
283    /// Sets a custom outgoing message buffer capacity.
284    ///
285    /// # Panics
286    ///
287    /// If the provided capacity is `0`.
288    pub const fn set_outgoing_message_buffer_capacity(&mut self, capacity: usize) {
289        self.outgoing_message_buffer_capacity = capacity;
290    }
291
292    /// Returns the shared capabilities for this stream.
293    ///
294    /// This includes all the shared capabilities that were negotiated during the handshake and
295    /// their offsets based on the number of messages of each capability.
296    pub const fn shared_capabilities(&self) -> &SharedCapabilities {
297        &self.shared_capabilities
298    }
299
300    /// Returns `true` if the stream has outgoing capacity.
301    fn has_outgoing_capacity(&self) -> bool {
302        self.outgoing_messages.len() < self.outgoing_message_buffer_capacity
303    }
304
305    /// Queues in a _snappy_ encoded [`P2PMessage::Pong`] message.
306    fn send_pong(&mut self) {
307        self.outgoing_messages.push_back(Bytes::from(alloy_rlp::encode(P2PMessage::Pong)));
308    }
309
310    /// Queues in a _snappy_ encoded [`P2PMessage::Ping`] message.
311    pub fn send_ping(&mut self) {
312        self.outgoing_messages.push_back(Bytes::from(alloy_rlp::encode(P2PMessage::Ping)));
313    }
314}
315
/// Gracefully disconnects the connection by sending a disconnect message and stops reading new
/// messages.
///
/// Implemented for [`P2PStream`].
pub trait DisconnectP2P {
    /// Starts to gracefully disconnect with the given `reason`.
    ///
    /// This is a synchronous call: it only begins the disconnect. Actually delivering the
    /// disconnect message is left to the implementor's send path.
    fn start_disconnect(&mut self, reason: DisconnectReason) -> Result<(), P2PStreamError>;

    /// Returns `true` if the connection is about to disconnect.
    fn is_disconnecting(&self) -> bool;
}
325
326impl<S> DisconnectP2P for P2PStream<S> {
327    /// Starts to gracefully disconnect the connection by sending a Disconnect message and stop
328    /// reading new messages.
329    ///
330    /// Once disconnect process has started, the [`Stream`] will terminate immediately.
331    ///
332    /// # Errors
333    ///
334    /// Returns an error only if the message fails to compress.
335    fn start_disconnect(&mut self, reason: DisconnectReason) -> Result<(), P2PStreamError> {
336        // clear any buffered messages and queue in
337        self.outgoing_messages.clear();
338        let disconnect = P2PMessage::Disconnect(reason);
339        let mut buf = Vec::with_capacity(disconnect.length());
340        disconnect.encode(&mut buf);
341
342        let mut compressed = vec![0u8; 1 + snap::raw::max_compress_len(buf.len() - 1)];
343        let compressed_size =
344            self.encoder.compress(&buf[1..], &mut compressed[1..]).map_err(|err| {
345                debug!(
346                    %err,
347                    msg=%hex::encode(&buf[1..]),
348                    "error compressing disconnect"
349                );
350                err
351            })?;
352
353        // truncate the compressed buffer to the actual compressed size (plus one for the message
354        // id)
355        compressed.truncate(compressed_size + 1);
356
357        // we do not add the capability offset because the disconnect message is a `p2p` reserved
358        // message
359        compressed[0] = buf[0];
360
361        self.outgoing_messages.push_back(compressed.into());
362        self.disconnecting = true;
363        Ok(())
364    }
365
366    fn is_disconnecting(&self) -> bool {
367        self.disconnecting
368    }
369}
370
371impl<S> P2PStream<S>
372where
373    S: Sink<Bytes, Error = io::Error> + Unpin + Send,
374{
375    /// Disconnects the connection by sending a disconnect message.
376    ///
377    /// This future resolves once the disconnect message has been sent and the stream has been
378    /// closed.
379    pub async fn disconnect(&mut self, reason: DisconnectReason) -> Result<(), P2PStreamError> {
380        self.start_disconnect(reason)?;
381        self.close().await
382    }
383}
384
// S must also be `Sink` because we need to be able to respond with ping messages to follow the
// protocol
impl<S> Stream for P2PStream<S>
where
    S: Stream<Item = io::Result<BytesMut>> + Sink<Bytes, Error = io::Error> + Unpin,
{
    type Item = Result<BytesMut, P2PStreamError>;

    /// Reads the next subprotocol message from the inner stream.
    ///
    /// `p2p` control messages are consumed here and not yielded:
    /// - a `Ping` queues a `Pong` and wakes the current task (see the inline comment);
    /// - a `Pong` is forwarded to the pinger;
    /// - a `Disconnect`, compressed or not, is surfaced as [`P2PStreamError::Disconnected`].
    ///
    /// A `Hello` after the handshake, and unknown ids in the reserved range above
    /// [`MAX_P2P_MESSAGE_ID`] up to [`MAX_RESERVED_MESSAGE_ID`], are returned as errors.
    ///
    /// Any other message is snappy-decompressed and yielded with its first byte rebased so that
    /// subprotocol ids start at `0`. Errors from the inner stream, empty frames, payloads whose
    /// decompressed size exceeds [`MAX_PAYLOAD_SIZE`], and snappy failures are yielded as errors.
    /// Once a local disconnect has started, this yields `None` without reading further.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();

        if this.disconnecting {
            // if disconnecting, stop reading messages
            return Poll::Ready(None)
        }

        // we should loop here to ensure we don't return Poll::Pending if we have a message to
        // return behind any pings we need to respond to
        while let Poll::Ready(res) = this.inner.poll_next_unpin(cx) {
            let bytes = match res {
                Some(Ok(bytes)) => bytes,
                Some(Err(err)) => return Poll::Ready(Some(Err(err.into()))),
                None => return Poll::Ready(None),
            };

            if bytes.is_empty() {
                // empty messages are not allowed
                return Poll::Ready(Some(Err(P2PStreamError::EmptyProtocolMessage)))
            }

            // first decode disconnect reasons, because they can be encoded in a variety of forms
            // over the wire, in both snappy compressed and uncompressed forms.
            //
            // see: [crate::disconnect::tests::test_decode_known_reasons]
            let id = bytes[0];
            if id == P2PMessageID::Disconnect as u8 {
                // We can't handle the error here because disconnect reasons are encoded as both:
                // * snappy compressed, AND
                // * uncompressed
                // over the network.
                //
                // If the decoding succeeds, we already checked the id and know this is a
                // disconnect message, so we can return with the reason.
                //
                // If the decoding fails, we continue, and will attempt to decode it again if the
                // message is snappy compressed. Failure handling in that step is the primary point
                // where an error is returned if the disconnect reason is malformed.
                if let Ok(reason) = DisconnectReason::decode(&mut &bytes[1..]) {
                    return Poll::Ready(Some(Err(P2PStreamError::Disconnected(reason))))
                }
            }

            // check that the decompressed length declared in the snappy header does not exceed
            // the max payload size, before allocating a buffer of that size
            let decompressed_len = snap::raw::decompress_len(&bytes[1..])?;
            if decompressed_len > MAX_PAYLOAD_SIZE {
                return Poll::Ready(Some(Err(P2PStreamError::MessageTooBig {
                    message_size: decompressed_len,
                    max_size: MAX_PAYLOAD_SIZE,
                })))
            }

            // create a buffer to hold the decompressed message, adding a byte to the length for
            // the message ID byte, which is the first byte in this buffer
            let mut decompress_buf = BytesMut::zeroed(decompressed_len + 1);

            // each message following a successful handshake is compressed with snappy, so we need
            // to decompress the message before we can decode it.
            this.decoder.decompress(&bytes[1..], &mut decompress_buf[1..]).map_err(|err| {
                debug!(
                    %err,
                    msg=%hex::encode(&bytes[1..]),
                    "error decompressing p2p message"
                );
                err
            })?;

            match id {
                _ if id == P2PMessageID::Ping as u8 => {
                    trace!("Received Ping, Sending Pong");
                    this.send_pong();
                    // This is required because the `Sink` may not be polled externally, and if
                    // that happens, the pong will never be sent.
                    cx.waker().wake_by_ref();
                    // no message to yield: keep reading
                }
                _ if id == P2PMessageID::Hello as u8 => {
                    // we have received a hello message outside of the handshake, so we will return
                    // an error
                    return Poll::Ready(Some(Err(P2PStreamError::HandshakeError(
                        P2PHandshakeError::HelloNotInHandshake,
                    ))))
                }
                _ if id == P2PMessageID::Pong as u8 => {
                    // if we were waiting for a pong, this will reset the pinger state; a pinger
                    // error is propagated, otherwise we keep reading
                    this.pinger.on_pong()?
                }
                _ if id == P2PMessageID::Disconnect as u8 => {
                    // At this point, the `decompress_buf` contains the snappy decompressed
                    // disconnect message.
                    //
                    // It's possible we already tried to RLP decode this, but it was snappy
                    // compressed, so we need to RLP decode it again.
                    let reason = DisconnectReason::decode(&mut &decompress_buf[1..]).inspect_err(|err| {
                        debug!(
                            %err, msg=%hex::encode(&decompress_buf[1..]), "Failed to decode disconnect message from peer"
                        );
                    })?;
                    return Poll::Ready(Some(Err(P2PStreamError::Disconnected(reason))))
                }
                _ if id > MAX_P2P_MESSAGE_ID && id <= MAX_RESERVED_MESSAGE_ID => {
                    // we have received an unknown reserved message
                    return Poll::Ready(Some(Err(P2PStreamError::UnknownReservedMessageId(id))))
                }
                _ => {
                    // we have received a message that is outside the `p2p` reserved message space,
                    // so it is a subprotocol message.

                    // Peers must be able to identify messages meant for different subprotocols
                    // using a single message ID byte, and those messages must be distinct from the
                    // lower-level `p2p` messages.
                    //
                    // To ensure that messages for subprotocols are distinct from messages meant
                    // for the `p2p` capability, message IDs 0x00 - 0x0f are reserved for `p2p`
                    // messages, so subprotocol messages must have an ID of 0x10 or higher.
                    //
                    // To ensure that messages for two different capabilities are distinct from
                    // each other, all shared capabilities are first ordered lexicographically.
                    // Message IDs are then reserved in this order, starting at 0x10, reserving a
                    // message ID for each message the capability supports.
                    //
                    // For example, if the shared capabilities are `eth/67` (containing 10
                    // messages), and "qrs/65" (containing 8 messages):
                    //
                    //  * The special case of `p2p`: `p2p` is reserved message IDs 0x00 - 0x0f.
                    //  * `eth/67` is reserved message IDs 0x10 - 0x19.
                    //  * `qrs/65` is reserved message IDs 0x1a - 0x21.
                    //
                    // This subtraction cannot underflow: every id in 0x00 - 0x0f was handled by
                    // one of the arms above.
                    decompress_buf[0] = bytes[0] - MAX_RESERVED_MESSAGE_ID - 1;

                    return Poll::Ready(Some(Ok(decompress_buf)))
                }
            }
        }

        Poll::Pending
    }
}
532
533impl<S> Sink<Bytes> for P2PStream<S>
534where
535    S: Sink<Bytes, Error = io::Error> + Unpin,
536{
537    type Error = P2PStreamError;
538
539    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
540        let mut this = self.as_mut();
541
542        // poll the pinger to determine if we should send a ping
543        match this.pinger.poll_ping(cx) {
544            Poll::Pending => {}
545            Poll::Ready(Ok(PingerEvent::Ping)) => {
546                this.send_ping();
547            }
548            _ => {
549                // encode the disconnect message
550                this.start_disconnect(DisconnectReason::PingTimeout)?;
551
552                // End the stream after ping related error
553                return Poll::Ready(Ok(()))
554            }
555        }
556
557        match this.inner.poll_ready_unpin(cx) {
558            Poll::Pending => {}
559            Poll::Ready(Err(err)) => return Poll::Ready(Err(P2PStreamError::Io(err))),
560            Poll::Ready(Ok(())) => {
561                let flushed = this.poll_flush(cx);
562                if flushed.is_ready() {
563                    return flushed
564                }
565            }
566        }
567
568        if self.has_outgoing_capacity() {
569            // still has capacity
570            Poll::Ready(Ok(()))
571        } else {
572            Poll::Pending
573        }
574    }
575
576    fn start_send(self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> {
577        if item.len() > MAX_PAYLOAD_SIZE {
578            return Err(P2PStreamError::MessageTooBig {
579                message_size: item.len(),
580                max_size: MAX_PAYLOAD_SIZE,
581            })
582        }
583
584        if item.is_empty() {
585            // empty messages are not allowed
586            return Err(P2PStreamError::EmptyProtocolMessage)
587        }
588
589        // ensure we have free capacity
590        if !self.has_outgoing_capacity() {
591            return Err(P2PStreamError::SendBufferFull)
592        }
593
594        let this = self.project();
595
596        let mut compressed = BytesMut::zeroed(1 + snap::raw::max_compress_len(item.len() - 1));
597        let compressed_size =
598            this.encoder.compress(&item[1..], &mut compressed[1..]).map_err(|err| {
599                debug!(
600                    %err,
601                    msg=%hex::encode(&item[1..]),
602                    "error compressing p2p message"
603                );
604                err
605            })?;
606
607        // truncate the compressed buffer to the actual compressed size (plus one for the message
608        // id)
609        compressed.truncate(compressed_size + 1);
610
611        // all messages sent in this stream are subprotocol messages, so we need to switch the
612        // message id based on the offset
613        compressed[0] = item[0] + MAX_RESERVED_MESSAGE_ID + 1;
614        this.outgoing_messages.push_back(compressed.freeze());
615
616        Ok(())
617    }
618
619    /// Returns `Poll::Ready(Ok(()))` when no buffered items remain.
620    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
621        let mut this = self.project();
622        let poll_res = loop {
623            match this.inner.as_mut().poll_ready(cx) {
624                Poll::Pending => break Poll::Pending,
625                Poll::Ready(Err(err)) => break Poll::Ready(Err(err.into())),
626                Poll::Ready(Ok(())) => {
627                    let Some(message) = this.outgoing_messages.pop_front() else {
628                        break Poll::Ready(Ok(()))
629                    };
630                    if let Err(err) = this.inner.as_mut().start_send(message) {
631                        break Poll::Ready(Err(err.into()))
632                    }
633                }
634            }
635        };
636
637        ready!(this.inner.as_mut().poll_flush(cx))?;
638
639        poll_res
640    }
641
642    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
643        ready!(self.as_mut().poll_flush(cx))?;
644        ready!(self.project().inner.poll_close(cx))?;
645
646        Poll::Ready(Ok(()))
647    }
648}
649
/// This represents only the reserved `p2p` subprotocol messages.
///
/// Subprotocol (capability) messages are not represented here; [`P2PStream`] passes them through
/// as raw bytes.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(any(test, feature = "arbitrary"), derive(arbitrary::Arbitrary))]
#[add_arbitrary_tests(rlp)]
pub enum P2PMessage {
    /// The first packet sent over the connection, and sent once by both sides.
    ///
    /// Its payload is encoded and decoded as plain RLP, without snappy compression.
    Hello(HelloMessage),

    /// Inform the peer that a disconnection is imminent; if received, a peer should disconnect
    /// immediately.
    Disconnect(DisconnectReason),

    /// Requests an immediate reply of [`P2PMessage::Pong`] from the peer.
    ///
    /// Its (empty-list) payload is always written and expected in snappy-compressed form.
    Ping,

    /// Reply to the peer's [`P2PMessage::Ping`] packet.
    ///
    /// Its (empty-list) payload is always written and expected in snappy-compressed form.
    Pong,
}
669
670impl P2PMessage {
671    /// Gets the [`P2PMessageID`] for the given message.
672    pub const fn message_id(&self) -> P2PMessageID {
673        match self {
674            Self::Hello(_) => P2PMessageID::Hello,
675            Self::Disconnect(_) => P2PMessageID::Disconnect,
676            Self::Ping => P2PMessageID::Ping,
677            Self::Pong => P2PMessageID::Pong,
678        }
679    }
680}
681
682impl Encodable for P2PMessage {
683    /// The [`Encodable`] implementation for [`P2PMessage::Ping`] and [`P2PMessage::Pong`] encodes
684    /// the message as RLP, and prepends a snappy header to the RLP bytes for all variants except
685    /// the [`P2PMessage::Hello`] variant, because the hello message is never compressed in the
686    /// `p2p` subprotocol.
687    fn encode(&self, out: &mut dyn BufMut) {
688        (self.message_id() as u8).encode(out);
689        match self {
690            Self::Hello(msg) => msg.encode(out),
691            Self::Disconnect(msg) => msg.encode(out),
692            Self::Ping => {
693                // Ping payload is _always_ snappy encoded
694                out.put_u8(0x01);
695                out.put_u8(0x00);
696                out.put_u8(EMPTY_LIST_CODE);
697            }
698            Self::Pong => {
699                // Pong payload is _always_ snappy encoded
700                out.put_u8(0x01);
701                out.put_u8(0x00);
702                out.put_u8(EMPTY_LIST_CODE);
703            }
704        }
705    }
706
707    fn length(&self) -> usize {
708        let payload_len = match self {
709            Self::Hello(msg) => msg.length(),
710            Self::Disconnect(msg) => msg.length(),
711            // id + snappy encoded payload
712            Self::Ping | Self::Pong => 3, // len([0x01, 0x00, 0xc0]) = 3
713        };
714        payload_len + 1 // (1 for length of p2p message id)
715    }
716}
717
718impl Decodable for P2PMessage {
719    /// The [`Decodable`] implementation for [`P2PMessage`] assumes that each of the message
720    /// variants are snappy compressed, except for the [`P2PMessage::Hello`] variant since the
721    /// hello message is never compressed in the `p2p` subprotocol.
722    ///
723    /// The [`Decodable`] implementation for [`P2PMessage::Ping`] and [`P2PMessage::Pong`] expects
724    /// a snappy encoded payload, see [`Encodable`] implementation.
725    fn decode(buf: &mut &[u8]) -> alloy_rlp::Result<Self> {
726        /// Removes the snappy prefix from the Ping/Pong buffer
727        fn advance_snappy_ping_pong_payload(buf: &mut &[u8]) -> alloy_rlp::Result<()> {
728            if buf.len() < 3 {
729                return Err(RlpError::InputTooShort)
730            }
731            if buf[..3] != [0x01, 0x00, EMPTY_LIST_CODE] {
732                return Err(RlpError::Custom("expected snappy payload"))
733            }
734            buf.advance(3);
735            Ok(())
736        }
737
738        let message_id = u8::decode(&mut &buf[..])?;
739        let id = P2PMessageID::try_from(message_id)
740            .or(Err(RlpError::Custom("unknown p2p message id")))?;
741        buf.advance(1);
742        match id {
743            P2PMessageID::Hello => Ok(Self::Hello(HelloMessage::decode(buf)?)),
744            P2PMessageID::Disconnect => Ok(Self::Disconnect(DisconnectReason::decode(buf)?)),
745            P2PMessageID::Ping => {
746                advance_snappy_ping_pong_payload(buf)?;
747                Ok(Self::Ping)
748            }
749            P2PMessageID::Pong => {
750                advance_snappy_ping_pong_payload(buf)?;
751                Ok(Self::Pong)
752            }
753        }
754    }
755}
756
/// Wire identifiers of the messages defined by the `p2p` base subprotocol.
///
/// Each discriminant is the numeric message id for that message; ids above these, up to
/// [`MAX_RESERVED_MESSAGE_ID`], are reserved for `p2p` as well.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum P2PMessageID {
    /// Id `0x00`, used by [`P2PMessage::Hello`].
    Hello = 0,

    /// Id `0x01`, used by [`P2PMessage::Disconnect`].
    Disconnect = 1,

    /// Id `0x02`, used by [`P2PMessage::Ping`].
    Ping = 2,

    /// Id `0x03`, used by [`P2PMessage::Pong`].
    Pong = 3,
}
772
773impl From<P2PMessage> for P2PMessageID {
774    fn from(msg: P2PMessage) -> Self {
775        match msg {
776            P2PMessage::Hello(_) => Self::Hello,
777            P2PMessage::Disconnect(_) => Self::Disconnect,
778            P2PMessage::Ping => Self::Ping,
779            P2PMessage::Pong => Self::Pong,
780        }
781    }
782}
783
784impl TryFrom<u8> for P2PMessageID {
785    type Error = P2PStreamError;
786
787    fn try_from(id: u8) -> Result<Self, Self::Error> {
788        match id {
789            0x00 => Ok(Self::Hello),
790            0x01 => Ok(Self::Disconnect),
791            0x02 => Ok(Self::Ping),
792            0x03 => Ok(Self::Pong),
793            _ => Err(P2PStreamError::UnknownReservedMessageId(id)),
794        }
795    }
796}
797
// Tests for this module: the networked tests run real handshakes over loopback TCP, and the
// snappy tests check the exact Ping/Pong wire bytes.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{capability::SharedCapability, test_utils::eth_hello, EthVersion, ProtocolVersion};
    use tokio::net::{TcpListener, TcpStream};
    use tokio_util::codec::Decoder;

    /// After a successful handshake, the server calls `disconnect`; the client's stream must then
    /// yield `P2PStreamError::Disconnected` carrying the same reason.
    #[tokio::test]
    async fn test_can_disconnect() {
        reth_tracing::init_test_tracing();
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();

        let expected_disconnect = DisconnectReason::UselessPeer;

        let handle = tokio::spawn(async move {
            // roughly based off of the design of tokio::net::TcpListener
            let (incoming, _) = listener.accept().await.unwrap();
            let stream = crate::PassthroughCodec::default().framed(incoming);

            let (server_hello, _) = eth_hello();

            let (mut p2p_stream, _) =
                UnauthedP2PStream::new(stream).handshake(server_hello).await.unwrap();

            p2p_stream.disconnect(expected_disconnect).await.unwrap();
        });

        let outgoing = TcpStream::connect(local_addr).await.unwrap();
        let sink = crate::PassthroughCodec::default().framed(outgoing);

        let (client_hello, _) = eth_hello();

        let (mut p2p_stream, _) =
            UnauthedP2PStream::new(sink).handshake(client_hello).await.unwrap();

        // The first item polled after the handshake must be the peer's disconnect, as an error.
        let err = p2p_stream.next().await.unwrap().unwrap_err();
        match err {
            P2PStreamError::Disconnected(reason) => assert_eq!(reason, expected_disconnect),
            e => panic!("unexpected err: {e}"),
        }

        handle.await.unwrap();
    }

    /// Like `test_can_disconnect`, but the server bypasses `disconnect` and queues a Disconnect
    /// message that is plain RLP (not compressed); the client must still report the reason.
    #[tokio::test]
    async fn test_can_disconnect_weird_disconnect_encoding() {
        reth_tracing::init_test_tracing();
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();

        let expected_disconnect = DisconnectReason::SubprotocolSpecific;

        let handle = tokio::spawn(async move {
            // roughly based off of the design of tokio::net::TcpListener
            let (incoming, _) = listener.accept().await.unwrap();
            let stream = crate::PassthroughCodec::default().framed(incoming);

            let (server_hello, _) = eth_hello();

            let (mut p2p_stream, _) =
                UnauthedP2PStream::new(stream).handshake(server_hello).await.unwrap();

            // Unrolled `disconnect` method, without compression
            p2p_stream.outgoing_messages.clear();

            p2p_stream.outgoing_messages.push_back(Bytes::from(alloy_rlp::encode(
                P2PMessage::Disconnect(DisconnectReason::SubprotocolSpecific),
            )));
            p2p_stream.disconnecting = true;
            p2p_stream.close().await.unwrap();
        });

        let outgoing = TcpStream::connect(local_addr).await.unwrap();
        let sink = crate::PassthroughCodec::default().framed(outgoing);

        let (client_hello, _) = eth_hello();

        let (mut p2p_stream, _) =
            UnauthedP2PStream::new(sink).handshake(client_hello).await.unwrap();

        let err = p2p_stream.next().await.unwrap().unwrap_err();
        match err {
            P2PStreamError::Disconnected(reason) => assert_eq!(reason, expected_disconnect),
            e => panic!("unexpected err: {e}"),
        }

        handle.await.unwrap();
    }

    /// Both peers handshake with the same `eth_hello()` and must agree that eth67 is their first
    /// shared capability, at the first message id after the reserved `p2p` range.
    #[tokio::test]
    async fn test_handshake_passthrough() {
        // create a p2p stream and server, then confirm that the two are authed
        // create tcpstream
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();

        let handle = tokio::spawn(async move {
            // roughly based off of the design of tokio::net::TcpListener
            let (incoming, _) = listener.accept().await.unwrap();
            let stream = crate::PassthroughCodec::default().framed(incoming);

            let (server_hello, _) = eth_hello();

            let unauthed_stream = UnauthedP2PStream::new(stream);
            let (p2p_stream, _) = unauthed_stream.handshake(server_hello).await.unwrap();

            // ensure that the first capability the two share is eth67, offset just past the
            // reserved p2p message ids
            assert_eq!(
                *p2p_stream.shared_capabilities.iter_caps().next().unwrap(),
                SharedCapability::Eth {
                    version: EthVersion::Eth67,
                    offset: MAX_RESERVED_MESSAGE_ID + 1
                }
            );
        });

        let outgoing = TcpStream::connect(local_addr).await.unwrap();
        let sink = crate::PassthroughCodec::default().framed(outgoing);

        let (client_hello, _) = eth_hello();

        let unauthed_stream = UnauthedP2PStream::new(sink);
        let (p2p_stream, _) = unauthed_stream.handshake(client_hello).await.unwrap();

        // ensure that the first capability the two share is eth67, offset just past the
        // reserved p2p message ids
        assert_eq!(
            *p2p_stream.shared_capabilities.iter_caps().next().unwrap(),
            SharedCapability::Eth {
                version: EthVersion::Eth67,
                offset: MAX_RESERVED_MESSAGE_ID + 1
            }
        );

        // wait for the server task so its assertions (and any panic) fail this test
        handle.await.unwrap();
    }

    /// The client advertises `p2p` protocol version V4 while the server keeps the `eth_hello()`
    /// default; both sides must fail the handshake with `MismatchedProtocolVersion`, each
    /// reporting its own version as `expected`.
    #[tokio::test]
    async fn test_handshake_disconnect() {
        // create a p2p stream and server with different p2p protocol versions, then confirm that
        // both sides reject the handshake
        // create tcpstream
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();

        let handle = tokio::spawn(Box::pin(async move {
            // roughly based off of the design of tokio::net::TcpListener
            let (incoming, _) = listener.accept().await.unwrap();
            let stream = crate::PassthroughCodec::default().framed(incoming);

            let (server_hello, _) = eth_hello();

            let unauthed_stream = UnauthedP2PStream::new(stream);
            match unauthed_stream.handshake(server_hello.clone()).await {
                Ok((_, hello)) => {
                    panic!("expected handshake to fail, instead got a successful Hello: {hello:?}")
                }
                Err(P2PStreamError::MismatchedProtocolVersion(GotExpected { got, expected })) => {
                    assert_ne!(expected, got);
                    assert_eq!(expected, server_hello.protocol_version);
                }
                Err(other_err) => {
                    panic!("expected mismatched protocol version error, got {other_err:?}")
                }
            }
        }));

        let outgoing = TcpStream::connect(local_addr).await.unwrap();
        let sink = crate::PassthroughCodec::default().framed(outgoing);

        let (mut client_hello, _) = eth_hello();

        // modify the hello to include an incompatible p2p protocol version
        client_hello.protocol_version = ProtocolVersion::V4;

        let unauthed_stream = UnauthedP2PStream::new(sink);
        match unauthed_stream.handshake(client_hello.clone()).await {
            Ok((_, hello)) => {
                panic!("expected handshake to fail, instead got a successful Hello: {hello:?}")
            }
            Err(P2PStreamError::MismatchedProtocolVersion(GotExpected { got, expected })) => {
                assert_ne!(expected, got);
                assert_eq!(expected, client_hello.protocol_version);
            }
            Err(other_err) => {
                panic!("expected mismatched protocol version error, got {other_err:?}")
            }
        }

        // wait for the server task so its assertions (and any panic) fail this test
        handle.await.unwrap();
    }

    /// Ping on the wire is id `0x02` followed by the snappy-encoded empty list
    /// `[0x01, 0x00, 0xc0]`; decoding and re-encoding must round-trip byte-for-byte.
    #[test]
    fn snappy_decode_encode_ping() {
        let snappy_ping = b"\x02\x01\0\xc0";
        let ping = P2PMessage::decode(&mut &snappy_ping[..]).unwrap();
        assert!(matches!(ping, P2PMessage::Ping));
        assert_eq!(alloy_rlp::encode(ping), &snappy_ping[..]);
    }

    /// Pong on the wire is id `0x03` followed by the snappy-encoded empty list
    /// `[0x01, 0x00, 0xc0]`; decoding and re-encoding must round-trip byte-for-byte.
    #[test]
    fn snappy_decode_encode_pong() {
        let snappy_pong = b"\x03\x01\0\xc0";
        let pong = P2PMessage::decode(&mut &snappy_pong[..]).unwrap();
        assert!(matches!(pong, P2PMessage::Pong));
        assert_eq!(alloy_rlp::encode(pong), &snappy_pong[..]);
    }
}