1use super::{
10 broadcast::NewBlockHashes, BlockAccessLists, BlockBodies, BlockHeaders, GetBlockAccessLists,
11 GetBlockBodies, GetBlockHeaders, GetNodeData, GetPooledTransactions, GetReceipts,
12 GetReceipts70, NewPooledTransactionHashes66, NewPooledTransactionHashes68, NodeData,
13 PooledTransactions, Receipts, Status, StatusEth69, Transactions,
14};
15use crate::{
16 status::StatusMessage, BlockRangeUpdate, Cells, EthNetworkPrimitives, EthVersion, GetCells,
17 NetworkPrimitives, RawCapabilityMessage, Receipts69, Receipts70, SharedTransactions,
18};
19use alloc::{boxed::Box, string::String, sync::Arc};
20use alloy_primitives::{
21 bytes::{Buf, BufMut},
22 Bytes,
23};
24use alloy_rlp::{length_of_length, Decodable, Encodable, Header};
25use core::fmt::Debug;
26
/// Upper bound on the size of a single `eth` message: 10 MiB.
///
/// NOTE(review): presumably measured in bytes of the encoded message; confirm at the
/// enforcement site.
pub const MAX_MESSAGE_SIZE: usize = 10 * 1024 * 1024;

/// Multiplier used when deriving the memory budget for decoding transactions.
///
/// NOTE(review): not referenced in this module; presumably callers scale a size limit by this
/// factor before passing it to `ProtocolMessage::decode_message_with_tx_memory_budget` —
/// verify against the call sites.
pub const TX_MEMORY_BUDGET_MULTIPLIER: usize = 2;
39
40#[derive(thiserror::Error, Debug)]
42pub enum MessageError {
43 #[error("message id {1:?} is invalid for version {0:?}")]
45 Invalid(EthVersion, EthMessageID),
46 #[error("expected status message but received {0:?}")]
48 ExpectedStatusMessage(EthMessageID),
49 #[error("RLP error: {0}")]
51 RlpError(#[from] alloy_rlp::Error),
52 #[error("{0}")]
54 Other(String),
55}
56
57#[derive(Clone, Debug, PartialEq, Eq)]
59#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
60pub struct ProtocolMessage<N: NetworkPrimitives = EthNetworkPrimitives> {
61 pub message_type: EthMessageID,
63 #[cfg_attr(
65 feature = "serde",
66 serde(bound = "EthMessage<N>: serde::Serialize + serde::de::DeserializeOwned")
67 )]
68 pub message: EthMessage<N>,
69}
70
71impl<N: NetworkPrimitives> ProtocolMessage<N> {
72 pub fn decode_status(
77 version: EthVersion,
78 buf: &mut &[u8],
79 ) -> Result<StatusMessage, MessageError> {
80 let message_type = EthMessageID::decode(buf)?;
81
82 if message_type != EthMessageID::Status {
83 return Err(MessageError::ExpectedStatusMessage(message_type))
84 }
85
86 let status = if version < EthVersion::Eth69 {
87 StatusMessage::Legacy(Status::decode(buf)?)
88 } else {
89 StatusMessage::Eth69(StatusEth69::decode(buf)?)
90 };
91
92 Ok(status)
93 }
94
95 pub fn decode_message(version: EthVersion, buf: &mut &[u8]) -> Result<Self, MessageError> {
99 Self::decode_message_with_tx_memory_budget(version, buf, usize::MAX)
100 }
101
102 pub fn decode_message_with_tx_memory_budget(
108 version: EthVersion,
109 buf: &mut &[u8],
110 tx_memory_budget: usize,
111 ) -> Result<Self, MessageError> {
112 let message_type = EthMessageID::decode(buf)?;
113
114 let message = match message_type {
117 EthMessageID::Status => EthMessage::Status(if version < EthVersion::Eth69 {
118 StatusMessage::Legacy(Status::decode(buf)?)
119 } else {
120 StatusMessage::Eth69(StatusEth69::decode(buf)?)
121 }),
122 EthMessageID::NewBlockHashes => {
123 EthMessage::NewBlockHashes(NewBlockHashes::decode(buf)?)
124 }
125 EthMessageID::NewBlock => {
126 EthMessage::NewBlock(Box::new(N::NewBlockPayload::decode(buf)?))
127 }
128 EthMessageID::Transactions => EthMessage::Transactions(
129 Transactions::decode_with_memory_budget(buf, tx_memory_budget)?,
130 ),
131 EthMessageID::NewPooledTransactionHashes => {
132 if version >= EthVersion::Eth68 {
133 EthMessage::NewPooledTransactionHashes68(NewPooledTransactionHashes68::decode(
134 buf,
135 )?)
136 } else {
137 EthMessage::NewPooledTransactionHashes66(NewPooledTransactionHashes66::decode(
138 buf,
139 )?)
140 }
141 }
142 EthMessageID::GetBlockHeaders => EthMessage::GetBlockHeaders(RequestPair::decode(buf)?),
143 EthMessageID::BlockHeaders => EthMessage::BlockHeaders(RequestPair::decode(buf)?),
144 EthMessageID::GetBlockBodies => EthMessage::GetBlockBodies(RequestPair::decode(buf)?),
145 EthMessageID::BlockBodies => EthMessage::BlockBodies(RequestPair::decode(buf)?),
146 EthMessageID::GetPooledTransactions => {
147 EthMessage::GetPooledTransactions(RequestPair::decode(buf)?)
148 }
149 EthMessageID::PooledTransactions => {
150 EthMessage::PooledTransactions(RequestPair::decode_with(buf, |buf| {
151 PooledTransactions::decode_with_memory_budget(buf, tx_memory_budget)
152 })?)
153 }
154 EthMessageID::GetNodeData => {
155 if version >= EthVersion::Eth67 {
156 return Err(MessageError::Invalid(version, EthMessageID::GetNodeData))
157 }
158 EthMessage::GetNodeData(RequestPair::decode(buf)?)
159 }
160 EthMessageID::NodeData => {
161 if version >= EthVersion::Eth67 {
162 return Err(MessageError::Invalid(version, EthMessageID::NodeData))
163 }
164 EthMessage::NodeData(RequestPair::decode(buf)?)
165 }
166 EthMessageID::GetReceipts => {
167 if version >= EthVersion::Eth70 {
168 EthMessage::GetReceipts70(RequestPair::decode(buf)?)
169 } else {
170 EthMessage::GetReceipts(RequestPair::decode(buf)?)
171 }
172 }
173 EthMessageID::Receipts => {
174 match version {
175 v if v >= EthVersion::Eth70 => {
176 EthMessage::Receipts70(RequestPair::decode(buf)?)
180 }
181 EthVersion::Eth69 => {
182 EthMessage::Receipts69(RequestPair::decode(buf)?)
184 }
185 _ => {
186 EthMessage::Receipts(RequestPair::decode(buf)?)
188 }
189 }
190 }
191 EthMessageID::BlockRangeUpdate => {
192 if version < EthVersion::Eth69 {
193 return Err(MessageError::Invalid(version, EthMessageID::BlockRangeUpdate))
194 }
195 EthMessage::BlockRangeUpdate(BlockRangeUpdate::decode(buf)?)
196 }
197 EthMessageID::GetBlockAccessLists => {
198 if version < EthVersion::Eth71 {
199 return Err(MessageError::Invalid(version, EthMessageID::GetBlockAccessLists))
200 }
201 EthMessage::GetBlockAccessLists(RequestPair::decode(buf)?)
202 }
203 EthMessageID::BlockAccessLists => {
204 if version < EthVersion::Eth71 {
205 return Err(MessageError::Invalid(version, EthMessageID::BlockAccessLists))
206 }
207 EthMessage::BlockAccessLists(RequestPair::decode(buf)?)
208 }
209 EthMessageID::Cells => {
210 if version < EthVersion::Eth72 {
211 return Err(MessageError::Invalid(version, EthMessageID::Cells))
212 }
213 EthMessage::Cells(RequestPair::decode(buf)?)
214 }
215 EthMessageID::GetCells => {
216 if version < EthVersion::Eth72 {
217 return Err(MessageError::Invalid(version, EthMessageID::GetCells))
218 }
219 EthMessage::GetCells(RequestPair::decode(buf)?)
220 }
221 EthMessageID::Other(_) => {
222 let raw_payload = Bytes::copy_from_slice(buf);
223 buf.advance(raw_payload.len());
224 EthMessage::Other(RawCapabilityMessage::new(
225 message_type.to_u8() as usize,
226 raw_payload.into(),
227 ))
228 }
229 };
230 Ok(Self { message_type, message })
231 }
232}
233
234impl<N: NetworkPrimitives> Encodable for ProtocolMessage<N> {
235 fn encode(&self, out: &mut dyn BufMut) {
238 self.message_type.encode(out);
239 self.message.encode(out);
240 }
241 fn length(&self) -> usize {
242 self.message_type.length() + self.message.length()
243 }
244}
245
246impl<N: NetworkPrimitives> From<EthMessage<N>> for ProtocolMessage<N> {
247 fn from(message: EthMessage<N>) -> Self {
248 Self { message_type: message.message_id(), message }
249 }
250}
251
252#[derive(Clone, Debug)]
254pub struct ProtocolBroadcastMessage<N: NetworkPrimitives = EthNetworkPrimitives> {
255 pub message_type: EthMessageID,
257 pub message: EthBroadcastMessage<N>,
260}
261
262impl<N: NetworkPrimitives> Encodable for ProtocolBroadcastMessage<N> {
263 fn encode(&self, out: &mut dyn BufMut) {
266 self.message_type.encode(out);
267 self.message.encode(out);
268 }
269 fn length(&self) -> usize {
270 self.message_type.length() + self.message.length()
271 }
272}
273
274impl<N: NetworkPrimitives> From<EthBroadcastMessage<N>> for ProtocolBroadcastMessage<N> {
275 fn from(message: EthBroadcastMessage<N>) -> Self {
276 Self { message_type: message.message_id(), message }
277 }
278}
279
280#[derive(Clone, Debug, PartialEq, Eq)]
306#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
307pub enum EthMessage<N: NetworkPrimitives = EthNetworkPrimitives> {
308 Status(StatusMessage),
310 NewBlockHashes(NewBlockHashes),
312 #[cfg_attr(
314 feature = "serde",
315 serde(bound = "N::NewBlockPayload: serde::Serialize + serde::de::DeserializeOwned")
316 )]
317 NewBlock(Box<N::NewBlockPayload>),
318 #[cfg_attr(
320 feature = "serde",
321 serde(bound = "N::BroadcastedTransaction: serde::Serialize + serde::de::DeserializeOwned")
322 )]
323 Transactions(Transactions<N::BroadcastedTransaction>),
324 NewPooledTransactionHashes66(NewPooledTransactionHashes66),
326 NewPooledTransactionHashes68(NewPooledTransactionHashes68),
328 GetBlockHeaders(RequestPair<GetBlockHeaders>),
331 #[cfg_attr(
333 feature = "serde",
334 serde(bound = "N::BlockHeader: serde::Serialize + serde::de::DeserializeOwned")
335 )]
336 BlockHeaders(RequestPair<BlockHeaders<N::BlockHeader>>),
337 GetBlockBodies(RequestPair<GetBlockBodies>),
339 #[cfg_attr(
341 feature = "serde",
342 serde(bound = "N::BlockBody: serde::Serialize + serde::de::DeserializeOwned")
343 )]
344 BlockBodies(RequestPair<BlockBodies<N::BlockBody>>),
345 GetPooledTransactions(RequestPair<GetPooledTransactions>),
347 #[cfg_attr(
349 feature = "serde",
350 serde(bound = "N::PooledTransaction: serde::Serialize + serde::de::DeserializeOwned")
351 )]
352 PooledTransactions(RequestPair<PooledTransactions<N::PooledTransaction>>),
353 GetNodeData(RequestPair<GetNodeData>),
355 NodeData(RequestPair<NodeData>),
357 GetReceipts(RequestPair<GetReceipts>),
359 GetReceipts70(RequestPair<GetReceipts70>),
365 GetBlockAccessLists(RequestPair<GetBlockAccessLists>),
367 #[cfg_attr(
369 feature = "serde",
370 serde(bound = "N::Receipt: serde::Serialize + serde::de::DeserializeOwned")
371 )]
372 Receipts(RequestPair<Receipts<N::Receipt>>),
373 #[cfg_attr(
375 feature = "serde",
376 serde(bound = "N::Receipt: serde::Serialize + serde::de::DeserializeOwned")
377 )]
378 Receipts69(RequestPair<Receipts69<N::Receipt>>),
379 #[cfg_attr(
381 feature = "serde",
382 serde(bound = "N::Receipt: serde::Serialize + serde::de::DeserializeOwned")
383 )]
384 Receipts70(RequestPair<Receipts70<N::Receipt>>),
389 BlockAccessLists(RequestPair<BlockAccessLists>),
391 Cells(RequestPair<Cells>),
393 GetCells(RequestPair<GetCells>),
395 #[cfg_attr(
397 feature = "serde",
398 serde(bound = "N::BroadcastedTransaction: serde::Serialize + serde::de::DeserializeOwned")
399 )]
400 BlockRangeUpdate(BlockRangeUpdate),
401 Other(RawCapabilityMessage),
403}
404
405impl<N: NetworkPrimitives> EthMessage<N> {
406 pub const fn message_id(&self) -> EthMessageID {
408 match self {
409 Self::Status(_) => EthMessageID::Status,
410 Self::NewBlockHashes(_) => EthMessageID::NewBlockHashes,
411 Self::NewBlock(_) => EthMessageID::NewBlock,
412 Self::Transactions(_) => EthMessageID::Transactions,
413 Self::NewPooledTransactionHashes66(_) | Self::NewPooledTransactionHashes68(_) => {
414 EthMessageID::NewPooledTransactionHashes
415 }
416 Self::GetBlockHeaders(_) => EthMessageID::GetBlockHeaders,
417 Self::BlockHeaders(_) => EthMessageID::BlockHeaders,
418 Self::GetBlockBodies(_) => EthMessageID::GetBlockBodies,
419 Self::BlockBodies(_) => EthMessageID::BlockBodies,
420 Self::GetPooledTransactions(_) => EthMessageID::GetPooledTransactions,
421 Self::PooledTransactions(_) => EthMessageID::PooledTransactions,
422 Self::GetNodeData(_) => EthMessageID::GetNodeData,
423 Self::NodeData(_) => EthMessageID::NodeData,
424 Self::GetReceipts(_) | Self::GetReceipts70(_) => EthMessageID::GetReceipts,
425 Self::Receipts(_) | Self::Receipts69(_) | Self::Receipts70(_) => EthMessageID::Receipts,
426 Self::BlockRangeUpdate(_) => EthMessageID::BlockRangeUpdate,
427 Self::GetBlockAccessLists(_) => EthMessageID::GetBlockAccessLists,
428 Self::BlockAccessLists(_) => EthMessageID::BlockAccessLists,
429 Self::Cells(_) => EthMessageID::Cells,
430 Self::GetCells(_) => EthMessageID::GetCells,
431 Self::Other(msg) => EthMessageID::Other(msg.id as u8),
432 }
433 }
434
435 pub const fn is_request(&self) -> bool {
437 matches!(
438 self,
439 Self::GetBlockBodies(_) |
440 Self::GetBlockHeaders(_) |
441 Self::GetReceipts(_) |
442 Self::GetReceipts70(_) |
443 Self::GetBlockAccessLists(_) |
444 Self::GetCells(_) |
445 Self::GetPooledTransactions(_) |
446 Self::GetNodeData(_)
447 )
448 }
449
450 pub const fn is_response(&self) -> bool {
452 matches!(
453 self,
454 Self::PooledTransactions(_) |
455 Self::Receipts(_) |
456 Self::Receipts69(_) |
457 Self::Receipts70(_) |
458 Self::BlockAccessLists(_) |
459 Self::BlockHeaders(_) |
460 Self::BlockBodies(_) |
461 Self::NodeData(_) |
462 Self::Cells(_)
463 )
464 }
465
466 pub fn map_versioned(self, version: EthVersion) -> Self {
471 if version >= EthVersion::Eth70 {
475 return match self {
476 Self::GetReceipts(pair) => {
477 let RequestPair { request_id, message } = pair;
478 let req = RequestPair {
479 request_id,
480 message: GetReceipts70 {
481 first_block_receipt_index: 0,
482 block_hashes: message.0,
483 },
484 };
485 Self::GetReceipts70(req)
486 }
487 other => other,
488 }
489 }
490
491 self
492 }
493}
494
495impl<N: NetworkPrimitives> Encodable for EthMessage<N> {
496 fn encode(&self, out: &mut dyn BufMut) {
497 match self {
498 Self::Status(status) => status.encode(out),
499 Self::NewBlockHashes(new_block_hashes) => new_block_hashes.encode(out),
500 Self::NewBlock(new_block) => new_block.encode(out),
501 Self::Transactions(transactions) => transactions.encode(out),
502 Self::NewPooledTransactionHashes66(hashes) => hashes.encode(out),
503 Self::NewPooledTransactionHashes68(hashes) => hashes.encode(out),
504 Self::GetBlockHeaders(request) => request.encode(out),
505 Self::BlockHeaders(headers) => headers.encode(out),
506 Self::GetBlockBodies(request) => request.encode(out),
507 Self::BlockBodies(bodies) => bodies.encode(out),
508 Self::GetPooledTransactions(request) => request.encode(out),
509 Self::PooledTransactions(transactions) => transactions.encode(out),
510 Self::GetNodeData(request) => request.encode(out),
511 Self::NodeData(data) => data.encode(out),
512 Self::GetReceipts(request) => request.encode(out),
513 Self::GetReceipts70(request) => request.encode(out),
514 Self::GetBlockAccessLists(request) => request.encode(out),
515 Self::GetCells(request) => request.encode(out),
516 Self::Receipts(receipts) => receipts.encode(out),
517 Self::Receipts69(receipt69) => receipt69.encode(out),
518 Self::Receipts70(receipt70) => receipt70.encode(out),
519 Self::BlockAccessLists(block_access_lists) => block_access_lists.encode(out),
520 Self::BlockRangeUpdate(block_range_update) => block_range_update.encode(out),
521 Self::Cells(cells) => cells.encode(out),
522 Self::Other(unknown) => out.put_slice(&unknown.payload),
523 }
524 }
525 fn length(&self) -> usize {
526 match self {
527 Self::Status(status) => status.length(),
528 Self::NewBlockHashes(new_block_hashes) => new_block_hashes.length(),
529 Self::NewBlock(new_block) => new_block.length(),
530 Self::Transactions(transactions) => transactions.length(),
531 Self::NewPooledTransactionHashes66(hashes) => hashes.length(),
532 Self::NewPooledTransactionHashes68(hashes) => hashes.length(),
533 Self::GetBlockHeaders(request) => request.length(),
534 Self::BlockHeaders(headers) => headers.length(),
535 Self::GetBlockBodies(request) => request.length(),
536 Self::BlockBodies(bodies) => bodies.length(),
537 Self::GetPooledTransactions(request) => request.length(),
538 Self::PooledTransactions(transactions) => transactions.length(),
539 Self::GetNodeData(request) => request.length(),
540 Self::NodeData(data) => data.length(),
541 Self::GetReceipts(request) => request.length(),
542 Self::GetReceipts70(request) => request.length(),
543 Self::GetBlockAccessLists(request) => request.length(),
544 Self::GetCells(request) => request.length(),
545 Self::Receipts(receipts) => receipts.length(),
546 Self::Receipts69(receipt69) => receipt69.length(),
547 Self::Receipts70(receipt70) => receipt70.length(),
548 Self::BlockAccessLists(block_access_lists) => block_access_lists.length(),
549 Self::BlockRangeUpdate(block_range_update) => block_range_update.length(),
550 Self::Cells(cells) => cells.length(),
551 Self::Other(unknown) => unknown.length(),
552 }
553 }
554}
555
556#[derive(Clone, Debug, PartialEq, Eq)]
564pub enum EthBroadcastMessage<N: NetworkPrimitives = EthNetworkPrimitives> {
565 NewBlock(Arc<N::NewBlockPayload>),
567 Transactions(SharedTransactions<N::BroadcastedTransaction>),
569}
570
571impl<N: NetworkPrimitives> EthBroadcastMessage<N> {
574 pub const fn message_id(&self) -> EthMessageID {
576 match self {
577 Self::NewBlock(_) => EthMessageID::NewBlock,
578 Self::Transactions(_) => EthMessageID::Transactions,
579 }
580 }
581}
582
583impl<N: NetworkPrimitives> Encodable for EthBroadcastMessage<N> {
584 fn encode(&self, out: &mut dyn BufMut) {
585 match self {
586 Self::NewBlock(new_block) => new_block.encode(out),
587 Self::Transactions(transactions) => transactions.encode(out),
588 }
589 }
590
591 fn length(&self) -> usize {
592 match self {
593 Self::NewBlock(new_block) => new_block.length(),
594 Self::Transactions(transactions) => transactions.length(),
595 }
596 }
597}
598
/// Id of an `eth` protocol message, written as the first byte of every encoded message.
///
/// Known ids have dedicated variants; any other byte is preserved in [`EthMessageID::Other`].
#[repr(u8)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum EthMessageID {
    /// Status handshake message.
    Status = 0x00,
    /// New block hashes announcement.
    NewBlockHashes = 0x01,
    /// Transactions message.
    Transactions = 0x02,
    /// Block headers request.
    GetBlockHeaders = 0x03,
    /// Block headers response.
    BlockHeaders = 0x04,
    /// Block bodies request.
    GetBlockBodies = 0x05,
    /// Block bodies response.
    BlockBodies = 0x06,
    /// New block message.
    NewBlock = 0x07,
    /// Pooled transaction hashes announcement.
    NewPooledTransactionHashes = 0x08,
    /// Pooled transactions request.
    GetPooledTransactions = 0x09,
    /// Pooled transactions response.
    PooledTransactions = 0x0a,
    /// Node data request; only accepted by the decoder before `eth/67`.
    GetNodeData = 0x0d,
    /// Node data response; only accepted by the decoder before `eth/67`.
    NodeData = 0x0e,
    /// Receipts request.
    GetReceipts = 0x0f,
    /// Receipts response.
    Receipts = 0x10,
    /// Block range update; only accepted from `eth/69` onwards.
    BlockRangeUpdate = 0x11,
    /// Block access lists request; only accepted from `eth/71` onwards.
    GetBlockAccessLists = 0x12,
    /// Block access lists response; only accepted from `eth/71` onwards.
    BlockAccessLists = 0x13,
    /// Cells request; only accepted from `eth/72` onwards.
    GetCells = 0x14,
    /// Cells response; only accepted from `eth/72` onwards.
    Cells = 0x15,
    /// Any id without a dedicated variant, carrying the raw byte.
    Other(u8),
}
658
659impl EthMessageID {
660 pub const fn to_u8(&self) -> u8 {
662 match self {
663 Self::Status => 0x00,
664 Self::NewBlockHashes => 0x01,
665 Self::Transactions => 0x02,
666 Self::GetBlockHeaders => 0x03,
667 Self::BlockHeaders => 0x04,
668 Self::GetBlockBodies => 0x05,
669 Self::BlockBodies => 0x06,
670 Self::NewBlock => 0x07,
671 Self::NewPooledTransactionHashes => 0x08,
672 Self::GetPooledTransactions => 0x09,
673 Self::PooledTransactions => 0x0a,
674 Self::GetNodeData => 0x0d,
675 Self::NodeData => 0x0e,
676 Self::GetReceipts => 0x0f,
677 Self::Receipts => 0x10,
678 Self::BlockRangeUpdate => 0x11,
679 Self::GetBlockAccessLists => 0x12,
680 Self::BlockAccessLists => 0x13,
681 Self::GetCells => 0x14,
682 Self::Cells => 0x15,
683 Self::Other(value) => *value, }
685 }
686
687 pub const fn max(version: EthVersion) -> u8 {
689 if version.is_eth72() {
690 Self::Cells.to_u8()
691 } else if version.is_eth71() {
692 Self::BlockAccessLists.to_u8()
693 } else if version.is_eth69_or_newer() {
694 Self::BlockRangeUpdate.to_u8()
695 } else {
696 Self::Receipts.to_u8()
697 }
698 }
699
700 pub const fn message_count(version: EthVersion) -> u8 {
706 Self::max(version) + 1
707 }
708}
709
710impl Encodable for EthMessageID {
711 fn encode(&self, out: &mut dyn BufMut) {
712 out.put_u8(self.to_u8());
713 }
714 fn length(&self) -> usize {
715 1
716 }
717}
718
719impl Decodable for EthMessageID {
720 fn decode(buf: &mut &[u8]) -> alloy_rlp::Result<Self> {
721 let id = match buf.first().ok_or(alloy_rlp::Error::InputTooShort)? {
722 0x00 => Self::Status,
723 0x01 => Self::NewBlockHashes,
724 0x02 => Self::Transactions,
725 0x03 => Self::GetBlockHeaders,
726 0x04 => Self::BlockHeaders,
727 0x05 => Self::GetBlockBodies,
728 0x06 => Self::BlockBodies,
729 0x07 => Self::NewBlock,
730 0x08 => Self::NewPooledTransactionHashes,
731 0x09 => Self::GetPooledTransactions,
732 0x0a => Self::PooledTransactions,
733 0x0d => Self::GetNodeData,
734 0x0e => Self::NodeData,
735 0x0f => Self::GetReceipts,
736 0x10 => Self::Receipts,
737 0x11 => Self::BlockRangeUpdate,
738 0x12 => Self::GetBlockAccessLists,
739 0x13 => Self::BlockAccessLists,
740 0x14 => Self::GetCells,
741 0x15 => Self::Cells,
742 unknown => Self::Other(*unknown),
743 };
744 buf.advance(1);
745 Ok(id)
746 }
747}
748
749impl TryFrom<usize> for EthMessageID {
750 type Error = &'static str;
751
752 fn try_from(value: usize) -> Result<Self, Self::Error> {
753 match value {
754 0x00 => Ok(Self::Status),
755 0x01 => Ok(Self::NewBlockHashes),
756 0x02 => Ok(Self::Transactions),
757 0x03 => Ok(Self::GetBlockHeaders),
758 0x04 => Ok(Self::BlockHeaders),
759 0x05 => Ok(Self::GetBlockBodies),
760 0x06 => Ok(Self::BlockBodies),
761 0x07 => Ok(Self::NewBlock),
762 0x08 => Ok(Self::NewPooledTransactionHashes),
763 0x09 => Ok(Self::GetPooledTransactions),
764 0x0a => Ok(Self::PooledTransactions),
765 0x0d => Ok(Self::GetNodeData),
766 0x0e => Ok(Self::NodeData),
767 0x0f => Ok(Self::GetReceipts),
768 0x10 => Ok(Self::Receipts),
769 0x11 => Ok(Self::BlockRangeUpdate),
770 0x12 => Ok(Self::GetBlockAccessLists),
771 0x13 => Ok(Self::BlockAccessLists),
772 0x14 => Ok(Self::GetCells),
773 0x15 => Ok(Self::Cells),
774 _ => Err("Invalid message ID"),
775 }
776 }
777}
778
/// A request or response payload paired with the id that links the two.
///
/// Encoded as the RLP list `[request_id, message]`.
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(any(test, feature = "arbitrary"), derive(arbitrary::Arbitrary))]
pub struct RequestPair<T> {
    /// Id linking a response to the request that triggered it.
    pub request_id: u64,

