reth_eth_wire_types/
message.rs

//! Implements the Ethereum wire protocol for versions 66, 67, 68, and 69.
2//! Defines structs/enums for messages, request-response pairs, and broadcasts.
3//! Handles compatibility with [`EthVersion`].
4//!
5//! Examples include creating, encoding, and decoding protocol messages.
6//!
7//! Reference: [Ethereum Wire Protocol](https://github.com/ethereum/wiki/wiki/Ethereum-Wire-Protocol).
8
9use super::{
10    broadcast::NewBlockHashes, BlockBodies, BlockHeaders, GetBlockBodies, GetBlockHeaders,
11    GetNodeData, GetPooledTransactions, GetReceipts, NewBlock, NewPooledTransactionHashes66,
12    NewPooledTransactionHashes68, NodeData, PooledTransactions, Receipts, Status, StatusEth69,
13    Transactions,
14};
15use crate::{
16    status::StatusMessage, EthNetworkPrimitives, EthVersion, NetworkPrimitives,
17    RawCapabilityMessage, SharedTransactions,
18};
19use alloc::{boxed::Box, sync::Arc};
20use alloy_primitives::{
21    bytes::{Buf, BufMut},
22    Bytes,
23};
24use alloy_rlp::{length_of_length, Decodable, Encodable, Header};
25use core::fmt::Debug;
26
/// Maximum permitted size of a single protocol message: `10 * 2^20`, i.e. 10 MiB assuming the
/// size is counted in bytes.
// https://github.com/ethereum/go-ethereum/blob/30602163d5d8321fbc68afdcbbaf2362b2641bde/eth/protocols/eth/protocol.go#L50
pub const MAX_MESSAGE_SIZE: usize = 10 * (1 << 20);
30
31/// Error when sending/receiving a message
32#[derive(thiserror::Error, Debug)]
33pub enum MessageError {
34    /// Flags an unrecognized message ID for a given protocol version.
35    #[error("message id {1:?} is invalid for version {0:?}")]
36    Invalid(EthVersion, EthMessageID),
37    /// Thrown when rlp decoding a message failed.
38    #[error("RLP error: {0}")]
39    RlpError(#[from] alloy_rlp::Error),
40}
41
42/// An `eth` protocol message, containing a message ID and payload.
43#[derive(Clone, Debug, PartialEq, Eq)]
44#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
45pub struct ProtocolMessage<N: NetworkPrimitives = EthNetworkPrimitives> {
46    /// The unique identifier representing the type of the Ethereum message.
47    pub message_type: EthMessageID,
48    /// The content of the message, including specific data based on the message type.
49    #[cfg_attr(
50        feature = "serde",
51        serde(bound = "EthMessage<N>: serde::Serialize + serde::de::DeserializeOwned")
52    )]
53    pub message: EthMessage<N>,
54}
55
56impl<N: NetworkPrimitives> ProtocolMessage<N> {
57    /// Create a new `ProtocolMessage` from a message type and message rlp bytes.
58    pub fn decode_message(version: EthVersion, buf: &mut &[u8]) -> Result<Self, MessageError> {
59        let message_type = EthMessageID::decode(buf)?;
60
61        // For EIP-7642 (https://github.com/ethereum/EIPs/blob/master/EIPS/eip-7642.md):
62        // pre-merge (legacy) status messages include total difficulty, whereas eth/69 omits it.
63        let message = match message_type {
64            EthMessageID::Status => EthMessage::Status(if version < EthVersion::Eth69 {
65                StatusMessage::Legacy(Status::decode(buf)?)
66            } else {
67                StatusMessage::Eth69(StatusEth69::decode(buf)?)
68            }),
69            EthMessageID::NewBlockHashes => {
70                if version.is_eth69() {
71                    return Err(MessageError::Invalid(version, EthMessageID::NewBlockHashes));
72                }
73                EthMessage::NewBlockHashes(NewBlockHashes::decode(buf)?)
74            }
75            EthMessageID::NewBlock => {
76                if version.is_eth69() {
77                    return Err(MessageError::Invalid(version, EthMessageID::NewBlock));
78                }
79                EthMessage::NewBlock(Box::new(NewBlock::decode(buf)?))
80            }
81            EthMessageID::Transactions => EthMessage::Transactions(Transactions::decode(buf)?),
82            EthMessageID::NewPooledTransactionHashes => {
83                if version >= EthVersion::Eth68 {
84                    EthMessage::NewPooledTransactionHashes68(NewPooledTransactionHashes68::decode(
85                        buf,
86                    )?)
87                } else {
88                    EthMessage::NewPooledTransactionHashes66(NewPooledTransactionHashes66::decode(
89                        buf,
90                    )?)
91                }
92            }
93            EthMessageID::GetBlockHeaders => EthMessage::GetBlockHeaders(RequestPair::decode(buf)?),
94            EthMessageID::BlockHeaders => EthMessage::BlockHeaders(RequestPair::decode(buf)?),
95            EthMessageID::GetBlockBodies => EthMessage::GetBlockBodies(RequestPair::decode(buf)?),
96            EthMessageID::BlockBodies => EthMessage::BlockBodies(RequestPair::decode(buf)?),
97            EthMessageID::GetPooledTransactions => {
98                EthMessage::GetPooledTransactions(RequestPair::decode(buf)?)
99            }
100            EthMessageID::PooledTransactions => {
101                EthMessage::PooledTransactions(RequestPair::decode(buf)?)
102            }
103            EthMessageID::GetNodeData => {
104                if version >= EthVersion::Eth67 {
105                    return Err(MessageError::Invalid(version, EthMessageID::GetNodeData))
106                }
107                EthMessage::GetNodeData(RequestPair::decode(buf)?)
108            }
109            EthMessageID::NodeData => {
110                if version >= EthVersion::Eth67 {
111                    return Err(MessageError::Invalid(version, EthMessageID::GetNodeData))
112                }
113                EthMessage::NodeData(RequestPair::decode(buf)?)
114            }
115            EthMessageID::GetReceipts => EthMessage::GetReceipts(RequestPair::decode(buf)?),
116            EthMessageID::Receipts => EthMessage::Receipts(RequestPair::decode(buf)?),
117            EthMessageID::Other(_) => {
118                let raw_payload = Bytes::copy_from_slice(buf);
119                buf.advance(raw_payload.len());
120                EthMessage::Other(RawCapabilityMessage::new(
121                    message_type.to_u8() as usize,
122                    raw_payload.into(),
123                ))
124            }
125        };
126        Ok(Self { message_type, message })
127    }
128}
129
130impl<N: NetworkPrimitives> Encodable for ProtocolMessage<N> {
131    /// Encodes the protocol message into bytes. The message type is encoded as a single byte and
132    /// prepended to the message.
133    fn encode(&self, out: &mut dyn BufMut) {
134        self.message_type.encode(out);
135        self.message.encode(out);
136    }
137    fn length(&self) -> usize {
138        self.message_type.length() + self.message.length()
139    }
140}
141
142impl<N: NetworkPrimitives> From<EthMessage<N>> for ProtocolMessage<N> {
143    fn from(message: EthMessage<N>) -> Self {
144        Self { message_type: message.message_id(), message }
145    }
146}
147
148/// Represents messages that can be sent to multiple peers.
149#[derive(Clone, Debug)]
150pub struct ProtocolBroadcastMessage<N: NetworkPrimitives = EthNetworkPrimitives> {
151    /// The unique identifier representing the type of the Ethereum message.
152    pub message_type: EthMessageID,
153    /// The content of the message to be broadcasted, including specific data based on the message
154    /// type.
155    pub message: EthBroadcastMessage<N>,
156}
157
158impl<N: NetworkPrimitives> Encodable for ProtocolBroadcastMessage<N> {
159    /// Encodes the protocol message into bytes. The message type is encoded as a single byte and
160    /// prepended to the message.
161    fn encode(&self, out: &mut dyn BufMut) {
162        self.message_type.encode(out);
163        self.message.encode(out);
164    }
165    fn length(&self) -> usize {
166        self.message_type.length() + self.message.length()
167    }
168}
169
170impl<N: NetworkPrimitives> From<EthBroadcastMessage<N>> for ProtocolBroadcastMessage<N> {
171    fn from(message: EthBroadcastMessage<N>) -> Self {
172        Self { message_type: message.message_id(), message }
173    }
174}
175
176/// Represents a message in the eth wire protocol, versions 66, 67 and 68.
177///
178/// The ethereum wire protocol is a set of messages that are broadcast to the network in two
179/// styles:
180///  * A request message sent by a peer (such as [`GetPooledTransactions`]), and an associated
181///    response message (such as [`PooledTransactions`]).
182///  * A message that is broadcast to the network, without a corresponding request.
183///
184/// The newer `eth/66` is an efficiency upgrade on top of `eth/65`, introducing a request id to
185/// correlate request-response message pairs. This allows for request multiplexing.
186///
187/// The `eth/67` is based on `eth/66` but only removes two messages, [`GetNodeData`] and
188/// [`NodeData`].
189///
190/// The `eth/68` changes only `NewPooledTransactionHashes` to include `types` and `sized`. For
191/// it, `NewPooledTransactionHashes` is renamed as [`NewPooledTransactionHashes66`] and
192/// [`NewPooledTransactionHashes68`] is defined.
193#[derive(Clone, Debug, PartialEq, Eq)]
194#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
195pub enum EthMessage<N: NetworkPrimitives = EthNetworkPrimitives> {
196    /// Represents a Status message required for the protocol handshake.
197    Status(StatusMessage),
198    /// Represents a `NewBlockHashes` message broadcast to the network.
199    NewBlockHashes(NewBlockHashes),
200    /// Represents a `NewBlock` message broadcast to the network.
201    #[cfg_attr(
202        feature = "serde",
203        serde(bound = "N::Block: serde::Serialize + serde::de::DeserializeOwned")
204    )]
205    NewBlock(Box<NewBlock<N::Block>>),
206    /// Represents a Transactions message broadcast to the network.
207    #[cfg_attr(
208        feature = "serde",
209        serde(bound = "N::BroadcastedTransaction: serde::Serialize + serde::de::DeserializeOwned")
210    )]
211    Transactions(Transactions<N::BroadcastedTransaction>),
212    /// Represents a `NewPooledTransactionHashes` message for eth/66 version.
213    NewPooledTransactionHashes66(NewPooledTransactionHashes66),
214    /// Represents a `NewPooledTransactionHashes` message for eth/68 version.
215    NewPooledTransactionHashes68(NewPooledTransactionHashes68),
216    // The following messages are request-response message pairs
217    /// Represents a `GetBlockHeaders` request-response pair.
218    GetBlockHeaders(RequestPair<GetBlockHeaders>),
219    /// Represents a `BlockHeaders` request-response pair.
220    #[cfg_attr(
221        feature = "serde",
222        serde(bound = "N::BlockHeader: serde::Serialize + serde::de::DeserializeOwned")
223    )]
224    BlockHeaders(RequestPair<BlockHeaders<N::BlockHeader>>),
225    /// Represents a `GetBlockBodies` request-response pair.
226    GetBlockBodies(RequestPair<GetBlockBodies>),
227    /// Represents a `BlockBodies` request-response pair.
228    #[cfg_attr(
229        feature = "serde",
230        serde(bound = "N::BlockBody: serde::Serialize + serde::de::DeserializeOwned")
231    )]
232    BlockBodies(RequestPair<BlockBodies<N::BlockBody>>),
233    /// Represents a `GetPooledTransactions` request-response pair.
234    GetPooledTransactions(RequestPair<GetPooledTransactions>),
235    /// Represents a `PooledTransactions` request-response pair.
236    #[cfg_attr(
237        feature = "serde",
238        serde(bound = "N::PooledTransaction: serde::Serialize + serde::de::DeserializeOwned")
239    )]
240    PooledTransactions(RequestPair<PooledTransactions<N::PooledTransaction>>),
241    /// Represents a `GetNodeData` request-response pair.
242    GetNodeData(RequestPair<GetNodeData>),
243    /// Represents a `NodeData` request-response pair.
244    NodeData(RequestPair<NodeData>),
245    /// Represents a `GetReceipts` request-response pair.
246    GetReceipts(RequestPair<GetReceipts>),
247    /// Represents a Receipts request-response pair.
248    #[cfg_attr(
249        feature = "serde",
250        serde(bound = "N::Receipt: serde::Serialize + serde::de::DeserializeOwned")
251    )]
252    Receipts(RequestPair<Receipts<N::Receipt>>),
253    /// Represents an encoded message that doesn't match any other variant
254    Other(RawCapabilityMessage),
255}
256
257impl<N: NetworkPrimitives> EthMessage<N> {
258    /// Returns the message's ID.
259    pub const fn message_id(&self) -> EthMessageID {
260        match self {
261            Self::Status(_) => EthMessageID::Status,
262            Self::NewBlockHashes(_) => EthMessageID::NewBlockHashes,
263            Self::NewBlock(_) => EthMessageID::NewBlock,
264            Self::Transactions(_) => EthMessageID::Transactions,
265            Self::NewPooledTransactionHashes66(_) | Self::NewPooledTransactionHashes68(_) => {
266                EthMessageID::NewPooledTransactionHashes
267            }
268            Self::GetBlockHeaders(_) => EthMessageID::GetBlockHeaders,
269            Self::BlockHeaders(_) => EthMessageID::BlockHeaders,
270            Self::GetBlockBodies(_) => EthMessageID::GetBlockBodies,
271            Self::BlockBodies(_) => EthMessageID::BlockBodies,
272            Self::GetPooledTransactions(_) => EthMessageID::GetPooledTransactions,
273            Self::PooledTransactions(_) => EthMessageID::PooledTransactions,
274            Self::GetNodeData(_) => EthMessageID::GetNodeData,
275            Self::NodeData(_) => EthMessageID::NodeData,
276            Self::GetReceipts(_) => EthMessageID::GetReceipts,
277            Self::Receipts(_) => EthMessageID::Receipts,
278            Self::Other(msg) => EthMessageID::Other(msg.id as u8),
279        }
280    }
281
282    /// Returns true if the message variant is a request.
283    pub const fn is_request(&self) -> bool {
284        matches!(
285            self,
286            Self::GetBlockBodies(_) |
287                Self::GetBlockHeaders(_) |
288                Self::GetReceipts(_) |
289                Self::GetPooledTransactions(_) |
290                Self::GetNodeData(_)
291        )
292    }
293
294    /// Returns true if the message variant is a response to a request.
295    pub const fn is_response(&self) -> bool {
296        matches!(
297            self,
298            Self::PooledTransactions(_) |
299                Self::Receipts(_) |
300                Self::BlockHeaders(_) |
301                Self::BlockBodies(_) |
302                Self::NodeData(_)
303        )
304    }
305}
306
307impl<N: NetworkPrimitives> Encodable for EthMessage<N> {
308    fn encode(&self, out: &mut dyn BufMut) {
309        match self {
310            Self::Status(status) => status.encode(out),
311            Self::NewBlockHashes(new_block_hashes) => new_block_hashes.encode(out),
312            Self::NewBlock(new_block) => new_block.encode(out),
313            Self::Transactions(transactions) => transactions.encode(out),
314            Self::NewPooledTransactionHashes66(hashes) => hashes.encode(out),
315            Self::NewPooledTransactionHashes68(hashes) => hashes.encode(out),
316            Self::GetBlockHeaders(request) => request.encode(out),
317            Self::BlockHeaders(headers) => headers.encode(out),
318            Self::GetBlockBodies(request) => request.encode(out),
319            Self::BlockBodies(bodies) => bodies.encode(out),
320            Self::GetPooledTransactions(request) => request.encode(out),
321            Self::PooledTransactions(transactions) => transactions.encode(out),
322            Self::GetNodeData(request) => request.encode(out),
323            Self::NodeData(data) => data.encode(out),
324            Self::GetReceipts(request) => request.encode(out),
325            Self::Receipts(receipts) => receipts.encode(out),
326            Self::Other(unknown) => out.put_slice(&unknown.payload),
327        }
328    }
329    fn length(&self) -> usize {
330        match self {
331            Self::Status(status) => status.length(),
332            Self::NewBlockHashes(new_block_hashes) => new_block_hashes.length(),
333            Self::NewBlock(new_block) => new_block.length(),
334            Self::Transactions(transactions) => transactions.length(),
335            Self::NewPooledTransactionHashes66(hashes) => hashes.length(),
336            Self::NewPooledTransactionHashes68(hashes) => hashes.length(),
337            Self::GetBlockHeaders(request) => request.length(),
338            Self::BlockHeaders(headers) => headers.length(),
339            Self::GetBlockBodies(request) => request.length(),
340            Self::BlockBodies(bodies) => bodies.length(),
341            Self::GetPooledTransactions(request) => request.length(),
342            Self::PooledTransactions(transactions) => transactions.length(),
343            Self::GetNodeData(request) => request.length(),
344            Self::NodeData(data) => data.length(),
345            Self::GetReceipts(request) => request.length(),
346            Self::Receipts(receipts) => receipts.length(),
347            Self::Other(unknown) => unknown.length(),
348        }
349    }
350}
351
352/// Represents broadcast messages of [`EthMessage`] with the same object that can be sent to
353/// multiple peers.
354///
355/// Messages that contain a list of hashes depend on the peer the message is sent to. A peer should
356/// never receive a hash of an object (block, transaction) it has already seen.
357///
358/// Note: This is only useful for outgoing messages.
359#[derive(Clone, Debug, PartialEq, Eq)]
360pub enum EthBroadcastMessage<N: NetworkPrimitives = EthNetworkPrimitives> {
361    /// Represents a new block broadcast message.
362    NewBlock(Arc<NewBlock<N::Block>>),
363    /// Represents a transactions broadcast message.
364    Transactions(SharedTransactions<N::BroadcastedTransaction>),
365}
366
367// === impl EthBroadcastMessage ===
368
369impl<N: NetworkPrimitives> EthBroadcastMessage<N> {
370    /// Returns the message's ID.
371    pub const fn message_id(&self) -> EthMessageID {
372        match self {
373            Self::NewBlock(_) => EthMessageID::NewBlock,
374            Self::Transactions(_) => EthMessageID::Transactions,
375        }
376    }
377}
378
379impl<N: NetworkPrimitives> Encodable for EthBroadcastMessage<N> {
380    fn encode(&self, out: &mut dyn BufMut) {
381        match self {
382            Self::NewBlock(new_block) => new_block.encode(out),
383            Self::Transactions(transactions) => transactions.encode(out),
384        }
385    }
386
387    fn length(&self) -> usize {
388        match self {
389            Self::NewBlock(new_block) => new_block.length(),
390            Self::Transactions(transactions) => transactions.length(),
391        }
392    }
393}
394
/// Message IDs of the eth protocol.
///
/// No dedicated variant is defined for the values `0x0b` and `0x0c`; any value without a
/// dedicated variant can be carried by [`EthMessageID::Other`].
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum EthMessageID {
    /// ID of the `Status` message.
    Status = 0x00,
    /// ID of the `NewBlockHashes` message.
    NewBlockHashes = 0x01,
    /// ID of the `Transactions` message.
    Transactions = 0x02,
    /// ID of the `GetBlockHeaders` message.
    GetBlockHeaders = 0x03,
    /// ID of the `BlockHeaders` message.
    BlockHeaders = 0x04,
    /// ID of the `GetBlockBodies` message.
    GetBlockBodies = 0x05,
    /// ID of the `BlockBodies` message.
    BlockBodies = 0x06,
    /// ID of the `NewBlock` message.
    NewBlock = 0x07,
    /// ID of the `NewPooledTransactionHashes` message.
    NewPooledTransactionHashes = 0x08,
    /// ID of the `GetPooledTransactions` message.
    GetPooledTransactions = 0x09,
    /// ID of the `PooledTransactions` message.
    PooledTransactions = 0x0a,
    /// ID of the `GetNodeData` message.
    GetNodeData = 0x0d,
    /// ID of the `NodeData` message.
    NodeData = 0x0e,
    /// ID of the `GetReceipts` message.
    GetReceipts = 0x0f,
    /// ID of the `Receipts` message.
    Receipts = 0x10,
    /// An unknown message type, carrying its raw ID byte.
    Other(u8),
}

