reth_eth_wire_types/
snap.rs

1//! Implements Ethereum SNAP message types.
//! The snap protocol runs on top of `RLPx`, facilitating the exchange of Ethereum state
//! snapshots between peers.
4//! Reference: [Ethereum Snapshot Protocol](https://github.com/ethereum/devp2p/blob/master/caps/snap.md#protocol-messages)
5//!
6//! Current version: snap/1
7
8use alloc::vec::Vec;
9use alloy_primitives::{Bytes, B256};
10use alloy_rlp::{Decodable, Encodable, RlpDecodable, RlpEncodable};
11use reth_codecs_derive::add_arbitrary_tests;
12
/// Message IDs for the snap sync protocol.
///
/// Each discriminant is the message code written in front of the RLP body on the wire. The enum
/// is `#[repr(u8)]`, so `id as u8` is always lossless. The reverse mapping is provided by the
/// `TryFrom<u8>` implementation below.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u8)]
pub enum SnapMessageId {
    /// Requests of an unknown number of accounts from a given account trie.
    GetAccountRange = 0x00,
    /// Response with the number of consecutive accounts and the Merkle proofs for the entire
    /// range.
    AccountRange = 0x01,
    /// Requests for the storage slots of multiple accounts' storage tries.
    GetStorageRanges = 0x02,
    /// Response for the number of consecutive storage slots for the requested account.
    StorageRanges = 0x03,
    /// Request of the number of contract byte-codes by hash.
    GetByteCodes = 0x04,
    /// Response for the number of requested contract codes.
    ByteCodes = 0x05,
    /// Request of the number of state (either account or storage) Merkle trie nodes by path.
    GetTrieNodes = 0x06,
    /// Response for the number of requested state trie nodes.
    TrieNodes = 0x07,
}

impl TryFrom<u8> for SnapMessageId {
    /// The unrecognised message ID, handed back unchanged.
    type Error = u8;