    /// The request or response payload.
    pub message: T,
}
792
793impl<T> RequestPair<T> {
794 pub fn map<F, R>(self, f: F) -> RequestPair<R>
796 where
797 F: FnOnce(T) -> R,
798 {
799 let Self { request_id, message } = self;
800 RequestPair { request_id, message: f(message) }
801 }
802
803 pub fn decode_with<F>(buf: &mut &[u8], decode_msg: F) -> alloy_rlp::Result<Self>
805 where
806 F: FnOnce(&mut &[u8]) -> alloy_rlp::Result<T>,
807 {
808 let header = Header::decode(buf)?;
809
810 let initial_length = buf.len();
811 let request_id = u64::decode(buf)?;
812 let message = decode_msg(buf)?;
813
814 let consumed_len = initial_length - buf.len();
815 if consumed_len != header.payload_length {
816 return Err(alloy_rlp::Error::UnexpectedLength)
817 }
818
819 Ok(Self { request_id, message })
820 }
821}
822
823impl<T> Encodable for RequestPair<T>
825where
826 T: Encodable,
827{
828 fn encode(&self, out: &mut dyn alloy_rlp::BufMut) {
829 let header =
830 Header { list: true, payload_length: self.request_id.length() + self.message.length() };
831
832 header.encode(out);
833 self.request_id.encode(out);
834 self.message.encode(out);
835 }
836
837 fn length(&self) -> usize {
838 let mut length = 0;
839 length += self.request_id.length();
840 length += self.message.length();
841 length += length_of_length(length);
842 length
843 }
844}
845
846impl<T> Decodable for RequestPair<T>
848where
849 T: Decodable,
850{
851 fn decode(buf: &mut &[u8]) -> alloy_rlp::Result<Self> {
852 let header = Header::decode(buf)?;
853
854 let initial_length = buf.len();
855 let request_id = u64::decode(buf)?;
856 let message = T::decode(buf)?;
857
858 let consumed_len = initial_length - buf.len();
861 if consumed_len != header.payload_length {
862 return Err(alloy_rlp::Error::UnexpectedLength)
863 }
864
865 Ok(Self { request_id, message })
866 }
867}
868
#[cfg(test)]
mod tests {
    use super::MessageError;
    use crate::{
        message::RequestPair, BlockAccessLists, EthMessage, EthMessageID, EthNetworkPrimitives,
        EthVersion, GetBlockAccessLists, GetNodeData, NodeData, ProtocolMessage,
        RawCapabilityMessage,
    };
    use alloy_primitives::hex;
    use alloy_rlp::{Decodable, Encodable, Error};
    use reth_ethereum_primitives::BlockBody;

