1use super::{
10 broadcast::NewBlockHashes, BlockAccessLists, BlockBodies, BlockHeaders, GetBlockAccessLists,
11 GetBlockBodies, GetBlockHeaders, GetNodeData, GetPooledTransactions, GetReceipts,
12 GetReceipts70, NewPooledTransactionHashes66, NewPooledTransactionHashes68, NodeData,
13 PooledTransactions, Receipts, Status, StatusEth69, Transactions,
14};
15use crate::{
16 status::StatusMessage, BlockRangeUpdate, Cells, EthNetworkPrimitives, EthVersion, GetCells,
17 NetworkPrimitives, NewPooledTransactionHashes72, RawCapabilityMessage, Receipts69, Receipts70,
18 SharedTransactions,
19};
20use alloc::{boxed::Box, string::String, sync::Arc};
21use alloy_primitives::{
22 bytes::{Buf, BufMut},
23 Bytes,
24};
25use alloy_rlp::{length_of_length, Decodable, Encodable, Header};
26use core::fmt::Debug;
27
/// Upper bound on the size of a single message: 10 MiB (`10 << 20`).
///
/// NOTE(review): the unit (bytes) is inferred from the name and the MiB-sized value; this
/// constant is not enforced anywhere in this file — confirm against the code that uses it.
pub const MAX_MESSAGE_SIZE: usize = 10 << 20;

