1use crate::BlockAccessLists;
9use alloc::vec::Vec;
10use alloy_primitives::{Bytes, B256};
11use alloy_rlp::{Decodable, Encodable, RlpDecodable, RlpEncodable};
12use reth_codecs_derive::add_arbitrary_tests;
13
/// Version of the snap protocol.
///
/// The `repr(u8)` discriminant of each variant equals the version number, and
/// [`SnapVersion::V1`] is the [`Default`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[repr(u8)]
pub enum SnapVersion {
    /// Version 1, which has 8 message ids (see [`SnapVersion::message_count`]).
    #[default]
    V1 = 1,
    /// Version 2, which has 10 message ids (see [`SnapVersion::message_count`]).
    V2 = 2,
}

impl SnapVersion {
    /// Returns how many message ids this version defines: 8 for `V1`, 10 for `V2`.
    pub const fn message_count(self) -> u8 {
        match self {
            SnapVersion::V1 => 8,
            SnapVersion::V2 => 10,
        }
    }

    /// Returns the largest message id of this version, i.e. one less than
    /// [`SnapVersion::message_count`].
    pub const fn max_message_id(self) -> u8 {
        // `message_count` is never zero, so the subtraction cannot underflow.
        let count = self.message_count();
        count - 1
    }
}
40
/// One-byte identifiers of the snap protocol messages.
///
/// Each discriminant is the value [`SnapProtocolMessage::encode`] writes as the first byte
/// of an encoded message and the value [`SnapProtocolMessage::decode`] dispatches on.
///
/// NOTE(review): ids `0x08` and `0x09` exceed `SnapVersion::V1.max_message_id()` (7), so they
/// are presumably only valid for `SnapVersion::V2`; nothing in this enum enforces that.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SnapMessageId {
    /// Id of [`SnapProtocolMessage::GetAccountRange`].
    GetAccountRange = 0x00,
    /// Id of [`SnapProtocolMessage::AccountRange`].
    AccountRange = 0x01,
    /// Id of [`SnapProtocolMessage::GetStorageRanges`].
    GetStorageRanges = 0x02,
    /// Id of [`SnapProtocolMessage::StorageRanges`].
    StorageRanges = 0x03,
    /// Id of [`SnapProtocolMessage::GetByteCodes`].
    GetByteCodes = 0x04,
    /// Id of [`SnapProtocolMessage::ByteCodes`].
    ByteCodes = 0x05,
    /// Id of [`SnapProtocolMessage::GetTrieNodes`].
    GetTrieNodes = 0x06,
    /// Id of [`SnapProtocolMessage::TrieNodes`].
    TrieNodes = 0x07,
    /// Id of [`SnapProtocolMessage::GetBlockAccessLists`].
    GetBlockAccessLists = 0x08,
    /// Id of [`SnapProtocolMessage::BlockAccessLists`].
    BlockAccessLists = 0x09,
}
74
/// Payload of [`SnapProtocolMessage::GetAccountRange`] (message id `0x00`).
///
/// NOTE(review): field order is presumably the RLP wire order produced by the derives; do not
/// reorder fields without checking. Field semantics below are read off the names; confirm them
/// against the snap protocol specification.
#[derive(Debug, Clone, PartialEq, Eq, RlpEncodable, RlpDecodable)]
#[cfg_attr(any(test, feature = "arbitrary"), derive(arbitrary::Arbitrary))]
#[add_arbitrary_tests(rlp)]
pub struct GetAccountRangeMessage {
    /// Identifier of this request.
    pub request_id: u64,
    /// Root hash the request refers to (presumably the account trie root; TODO confirm).
    pub root_hash: B256,
    /// Hash at which the requested range starts.
    pub starting_hash: B256,
    /// Hash at which the requested range stops (exact boundary semantics: see the snap spec).
    pub limit_hash: B256,
    /// Byte budget for the response (presumably a soft limit; TODO confirm).
    pub response_bytes: u64,
}
92
/// One account entry; the element type of [`AccountRangeMessage::accounts`].
///
/// NOTE(review): field order is presumably the RLP wire order produced by the derives; do not
/// reorder fields without checking.
#[derive(Debug, Clone, PartialEq, Eq, RlpEncodable, RlpDecodable)]
#[cfg_attr(any(test, feature = "arbitrary"), derive(arbitrary::Arbitrary))]
#[add_arbitrary_tests(rlp)]
pub struct AccountData {
    /// Hash identifying the account (presumably its key in the account trie; TODO confirm).
    pub hash: B256,
    /// Account payload as raw bytes; its inner encoding is not visible in this file.
    pub body: Bytes,
}
103
/// Payload of [`SnapProtocolMessage::AccountRange`] (message id `0x01`).
///
/// NOTE(review): field order is presumably the RLP wire order produced by the derives; do not
/// reorder fields without checking.
#[derive(Debug, Clone, PartialEq, Eq, RlpEncodable, RlpDecodable)]
#[cfg_attr(any(test, feature = "arbitrary"), derive(arbitrary::Arbitrary))]
#[add_arbitrary_tests(rlp)]
pub struct AccountRangeMessage {
    /// Request identifier (presumably echoes the `GetAccountRangeMessage` being answered).
    pub request_id: u64,
    /// Returned account entries.
    pub accounts: Vec<AccountData>,
    /// Proof items as raw byte strings (presumably trie proof nodes for the range; TODO confirm).
    pub proof: Vec<Bytes>,
}
117
/// Payload of [`SnapProtocolMessage::GetStorageRanges`] (message id `0x02`).
///
/// NOTE(review): field order is presumably the RLP wire order produced by the derives; do not
/// reorder fields without checking. Field semantics below are read off the names; confirm them
/// against the snap protocol specification.
#[derive(Debug, Clone, PartialEq, Eq, RlpEncodable, RlpDecodable)]
#[cfg_attr(any(test, feature = "arbitrary"), derive(arbitrary::Arbitrary))]
#[add_arbitrary_tests(rlp)]
pub struct GetStorageRangesMessage {
    /// Identifier of this request.
    pub request_id: u64,
    /// Root hash the request refers to (presumably the account trie root; TODO confirm).
    pub root_hash: B256,
    /// Hashes of the accounts whose storage is requested.
    pub account_hashes: Vec<B256>,
    /// Hash at which the requested storage range starts.
    pub starting_hash: B256,
    /// Hash at which the requested storage range stops (boundary semantics: see the snap spec).
    pub limit_hash: B256,
    /// Byte budget for the response (presumably a soft limit; TODO confirm).
    pub response_bytes: u64,
}
137
/// One storage entry; the element type of the inner lists of [`StorageRangesMessage::slots`].
///
/// NOTE(review): field order is presumably the RLP wire order produced by the derives; do not
/// reorder fields without checking.
#[derive(Debug, Clone, PartialEq, Eq, RlpEncodable, RlpDecodable)]
#[cfg_attr(any(test, feature = "arbitrary"), derive(arbitrary::Arbitrary))]
#[add_arbitrary_tests(rlp)]
pub struct StorageData {
    /// Hash identifying the storage slot (presumably its key in the storage trie; TODO confirm).
    pub hash: B256,
    /// Slot payload as raw bytes; its inner encoding is not visible in this file.
    pub data: Bytes,
}
148
/// Payload of [`SnapProtocolMessage::StorageRanges`] (message id `0x03`).
///
/// NOTE(review): field order is presumably the RLP wire order produced by the derives; do not
/// reorder fields without checking.
#[derive(Debug, Clone, PartialEq, Eq, RlpEncodable, RlpDecodable)]
#[cfg_attr(any(test, feature = "arbitrary"), derive(arbitrary::Arbitrary))]
#[add_arbitrary_tests(rlp)]
pub struct StorageRangesMessage {
    /// Request identifier (presumably echoes the `GetStorageRangesMessage` being answered).
    pub request_id: u64,
    /// Storage entries as a list of lists (presumably one inner list per requested account,
    /// in request order; TODO confirm).
    pub slots: Vec<Vec<StorageData>>,
    /// Proof items as raw byte strings (presumably trie proof nodes; TODO confirm).
    pub proof: Vec<Bytes>,
}
164
/// Payload of [`SnapProtocolMessage::GetByteCodes`] (message id `0x04`).
///
/// NOTE(review): field order is presumably the RLP wire order produced by the derives; do not
/// reorder fields without checking.
#[derive(Debug, Clone, PartialEq, Eq, RlpEncodable, RlpDecodable)]
#[cfg_attr(any(test, feature = "arbitrary"), derive(arbitrary::Arbitrary))]
#[add_arbitrary_tests(rlp)]
pub struct GetByteCodesMessage {
    /// Identifier of this request.
    pub request_id: u64,
    /// Hashes naming the bytecodes being requested (presumably code hashes; TODO confirm).
    pub hashes: Vec<B256>,
    /// Byte budget for the response (presumably a soft limit; TODO confirm).
    pub response_bytes: u64,
}
178
/// Payload of [`SnapProtocolMessage::ByteCodes`] (message id `0x05`).
///
/// NOTE(review): field order is presumably the RLP wire order produced by the derives; do not
/// reorder fields without checking.
#[derive(Debug, Clone, PartialEq, Eq, RlpEncodable, RlpDecodable)]
#[cfg_attr(any(test, feature = "arbitrary"), derive(arbitrary::Arbitrary))]
#[add_arbitrary_tests(rlp)]
pub struct ByteCodesMessage {
    /// Request identifier (presumably echoes the `GetByteCodesMessage` being answered).
    pub request_id: u64,
    /// Returned bytecodes, each as a raw byte string.
    pub codes: Vec<Bytes>,
}
190
/// One path entry; the element type of [`GetTrieNodesMessage::paths`].
///
/// NOTE(review): field order is presumably the RLP wire order produced by the derives; do not
/// reorder fields without checking. The byte-level path encoding is not visible in this file.
#[derive(Debug, Clone, PartialEq, Eq, RlpEncodable, RlpDecodable)]
#[cfg_attr(any(test, feature = "arbitrary"), derive(arbitrary::Arbitrary))]
#[add_arbitrary_tests(rlp)]
pub struct TriePath {
    /// Path on the account side, as raw bytes (presumably into the account trie; TODO confirm).
    pub account_path: Bytes,
    /// Slot paths associated with `account_path` (presumably into that account's storage trie;
    /// TODO confirm).
    pub slot_paths: Vec<Bytes>,
}
201
/// Payload of [`SnapProtocolMessage::GetTrieNodes`] (message id `0x06`).
///
/// NOTE(review): field order is presumably the RLP wire order produced by the derives; do not
/// reorder fields without checking.
#[derive(Debug, Clone, PartialEq, Eq, RlpEncodable, RlpDecodable)]
#[cfg_attr(any(test, feature = "arbitrary"), derive(arbitrary::Arbitrary))]
#[add_arbitrary_tests(rlp)]
pub struct GetTrieNodesMessage {
    /// Identifier of this request.
    pub request_id: u64,
    /// Root hash the request refers to (presumably the account trie root; TODO confirm).
    pub root_hash: B256,
    /// Paths of the trie nodes being requested.
    pub paths: Vec<TriePath>,
    /// Byte budget for the response (presumably a soft limit; TODO confirm).
    pub response_bytes: u64,
}
217
/// Payload of [`SnapProtocolMessage::TrieNodes`] (message id `0x07`).
///
/// NOTE(review): field order is presumably the RLP wire order produced by the derives; do not
/// reorder fields without checking.
#[derive(Debug, Clone, PartialEq, Eq, RlpEncodable, RlpDecodable)]
#[cfg_attr(any(test, feature = "arbitrary"), derive(arbitrary::Arbitrary))]
#[add_arbitrary_tests(rlp)]
pub struct TrieNodesMessage {
    /// Request identifier (presumably echoes the `GetTrieNodesMessage` being answered).
    pub request_id: u64,
    /// Returned trie nodes, each as a raw byte string.
    pub nodes: Vec<Bytes>,
}
229
/// Payload of [`SnapProtocolMessage::GetBlockAccessLists`] (message id `0x08`).
///
/// NOTE(review): field order is presumably the RLP wire order produced by the derives; do not
/// reorder fields without checking.
#[derive(Debug, Clone, PartialEq, Eq, RlpEncodable, RlpDecodable)]
#[cfg_attr(any(test, feature = "arbitrary"), derive(arbitrary::Arbitrary))]
#[add_arbitrary_tests(rlp)]
pub struct GetBlockAccessListsMessage {
    /// Identifier of this request.
    pub request_id: u64,
    /// Hashes of the blocks whose access lists are requested.
    pub block_hashes: Vec<B256>,
    /// Byte budget for the response (presumably a soft limit; TODO confirm).
    pub response_bytes: u64,
}
242
/// Payload of [`SnapProtocolMessage::BlockAccessLists`] (message id `0x09`).
///
/// NOTE(review): field order is presumably the RLP wire order produced by the derives; do not
/// reorder fields without checking.
#[derive(Debug, Clone, PartialEq, Eq, RlpEncodable, RlpDecodable)]
#[cfg_attr(any(test, feature = "arbitrary"), derive(arbitrary::Arbitrary))]
#[add_arbitrary_tests(rlp)]
pub struct BlockAccessListsMessage {
    /// Request identifier (presumably echoes the `GetBlockAccessListsMessage` being answered).
    pub request_id: u64,
    /// The returned access lists, in the crate-level `BlockAccessLists` wrapper. The unit test
    /// in this file builds it from a list of `Some(Bytes)` entries, so entries appear to be
    /// optional raw byte strings; see that type for the authoritative definition.
    pub block_access_lists: BlockAccessLists,
}
253
/// A snap protocol message: one variant per [`SnapMessageId`], each wrapping the payload
/// struct for that message.
///
/// Use [`SnapProtocolMessage::encode`] / [`SnapProtocolMessage::decode`] to convert to and
/// from the id-prefixed byte representation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SnapProtocolMessage {
    /// Message id `0x00`, carrying a [`GetAccountRangeMessage`].
    GetAccountRange(GetAccountRangeMessage),
    /// Message id `0x01`, carrying an [`AccountRangeMessage`].
    AccountRange(AccountRangeMessage),
    /// Message id `0x02`, carrying a [`GetStorageRangesMessage`].
    GetStorageRanges(GetStorageRangesMessage),
    /// Message id `0x03`, carrying a [`StorageRangesMessage`].
    StorageRanges(StorageRangesMessage),
    /// Message id `0x04`, carrying a [`GetByteCodesMessage`].
    GetByteCodes(GetByteCodesMessage),
    /// Message id `0x05`, carrying a [`ByteCodesMessage`].
    ByteCodes(ByteCodesMessage),
    /// Message id `0x06`, carrying a [`GetTrieNodesMessage`].
    GetTrieNodes(GetTrieNodesMessage),
    /// Message id `0x07`, carrying a [`TrieNodesMessage`].
    TrieNodes(TrieNodesMessage),
    /// Message id `0x08`, carrying a [`GetBlockAccessListsMessage`].
    GetBlockAccessLists(GetBlockAccessListsMessage),
    /// Message id `0x09`, carrying a [`BlockAccessListsMessage`].
    BlockAccessLists(BlockAccessListsMessage),
}
286
287impl SnapProtocolMessage {
288 pub const fn message_id(&self) -> SnapMessageId {
292 match self {
293 Self::GetAccountRange(_) => SnapMessageId::GetAccountRange,
294 Self::AccountRange(_) => SnapMessageId::AccountRange,
295 Self::GetStorageRanges(_) => SnapMessageId::GetStorageRanges,
296 Self::StorageRanges(_) => SnapMessageId::StorageRanges,
297 Self::GetByteCodes(_) => SnapMessageId::GetByteCodes,
298 Self::ByteCodes(_) => SnapMessageId::ByteCodes,
299 Self::GetTrieNodes(_) => SnapMessageId::GetTrieNodes,
300 Self::TrieNodes(_) => SnapMessageId::TrieNodes,
301 Self::GetBlockAccessLists(_) => SnapMessageId::GetBlockAccessLists,
302 Self::BlockAccessLists(_) => SnapMessageId::BlockAccessLists,
303 }
304 }
305
306 pub fn encode(&self) -> Bytes {
308 let mut buf = Vec::new();
309 buf.push(self.message_id() as u8);
311
312 match self {
314 Self::GetAccountRange(msg) => msg.encode(&mut buf),
315 Self::AccountRange(msg) => msg.encode(&mut buf),
316 Self::GetStorageRanges(msg) => msg.encode(&mut buf),
317 Self::StorageRanges(msg) => msg.encode(&mut buf),
318 Self::GetByteCodes(msg) => msg.encode(&mut buf),
319 Self::ByteCodes(msg) => msg.encode(&mut buf),
320 Self::GetTrieNodes(msg) => msg.encode(&mut buf),
321 Self::TrieNodes(msg) => msg.encode(&mut buf),
322 Self::GetBlockAccessLists(msg) => msg.encode(&mut buf),
323 Self::BlockAccessLists(msg) => msg.encode(&mut buf),
324 }
325
326 Bytes::from(buf)
327 }
328
329 pub fn decode(message_id: u8, buf: &mut &[u8]) -> Result<Self, alloy_rlp::Error> {
331 macro_rules! decode_snap_message_variant {
333 ($message_id:expr, $buf:expr, $id:expr, $variant:ident, $msg_type:ty) => {
334 if $message_id == $id as u8 {
335 return Ok(Self::$variant(<$msg_type>::decode($buf)?));
336 }
337 };
338 }
339
340 decode_snap_message_variant!(
342 message_id,
343 buf,
344 SnapMessageId::GetAccountRange,
345 GetAccountRange,
346 GetAccountRangeMessage
347 );
348 decode_snap_message_variant!(
349 message_id,
350 buf,
351 SnapMessageId::AccountRange,
352 AccountRange,
353 AccountRangeMessage
354 );
355 decode_snap_message_variant!(
356 message_id,
357 buf,
358 SnapMessageId::GetStorageRanges,
359 GetStorageRanges,
360 GetStorageRangesMessage
361 );
362 decode_snap_message_variant!(
363 message_id,
364 buf,
365 SnapMessageId::StorageRanges,
366 StorageRanges,
367 StorageRangesMessage
368 );
369 decode_snap_message_variant!(
370 message_id,
371 buf,
372 SnapMessageId::GetByteCodes,
373 GetByteCodes,
374 GetByteCodesMessage
375 );
376 decode_snap_message_variant!(
377 message_id,
378 buf,
379 SnapMessageId::ByteCodes,
380 ByteCodes,
381 ByteCodesMessage
382 );
383 decode_snap_message_variant!(
384 message_id,
385 buf,
386 SnapMessageId::GetTrieNodes,
387 GetTrieNodes,
388 GetTrieNodesMessage
389 );
390 decode_snap_message_variant!(
391 message_id,
392 buf,
393 SnapMessageId::TrieNodes,
394 TrieNodes,
395 TrieNodesMessage
396 );
397 decode_snap_message_variant!(
398 message_id,
399 buf,
400 SnapMessageId::GetBlockAccessLists,
401 GetBlockAccessLists,
402 GetBlockAccessListsMessage
403 );
404 decode_snap_message_variant!(
405 message_id,
406 buf,
407 SnapMessageId::BlockAccessLists,
408 BlockAccessLists,
409 BlockAccessListsMessage
410 );
411
412 Err(alloy_rlp::Error::Custom("Unknown message ID"))
413 }
414}
415
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `B256` from the big-endian bytes of `value` via `B256::left_padding_from`.
    fn b256_from_u64(value: u64) -> B256 {
        let be_bytes = value.to_be_bytes();
        B256::left_padding_from(&be_bytes)
    }

    /// Encodes `original`, checks the leading id byte, decodes the remainder and expects to
    /// get the same message back.
    fn test_roundtrip(original: SnapProtocolMessage) {
        let encoded = original.encode();
        let wire: &[u8] = &encoded;

        // The first byte is the message id; everything after it is the RLP payload.
        let (&id, mut payload) =
            wire.split_first().expect("encoded message must start with its id byte");
        assert_eq!(id, original.message_id() as u8);

        let decoded = SnapProtocolMessage::decode(id, &mut payload).unwrap();
        assert_eq!(decoded, original);
    }

    #[test]
    fn test_all_message_roundtrips() {
        assert_eq!(SnapVersion::V1.message_count(), 8);
        assert_eq!(SnapVersion::V2.message_count(), 10);

        // One representative message per variant, in message-id order.
        let messages = [
            SnapProtocolMessage::GetAccountRange(GetAccountRangeMessage {
                request_id: 42,
                root_hash: b256_from_u64(123),
                starting_hash: b256_from_u64(456),
                limit_hash: b256_from_u64(789),
                response_bytes: 1024,
            }),
            SnapProtocolMessage::AccountRange(AccountRangeMessage {
                request_id: 42,
                accounts: vec![AccountData {
                    hash: b256_from_u64(123),
                    body: Bytes::from(vec![1, 2, 3]),
                }],
                proof: vec![Bytes::from(vec![4, 5, 6])],
            }),
            SnapProtocolMessage::GetStorageRanges(GetStorageRangesMessage {
                request_id: 42,
                root_hash: b256_from_u64(123),
                account_hashes: vec![b256_from_u64(456)],
                starting_hash: b256_from_u64(789),
                limit_hash: b256_from_u64(101112),
                response_bytes: 2048,
            }),
            SnapProtocolMessage::StorageRanges(StorageRangesMessage {
                request_id: 42,
                slots: vec![vec![StorageData {
                    hash: b256_from_u64(123),
                    data: Bytes::from(vec![1, 2, 3]),
                }]],
                proof: vec![Bytes::from(vec![4, 5, 6])],
            }),
            SnapProtocolMessage::GetByteCodes(GetByteCodesMessage {
                request_id: 42,
                hashes: vec![b256_from_u64(123)],
                response_bytes: 1024,
            }),
            SnapProtocolMessage::ByteCodes(ByteCodesMessage {
                request_id: 42,
                codes: vec![Bytes::from(vec![1, 2, 3])],
            }),
            SnapProtocolMessage::GetTrieNodes(GetTrieNodesMessage {
                request_id: 42,
                root_hash: b256_from_u64(123),
                paths: vec![TriePath {
                    account_path: Bytes::from(vec![1, 2, 3]),
                    slot_paths: vec![Bytes::from(vec![4, 5, 6])],
                }],
                response_bytes: 1024,
            }),
            SnapProtocolMessage::TrieNodes(TrieNodesMessage {
                request_id: 42,
                nodes: vec![Bytes::from(vec![1, 2, 3])],
            }),
            SnapProtocolMessage::GetBlockAccessLists(GetBlockAccessListsMessage {
                request_id: 42,
                block_hashes: vec![b256_from_u64(123), b256_from_u64(456)],
                response_bytes: 4096,
            }),
            SnapProtocolMessage::BlockAccessLists(BlockAccessListsMessage {
                request_id: 42,
                block_access_lists: BlockAccessLists(vec![
                    Some(Bytes::from_static(&[alloy_rlp::EMPTY_LIST_CODE])),
                    Some(Bytes::from_static(&[0xc1, alloy_rlp::EMPTY_LIST_CODE])),
                ]),
            }),
        ];

        for message in messages {
            test_roundtrip(message);
        }
    }

    #[test]
    fn test_unknown_message_id() {
        let raw = [1u8, 2, 3, 4];
        let mut cursor: &[u8] = &raw;

        // 255 matches no `SnapMessageId`, so decoding must fail before touching the payload.
        let err = SnapProtocolMessage::decode(255, &mut cursor).unwrap_err();

        assert_eq!(err.to_string(), "Unknown message ID");
    }
}