    /// Encodes `value` into a freshly allocated byte vector.
    fn encode<T: Encodable>(value: T) -> Vec<u8> {
        let mut buf = vec![];
        value.encode(&mut buf);
        buf
    }

    #[test]
    fn test_removed_message_at_eth67() {
        let get_node_data = EthMessage::<EthNetworkPrimitives>::GetNodeData(RequestPair {
            request_id: 1337,
            message: GetNodeData(vec![]),
        });
        let buf = encode(ProtocolMessage {
            message_type: EthMessageID::GetNodeData,
            message: get_node_data,
        });
        let msg = ProtocolMessage::<EthNetworkPrimitives>::decode_message(
            crate::EthVersion::Eth67,
            &mut &buf[..],
        );
        assert!(matches!(msg, Err(MessageError::Invalid(..))));

        let node_data = EthMessage::<EthNetworkPrimitives>::NodeData(RequestPair {
            request_id: 1337,
            message: NodeData(vec![]),
        });
        let buf =
            encode(ProtocolMessage { message_type: EthMessageID::NodeData, message: node_data });
        let msg = ProtocolMessage::<EthNetworkPrimitives>::decode_message(
            crate::EthVersion::Eth67,
            &mut &buf[..],
        );
        assert!(matches!(msg, Err(MessageError::Invalid(..))));
    }