/// Multiplier related to the transaction memory budget.
///
/// NOTE(review): not referenced in this file; presumably scales a budget passed as
/// `tx_memory_budget` to the decoders — confirm at the use site.
pub const TX_MEMORY_BUDGET_MULTIPLIER: usize = 2;
40
/// Errors produced while decoding an `eth` protocol message.
#[derive(thiserror::Error, Debug)]
pub enum MessageError {
    /// The message id is not valid for the negotiated protocol version.
    ///
    /// Fields: the negotiated version, then the rejected message id.
    #[error("message id {1:?} is invalid for version {0:?}")]
    Invalid(EthVersion, EthMessageID),
    /// A `Status` message was required but a message with a different id was received
    /// (returned by `ProtocolMessage::decode_status`).
    #[error("expected status message but received {0:?}")]
    ExpectedStatusMessage(EthMessageID),
    /// The message id or body failed RLP decoding.
    #[error("RLP error: {0}")]
    RlpError(#[from] alloy_rlp::Error),
    /// Any other error, carrying a free-form description.
    #[error("{0}")]
    Other(String),
}
57
/// An `eth` protocol message paired with its message id.
///
/// On the wire this is the one-byte id followed by the encoded message body.
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ProtocolMessage<N: NetworkPrimitives = EthNetworkPrimitives> {
    /// The id of the message, written as a single byte before the body.
    pub message_type: EthMessageID,
    /// The message body.
    #[cfg_attr(
        feature = "serde",
        serde(bound = "EthMessage<N>: serde::Serialize + serde::de::DeserializeOwned")
    )]
    pub message: EthMessage<N>,
}
71
72impl<N: NetworkPrimitives> ProtocolMessage<N> {
73 pub fn decode_status(
78 version: EthVersion,
79 buf: &mut &[u8],
80 ) -> Result<StatusMessage, MessageError> {
81 let message_type = EthMessageID::decode(buf)?;
82
83 if message_type != EthMessageID::Status {
84 return Err(MessageError::ExpectedStatusMessage(message_type))
85 }
86
87 let status = if version < EthVersion::Eth69 {
88 StatusMessage::Legacy(Status::decode(buf)?)
89 } else {
90 StatusMessage::Eth69(StatusEth69::decode(buf)?)
91 };
92
93 Ok(status)
94 }
95
96 pub fn decode_message(version: EthVersion, buf: &mut &[u8]) -> Result<Self, MessageError> {
100 Self::decode_message_with_tx_memory_budget(version, buf, usize::MAX)
101 }
102
103 pub fn decode_message_with_tx_memory_budget(
109 version: EthVersion,
110 buf: &mut &[u8],
111 tx_memory_budget: usize,
112 ) -> Result<Self, MessageError> {
113 let message_type = EthMessageID::decode(buf)?;
114
115 let message = match message_type {
118 EthMessageID::Status => EthMessage::Status(if version < EthVersion::Eth69 {
119 StatusMessage::Legacy(Status::decode(buf)?)
120 } else {
121 StatusMessage::Eth69(StatusEth69::decode(buf)?)
122 }),
123 EthMessageID::NewBlockHashes => {
124 EthMessage::NewBlockHashes(NewBlockHashes::decode(buf)?)
125 }
126 EthMessageID::NewBlock => {
127 EthMessage::NewBlock(Box::new(N::NewBlockPayload::decode(buf)?))
128 }
129 EthMessageID::Transactions => EthMessage::Transactions(
130 Transactions::decode_with_memory_budget(buf, tx_memory_budget)?,
131 ),
132 EthMessageID::NewPooledTransactionHashes => {
133 if version >= EthVersion::Eth72 {
134 EthMessage::NewPooledTransactionHashes72(NewPooledTransactionHashes72::decode(
135 buf,
136 )?)
137 } else if version >= EthVersion::Eth68 {
138 EthMessage::NewPooledTransactionHashes68(NewPooledTransactionHashes68::decode(
139 buf,
140 )?)
141 } else {
142 EthMessage::NewPooledTransactionHashes66(NewPooledTransactionHashes66::decode(
143 buf,
144 )?)
145 }
146 }
147 EthMessageID::GetBlockHeaders => EthMessage::GetBlockHeaders(RequestPair::decode(buf)?),
148 EthMessageID::BlockHeaders => EthMessage::BlockHeaders(RequestPair::decode(buf)?),
149 EthMessageID::GetBlockBodies => EthMessage::GetBlockBodies(RequestPair::decode(buf)?),
150 EthMessageID::BlockBodies => EthMessage::BlockBodies(RequestPair::decode(buf)?),
151 EthMessageID::GetPooledTransactions => {
152 EthMessage::GetPooledTransactions(RequestPair::decode(buf)?)
153 }
154 EthMessageID::PooledTransactions => {
155 EthMessage::PooledTransactions(RequestPair::decode_with(buf, |buf| {
156 PooledTransactions::decode_with_memory_budget(buf, tx_memory_budget)
157 })?)
158 }
159 EthMessageID::GetNodeData => {
160 if version >= EthVersion::Eth67 {
161 return Err(MessageError::Invalid(version, EthMessageID::GetNodeData))
162 }
163 EthMessage::GetNodeData(RequestPair::decode(buf)?)
164 }
165 EthMessageID::NodeData => {
166 if version >= EthVersion::Eth67 {
167 return Err(MessageError::Invalid(version, EthMessageID::NodeData))
168 }
169 EthMessage::NodeData(RequestPair::decode(buf)?)
170 }
171 EthMessageID::GetReceipts => {
172 if version >= EthVersion::Eth70 {
173 EthMessage::GetReceipts70(RequestPair::decode(buf)?)
174 } else {
175 EthMessage::GetReceipts(RequestPair::decode(buf)?)
176 }
177 }
178 EthMessageID::Receipts => {
179 match version {
180 v if v >= EthVersion::Eth70 => {
181 EthMessage::Receipts70(RequestPair::decode(buf)?)
185 }
186 EthVersion::Eth69 => {
187 EthMessage::Receipts69(RequestPair::decode(buf)?)
189 }
190 _ => {
191 EthMessage::Receipts(RequestPair::decode(buf)?)
193 }
194 }
195 }
196 EthMessageID::BlockRangeUpdate => {
197 if version < EthVersion::Eth69 {
198 return Err(MessageError::Invalid(version, EthMessageID::BlockRangeUpdate))
199 }
200 EthMessage::BlockRangeUpdate(BlockRangeUpdate::decode(buf)?)
201 }
202 EthMessageID::GetBlockAccessLists => {
203 if version < EthVersion::Eth71 {
204 return Err(MessageError::Invalid(version, EthMessageID::GetBlockAccessLists))
205 }
206 EthMessage::GetBlockAccessLists(RequestPair::decode(buf)?)
207 }
208 EthMessageID::BlockAccessLists => {
209 if version < EthVersion::Eth71 {
210 return Err(MessageError::Invalid(version, EthMessageID::BlockAccessLists))
211 }
212 EthMessage::BlockAccessLists(RequestPair::decode(buf)?)
213 }
214 EthMessageID::Cells => {
215 if version < EthVersion::Eth72 {
216 return Err(MessageError::Invalid(version, EthMessageID::Cells))
217 }
218 EthMessage::Cells(RequestPair::decode(buf)?)
219 }
220 EthMessageID::GetCells => {
221 if version < EthVersion::Eth72 {
222 return Err(MessageError::Invalid(version, EthMessageID::GetCells))
223 }
224 EthMessage::GetCells(RequestPair::decode(buf)?)
225 }
226 EthMessageID::Other(_) => {
227 let raw_payload = Bytes::copy_from_slice(buf);
228 buf.advance(raw_payload.len());
229 EthMessage::Other(RawCapabilityMessage::new(
230 message_type.to_u8() as usize,
231 raw_payload.into(),
232 ))
233 }
234 };
235 Ok(Self { message_type, message })
236 }
237}
238
239impl<N: NetworkPrimitives> Encodable for ProtocolMessage<N> {
240 fn encode(&self, out: &mut dyn BufMut) {
243 self.message_type.encode(out);
244 self.message.encode(out);
245 }
246 fn length(&self) -> usize {
247 self.message_type.length() + self.message.length()
248 }
249}
250
251impl<N: NetworkPrimitives> From<EthMessage<N>> for ProtocolMessage<N> {
252 fn from(message: EthMessage<N>) -> Self {
253 Self { message_type: message.message_id(), message }
254 }
255}
256
/// A broadcast message paired with its message id.
///
/// Encodes like [`ProtocolMessage`]: the one-byte id followed by the message body.
#[derive(Clone, Debug)]
pub struct ProtocolBroadcastMessage<N: NetworkPrimitives = EthNetworkPrimitives> {
    /// The id of the message, written before the body.
    pub message_type: EthMessageID,
    /// The broadcast message body.
    pub message: EthBroadcastMessage<N>,
}
266
267impl<N: NetworkPrimitives> Encodable for ProtocolBroadcastMessage<N> {
268 fn encode(&self, out: &mut dyn BufMut) {
271 self.message_type.encode(out);
272 self.message.encode(out);
273 }
274 fn length(&self) -> usize {
275 self.message_type.length() + self.message.length()
276 }
277}
278
279impl<N: NetworkPrimitives> From<EthBroadcastMessage<N>> for ProtocolBroadcastMessage<N> {
280 fn from(message: EthBroadcastMessage<N>) -> Self {
281 Self { message_type: message.message_id(), message }
282 }
283}
284
/// All `eth` protocol messages, including version-specific layouts of the same message id.
///
/// Request/response messages are wrapped in [`RequestPair`] to carry the request id. Which
/// variant a given id decodes to depends on the negotiated version
/// (see `ProtocolMessage::decode_message_with_tx_memory_budget`).
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum EthMessage<N: NetworkPrimitives = EthNetworkPrimitives> {
    /// `Status` (id 0x00): legacy layout before eth/69, eth/69 layout afterwards.
    Status(StatusMessage),
    /// `NewBlockHashes` (id 0x01).
    NewBlockHashes(NewBlockHashes),
    /// `NewBlock` (id 0x07), boxed.
    #[cfg_attr(
        feature = "serde",
        serde(bound = "N::NewBlockPayload: serde::Serialize + serde::de::DeserializeOwned")
    )]
    NewBlock(Box<N::NewBlockPayload>),
    /// `Transactions` (id 0x02).
    #[cfg_attr(
        feature = "serde",
        serde(bound = "N::BroadcastedTransaction: serde::Serialize + serde::de::DeserializeOwned")
    )]
    Transactions(Transactions<N::BroadcastedTransaction>),
    /// `NewPooledTransactionHashes` (id 0x08) as decoded before eth/68.
    NewPooledTransactionHashes66(NewPooledTransactionHashes66),
    /// `NewPooledTransactionHashes` (id 0x08) as decoded from eth/68 up to (excluding) eth/72.
    NewPooledTransactionHashes68(NewPooledTransactionHashes68),
    /// `NewPooledTransactionHashes` (id 0x08) as decoded from eth/72.
    NewPooledTransactionHashes72(NewPooledTransactionHashes72),

    /// `GetBlockHeaders` request (id 0x03).
    GetBlockHeaders(RequestPair<GetBlockHeaders>),
    /// `BlockHeaders` response (id 0x04).
    #[cfg_attr(
        feature = "serde",
        serde(bound = "N::BlockHeader: serde::Serialize + serde::de::DeserializeOwned")
    )]
    BlockHeaders(RequestPair<BlockHeaders<N::BlockHeader>>),
    /// `GetBlockBodies` request (id 0x05).
    GetBlockBodies(RequestPair<GetBlockBodies>),
    /// `BlockBodies` response (id 0x06).
    #[cfg_attr(
        feature = "serde",
        serde(bound = "N::BlockBody: serde::Serialize + serde::de::DeserializeOwned")
    )]
    BlockBodies(RequestPair<BlockBodies<N::BlockBody>>),
    /// `GetPooledTransactions` request (id 0x09).
    GetPooledTransactions(RequestPair<GetPooledTransactions>),
    /// `PooledTransactions` response (id 0x0a).
    #[cfg_attr(
        feature = "serde",
        serde(bound = "N::PooledTransaction: serde::Serialize + serde::de::DeserializeOwned")
    )]
    PooledTransactions(RequestPair<PooledTransactions<N::PooledTransaction>>),
    /// `GetNodeData` request (id 0x0d); rejected when decoding at eth/67 or later.
    GetNodeData(RequestPair<GetNodeData>),
    /// `NodeData` response (id 0x0e); rejected when decoding at eth/67 or later.
    NodeData(RequestPair<NodeData>),
    /// `GetReceipts` request (id 0x0f) as decoded before eth/70.
    GetReceipts(RequestPair<GetReceipts>),
    /// `GetReceipts` request (id 0x0f) as decoded from eth/70.
    GetReceipts70(RequestPair<GetReceipts70>),
    /// `GetBlockAccessLists` request (id 0x12); accepted from eth/71.
    GetBlockAccessLists(RequestPair<GetBlockAccessLists>),
    /// `Receipts` response (id 0x10) as decoded before eth/69.
    #[cfg_attr(
        feature = "serde",
        serde(bound = "N::Receipt: serde::Serialize + serde::de::DeserializeOwned")
    )]
    Receipts(RequestPair<Receipts<N::Receipt>>),
    /// `Receipts` response (id 0x10) as decoded at eth/69.
    #[cfg_attr(
        feature = "serde",
        serde(bound = "N::Receipt: serde::Serialize + serde::de::DeserializeOwned")
    )]
    Receipts69(RequestPair<Receipts69<N::Receipt>>),
    /// `Receipts` response (id 0x10) as decoded from eth/70.
    #[cfg_attr(
        feature = "serde",
        serde(bound = "N::Receipt: serde::Serialize + serde::de::DeserializeOwned")
    )]
    Receipts70(RequestPair<Receipts70<N::Receipt>>),
    /// `BlockAccessLists` response (id 0x13); accepted from eth/71.
    BlockAccessLists(RequestPair<BlockAccessLists>),
    /// `Cells` response (id 0x15); accepted from eth/72.
    Cells(RequestPair<Cells>),
    /// `GetCells` request (id 0x14); accepted from eth/72.
    GetCells(RequestPair<GetCells>),
    /// `BlockRangeUpdate` (id 0x11); accepted from eth/69.
    // NOTE(review): this serde bound names `N::BroadcastedTransaction`, which this variant does
    // not contain — looks like a copy-paste from `Transactions`; confirm whether it is needed.
    #[cfg_attr(
        feature = "serde",
        serde(bound = "N::BroadcastedTransaction: serde::Serialize + serde::de::DeserializeOwned")
    )]
    BlockRangeUpdate(BlockRangeUpdate),
    /// A message with an unrecognized id, kept as its raw (undecoded) payload.
    Other(RawCapabilityMessage),
}
412
413impl<N: NetworkPrimitives> EthMessage<N> {
414 pub const fn message_id(&self) -> EthMessageID {
416 match self {
417 Self::Status(_) => EthMessageID::Status,
418 Self::NewBlockHashes(_) => EthMessageID::NewBlockHashes,
419 Self::NewBlock(_) => EthMessageID::NewBlock,
420 Self::Transactions(_) => EthMessageID::Transactions,
421 Self::NewPooledTransactionHashes66(_) |
422 Self::NewPooledTransactionHashes68(_) |
423 Self::NewPooledTransactionHashes72(_) => EthMessageID::NewPooledTransactionHashes,
424 Self::GetBlockHeaders(_) => EthMessageID::GetBlockHeaders,
425 Self::BlockHeaders(_) => EthMessageID::BlockHeaders,
426 Self::GetBlockBodies(_) => EthMessageID::GetBlockBodies,
427 Self::BlockBodies(_) => EthMessageID::BlockBodies,
428 Self::GetPooledTransactions(_) => EthMessageID::GetPooledTransactions,
429 Self::PooledTransactions(_) => EthMessageID::PooledTransactions,
430 Self::GetNodeData(_) => EthMessageID::GetNodeData,
431 Self::NodeData(_) => EthMessageID::NodeData,
432 Self::GetReceipts(_) | Self::GetReceipts70(_) => EthMessageID::GetReceipts,
433 Self::Receipts(_) | Self::Receipts69(_) | Self::Receipts70(_) => EthMessageID::Receipts,
434 Self::BlockRangeUpdate(_) => EthMessageID::BlockRangeUpdate,
435 Self::GetBlockAccessLists(_) => EthMessageID::GetBlockAccessLists,
436 Self::BlockAccessLists(_) => EthMessageID::BlockAccessLists,
437 Self::Cells(_) => EthMessageID::Cells,
438 Self::GetCells(_) => EthMessageID::GetCells,
439 Self::Other(msg) => EthMessageID::Other(msg.id as u8),
440 }
441 }
442
443 pub const fn is_request(&self) -> bool {
445 matches!(
446 self,
447 Self::GetBlockBodies(_) |
448 Self::GetBlockHeaders(_) |
449 Self::GetReceipts(_) |
450 Self::GetReceipts70(_) |
451 Self::GetBlockAccessLists(_) |
452 Self::GetCells(_) |
453 Self::GetPooledTransactions(_) |
454 Self::GetNodeData(_)
455 )
456 }
457
458 pub const fn is_response(&self) -> bool {
460 matches!(
461 self,
462 Self::PooledTransactions(_) |
463 Self::Receipts(_) |
464 Self::Receipts69(_) |
465 Self::Receipts70(_) |
466 Self::BlockAccessLists(_) |
467 Self::BlockHeaders(_) |
468 Self::BlockBodies(_) |
469 Self::NodeData(_) |
470 Self::Cells(_)
471 )
472 }
473
474 pub fn map_versioned(self, version: EthVersion) -> Self {
479 if version >= EthVersion::Eth70 {
483 return match self {
484 Self::GetReceipts(pair) => {
485 let RequestPair { request_id, message } = pair;
486 let req = RequestPair {
487 request_id,
488 message: GetReceipts70 {
489 first_block_receipt_index: 0,
490 block_hashes: message.0,
491 },
492 };
493 Self::GetReceipts70(req)
494 }
495 other => other,
496 }
497 }
498
499 self
500 }
501}
502
503impl<N: NetworkPrimitives> Encodable for EthMessage<N> {
504 fn encode(&self, out: &mut dyn BufMut) {
505 match self {
506 Self::Status(status) => status.encode(out),
507 Self::NewBlockHashes(new_block_hashes) => new_block_hashes.encode(out),
508 Self::NewBlock(new_block) => new_block.encode(out),
509 Self::Transactions(transactions) => transactions.encode(out),
510 Self::NewPooledTransactionHashes66(hashes) => hashes.encode(out),
511 Self::NewPooledTransactionHashes68(hashes) => hashes.encode(out),
512 Self::NewPooledTransactionHashes72(hashes) => hashes.encode(out),
513 Self::GetBlockHeaders(request) => request.encode(out),
514 Self::BlockHeaders(headers) => headers.encode(out),
515 Self::GetBlockBodies(request) => request.encode(out),
516 Self::BlockBodies(bodies) => bodies.encode(out),
517 Self::GetPooledTransactions(request) => request.encode(out),
518 Self::PooledTransactions(transactions) => transactions.encode(out),
519 Self::GetNodeData(request) => request.encode(out),
520 Self::NodeData(data) => data.encode(out),
521 Self::GetReceipts(request) => request.encode(out),
522 Self::GetReceipts70(request) => request.encode(out),
523 Self::GetBlockAccessLists(request) => request.encode(out),
524 Self::GetCells(request) => request.encode(out),
525 Self::Receipts(receipts) => receipts.encode(out),
526 Self::Receipts69(receipt69) => receipt69.encode(out),
527 Self::Receipts70(receipt70) => receipt70.encode(out),
528 Self::BlockAccessLists(block_access_lists) => block_access_lists.encode(out),
529 Self::BlockRangeUpdate(block_range_update) => block_range_update.encode(out),
530 Self::Cells(cells) => cells.encode(out),
531 Self::Other(unknown) => out.put_slice(&unknown.payload),
532 }
533 }
534 fn length(&self) -> usize {
535 match self {
536 Self::Status(status) => status.length(),
537 Self::NewBlockHashes(new_block_hashes) => new_block_hashes.length(),
538 Self::NewBlock(new_block) => new_block.length(),
539 Self::Transactions(transactions) => transactions.length(),
540 Self::NewPooledTransactionHashes66(hashes) => hashes.length(),
541 Self::NewPooledTransactionHashes68(hashes) => hashes.length(),
542 Self::NewPooledTransactionHashes72(hashes) => hashes.length(),
543 Self::GetBlockHeaders(request) => request.length(),
544 Self::BlockHeaders(headers) => headers.length(),
545 Self::GetBlockBodies(request) => request.length(),
546 Self::BlockBodies(bodies) => bodies.length(),
547 Self::GetPooledTransactions(request) => request.length(),
548 Self::PooledTransactions(transactions) => transactions.length(),
549 Self::GetNodeData(request) => request.length(),
550 Self::NodeData(data) => data.length(),
551 Self::GetReceipts(request) => request.length(),
552 Self::GetReceipts70(request) => request.length(),
553 Self::GetBlockAccessLists(request) => request.length(),
554 Self::GetCells(request) => request.length(),
555 Self::Receipts(receipts) => receipts.length(),
556 Self::Receipts69(receipt69) => receipt69.length(),
557 Self::Receipts70(receipt70) => receipt70.length(),
558 Self::BlockAccessLists(block_access_lists) => block_access_lists.length(),
559 Self::BlockRangeUpdate(block_range_update) => block_range_update.length(),
560 Self::Cells(cells) => cells.length(),
561 Self::Other(unknown) => unknown.length(),
562 }
563 }
564}
565
/// Broadcast messages whose payloads are held behind shared handles (`Arc`,
/// `SharedTransactions`).
///
/// NOTE(review): presumably shared so the same payload can be sent to many peers without deep
/// copies — confirm at the broadcast call sites.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum EthBroadcastMessage<N: NetworkPrimitives = EthNetworkPrimitives> {
    /// `NewBlock` (id 0x07).
    NewBlock(Arc<N::NewBlockPayload>),
    /// `Transactions` (id 0x02).
    Transactions(SharedTransactions<N::BroadcastedTransaction>),
}
580
581impl<N: NetworkPrimitives> EthBroadcastMessage<N> {
584 pub const fn message_id(&self) -> EthMessageID {
586 match self {
587 Self::NewBlock(_) => EthMessageID::NewBlock,
588 Self::Transactions(_) => EthMessageID::Transactions,
589 }
590 }
591}
592
593impl<N: NetworkPrimitives> Encodable for EthBroadcastMessage<N> {
594 fn encode(&self, out: &mut dyn BufMut) {
595 match self {
596 Self::NewBlock(new_block) => new_block.encode(out),
597 Self::Transactions(transactions) => transactions.encode(out),
598 }
599 }
600
601 fn length(&self) -> usize {
602 match self {
603 Self::NewBlock(new_block) => new_block.length(),
604 Self::Transactions(transactions) => transactions.length(),
605 }
606 }
607}
608
/// Wire ids of `eth` protocol messages.
///
/// Ids 0x0b and 0x0c have no named variant here. Any id without a named variant is
/// represented as [`EthMessageID::Other`].
#[repr(u8)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum EthMessageID {
    /// Status message.
    Status = 0x00,
    /// Announcement of new block hashes.
    NewBlockHashes = 0x01,
    /// Full transactions broadcast.
    Transactions = 0x02,
    /// Request for block headers.
    GetBlockHeaders = 0x03,
    /// Block headers response.
    BlockHeaders = 0x04,
    /// Request for block bodies.
    GetBlockBodies = 0x05,
    /// Block bodies response.
    BlockBodies = 0x06,
    /// New block broadcast.
    NewBlock = 0x07,
    /// Announcement of pooled transaction hashes.
    NewPooledTransactionHashes = 0x08,
    /// Request for pooled transactions.
    GetPooledTransactions = 0x09,
    /// Pooled transactions response.
    PooledTransactions = 0x0a,
    /// Request for node data; rejected when decoding at eth/67 or later.
    GetNodeData = 0x0d,
    /// Node data response; rejected when decoding at eth/67 or later.
    NodeData = 0x0e,
    /// Request for receipts.
    GetReceipts = 0x0f,
    /// Receipts response.
    Receipts = 0x10,
    /// Block range update; accepted when decoding at eth/69 or later.
    BlockRangeUpdate = 0x11,
    /// Request for block access lists; accepted when decoding at eth/71 or later.
    GetBlockAccessLists = 0x12,
    /// Block access lists response; accepted when decoding at eth/71 or later.
    BlockAccessLists = 0x13,

    /// Request for cells; accepted when decoding at eth/72 or later.
    GetCells = 0x14,
    /// Cells response; accepted when decoding at eth/72 or later.
    Cells = 0x15,
    /// Any id without a named variant, carrying the raw id byte.
    Other(u8),
}
668
669impl EthMessageID {
670 pub const fn to_u8(&self) -> u8 {
672 match self {
673 Self::Status => 0x00,
674 Self::NewBlockHashes => 0x01,
675 Self::Transactions => 0x02,
676 Self::GetBlockHeaders => 0x03,
677 Self::BlockHeaders => 0x04,
678 Self::GetBlockBodies => 0x05,
679 Self::BlockBodies => 0x06,
680 Self::NewBlock => 0x07,
681 Self::NewPooledTransactionHashes => 0x08,
682 Self::GetPooledTransactions => 0x09,
683 Self::PooledTransactions => 0x0a,
684 Self::GetNodeData => 0x0d,
685 Self::NodeData => 0x0e,
686 Self::GetReceipts => 0x0f,
687 Self::Receipts => 0x10,
688 Self::BlockRangeUpdate => 0x11,
689 Self::GetBlockAccessLists => 0x12,
690 Self::BlockAccessLists => 0x13,
691 Self::GetCells => 0x14,
692 Self::Cells => 0x15,
693 Self::Other(value) => *value, }
695 }
696
697 pub const fn max(version: EthVersion) -> u8 {
699 if version.is_eth72() {
700 Self::Cells.to_u8()
701 } else if version.is_eth71() {
702 Self::BlockAccessLists.to_u8()
703 } else if version.is_eth69_or_newer() {
704 Self::BlockRangeUpdate.to_u8()
705 } else {
706 Self::Receipts.to_u8()
707 }
708 }
709
710 pub const fn message_count(version: EthVersion) -> u8 {
716 Self::max(version) + 1
717 }
718}
719
720impl Encodable for EthMessageID {
721 fn encode(&self, out: &mut dyn BufMut) {
722 out.put_u8(self.to_u8());
723 }
724 fn length(&self) -> usize {
725 1
726 }
727}
728
729impl Decodable for EthMessageID {
730 fn decode(buf: &mut &[u8]) -> alloy_rlp::Result<Self> {
731 let id = match buf.first().ok_or(alloy_rlp::Error::InputTooShort)? {
732 0x00 => Self::Status,
733 0x01 => Self::NewBlockHashes,
734 0x02 => Self::Transactions,
735 0x03 => Self::GetBlockHeaders,
736 0x04 => Self::BlockHeaders,
737 0x05 => Self::GetBlockBodies,
738 0x06 => Self::BlockBodies,
739 0x07 => Self::NewBlock,
740 0x08 => Self::NewPooledTransactionHashes,
741 0x09 => Self::GetPooledTransactions,
742 0x0a => Self::PooledTransactions,
743 0x0d => Self::GetNodeData,
744 0x0e => Self::NodeData,
745 0x0f => Self::GetReceipts,
746 0x10 => Self::Receipts,
747 0x11 => Self::BlockRangeUpdate,
748 0x12 => Self::GetBlockAccessLists,
749 0x13 => Self::BlockAccessLists,
750 0x14 => Self::GetCells,
751 0x15 => Self::Cells,
752 unknown => Self::Other(*unknown),
753 };
754 buf.advance(1);
755 Ok(id)
756 }
757}
758
759impl TryFrom<usize> for EthMessageID {
760 type Error = &'static str;
761
762 fn try_from(value: usize) -> Result<Self, Self::Error> {
763 match value {
764 0x00 => Ok(Self::Status),
765 0x01 => Ok(Self::NewBlockHashes),
766 0x02 => Ok(Self::Transactions),
767 0x03 => Ok(Self::GetBlockHeaders),
768 0x04 => Ok(Self::BlockHeaders),
769 0x05 => Ok(Self::GetBlockBodies),
770 0x06 => Ok(Self::BlockBodies),
771 0x07 => Ok(Self::NewBlock),
772 0x08 => Ok(Self::NewPooledTransactionHashes),
773 0x09 => Ok(Self::GetPooledTransactions),
774 0x0a => Ok(Self::PooledTransactions),
775 0x0d => Ok(Self::GetNodeData),
776 0x0e => Ok(Self::NodeData),
777 0x0f => Ok(Self::GetReceipts),
778 0x10 => Ok(Self::Receipts),
779 0x11 => Ok(Self::BlockRangeUpdate),
780 0x12 => Ok(Self::GetBlockAccessLists),
781 0x13 => Ok(Self::BlockAccessLists),
782 0x14 => Ok(Self::GetCells),
783 0x15 => Ok(Self::Cells),
784 _ => Err("Invalid message ID"),
785 }
786 }
787}
788
/// A message tagged with a request id, RLP-encoded as the list `[request_id, message]`.
///
/// Used for request/response messages so a response can be matched to its request.
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(any(test, feature = "arbitrary"), derive(arbitrary::Arbitrary))]
pub struct RequestPair<T> {
    /// Id of the request; a response carries the id of the request it answers.
    pub request_id: u64,