    /// Maps a wire message code back to its [`SnapMessageId`].
    ///
    /// # Errors
    ///
    /// Returns `Err(id)` when `id` is not one of the snap/1 message codes (`0x00..=0x07`).
    fn try_from(id: u8) -> Result<Self, Self::Error> {
        Ok(match id {
            0x00 => Self::GetAccountRange,
            0x01 => Self::AccountRange,
            0x02 => Self::GetStorageRanges,
            0x03 => Self::StorageRanges,
            0x04 => Self::GetByteCodes,
            0x05 => Self::ByteCodes,
            0x06 => Self::GetTrieNodes,
            0x07 => Self::TrieNodes,
            unknown => return Err(unknown),
        })
    }
}
34
35/// Request for a range of accounts from the state trie.
36// https://github.com/ethereum/devp2p/blob/master/caps/snap.md#getaccountrange-0x00
37#[derive(Debug, Clone, PartialEq, Eq, RlpEncodable, RlpDecodable)]
38#[cfg_attr(any(test, feature = "arbitrary"), derive(arbitrary::Arbitrary))]
39#[add_arbitrary_tests(rlp)]
40pub struct GetAccountRangeMessage {
41    /// Request ID to match up responses with
42    pub request_id: u64,
43    /// Root hash of the account trie to serve
44    pub root_hash: B256,
45    /// Account hash of the first to retrieve
46    pub starting_hash: B256,
47    /// Account hash after which to stop serving data
48    pub limit_hash: B256,
49    /// Soft limit at which to stop returning data
50    pub response_bytes: u64,
51}
52
53/// Account data in the response.
54#[derive(Debug, Clone, PartialEq, Eq, RlpEncodable, RlpDecodable)]
55#[cfg_attr(any(test, feature = "arbitrary"), derive(arbitrary::Arbitrary))]
56#[add_arbitrary_tests(rlp)]
57pub struct AccountData {
58    /// Hash of the account address (trie path)
59    pub hash: B256,
60    /// Account body in slim format
61    pub body: Bytes,
62}
63
64/// Response containing a number of consecutive accounts and the Merkle proofs for the entire range.
65// http://github.com/ethereum/devp2p/blob/master/caps/snap.md#accountrange-0x01
66#[derive(Debug, Clone, PartialEq, Eq, RlpEncodable, RlpDecodable)]
67#[cfg_attr(any(test, feature = "arbitrary"), derive(arbitrary::Arbitrary))]
68#[add_arbitrary_tests(rlp)]
69pub struct AccountRangeMessage {
70    /// ID of the request this is a response for
71    pub request_id: u64,
72    /// List of consecutive accounts from the trie
73    pub accounts: Vec<AccountData>,
74    /// List of trie nodes proving the account range
75    pub proof: Vec<Bytes>,
76}
77
78/// Request for the storage slots of multiple accounts' storage tries.
79// https://github.com/ethereum/devp2p/blob/master/caps/snap.md#getstorageranges-0x02
80#[derive(Debug, Clone, PartialEq, Eq, RlpEncodable, RlpDecodable)]
81#[cfg_attr(any(test, feature = "arbitrary"), derive(arbitrary::Arbitrary))]
82#[add_arbitrary_tests(rlp)]
83pub struct GetStorageRangesMessage {
84    /// Request ID to match up responses with
85    pub request_id: u64,
86    /// Root hash of the account trie to serve
87    pub root_hash: B256,
88    /// Account hashes of the storage tries to serve
89    pub account_hashes: Vec<B256>,
90    /// Storage slot hash of the first to retrieve
91    pub starting_hash: B256,
92    /// Storage slot hash after which to stop serving
93    pub limit_hash: B256,
94    /// Soft limit at which to stop returning data
95    pub response_bytes: u64,
96}
97
98/// Storage slot data in the response.
99#[derive(Debug, Clone, PartialEq, Eq, RlpEncodable, RlpDecodable)]
100#[cfg_attr(any(test, feature = "arbitrary"), derive(arbitrary::Arbitrary))]
101#[add_arbitrary_tests(rlp)]
102pub struct StorageData {
103    /// Hash of the storage slot key (trie path)
104    pub hash: B256,
105    /// Data content of the slot
106    pub data: Bytes,
107}
108
109/// Response containing a number of consecutive storage slots for the requested account
110/// and optionally the merkle proofs for the last range (boundary proofs) if it only partially
111/// covers the storage trie.
112// https://github.com/ethereum/devp2p/blob/master/caps/snap.md#storageranges-0x03
113#[derive(Debug, Clone, PartialEq, Eq, RlpEncodable, RlpDecodable)]
114#[cfg_attr(any(test, feature = "arbitrary"), derive(arbitrary::Arbitrary))]
115#[add_arbitrary_tests(rlp)]
116pub struct StorageRangesMessage {
117    /// ID of the request this is a response for
118    pub request_id: u64,
119    /// List of list of consecutive slots from the trie (one list per account)
120    pub slots: Vec<Vec<StorageData>>,
121    /// List of trie nodes proving the slot range (if partial)
122    pub proof: Vec<Bytes>,
123}
124
125/// Request to get a number of requested contract codes.
126// https://github.com/ethereum/devp2p/blob/master/caps/snap.md#getbytecodes-0x04
127#[derive(Debug, Clone, PartialEq, Eq, RlpEncodable, RlpDecodable)]
128#[cfg_attr(any(test, feature = "arbitrary"), derive(arbitrary::Arbitrary))]
129#[add_arbitrary_tests(rlp)]
130pub struct GetByteCodesMessage {
131    /// Request ID to match up responses with
132    pub request_id: u64,
133    /// Code hashes to retrieve the code for
134    pub hashes: Vec<B256>,
135    /// Soft limit at which to stop returning data (in bytes)
136    pub response_bytes: u64,
137}
138
139/// Response containing a number of requested contract codes.
140// https://github.com/ethereum/devp2p/blob/master/caps/snap.md#bytecodes-0x05
141#[derive(Debug, Clone, PartialEq, Eq, RlpEncodable, RlpDecodable)]
142#[cfg_attr(any(test, feature = "arbitrary"), derive(arbitrary::Arbitrary))]
143#[add_arbitrary_tests(rlp)]
144pub struct ByteCodesMessage {
145    /// ID of the request this is a response for
146    pub request_id: u64,
147    /// The requested bytecodes in order
148    pub codes: Vec<Bytes>,
149}
150
151/// Path in the trie for an account and its storage
152#[derive(Debug, Clone, PartialEq, Eq, RlpEncodable, RlpDecodable)]
153#[cfg_attr(any(test, feature = "arbitrary"), derive(arbitrary::Arbitrary))]
154#[add_arbitrary_tests(rlp)]
155pub struct TriePath {
156    /// Path in the account trie
157    pub account_path: Bytes,
158    /// Paths in the storage trie
159    pub slot_paths: Vec<Bytes>,
160}
161
162/// Request a number of state (either account or storage) Merkle trie nodes by path
163// https://github.com/ethereum/devp2p/blob/master/caps/snap.md#gettrienodes-0x06
164#[derive(Debug, Clone, PartialEq, Eq, RlpEncodable, RlpDecodable)]
165#[cfg_attr(any(test, feature = "arbitrary"), derive(arbitrary::Arbitrary))]
166#[add_arbitrary_tests(rlp)]
167pub struct GetTrieNodesMessage {
168    /// Request ID to match up responses with
169    pub request_id: u64,
170    /// Root hash of the account trie to serve
171    pub root_hash: B256,
172    /// Trie paths to retrieve the nodes for, grouped by account
173    pub paths: Vec<TriePath>,
174    /// Soft limit at which to stop returning data (in bytes)
175    pub response_bytes: u64,
176}
177
178/// Response containing a number of requested state trie nodes
179// https://github.com/ethereum/devp2p/blob/master/caps/snap.md#trienodes-0x07
180#[derive(Debug, Clone, PartialEq, Eq, RlpEncodable, RlpDecodable)]
181#[cfg_attr(any(test, feature = "arbitrary"), derive(arbitrary::Arbitrary))]
182#[add_arbitrary_tests(rlp)]
183pub struct TrieNodesMessage {
184    /// ID of the request this is a response for
185    pub request_id: u64,
186    /// The requested trie nodes in order
187    pub nodes: Vec<Bytes>,
188}
189
190/// Represents all types of messages in the snap sync protocol.
191#[derive(Debug, Clone, PartialEq, Eq)]
192pub enum SnapProtocolMessage {
193    /// Request for an account range - see [`GetAccountRangeMessage`]
194    GetAccountRange(GetAccountRangeMessage),
195    /// Response with accounts and proofs - see [`AccountRangeMessage`]
196    AccountRange(AccountRangeMessage),
197    /// Request for storage slots - see [`GetStorageRangesMessage`]
198    GetStorageRanges(GetStorageRangesMessage),
199    /// Response with storage slots - see [`StorageRangesMessage`]
200    StorageRanges(StorageRangesMessage),
201    /// Request for contract bytecodes - see [`GetByteCodesMessage`]
202    GetByteCodes(GetByteCodesMessage),
203    /// Response with contract codes - see [`ByteCodesMessage`]
204    ByteCodes(ByteCodesMessage),
205    /// Request for trie nodes - see [`GetTrieNodesMessage`]
206    GetTrieNodes(GetTrieNodesMessage),
207    /// Response with trie nodes - see [`TrieNodesMessage`]
208    TrieNodes(TrieNodesMessage),
209}
210
211impl SnapProtocolMessage {
212    /// Returns the protocol message ID for this message type.
213    ///
214    /// The message ID is used in the `RLPx` protocol to identify different types of messages.
215    pub const fn message_id(&self) -> SnapMessageId {
216        match self {
217            Self::GetAccountRange(_) => SnapMessageId::GetAccountRange,
218            Self::AccountRange(_) => SnapMessageId::AccountRange,
219            Self::GetStorageRanges(_) => SnapMessageId::GetStorageRanges,
220            Self::StorageRanges(_) => SnapMessageId::StorageRanges,
221            Self::GetByteCodes(_) => SnapMessageId::GetByteCodes,
222            Self::ByteCodes(_) => SnapMessageId::ByteCodes,
223            Self::GetTrieNodes(_) => SnapMessageId::GetTrieNodes,
224            Self::TrieNodes(_) => SnapMessageId::TrieNodes,
225        }
226    }
227
228    /// Encode the message to bytes
229    pub fn encode(&self) -> Bytes {
230        let mut buf = Vec::new();
231        // Add message ID as first byte
232        buf.push(self.message_id() as u8);
233
234        // Encode the message body based on its type
235        match self {
236            Self::GetAccountRange(msg) => msg.encode(&mut buf),
237            Self::AccountRange(msg) => msg.encode(&mut buf),
238            Self::GetStorageRanges(msg) => msg.encode(&mut buf),
239            Self::StorageRanges(msg) => msg.encode(&mut buf),
240            Self::GetByteCodes(msg) => msg.encode(&mut buf),
241            Self::ByteCodes(msg) => msg.encode(&mut buf),
242            Self::GetTrieNodes(msg) => msg.encode(&mut buf),
243            Self::TrieNodes(msg) => msg.encode(&mut buf),
244        }
245
246        Bytes::from(buf)
247    }
248
249    /// Decodes a SNAP protocol message from its message ID and RLP-encoded body.
250    pub fn decode(message_id: u8, buf: &mut &[u8]) -> Result<Self, alloy_rlp::Error> {
251        // Decoding protocol message variants based on message ID
252        macro_rules! decode_snap_message_variant {
253            ($message_id:expr, $buf:expr, $id:expr, $variant:ident, $msg_type:ty) => {
254                if $message_id == $id as u8 {
255                    return Ok(Self::$variant(<$msg_type>::decode($buf)?));
256                }
257            };
258        }
259
260        // Try to decode each message type based on the message ID
261        decode_snap_message_variant!(
262            message_id,
263            buf,
264            SnapMessageId::GetAccountRange,
265            GetAccountRange,
266            GetAccountRangeMessage
267        );
268        decode_snap_message_variant!(
269            message_id,
270            buf,
271            SnapMessageId::AccountRange,
272            AccountRange,
273            AccountRangeMessage
274        );
275        decode_snap_message_variant!(
276            message_id,
277            buf,
278            SnapMessageId::GetStorageRanges,
279            GetStorageRanges,
280            GetStorageRangesMessage
281        );
282        decode_snap_message_variant!(
283            message_id,
284            buf,
285            SnapMessageId::StorageRanges,
286            StorageRanges,
287            StorageRangesMessage
288        );
289        decode_snap_message_variant!(
290            message_id,
291            buf,
292            SnapMessageId::GetByteCodes,
293            GetByteCodes,
294            GetByteCodesMessage
295        );
296        decode_snap_message_variant!(
297            message_id,
298            buf,
299            SnapMessageId::ByteCodes,
300            ByteCodes,
301            ByteCodesMessage
302        );
303        decode_snap_message_variant!(
304            message_id,
305            buf,
306            SnapMessageId::GetTrieNodes,
307            GetTrieNodes,
308            GetTrieNodesMessage
309        );
310        decode_snap_message_variant!(
311            message_id,
312            buf,
313            SnapMessageId::TrieNodes,
314            TrieNodes,
315            TrieNodesMessage
316        );
317
318        Err(alloy_rlp::Error::Custom("Unknown message ID"))
319    }
320}
321
#[cfg(test)]
mod tests {
    use super::*;