    #[test]
    fn test_bal_message_version_gating() {
        let get_block_access_lists =
            EthMessage::<EthNetworkPrimitives>::GetBlockAccessLists(RequestPair {
                request_id: 1337,
                message: GetBlockAccessLists(vec![]),
            });
        let buf = encode(ProtocolMessage {
            message_type: EthMessageID::GetBlockAccessLists,
            message: get_block_access_lists,
        });
        let msg = ProtocolMessage::<EthNetworkPrimitives>::decode_message(
            EthVersion::Eth70,
            &mut &buf[..],
        );
        assert!(matches!(
            msg,
            Err(MessageError::Invalid(EthVersion::Eth70, EthMessageID::GetBlockAccessLists))
        ));

        let block_access_lists =
            EthMessage::<EthNetworkPrimitives>::BlockAccessLists(RequestPair {
                request_id: 1337,
                message: BlockAccessLists(vec![]),
            });
        let buf = encode(ProtocolMessage {
            message_type: EthMessageID::BlockAccessLists,
            message: block_access_lists,
        });
        let msg = ProtocolMessage::<EthNetworkPrimitives>::decode_message(
            EthVersion::Eth70,
            &mut &buf[..],
        );
        assert!(matches!(
            msg,
            Err(MessageError::Invalid(EthVersion::Eth70, EthMessageID::BlockAccessLists))
        ));
    }