    /// The message payload.
    pub message: T,
}
802
803impl<T> RequestPair<T> {
804 pub fn map<F, R>(self, f: F) -> RequestPair<R>
806 where
807 F: FnOnce(T) -> R,
808 {
809 let Self { request_id, message } = self;
810 RequestPair { request_id, message: f(message) }
811 }
812
813 pub fn decode_with<F>(buf: &mut &[u8], decode_msg: F) -> alloy_rlp::Result<Self>
815 where
816 F: FnOnce(&mut &[u8]) -> alloy_rlp::Result<T>,
817 {
818 let header = Header::decode(buf)?;
819
820 let initial_length = buf.len();
821 let request_id = u64::decode(buf)?;
822 let message = decode_msg(buf)?;
823
824 let consumed_len = initial_length - buf.len();
825 if consumed_len != header.payload_length {
826 return Err(alloy_rlp::Error::UnexpectedLength)
827 }
828
829 Ok(Self { request_id, message })
830 }
831}
832
833impl<T> Encodable for RequestPair<T>
835where
836 T: Encodable,
837{
838 fn encode(&self, out: &mut dyn alloy_rlp::BufMut) {
839 let header =
840 Header { list: true, payload_length: self.request_id.length() + self.message.length() };
841
842 header.encode(out);
843 self.request_id.encode(out);
844 self.message.encode(out);
845 }
846
847 fn length(&self) -> usize {
848 let mut length = 0;
849 length += self.request_id.length();
850 length += self.message.length();
851 length += length_of_length(length);
852 length
853 }
854}
855
856impl<T> Decodable for RequestPair<T>
858where
859 T: Decodable,
860{
861 fn decode(buf: &mut &[u8]) -> alloy_rlp::Result<Self> {
862 let header = Header::decode(buf)?;
863
864 let initial_length = buf.len();
865 let request_id = u64::decode(buf)?;
866 let message = T::decode(buf)?;
867
868 let consumed_len = initial_length - buf.len();
871 if consumed_len != header.payload_length {
872 return Err(alloy_rlp::Error::UnexpectedLength)
873 }
874
875 Ok(Self { request_id, message })
876 }
877}
878
#[cfg(test)]
mod tests {
    use super::MessageError;
    use crate::{
        message::RequestPair, BlockAccessLists, EthMessage, EthMessageID, EthNetworkPrimitives,
        EthVersion, GetBlockAccessLists, GetNodeData, NodeData, ProtocolMessage,
        RawCapabilityMessage,
    };
    use alloy_primitives::hex;
    use alloy_rlp::{Decodable, Encodable, Error};
    use reth_ethereum_primitives::BlockBody;

