1use alloc::vec::Vec;
9use alloy_primitives::{Bytes, B256};
10use alloy_rlp::{Decodable, Encodable, RlpDecodable, RlpEncodable};
11use reth_codecs_derive::add_arbitrary_tests;
12
/// Message IDs of the snap protocol.
///
/// Each discriminant is the byte written in front of the RLP payload of the matching
/// `SnapProtocolMessage` variant when it is encoded.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
// These values are emitted and compared as single `u8` bytes, so pin the representation to
// one byte and let the compiler reject any discriminant that would not fit.
#[repr(u8)]
pub enum SnapMessageId {
    /// `GetAccountRange` request (0x00).
    GetAccountRange = 0x00,
    /// `AccountRange` response (0x01).
    AccountRange = 0x01,
    /// `GetStorageRanges` request (0x02).
    GetStorageRanges = 0x02,
    /// `StorageRanges` response (0x03).
    StorageRanges = 0x03,
    /// `GetByteCodes` request (0x04).
    GetByteCodes = 0x04,
    /// `ByteCodes` response (0x05).
    ByteCodes = 0x05,
    /// `GetTrieNodes` request (0x06).
    GetTrieNodes = 0x06,
    /// `TrieNodes` response (0x07).
    TrieNodes = 0x07,
}
34
/// `GetAccountRange` request (message ID 0x00): asks a peer for a range of accounts.
///
/// The derived RLP encoding is a list of the fields in declaration order, so the field order
/// below is part of the wire format. Field descriptions follow the devp2p snap/1 spec.
#[derive(Debug, Clone, PartialEq, Eq, RlpEncodable, RlpDecodable)]
#[cfg_attr(any(test, feature = "arbitrary"), derive(arbitrary::Arbitrary))]
#[add_arbitrary_tests(rlp)]
pub struct GetAccountRangeMessage {
    /// Request identifier, used to match the `AccountRange` response to this request.
    pub request_id: u64,
    /// Root hash of the account trie to serve.
    pub root_hash: B256,
    /// Account hash of the first account to retrieve.
    pub starting_hash: B256,
    /// Account hash after which to stop serving data.
    pub limit_hash: B256,
    /// Soft limit, in bytes, at which to stop returning data.
    pub response_bytes: u64,
}
52
/// One account entry of an `AccountRange` response, RLP-encoded as `[hash, body]`.
#[derive(Debug, Clone, PartialEq, Eq, RlpEncodable, RlpDecodable)]
#[cfg_attr(any(test, feature = "arbitrary"), derive(arbitrary::Arbitrary))]
#[add_arbitrary_tests(rlp)]
pub struct AccountData {
    /// Hash of the account (its key in the account trie).
    pub hash: B256,
    /// Encoded account body, kept as opaque bytes here.
    /// NOTE(review): snap/1 specifies the "slim" account encoding — confirm callers expect that.
    pub body: Bytes,
}
63
/// `AccountRange` response (message ID 0x01) to a [`GetAccountRangeMessage`].
///
/// Encoded as the RLP list `[request_id, [account, ...], [proof_node, ...]]` in field order.
#[derive(Debug, Clone, PartialEq, Eq, RlpEncodable, RlpDecodable)]
#[cfg_attr(any(test, feature = "arbitrary"), derive(arbitrary::Arbitrary))]
#[add_arbitrary_tests(rlp)]
pub struct AccountRangeMessage {
    /// Identifier copied from the request this message answers.
    pub request_id: u64,
    /// Accounts of the requested range.
    pub accounts: Vec<AccountData>,
    /// Trie nodes proving the returned account range (per snap/1).
    pub proof: Vec<Bytes>,
}
77
/// `GetStorageRanges` request (message ID 0x02): asks for storage slots of one or more accounts.
///
/// The derived RLP encoding lists the fields in declaration order; do not reorder them.
#[derive(Debug, Clone, PartialEq, Eq, RlpEncodable, RlpDecodable)]
#[cfg_attr(any(test, feature = "arbitrary"), derive(arbitrary::Arbitrary))]
#[add_arbitrary_tests(rlp)]
pub struct GetStorageRangesMessage {
    /// Request identifier, used to match the `StorageRanges` response to this request.
    pub request_id: u64,
    /// Root hash of the account trie to serve.
    pub root_hash: B256,
    /// Hashes of the accounts whose storage is requested.
    pub account_hashes: Vec<B256>,
    /// Storage slot hash of the first slot to retrieve.
    pub starting_hash: B256,
    /// Storage slot hash after which to stop serving data.
    pub limit_hash: B256,
    /// Soft limit, in bytes, at which to stop returning data.
    pub response_bytes: u64,
}
97
/// One storage slot entry of a `StorageRanges` response, RLP-encoded as `[hash, data]`.
#[derive(Debug, Clone, PartialEq, Eq, RlpEncodable, RlpDecodable)]
#[cfg_attr(any(test, feature = "arbitrary"), derive(arbitrary::Arbitrary))]
#[add_arbitrary_tests(rlp)]
pub struct StorageData {
    /// Hash of the storage slot (its key in the storage trie).
    pub hash: B256,
    /// Slot value, kept as opaque bytes here.
    pub data: Bytes,
}
108
/// `StorageRanges` response (message ID 0x03) to a [`GetStorageRangesMessage`].
///
/// Encoded as the RLP list `[request_id, [[slot, ...], ...], [proof_node, ...]]` in field order.
#[derive(Debug, Clone, PartialEq, Eq, RlpEncodable, RlpDecodable)]
#[cfg_attr(any(test, feature = "arbitrary"), derive(arbitrary::Arbitrary))]
#[add_arbitrary_tests(rlp)]
pub struct StorageRangesMessage {
    /// Identifier copied from the request this message answers.
    pub request_id: u64,
    /// One inner list of slots per served account.
    /// NOTE(review): snap/1 orders these like the request's `account_hashes` — confirm.
    pub slots: Vec<Vec<StorageData>>,
    /// Trie nodes proving the returned slots (per snap/1, only needed for a partial last range).
    pub proof: Vec<Bytes>,
}
124
/// `GetByteCodes` request (message ID 0x04): asks for contract bytecodes by code hash.
///
/// The derived RLP encoding lists the fields in declaration order; do not reorder them.
#[derive(Debug, Clone, PartialEq, Eq, RlpEncodable, RlpDecodable)]
#[cfg_attr(any(test, feature = "arbitrary"), derive(arbitrary::Arbitrary))]
#[add_arbitrary_tests(rlp)]
pub struct GetByteCodesMessage {
    /// Request identifier, used to match the `ByteCodes` response to this request.
    pub request_id: u64,
    /// Code hashes of the bytecodes to retrieve.
    pub hashes: Vec<B256>,
    /// Soft limit, in bytes, at which to stop returning data.
    pub response_bytes: u64,
}
138
/// `ByteCodes` response (message ID 0x05) to a [`GetByteCodesMessage`].
#[derive(Debug, Clone, PartialEq, Eq, RlpEncodable, RlpDecodable)]
#[cfg_attr(any(test, feature = "arbitrary"), derive(arbitrary::Arbitrary))]
#[add_arbitrary_tests(rlp)]
pub struct ByteCodesMessage {
    /// Identifier copied from the request this message answers.
    pub request_id: u64,
    /// The returned bytecodes.
    pub codes: Vec<Bytes>,
}
150
151#[derive(Debug, Clone, PartialEq, Eq, RlpEncodable, RlpDecodable)]
153#[cfg_attr(any(test, feature = "arbitrary"), derive(arbitrary::Arbitrary))]
154#[add_arbitrary_tests(rlp)]
155pub struct TriePath {
156 pub account_path: Bytes,
158 pub slot_paths: Vec<Bytes>,
160}
161
/// `GetTrieNodes` request (message ID 0x06): asks for trie nodes by path.
///
/// The derived RLP encoding lists the fields in declaration order; do not reorder them.
#[derive(Debug, Clone, PartialEq, Eq, RlpEncodable, RlpDecodable)]
#[cfg_attr(any(test, feature = "arbitrary"), derive(arbitrary::Arbitrary))]
#[add_arbitrary_tests(rlp)]
pub struct GetTrieNodesMessage {
    /// Request identifier, used to match the `TrieNodes` response to this request.
    pub request_id: u64,
    /// Root hash of the account trie to serve.
    pub root_hash: B256,
    /// Path sets identifying the requested nodes.
    pub paths: Vec<TriePath>,
    /// Soft limit, in bytes, at which to stop returning data.
    pub response_bytes: u64,
}
177
/// `TrieNodes` response (message ID 0x07) to a [`GetTrieNodesMessage`].
#[derive(Debug, Clone, PartialEq, Eq, RlpEncodable, RlpDecodable)]
#[cfg_attr(any(test, feature = "arbitrary"), derive(arbitrary::Arbitrary))]
#[add_arbitrary_tests(rlp)]
pub struct TrieNodesMessage {
    /// Identifier copied from the request this message answers.
    pub request_id: u64,
    /// The returned trie nodes, as opaque encoded bytes.
    pub nodes: Vec<Bytes>,
}
189
/// Any snap protocol message, with one variant per [`SnapMessageId`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SnapProtocolMessage {
    /// Message ID 0x00.
    GetAccountRange(GetAccountRangeMessage),
    /// Message ID 0x01.
    AccountRange(AccountRangeMessage),
    /// Message ID 0x02.
    GetStorageRanges(GetStorageRangesMessage),
    /// Message ID 0x03.
    StorageRanges(StorageRangesMessage),
    /// Message ID 0x04.
    GetByteCodes(GetByteCodesMessage),
    /// Message ID 0x05.
    ByteCodes(ByteCodesMessage),
    /// Message ID 0x06.
    GetTrieNodes(GetTrieNodesMessage),
    /// Message ID 0x07.
    TrieNodes(TrieNodesMessage),
}
210
211impl SnapProtocolMessage {
212 pub const fn message_id(&self) -> SnapMessageId {
216 match self {
217 Self::GetAccountRange(_) => SnapMessageId::GetAccountRange,
218 Self::AccountRange(_) => SnapMessageId::AccountRange,
219 Self::GetStorageRanges(_) => SnapMessageId::GetStorageRanges,
220 Self::StorageRanges(_) => SnapMessageId::StorageRanges,
221 Self::GetByteCodes(_) => SnapMessageId::GetByteCodes,
222 Self::ByteCodes(_) => SnapMessageId::ByteCodes,
223 Self::GetTrieNodes(_) => SnapMessageId::GetTrieNodes,
224 Self::TrieNodes(_) => SnapMessageId::TrieNodes,
225 }
226 }
227
228 pub fn encode(&self) -> Bytes {
230 let mut buf = Vec::new();
231 buf.push(self.message_id() as u8);
233
234 match self {
236 Self::GetAccountRange(msg) => msg.encode(&mut buf),
237 Self::AccountRange(msg) => msg.encode(&mut buf),
238 Self::GetStorageRanges(msg) => msg.encode(&mut buf),
239 Self::StorageRanges(msg) => msg.encode(&mut buf),
240 Self::GetByteCodes(msg) => msg.encode(&mut buf),
241 Self::ByteCodes(msg) => msg.encode(&mut buf),
242 Self::GetTrieNodes(msg) => msg.encode(&mut buf),
243 Self::TrieNodes(msg) => msg.encode(&mut buf),
244 }
245
246 Bytes::from(buf)
247 }
248
249 pub fn decode(message_id: u8, buf: &mut &[u8]) -> Result<Self, alloy_rlp::Error> {
251 macro_rules! decode_snap_message_variant {
253 ($message_id:expr, $buf:expr, $id:expr, $variant:ident, $msg_type:ty) => {
254 if $message_id == $id as u8 {
255 return Ok(Self::$variant(<$msg_type>::decode($buf)?));
256 }
257 };
258 }
259
260 decode_snap_message_variant!(
262 message_id,
263 buf,
264 SnapMessageId::GetAccountRange,
265 GetAccountRange,
266 GetAccountRangeMessage
267 );
268 decode_snap_message_variant!(
269 message_id,
270 buf,
271 SnapMessageId::AccountRange,
272 AccountRange,
273 AccountRangeMessage
274 );
275 decode_snap_message_variant!(
276 message_id,
277 buf,
278 SnapMessageId::GetStorageRanges,
279 GetStorageRanges,
280 GetStorageRangesMessage
281 );
282 decode_snap_message_variant!(
283 message_id,
284 buf,
285 SnapMessageId::StorageRanges,
286 StorageRanges,
287 StorageRangesMessage
288 );
289 decode_snap_message_variant!(
290 message_id,
291 buf,
292 SnapMessageId::GetByteCodes,
293 GetByteCodes,
294 GetByteCodesMessage
295 );
296 decode_snap_message_variant!(
297 message_id,
298 buf,
299 SnapMessageId::ByteCodes,
300 ByteCodes,
301 ByteCodesMessage
302 );
303 decode_snap_message_variant!(
304 message_id,
305 buf,
306 SnapMessageId::GetTrieNodes,
307 GetTrieNodes,
308 GetTrieNodesMessage
309 );
310 decode_snap_message_variant!(
311 message_id,
312 buf,
313 SnapMessageId::TrieNodes,
314 TrieNodes,
315 TrieNodesMessage
316 );
317
318 Err(alloy_rlp::Error::Custom("Unknown message ID"))
319 }
320}
321
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `B256` holding `value` big-endian in its low eight bytes, zero-padded on the left.
    fn b256_from_u64(value: u64) -> B256 {
        B256::left_padding_from(&value.to_be_bytes())
    }

    /// Encodes `original`, checks the ID prefix, then decodes it back. Asserts that the decoder
    /// consumed the whole payload and reproduced the original message.
    fn test_roundtrip(original: SnapProtocolMessage) {
        let encoded = original.encode();

        assert_eq!(encoded[0], original.message_id() as u8);

        let mut buf = &encoded[1..];
        let decoded = SnapProtocolMessage::decode(encoded[0], &mut buf).unwrap();

        // `encode` writes nothing after the message, so a correct decoder leaves nothing behind.
        assert!(buf.is_empty(), "decoder left {} trailing bytes", buf.len());
        assert_eq!(decoded, original);
    }

    #[test]
    fn test_all_message_roundtrips() {
        test_roundtrip(SnapProtocolMessage::GetAccountRange(GetAccountRangeMessage {
            request_id: 42,
            root_hash: b256_from_u64(123),
            starting_hash: b256_from_u64(456),
            limit_hash: b256_from_u64(789),
            response_bytes: 1024,
        }));

        test_roundtrip(SnapProtocolMessage::AccountRange(AccountRangeMessage {
            request_id: 42,
            accounts: vec![AccountData {
                hash: b256_from_u64(123),
                body: Bytes::from(vec![1, 2, 3]),
            }],
            proof: vec![Bytes::from(vec![4, 5, 6])],
        }));

        test_roundtrip(SnapProtocolMessage::GetStorageRanges(GetStorageRangesMessage {
            request_id: 42,
            root_hash: b256_from_u64(123),
            account_hashes: vec![b256_from_u64(456)],
            starting_hash: b256_from_u64(789),
            limit_hash: b256_from_u64(101112),
            response_bytes: 2048,
        }));

        test_roundtrip(SnapProtocolMessage::StorageRanges(StorageRangesMessage {
            request_id: 42,
            slots: vec![vec![StorageData {
                hash: b256_from_u64(123),
                data: Bytes::from(vec![1, 2, 3]),
            }]],
            proof: vec![Bytes::from(vec![4, 5, 6])],
        }));

        test_roundtrip(SnapProtocolMessage::GetByteCodes(GetByteCodesMessage {
            request_id: 42,
            hashes: vec![b256_from_u64(123)],
            response_bytes: 1024,
        }));

        test_roundtrip(SnapProtocolMessage::ByteCodes(ByteCodesMessage {
            request_id: 42,
            codes: vec![Bytes::from(vec![1, 2, 3])],
        }));

        test_roundtrip(SnapProtocolMessage::GetTrieNodes(GetTrieNodesMessage {
            request_id: 42,
            root_hash: b256_from_u64(123),
            paths: vec![TriePath {
                account_path: Bytes::from(vec![1, 2, 3]),
                slot_paths: vec![Bytes::from(vec![4, 5, 6])],
            }],
            response_bytes: 1024,
        }));

        test_roundtrip(SnapProtocolMessage::TrieNodes(TrieNodesMessage {
            request_id: 42,
            nodes: vec![Bytes::from(vec![1, 2, 3])],
        }));
    }

    #[test]
    fn test_unknown_message_id() {
        let data = Bytes::from(vec![1, 2, 3, 4]);

        // 0x08 is the first ID past `TrieNodes`; 0xff is the top of the byte range.
        for message_id in [0x08_u8, 0xff] {
            let mut buf = data.as_ref();
            let err = SnapProtocolMessage::decode(message_id, &mut buf).unwrap_err();
            assert_eq!(err.to_string(), "Unknown message ID");
        }
    }
}