    #[test]
    fn test_bal_message_eth71_roundtrip() {
        let msg = ProtocolMessage::from(EthMessage::<EthNetworkPrimitives>::GetBlockAccessLists(
            RequestPair { request_id: 42, message: GetBlockAccessLists(vec![]) },
        ));
        let encoded = encode(msg.clone());
        let decoded = ProtocolMessage::<EthNetworkPrimitives>::decode_message(
            EthVersion::Eth71,
            &mut &encoded[..],
        )
        .unwrap();

        assert_eq!(decoded, msg);
    }

    #[test]
    fn request_pair_encode() {
        let request_pair = RequestPair { request_id: 1337, message: vec![5u8] };

        // c5: list header with a 5-byte payload
        // 820539: request id 1337
        // c105: the one-element message list [5]
        let expected = hex!("c5820539c105");
        let got = encode(request_pair);
        assert_eq!(expected[..], got, "expected: {expected:X?}, got: {got:X?}",);
    }

    #[test]
    fn request_pair_decode() {
        let raw_pair = &hex!("c5820539c105")[..];

        let expected = RequestPair { request_id: 1337, message: vec![5u8] };

        let got = RequestPair::<Vec<u8>>::decode(&mut &*raw_pair).unwrap();
        assert_eq!(expected.length(), raw_pair.len());
        assert_eq!(expected, got);
    }