    /// Encodes `value` into a fresh byte vector.
    fn encode<T: Encodable>(value: T) -> Vec<u8> {
        let mut buf = vec![];
        value.encode(&mut buf);
        buf
    }

    /// `GetNodeData` and `NodeData` must be rejected with `MessageError::Invalid` at eth/67.
    #[test]
    fn test_removed_message_at_eth67() {
        let get_node_data = EthMessage::<EthNetworkPrimitives>::GetNodeData(RequestPair {
            request_id: 1337,
            message: GetNodeData(vec![]),
        });
        let buf = encode(ProtocolMessage {
            message_type: EthMessageID::GetNodeData,
            message: get_node_data,
        });
        let msg = ProtocolMessage::<EthNetworkPrimitives>::decode_message(
            crate::EthVersion::Eth67,
            &mut &buf[..],
        );
        assert!(matches!(msg, Err(MessageError::Invalid(..))));

        let node_data = EthMessage::<EthNetworkPrimitives>::NodeData(RequestPair {
            request_id: 1337,
            message: NodeData(vec![]),
        });
        let buf =
            encode(ProtocolMessage { message_type: EthMessageID::NodeData, message: node_data });
        let msg = ProtocolMessage::<EthNetworkPrimitives>::decode_message(
            crate::EthVersion::Eth67,
            &mut &buf[..],
        );
        assert!(matches!(msg, Err(MessageError::Invalid(..))));
    }