impl EthMessageID {
    /// Returns the `u8` value of this message ID.
    ///
    /// Known IDs map to their fixed value; [`Self::Other`] yields the byte it stores.
    pub const fn to_u8(&self) -> u8 {
        match *self {
            Self::Other(raw) => raw,
            Self::Status => 0x00,
            Self::NewBlockHashes => 0x01,
            Self::Transactions => 0x02,
            Self::GetBlockHeaders => 0x03,
            Self::BlockHeaders => 0x04,
            Self::GetBlockBodies => 0x05,
            Self::BlockBodies => 0x06,
            Self::NewBlock => 0x07,
            Self::NewPooledTransactionHashes => 0x08,
            Self::GetPooledTransactions => 0x09,
            Self::PooledTransactions => 0x0a,
            Self::GetNodeData => 0x0d,
            Self::NodeData => 0x0e,
            Self::GetReceipts => 0x0f,
            Self::Receipts => 0x10,
        }
    }

    /// Returns the largest value among the known message IDs, which is that of
    /// [`Self::Receipts`].
    pub const fn max() -> u8 {
        const HIGHEST_KNOWN: EthMessageID = EthMessageID::Receipts;
        HIGHEST_KNOWN.to_u8()
    }
}
462
463impl Encodable for EthMessageID {
464    fn encode(&self, out: &mut dyn BufMut) {
465        out.put_u8(self.to_u8());
466    }
467    fn length(&self) -> usize {
468        1
469    }
470}
471
472impl Decodable for EthMessageID {
473    fn decode(buf: &mut &[u8]) -> alloy_rlp::Result<Self> {
474        let id = match buf.first().ok_or(alloy_rlp::Error::InputTooShort)? {
475            0x00 => Self::Status,
476            0x01 => Self::NewBlockHashes,
477            0x02 => Self::Transactions,
478            0x03 => Self::GetBlockHeaders,
479            0x04 => Self::BlockHeaders,
480            0x05 => Self::GetBlockBodies,
481            0x06 => Self::BlockBodies,
482            0x07 => Self::NewBlock,
483            0x08 => Self::NewPooledTransactionHashes,
484            0x09 => Self::GetPooledTransactions,
485            0x0a => Self::PooledTransactions,
486            0x0d => Self::GetNodeData,
487            0x0e => Self::NodeData,
488            0x0f => Self::GetReceipts,
489            0x10 => Self::Receipts,
490            unknown => Self::Other(*unknown),
491        };
492        buf.advance(1);
493        Ok(id)
494    }
495}
496
497impl TryFrom<usize> for EthMessageID {
498    type Error = &'static str;
499
500    fn try_from(value: usize) -> Result<Self, Self::Error> {
501        match value {
502            0x00 => Ok(Self::Status),
503            0x01 => Ok(Self::NewBlockHashes),
504            0x02 => Ok(Self::Transactions),
505            0x03 => Ok(Self::GetBlockHeaders),
506            0x04 => Ok(Self::BlockHeaders),
507            0x05 => Ok(Self::GetBlockBodies),
508            0x06 => Ok(Self::BlockBodies),
509            0x07 => Ok(Self::NewBlock),
510            0x08 => Ok(Self::NewPooledTransactionHashes),
511            0x09 => Ok(Self::GetPooledTransactions),
512            0x0a => Ok(Self::PooledTransactions),
513            0x0d => Ok(Self::GetNodeData),
514            0x0e => Ok(Self::NodeData),
515            0x0f => Ok(Self::GetReceipts),
516            0x10 => Ok(Self::Receipts),
517            _ => Err("Invalid message ID"),
518        }
519    }
520}
521
/// Container used by all request-response style `eth` protocol messages.
///
/// The same type serves for requests and for responses, because both consist of a request id
/// and a message payload.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(any(test, feature = "arbitrary"), derive(arbitrary::Arbitrary))]
pub struct RequestPair<T> {
    /// Id of the contained request or response message.
    pub request_id: u64,

