reth_ress_protocol/
message.rs

//! Implements the Ress protocol.
//! Defines the structs/enums for its messages, including request-response pairs.
//!
//! Also provides helpers for creating, encoding, and decoding protocol messages.
5
6use crate::NodeType;
7use alloy_consensus::Header;
8use alloy_primitives::{
9    bytes::{Buf, BufMut},
10    BlockHash, Bytes, B256,
11};
12use alloy_rlp::{BytesMut, Decodable, Encodable, RlpDecodable, RlpEncodable};
13use reth_eth_wire::{message::RequestPair, protocol::Protocol, Capability};
14use reth_ethereum_primitives::BlockBody;
15
16/// An Ress protocol message, containing a message ID and payload.
17#[derive(PartialEq, Eq, Clone, Debug)]
18pub struct RessProtocolMessage {
19    /// The unique identifier representing the type of the Ress message.
20    pub message_type: RessMessageID,
21    /// The content of the message, including specific data based on the message type.
22    pub message: RessMessage,
23}
24
#[cfg(any(test, feature = "arbitrary"))]
impl<'a> arbitrary::Arbitrary<'a> for RessProtocolMessage {
    /// Generates a random message whose `message_type` always matches its payload variant.
    fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
        // Only the payload is random; the ID is derived from it so the pair stays consistent.
        let payload = u.arbitrary::<RessMessage>()?;
        Ok(payload.into_protocol_message())
    }
}
32
33impl RessProtocolMessage {
34    /// Returns the capability for the `ress` protocol.
35    pub const fn capability() -> Capability {
36        Capability::new_static("ress", 1)
37    }
38
39    /// Returns the protocol for the `ress` protocol.
40    pub const fn protocol() -> Protocol {
41        Protocol::new(Self::capability(), 9)
42    }
43
44    /// Create node type message.
45    pub const fn node_type(node_type: NodeType) -> Self {
46        RessMessage::NodeType(node_type).into_protocol_message()
47    }
48
49    /// Headers request.
50    pub const fn get_headers(request_id: u64, request: GetHeaders) -> Self {
51        RessMessage::GetHeaders(RequestPair { request_id, message: request })
52            .into_protocol_message()
53    }
54
55    /// Headers response.
56    pub const fn headers(request_id: u64, headers: Vec<Header>) -> Self {
57        RessMessage::Headers(RequestPair { request_id, message: headers }).into_protocol_message()
58    }
59
60    /// Block bodies request.
61    pub const fn get_block_bodies(request_id: u64, block_hashes: Vec<B256>) -> Self {
62        RessMessage::GetBlockBodies(RequestPair { request_id, message: block_hashes })
63            .into_protocol_message()
64    }
65
66    /// Block bodies response.
67    pub const fn block_bodies(request_id: u64, bodies: Vec<BlockBody>) -> Self {
68        RessMessage::BlockBodies(RequestPair { request_id, message: bodies })
69            .into_protocol_message()
70    }
71
72    /// Bytecode request.
73    pub const fn get_bytecode(request_id: u64, code_hash: B256) -> Self {
74        RessMessage::GetBytecode(RequestPair { request_id, message: code_hash })
75            .into_protocol_message()
76    }
77
78    /// Bytecode response.
79    pub const fn bytecode(request_id: u64, bytecode: Bytes) -> Self {
80        RessMessage::Bytecode(RequestPair { request_id, message: bytecode }).into_protocol_message()
81    }
82
83    /// Execution witness request.
84    pub const fn get_witness(request_id: u64, block_hash: BlockHash) -> Self {
85        RessMessage::GetWitness(RequestPair { request_id, message: block_hash })
86            .into_protocol_message()
87    }
88
89    /// Execution witness response.
90    pub const fn witness(request_id: u64, witness: Vec<Bytes>) -> Self {
91        RessMessage::Witness(RequestPair { request_id, message: witness }).into_protocol_message()
92    }
93
94    /// Return RLP encoded message.
95    pub fn encoded(&self) -> BytesMut {
96        let mut buf = BytesMut::with_capacity(self.length());
97        self.encode(&mut buf);
98        buf
99    }
100
101    /// Decodes a `RessProtocolMessage` from the given message buffer.
102    pub fn decode_message(buf: &mut &[u8]) -> alloy_rlp::Result<Self> {
103        let message_type = RessMessageID::decode(buf)?;
104        let message = match message_type {
105            RessMessageID::NodeType => RessMessage::NodeType(NodeType::decode(buf)?),
106            RessMessageID::GetHeaders => RessMessage::GetHeaders(RequestPair::decode(buf)?),
107            RessMessageID::Headers => RessMessage::Headers(RequestPair::decode(buf)?),
108            RessMessageID::GetBlockBodies => RessMessage::GetBlockBodies(RequestPair::decode(buf)?),
109            RessMessageID::BlockBodies => RessMessage::BlockBodies(RequestPair::decode(buf)?),
110            RessMessageID::GetBytecode => RessMessage::GetBytecode(RequestPair::decode(buf)?),
111            RessMessageID::Bytecode => RessMessage::Bytecode(RequestPair::decode(buf)?),
112            RessMessageID::GetWitness => RessMessage::GetWitness(RequestPair::decode(buf)?),
113            RessMessageID::Witness => RessMessage::Witness(RequestPair::decode(buf)?),
114        };
115        Ok(Self { message_type, message })
116    }
117}
118
119impl Encodable for RessProtocolMessage {
120    fn encode(&self, out: &mut dyn BufMut) {
121        self.message_type.encode(out);
122        self.message.encode(out);
123    }
124
125    fn length(&self) -> usize {
126        self.message_type.length() + self.message.length()
127    }
128}
129
/// Wire identifiers of the `ress` protocol messages.
///
/// Every request ID is immediately followed by the ID of its response.
#[repr(u8)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[cfg_attr(any(test, feature = "arbitrary"), derive(arbitrary::Arbitrary))]
#[cfg_attr(test, derive(strum_macros::EnumCount))]
pub enum RessMessageID {
    /// Node type announcement used during the handshake.
    NodeType = 0x00,