    // Helper function to create a B256 from a u64 for testing
    fn b256_from_u64(value: u64) -> B256 {
        B256::left_padding_from(&value.to_be_bytes())
    }

    // Encodes `original`, decodes it back and checks ID, equality and full consumption.
    fn test_roundtrip(original: SnapProtocolMessage) {
        let encoded = original.encode();

        // The first byte is the message ID.
        assert_eq!(encoded[0], original.message_id() as u8);

        let mut buf = &encoded[1..];
        let decoded = SnapProtocolMessage::decode(encoded[0], &mut buf).unwrap();

        assert_eq!(decoded, original);
        // The decoder must consume the whole body.
        assert!(buf.is_empty());
    }

    #[test]
    fn test_all_message_roundtrips() {
        test_roundtrip(SnapProtocolMessage::GetAccountRange(GetAccountRangeMessage {
            request_id: 42,
            root_hash: b256_from_u64(123),
            starting_hash: b256_from_u64(456),
            limit_hash: b256_from_u64(789),
            response_bytes: 1024,
        }));

        test_roundtrip(SnapProtocolMessage::AccountRange(AccountRangeMessage {
            request_id: 42,
            accounts: vec![AccountData {
                hash: b256_from_u64(123),
                body: Bytes::from(vec![1, 2, 3]),
            }],
            proof: vec![Bytes::from(vec![4, 5, 6])],
        }));

        test_roundtrip(SnapProtocolMessage::GetStorageRanges(GetStorageRangesMessage {
            request_id: 42,
            root_hash: b256_from_u64(123),
            account_hashes: vec![b256_from_u64(456)],
            starting_hash: b256_from_u64(789),
            limit_hash: b256_from_u64(101112),
            response_bytes: 2048,
        }));

        test_roundtrip(SnapProtocolMessage::StorageRanges(StorageRangesMessage {
            request_id: 42,
            slots: vec![vec![StorageData {
                hash: b256_from_u64(123),
                data: Bytes::from(vec![1, 2, 3]),
            }]],
            proof: vec![Bytes::from(vec![4, 5, 6])],
        }));

        test_roundtrip(SnapProtocolMessage::GetByteCodes(GetByteCodesMessage {
            request_id: 42,
            hashes: vec![b256_from_u64(123)],
            response_bytes: 1024,
        }));

        test_roundtrip(SnapProtocolMessage::ByteCodes(ByteCodesMessage {
            request_id: 42,
            codes: vec![Bytes::from(vec![1, 2, 3])],
        }));

        test_roundtrip(SnapProtocolMessage::GetTrieNodes(GetTrieNodesMessage {
            request_id: 42,
            root_hash: b256_from_u64(123),
            paths: vec![TriePath {
                account_path: Bytes::from(vec![1, 2, 3]),
                slot_paths: vec![Bytes::from(vec![4, 5, 6])],
            }],
            response_bytes: 1024,
        }));

        test_roundtrip(SnapProtocolMessage::TrieNodes(TrieNodesMessage {
            request_id: 42,
            nodes: vec![Bytes::from(vec![1, 2, 3])],
        }));
    }