    #[test]
    fn malicious_request_pair_decode() {
        // The list header claims 5 payload bytes, but the fields span 6 bytes.
        let raw_pair = &hex!("c5820539c20505")[..];

        let result = RequestPair::<Vec<u8>>::decode(&mut &*raw_pair);
        assert!(matches!(result, Err(Error::UnexpectedLength)));
    }

    #[test]
    fn empty_block_bodies_protocol() {
        let empty_block_bodies =
            ProtocolMessage::from(EthMessage::<EthNetworkPrimitives>::BlockBodies(RequestPair {
                request_id: 0,
                message: Default::default(),
            }));
        let mut buf = Vec::new();
        empty_block_bodies.encode(&mut buf);
        let decoded =
            ProtocolMessage::decode_message(EthVersion::Eth68, &mut buf.as_slice()).unwrap();
        assert_eq!(empty_block_bodies, decoded);
    }

    #[test]
    fn empty_block_body_protocol() {
        let empty_block_bodies =
            ProtocolMessage::from(EthMessage::<EthNetworkPrimitives>::BlockBodies(RequestPair {
                request_id: 0,
                message: vec![BlockBody {
                    transactions: vec![],
                    ommers: vec![],
                    withdrawals: Some(Default::default()),
                }]
                .into(),
            }));
        let mut buf = Vec::new();
        empty_block_bodies.encode(&mut buf);
        let decoded =
            ProtocolMessage::decode_message(EthVersion::Eth68, &mut buf.as_slice()).unwrap();
        assert_eq!(empty_block_bodies, decoded);
    }