    /// Block access list messages are not available before eth/71, so decoding them at eth/70
    /// must fail with `Invalid(Eth70, <id>)`.
    #[test]
    fn test_bal_message_version_gating() {
        let get_block_access_lists =
            EthMessage::<EthNetworkPrimitives>::GetBlockAccessLists(RequestPair {
                request_id: 1337,
                message: GetBlockAccessLists(vec![]),
            });
        let buf = encode(ProtocolMessage {
            message_type: EthMessageID::GetBlockAccessLists,
            message: get_block_access_lists,
        });
        let msg = ProtocolMessage::<EthNetworkPrimitives>::decode_message(
            EthVersion::Eth70,
            &mut &buf[..],
        );
        assert!(matches!(
            msg,
            Err(MessageError::Invalid(EthVersion::Eth70, EthMessageID::GetBlockAccessLists))
        ));

        let block_access_lists =
            EthMessage::<EthNetworkPrimitives>::BlockAccessLists(RequestPair {
                request_id: 1337,
                message: BlockAccessLists(vec![]),
            });
        let buf = encode(ProtocolMessage {
            message_type: EthMessageID::BlockAccessLists,
            message: block_access_lists,
        });
        let msg = ProtocolMessage::<EthNetworkPrimitives>::decode_message(
            EthVersion::Eth70,
            &mut &buf[..],
        );
        assert!(matches!(
            msg,
            Err(MessageError::Invalid(EthVersion::Eth70, EthMessageID::BlockAccessLists))
        ));
    }