    /// Payload of the request or response message.
    pub message: T,
}
535
536/// Allows messages with request ids to be serialized into RLP bytes.
537impl<T> Encodable for RequestPair<T>
538where
539    T: Encodable,
540{
541    fn encode(&self, out: &mut dyn alloy_rlp::BufMut) {
542        let header =
543            Header { list: true, payload_length: self.request_id.length() + self.message.length() };
544
545        header.encode(out);
546        self.request_id.encode(out);
547        self.message.encode(out);
548    }
549
550    fn length(&self) -> usize {
551        let mut length = 0;
552        length += self.request_id.length();
553        length += self.message.length();
554        length += length_of_length(length);
555        length
556    }
557}
558
559/// Allows messages with request ids to be deserialized into RLP bytes.
560impl<T> Decodable for RequestPair<T>
561where
562    T: Decodable,
563{
564    fn decode(buf: &mut &[u8]) -> alloy_rlp::Result<Self> {
565        let header = Header::decode(buf)?;
566
567        let initial_length = buf.len();
568        let request_id = u64::decode(buf)?;
569        let message = T::decode(buf)?;
570
571        // Check that the buffer consumed exactly payload_length bytes after decoding the
572        // RequestPair
573        let consumed_len = initial_length - buf.len();
574        if consumed_len != header.payload_length {
575            return Err(alloy_rlp::Error::UnexpectedLength)
576        }
577
578        Ok(Self { request_id, message })
579    }
580}
581
#[cfg(test)]
mod tests {
    use super::MessageError;
    use crate::{
        message::RequestPair, EthMessage, EthMessageID, EthNetworkPrimitives, EthVersion,
        GetNodeData, NodeData, ProtocolMessage, RawCapabilityMessage,
    };
    use alloy_primitives::hex;
    use alloy_rlp::{Decodable, Encodable, Error};
    use reth_ethereum_primitives::BlockBody;