    /// Request for block headers.
    GetHeaders = 0x01,
    /// Response carrying block headers.
    Headers = 0x02,

    /// Request for block bodies.
    GetBlockBodies = 0x03,
    /// Response carrying block bodies.
    BlockBodies = 0x04,

    /// Request for contract bytecode.
    GetBytecode = 0x05,
    /// Response carrying contract bytecode.
    Bytecode = 0x06,

    /// Request for an execution witness.
    GetWitness = 0x07,
    /// Response carrying an execution witness.
    Witness = 0x08,
}
159
160impl Encodable for RessMessageID {
161    fn encode(&self, out: &mut dyn BufMut) {
162        out.put_u8(*self as u8);
163    }
164
165    fn length(&self) -> usize {
166        1
167    }
168}
169
170impl Decodable for RessMessageID {
171    fn decode(buf: &mut &[u8]) -> alloy_rlp::Result<Self> {
172        let id = match buf.first().ok_or(alloy_rlp::Error::InputTooShort)? {
173            0x00 => Self::NodeType,
174            0x01 => Self::GetHeaders,
175            0x02 => Self::Headers,
176            0x03 => Self::GetBlockBodies,
177            0x04 => Self::BlockBodies,
178            0x05 => Self::GetBytecode,
179            0x06 => Self::Bytecode,
180            0x07 => Self::GetWitness,
181            0x08 => Self::Witness,
182            _ => return Err(alloy_rlp::Error::Custom("Invalid message type")),
183        };
184        buf.advance(1);
185        Ok(id)
186    }
187}
188
189/// Represents a message in the ress protocol.
190#[derive(PartialEq, Eq, Clone, Debug)]
191#[cfg_attr(any(test, feature = "arbitrary"), derive(arbitrary::Arbitrary))]
192pub enum RessMessage {
193    /// Represents a node type message required for handshake.
194    NodeType(NodeType),
195
196    /// Represents a headers request message.
197    GetHeaders(RequestPair<GetHeaders>),
198    /// Represents a headers response message.
199    Headers(RequestPair<Vec<Header>>),
200
201    /// Represents a block bodies request message.
202    GetBlockBodies(RequestPair<Vec<B256>>),
203    /// Represents a block bodies response message.
204    BlockBodies(RequestPair<Vec<BlockBody>>),
205
206    /// Represents a bytecode request message.
207    GetBytecode(RequestPair<B256>),
208    /// Represents a bytecode response message.
209    Bytecode(RequestPair<Bytes>),
210
211    /// Represents a witness request message.
212    GetWitness(RequestPair<BlockHash>),
213    /// Represents a witness response message.
214    Witness(RequestPair<Vec<Bytes>>),
215}
216
217impl RessMessage {
218    /// Return [`RessMessageID`] that corresponds to the given message.
219    pub const fn message_id(&self) -> RessMessageID {
220        match self {
221            Self::NodeType(_) => RessMessageID::NodeType,
222            Self::GetHeaders(_) => RessMessageID::GetHeaders,
223            Self::Headers(_) => RessMessageID::Headers,
224            Self::GetBlockBodies(_) => RessMessageID::GetBlockBodies,
225            Self::BlockBodies(_) => RessMessageID::BlockBodies,
226            Self::GetBytecode(_) => RessMessageID::GetBytecode,
227            Self::Bytecode(_) => RessMessageID::Bytecode,
228            Self::GetWitness(_) => RessMessageID::GetWitness,
229            Self::Witness(_) => RessMessageID::Witness,
230        }
231    }
232
233    /// Convert message into [`RessProtocolMessage`].
234    pub const fn into_protocol_message(self) -> RessProtocolMessage {
235        let message_type = self.message_id();
236        RessProtocolMessage { message_type, message: self }
237    }
238}
239
240impl From<RessMessage> for RessProtocolMessage {
241    fn from(value: RessMessage) -> Self {
242        value.into_protocol_message()
243    }
244}
245
246impl Encodable for RessMessage {
247    fn encode(&self, out: &mut dyn BufMut) {
248        match self {
249            Self::NodeType(node_type) => node_type.encode(out),
250            Self::GetHeaders(request) => request.encode(out),
251            Self::Headers(header) => header.encode(out),
252            Self::GetBlockBodies(request) => request.encode(out),
253            Self::BlockBodies(body) => body.encode(out),
254            Self::GetBytecode(request) | Self::GetWitness(request) => request.encode(out),
255            Self::Bytecode(bytecode) => bytecode.encode(out),
256            Self::Witness(witness) => witness.encode(out),
257        }
258    }
259
260    fn length(&self) -> usize {
261        match self {
262            Self::NodeType(node_type) => node_type.length(),
263            Self::GetHeaders(request) => request.length(),
264            Self::Headers(header) => header.length(),
265            Self::GetBlockBodies(request) => request.length(),
266            Self::BlockBodies(body) => body.length(),
267            Self::GetBytecode(request) | Self::GetWitness(request) => request.length(),
268            Self::Bytecode(bytecode) => bytecode.length(),
269            Self::Witness(witness) => witness.length(),
270        }
271    }
272}
273
274/// A request for a peer to return block headers starting at the requested block.
275/// The peer must return at most [`limit`](#structfield.limit) headers.
276/// The headers will be returned starting at [`start_hash`](#structfield.start_hash), traversing
277/// towards the genesis block.
278#[derive(PartialEq, Eq, Clone, Copy, Debug, RlpEncodable, RlpDecodable)]
279#[cfg_attr(any(test, feature = "arbitrary"), derive(arbitrary::Arbitrary))]
280pub struct GetHeaders {
281    /// The block hash that the peer should start returning headers from.
282    pub start_hash: BlockHash,
283
284    /// The maximum number of headers to return.
285    pub limit: u64,
286}
287
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;
    use proptest_arbitrary_interop::arb;
    use std::fmt;
    use strum::EnumCount;

    /// Encodes `value` with RLP and asserts that decoding yields the same value.
    fn rlp_roundtrip<V>(value: V)
    where
        V: Encodable + Decodable + PartialEq + fmt::Debug,
    {
        let encoded = alloy_rlp::encode(&value);
        let decoded = V::decode(&mut &encoded[..]);
        assert_eq!(Ok(value), decoded);
    }

    #[test]
    fn protocol_message_count() {
        let protocol = RessProtocolMessage::protocol();
        assert_eq!(protocol.messages(), RessMessageID::COUNT as u8);
    }

    #[test]
    fn message_id_rejects_unknown_byte() {
        // The first byte past the last valid ID.
        let input = [RessMessageID::COUNT as u8];
        let mut buf = &input[..];
        assert_eq!(
            RessMessageID::decode(&mut buf),
            Err(alloy_rlp::Error::Custom("Invalid message type"))
        );
        // A rejected ID byte must not be consumed.
        assert_eq!(buf.len(), 1);
    }

    #[test]
    fn message_id_rejects_empty_input() {
        let mut buf: &[u8] = &[];
        assert_eq!(RessMessageID::decode(&mut buf), Err(alloy_rlp::Error::InputTooShort));
    }

    #[test]
    fn decode_message_rejects_empty_input() {
        let mut buf: &[u8] = &[];
        assert_eq!(
            RessProtocolMessage::decode_message(&mut buf),
            Err(alloy_rlp::Error::InputTooShort)
        );
    }

    #[test]
    fn encoded_starts_with_message_id_and_roundtrips() {
        let message = RessProtocolMessage::get_bytecode(7, B256::ZERO);
        let encoded = message.encoded();
        assert_eq!(encoded[0], RessMessageID::GetBytecode as u8);

        let mut buf = &encoded[..];
        assert_eq!(RessProtocolMessage::decode_message(&mut buf), Ok(message));
        // The whole encoding must be consumed by a single decode.
        assert!(buf.is_empty());
    }

    proptest! {
        #[test]
        fn message_type_roundtrip(message_type in arb::<RessMessageID>()) {
            rlp_roundtrip(message_type);
        }

        #[test]
        fn message_roundtrip(message in arb::<RessProtocolMessage>()) {
            let encoded = alloy_rlp::encode(&message);
            let decoded = RessProtocolMessage::decode_message(&mut &encoded[..]);
            assert_eq!(Ok(message), decoded);
        }
    }
}
324}