    /// At eth/71 a `GetBlockAccessLists` message survives an encode/decode round trip.
    #[test]
    fn test_bal_message_eth71_roundtrip() {
        let msg = ProtocolMessage::from(EthMessage::<EthNetworkPrimitives>::GetBlockAccessLists(
            RequestPair { request_id: 42, message: GetBlockAccessLists(vec![]) },
        ));
        let encoded = encode(msg.clone());
        let decoded = ProtocolMessage::<EthNetworkPrimitives>::decode_message(
            EthVersion::Eth71,
            &mut &encoded[..],
        )
        .unwrap();

        assert_eq!(decoded, msg);
    }

    #[test]
    fn request_pair_encode() {
        let request_pair = RequestPair { request_id: 1337, message: vec![5u8] };

        // c5: list header with a 5-byte payload
        // 820539: request id 1337 (0x0539) as a 2-byte string
        // c105: the message, a list holding the single byte 5
        let expected = hex!("c5820539c105");
        let got = encode(request_pair);
        assert_eq!(expected[..], got, "expected: {expected:X?}, got: {got:X?}",);
    }

    /// Decoding the vector from `request_pair_encode` yields the same pair, and `length()`
    /// matches the encoded size.
    #[test]
    fn request_pair_decode() {
        let raw_pair = &hex!("c5820539c105")[..];

        let expected = RequestPair { request_id: 1337, message: vec![5u8] };

        let got = RequestPair::<Vec<u8>>::decode(&mut &*raw_pair).unwrap();
        assert_eq!(expected.length(), raw_pair.len());
        assert_eq!(expected, got);
    }