    #[test]
    fn decode_block_bodies_message() {
        let buf = hex!("06c48199c1c0");
        let msg = ProtocolMessage::<EthNetworkPrimitives>::decode_message(
            EthVersion::Eth68,
            &mut &buf[..],
        )
        .unwrap_err();
        assert!(matches!(msg, MessageError::RlpError(alloy_rlp::Error::InputTooShort)));
    }

    #[test]
    fn custom_message_roundtrip() {
        let custom_payload = vec![1, 2, 3, 4, 5];
        let custom_message = RawCapabilityMessage::new(0x20, custom_payload.into());
        let protocol_message = ProtocolMessage::<EthNetworkPrimitives> {
            message_type: EthMessageID::Other(0x20),
            message: EthMessage::Other(custom_message),
        };

        let encoded = encode(protocol_message.clone());
        let decoded = ProtocolMessage::<EthNetworkPrimitives>::decode_message(
            EthVersion::Eth68,
            &mut &encoded[..],
        )
        .unwrap();

        assert_eq!(protocol_message, decoded);
    }

    #[test]
    fn custom_message_empty_payload_roundtrip() {
        let custom_message = RawCapabilityMessage::new(0x30, vec![].into());
        let protocol_message = ProtocolMessage::<EthNetworkPrimitives> {
            message_type: EthMessageID::Other(0x30),
            message: EthMessage::Other(custom_message),
        };

        let encoded = encode(protocol_message.clone());
        let decoded = ProtocolMessage::<EthNetworkPrimitives>::decode_message(
            EthVersion::Eth68,
            &mut &encoded[..],
        )
        .unwrap();

        assert_eq!(protocol_message, decoded);
    }