    /// Encodes `value` into a freshly allocated byte vector.
    fn encode<T: Encodable>(value: T) -> Vec<u8> {
        let mut out = Vec::new();
        value.encode(&mut out);
        out
    }

    /// Encodes a custom (`Other`) message with the given id and payload, decodes it again as
    /// eth/68 and asserts that the result equals the original.
    fn assert_custom_roundtrip(id: u8, payload: Vec<u8>) {
        let original = ProtocolMessage::<EthNetworkPrimitives> {
            message_type: EthMessageID::Other(id),
            message: EthMessage::Other(RawCapabilityMessage::new(usize::from(id), payload.into())),
        };

        let bytes = encode(original.clone());
        let decoded = ProtocolMessage::<EthNetworkPrimitives>::decode_message(
            EthVersion::Eth68,
            &mut &bytes[..],
        )
        .unwrap();

        assert_eq!(original, decoded);
    }

    #[test]
    fn test_removed_message_at_eth67() {
        // Both node-data messages must be rejected when decoding as eth/67.
        let removed = [
            (
                EthMessageID::GetNodeData,
                EthMessage::<EthNetworkPrimitives>::GetNodeData(RequestPair {
                    request_id: 1337,
                    message: GetNodeData(vec![]),
                }),
            ),
            (
                EthMessageID::NodeData,
                EthMessage::<EthNetworkPrimitives>::NodeData(RequestPair {
                    request_id: 1337,
                    message: NodeData(vec![]),
                }),
            ),
        ];

        for (message_type, message) in removed {
            let bytes = encode(ProtocolMessage { message_type, message });
            let result = ProtocolMessage::<EthNetworkPrimitives>::decode_message(
                EthVersion::Eth67,
                &mut &bytes[..],
            );
            assert!(matches!(result, Err(MessageError::Invalid(..))));
        }
    }