    #[test]
    fn test_message_ids_match_spec() {
        assert_eq!(SnapMessageId::GetAccountRange as u8, 0x00);
        assert_eq!(SnapMessageId::AccountRange as u8, 0x01);
        assert_eq!(SnapMessageId::GetStorageRanges as u8, 0x02);
        assert_eq!(SnapMessageId::StorageRanges as u8, 0x03);
        assert_eq!(SnapMessageId::GetByteCodes as u8, 0x04);
        assert_eq!(SnapMessageId::ByteCodes as u8, 0x05);
        assert_eq!(SnapMessageId::GetTrieNodes as u8, 0x06);
        assert_eq!(SnapMessageId::TrieNodes as u8, 0x07);
    }

    #[test]
    fn test_empty_body_is_rejected() {
        // Every known ID needs an RLP list body; an empty buffer must error, not panic.
        for id in 0x00..=0x07u8 {
            let mut buf: &[u8] = &[];
            assert!(SnapProtocolMessage::decode(id, &mut buf).is_err());
        }
    }

    #[test]
    fn test_unknown_message_id() {
        let data = Bytes::from(vec![1, 2, 3, 4]);

        for id in [0x08u8, 0xff] {
            let mut buf = data.as_ref();
            let err = SnapProtocolMessage::decode(id, &mut buf).unwrap_err();
            assert_eq!(err.to_string(), "Unknown message ID");
        }
    }
}