    #[test]
    fn decode_status_success() {
        use crate::{Status, StatusMessage};
        use alloy_hardforks::{ForkHash, ForkId};
        use alloy_primitives::{B256, U256};

        let status = Status {
            version: EthVersion::Eth68,
            chain: alloy_chains::Chain::mainnet(),
            total_difficulty: U256::from(100u64),
            blockhash: B256::random(),
            genesis: B256::random(),
            forkid: ForkId { hash: ForkHash([0xb7, 0x15, 0x07, 0x7d]), next: 0 },
        };

        let protocol_message = ProtocolMessage::<EthNetworkPrimitives>::from(EthMessage::Status(
            StatusMessage::Legacy(status),
        ));
        let encoded = encode(protocol_message);

        let decoded = ProtocolMessage::<EthNetworkPrimitives>::decode_status(
            EthVersion::Eth68,
            &mut &encoded[..],
        )
        .unwrap();

        assert!(matches!(decoded, StatusMessage::Legacy(s) if s == status));
    }

    #[test]
    fn eth_message_id_max_includes_block_range_update() {
        assert_eq!(EthMessageID::max(EthVersion::Eth69), EthMessageID::BlockRangeUpdate.to_u8(),);
        assert_eq!(EthMessageID::max(EthVersion::Eth70), EthMessageID::BlockRangeUpdate.to_u8(),);
        assert_eq!(EthMessageID::max(EthVersion::Eth68), EthMessageID::Receipts.to_u8());
    }

    #[test]
    fn eth_message_id_max_for_eth71_and_eth72() {
        assert_eq!(EthMessageID::max(EthVersion::Eth71), EthMessageID::BlockAccessLists.to_u8());
        assert_eq!(EthMessageID::max(EthVersion::Eth72), EthMessageID::Cells.to_u8());
        assert_eq!(
            EthMessageID::message_count(EthVersion::Eth68),
            EthMessageID::Receipts.to_u8() + 1
        );
        assert_eq!(
            EthMessageID::message_count(EthVersion::Eth72),
            EthMessageID::Cells.to_u8() + 1
        );
    }

    #[test]
    fn eth_message_id_byte_roundtrip() {
        for byte in 0..=u8::MAX {
            let raw = [byte];
            let mut slice = &raw[..];
            let decoded = EthMessageID::decode(&mut slice).unwrap();
            // Decoding consumes exactly the single id byte.
            assert!(slice.is_empty());
            assert_eq!(decoded.to_u8(), byte);
            // `TryFrom` agrees with decoding for known ids and rejects everything else.
            match EthMessageID::try_from(byte as usize) {
                Ok(known) => assert_eq!(known, decoded),
                Err(_) => assert_eq!(decoded, EthMessageID::Other(byte)),
            }
        }

        let mut empty: &[u8] = &[];
        assert!(matches!(EthMessageID::decode(&mut empty), Err(Error::InputTooShort)));
    }

    #[test]
    fn decode_status_rejects_non_status() {
        let msg = EthMessage::<EthNetworkPrimitives>::GetBlockBodies(RequestPair {
            request_id: 1,
            message: crate::GetBlockBodies::default(),
        });
        let protocol_message =
            ProtocolMessage { message_type: EthMessageID::GetBlockBodies, message: msg };
        let encoded = encode(protocol_message);

        let result = ProtocolMessage::<EthNetworkPrimitives>::decode_status(
            EthVersion::Eth68,
            &mut &encoded[..],
        );

        assert!(matches!(
            result,
            Err(MessageError::ExpectedStatusMessage(EthMessageID::GetBlockBodies))
        ));
    }
}