    #[test]
    fn request_pair_encode() {
        let pair = RequestPair { request_id: 1337, message: vec![5u8] };

        // c5: outer list header (0xc0 + payload length 5)
        // 82: string header (0x80 + 2 bytes)
        // 05 39: 1337 (request_id)
        // c1: inner list header (0xc0 + payload length 1)
        // 05: 5 (message)
        let expected = hex!("c5820539c105");
        let got = encode(pair);
        assert_eq!(expected[..], got, "expected: {expected:X?}, got: {got:X?}",);
    }

    #[test]
    fn request_pair_decode() {
        let raw = hex!("c5820539c105");
        let mut cursor = &raw[..];

        let expected = RequestPair { request_id: 1337, message: vec![5u8] };

        let got = RequestPair::<Vec<u8>>::decode(&mut cursor).unwrap();
        assert_eq!(expected.length(), raw.len());
        assert_eq!(expected, got);
    }

    #[test]
    fn malicious_request_pair_decode() {
        // A maliciously encoded request pair: the outer header announces a payload of 5 bytes,
        // but decoding the contents actually consumes 6 bytes.
        //
        // c5: outer list header (0xc0 + payload length 5)
        // 82: string header (0x80 + 2 bytes)
        // 05 39: 1337 (request_id)
        // c2: inner list header (0xc0 + payload length 2)
        // 05 05: 5 5 (message)
        let raw = hex!("c5820539c20505");
        let mut cursor = &raw[..];

        let result = RequestPair::<Vec<u8>>::decode(&mut cursor);
        assert!(matches!(result, Err(Error::UnexpectedLength)));
    }