    #[test]
    fn malicious_request_pair_decode() {
        // The list header c5 declares a 5-byte payload, but the request id (820539, 3 bytes)
        // plus the message c20505 (3 bytes) occupy 6 bytes. Decoding must reject the mismatch
        // with `UnexpectedLength` instead of silently accepting the extra byte.
        let raw_pair = &hex!("c5820539c20505")[..];

        let result = RequestPair::<Vec<u8>>::decode(&mut &*raw_pair);
        assert!(matches!(result, Err(Error::UnexpectedLength)));
    }

    /// An empty `BlockBodies` response round-trips at eth/68.
    #[test]
    fn empty_block_bodies_protocol() {
        let empty_block_bodies =
            ProtocolMessage::from(EthMessage::<EthNetworkPrimitives>::BlockBodies(RequestPair {
                request_id: 0,
                message: Default::default(),
            }));
        let mut buf = Vec::new();
        empty_block_bodies.encode(&mut buf);
        let decoded =
            ProtocolMessage::decode_message(EthVersion::Eth68, &mut buf.as_slice()).unwrap();
        assert_eq!(empty_block_bodies, decoded);
    }

    /// A `BlockBodies` response with one empty body (empty withdrawals present) round-trips at
    /// eth/68.
    #[test]
    fn empty_block_body_protocol() {
        let empty_block_bodies =
            ProtocolMessage::from(EthMessage::<EthNetworkPrimitives>::BlockBodies(RequestPair {
                request_id: 0,
                message: vec![BlockBody {
                    transactions: vec![],
                    ommers: vec![],
                    withdrawals: Some(Default::default()),
                }]
                .into(),
            }));
        let mut buf = Vec::new();
        empty_block_bodies.encode(&mut buf);
        let decoded =
            ProtocolMessage::decode_message(EthVersion::Eth68, &mut buf.as_slice()).unwrap();
        assert_eq!(empty_block_bodies, decoded);
    }