    #[test]
    fn empty_block_bodies_protocol() {
        let original =
            ProtocolMessage::from(EthMessage::<EthNetworkPrimitives>::BlockBodies(RequestPair {
                request_id: 0,
                message: Default::default(),
            }));

        let mut bytes = Vec::new();
        original.encode(&mut bytes);

        let decoded =
            ProtocolMessage::decode_message(EthVersion::Eth68, &mut bytes.as_slice()).unwrap();
        assert_eq!(original, decoded);
    }

    #[test]
    fn empty_block_body_protocol() {
        let body = BlockBody {
            transactions: vec![],
            ommers: vec![],
            withdrawals: Some(Default::default()),
        };
        let original =
            ProtocolMessage::from(EthMessage::<EthNetworkPrimitives>::BlockBodies(RequestPair {
                request_id: 0,
                message: vec![body].into(),
            }));

        let mut bytes = Vec::new();
        original.encode(&mut bytes);

        let decoded =
            ProtocolMessage::decode_message(EthVersion::Eth68, &mut bytes.as_slice()).unwrap();
        assert_eq!(original, decoded);
    }

    #[test]
    fn decode_block_bodies_message() {
        let raw = hex!("06c48199c1c0");
        let mut cursor = &raw[..];

        let err =
            ProtocolMessage::<EthNetworkPrimitives>::decode_message(EthVersion::Eth68, &mut cursor)
                .unwrap_err();
        assert!(matches!(err, MessageError::RlpError(Error::InputTooShort)));
    }

    #[test]
    fn custom_message_roundtrip() {
        assert_custom_roundtrip(0x20, vec![1, 2, 3, 4, 5]);
    }

    #[test]
    fn custom_message_empty_payload_roundtrip() {
        assert_custom_roundtrip(0x30, vec![]);
    }
}