    #[test]
    fn decode_block_bodies_message() {
        // 06: BlockBodies id; c4: list with a 4-byte payload; 8199: request id 0x99;
        // c1c0: a bodies list holding one empty list, too short to be a block body.
        let buf = hex!("06c48199c1c0");
        let msg = ProtocolMessage::<EthNetworkPrimitives>::decode_message(
            EthVersion::Eth68,
            &mut &buf[..],
        )
        .unwrap_err();
        assert!(matches!(msg, MessageError::RlpError(alloy_rlp::Error::InputTooShort)));
    }

    /// A message with an unknown id (0x20) round-trips as `EthMessage::Other` with its raw
    /// payload.
    #[test]
    fn custom_message_roundtrip() {
        let custom_payload = vec![1, 2, 3, 4, 5];
        let custom_message = RawCapabilityMessage::new(0x20, custom_payload.into());
        let protocol_message = ProtocolMessage::<EthNetworkPrimitives> {
            message_type: EthMessageID::Other(0x20),
            message: EthMessage::Other(custom_message),
        };

        let encoded = encode(protocol_message.clone());
        let decoded = ProtocolMessage::<EthNetworkPrimitives>::decode_message(
            EthVersion::Eth68,
            &mut &encoded[..],
        )
        .unwrap();

        assert_eq!(protocol_message, decoded);
    }

    /// An unknown-id message with an empty payload also round-trips.
    #[test]
    fn custom_message_empty_payload_roundtrip() {
        let custom_message = RawCapabilityMessage::new(0x30, vec![].into());
        let protocol_message = ProtocolMessage::<EthNetworkPrimitives> {
            message_type: EthMessageID::Other(0x30),
            message: EthMessage::Other(custom_message),
        };

        let encoded = encode(protocol_message.clone());
        let decoded = ProtocolMessage::<EthNetworkPrimitives>::decode_message(
            EthVersion::Eth68,
            &mut &encoded[..],
        )
        .unwrap();

        assert_eq!(protocol_message, decoded);
    }

    /// `decode_status` at eth/68 decodes an encoded legacy `Status` back to the same value.
    #[test]
    fn decode_status_success() {
        use crate::{Status, StatusMessage};
        use alloy_hardforks::{ForkHash, ForkId};
        use alloy_primitives::{B256, U256};

        let status = Status {
            version: EthVersion::Eth68,
            chain: alloy_chains::Chain::mainnet(),
            total_difficulty: U256::from(100u64),
            blockhash: B256::random(),
            genesis: B256::random(),
            forkid: ForkId { hash: ForkHash([0xb7, 0x15, 0x07, 0x7d]), next: 0 },
        };

        let protocol_message = ProtocolMessage::<EthNetworkPrimitives>::from(EthMessage::Status(
            StatusMessage::Legacy(status),
        ));
        let encoded = encode(protocol_message);

        let decoded = ProtocolMessage::<EthNetworkPrimitives>::decode_status(
            EthVersion::Eth68,
            &mut &encoded[..],
        )
        .unwrap();

        assert!(matches!(decoded, StatusMessage::Legacy(s) if s == status));
    }

    /// eth/69 and eth/70 end at `BlockRangeUpdate`; eth/68 ends at `Receipts`.
    #[test]
    fn eth_message_id_max_includes_block_range_update() {
        assert_eq!(EthMessageID::max(EthVersion::Eth69), EthMessageID::BlockRangeUpdate.to_u8(),);
        assert_eq!(EthMessageID::max(EthVersion::Eth70), EthMessageID::BlockRangeUpdate.to_u8(),);
        assert_eq!(EthMessageID::max(EthVersion::Eth68), EthMessageID::Receipts.to_u8());
    }

    /// `decode_status` reports the id it actually saw when the message is not a status.
    #[test]
    fn decode_status_rejects_non_status() {
        let msg = EthMessage::<EthNetworkPrimitives>::GetBlockBodies(RequestPair {
            request_id: 1,
            message: crate::GetBlockBodies::default(),
        });
        let protocol_message =
            ProtocolMessage { message_type: EthMessageID::GetBlockBodies, message: msg };
        let encoded = encode(protocol_message);

        let result = ProtocolMessage::<EthNetworkPrimitives>::decode_status(
            EthVersion::Eth68,
            &mut &encoded[..],
        );

        assert!(matches!(
            result,
            Err(MessageError::ExpectedStatusMessage(EthMessageID::GetBlockBodies))
        ));
    }
}
1160}