1use super::{
10 broadcast::NewBlockHashes, BlockAccessLists, BlockBodies, BlockHeaders, GetBlockAccessLists,
11 GetBlockBodies, GetBlockHeaders, GetNodeData, GetPooledTransactions, GetReceipts,
12 GetReceipts70, NewPooledTransactionHashes66, NewPooledTransactionHashes68, NodeData,
13 PooledTransactions, Receipts, Status, StatusEth69, Transactions,
14};
15use crate::{
16 status::StatusMessage, BlockRangeUpdate, EthNetworkPrimitives, EthVersion, NetworkPrimitives,
17 RawCapabilityMessage, Receipts69, Receipts70, SharedTransactions,
18};
19use alloc::{boxed::Box, string::String, sync::Arc};
20use alloy_primitives::{
21 bytes::{Buf, BufMut},
22 Bytes,
23};
24use alloy_rlp::{length_of_length, Decodable, Encodable, Header};
25use core::fmt::Debug;
26
/// Upper bound, in bytes, on the size of an `eth` protocol message: 10 MiB (`10 << 20`).
pub const MAX_MESSAGE_SIZE: usize = 10 << 20;
30
/// Errors produced while decoding an `eth` protocol message.
#[derive(thiserror::Error, Debug)]
pub enum MessageError {
    /// The message id is not allowed for the negotiated protocol version
    /// (e.g. `GetNodeData` on eth/67+, or block access list messages before eth/71).
    #[error("message id {1:?} is invalid for version {0:?}")]
    Invalid(EthVersion, EthMessageID),
    /// A status message was required, but a message with a different id was received.
    #[error("expected status message but received {0:?}")]
    ExpectedStatusMessage(EthMessageID),
    /// The message id or body could not be RLP-decoded.
    #[error("RLP error: {0}")]
    RlpError(#[from] alloy_rlp::Error),
    /// Any other error, described by its message.
    #[error("{0}")]
    Other(String),
}
47
/// An `eth` protocol message together with its wire message id.
///
/// Encoded as the id byte followed by the message's own encoding.
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ProtocolMessage<N: NetworkPrimitives = EthNetworkPrimitives> {
    /// Wire id of `message`, written before the message body.
    pub message_type: EthMessageID,
    /// The message itself.
    #[cfg_attr(
        feature = "serde",
        serde(bound = "EthMessage<N>: serde::Serialize + serde::de::DeserializeOwned")
    )]
    pub message: EthMessage<N>,
}
61
62impl<N: NetworkPrimitives> ProtocolMessage<N> {
63 pub fn decode_status(
68 version: EthVersion,
69 buf: &mut &[u8],
70 ) -> Result<StatusMessage, MessageError> {
71 let message_type = EthMessageID::decode(buf)?;
72
73 if message_type != EthMessageID::Status {
74 return Err(MessageError::ExpectedStatusMessage(message_type))
75 }
76
77 let status = if version < EthVersion::Eth69 {
78 StatusMessage::Legacy(Status::decode(buf)?)
79 } else {
80 StatusMessage::Eth69(StatusEth69::decode(buf)?)
81 };
82
83 Ok(status)
84 }
85
86 pub fn decode_message(version: EthVersion, buf: &mut &[u8]) -> Result<Self, MessageError> {
90 let message_type = EthMessageID::decode(buf)?;
91
92 let message = match message_type {
95 EthMessageID::Status => EthMessage::Status(if version < EthVersion::Eth69 {
96 StatusMessage::Legacy(Status::decode(buf)?)
97 } else {
98 StatusMessage::Eth69(StatusEth69::decode(buf)?)
99 }),
100 EthMessageID::NewBlockHashes => {
101 EthMessage::NewBlockHashes(NewBlockHashes::decode(buf)?)
102 }
103 EthMessageID::NewBlock => {
104 EthMessage::NewBlock(Box::new(N::NewBlockPayload::decode(buf)?))
105 }
106 EthMessageID::Transactions => EthMessage::Transactions(Transactions::decode(buf)?),
107 EthMessageID::NewPooledTransactionHashes => {
108 if version >= EthVersion::Eth68 {
109 EthMessage::NewPooledTransactionHashes68(NewPooledTransactionHashes68::decode(
110 buf,
111 )?)
112 } else {
113 EthMessage::NewPooledTransactionHashes66(NewPooledTransactionHashes66::decode(
114 buf,
115 )?)
116 }
117 }
118 EthMessageID::GetBlockHeaders => EthMessage::GetBlockHeaders(RequestPair::decode(buf)?),
119 EthMessageID::BlockHeaders => EthMessage::BlockHeaders(RequestPair::decode(buf)?),
120 EthMessageID::GetBlockBodies => EthMessage::GetBlockBodies(RequestPair::decode(buf)?),
121 EthMessageID::BlockBodies => EthMessage::BlockBodies(RequestPair::decode(buf)?),
122 EthMessageID::GetPooledTransactions => {
123 EthMessage::GetPooledTransactions(RequestPair::decode(buf)?)
124 }
125 EthMessageID::PooledTransactions => {
126 EthMessage::PooledTransactions(RequestPair::decode(buf)?)
127 }
128 EthMessageID::GetNodeData => {
129 if version >= EthVersion::Eth67 {
130 return Err(MessageError::Invalid(version, EthMessageID::GetNodeData))
131 }
132 EthMessage::GetNodeData(RequestPair::decode(buf)?)
133 }
134 EthMessageID::NodeData => {
135 if version >= EthVersion::Eth67 {
136 return Err(MessageError::Invalid(version, EthMessageID::NodeData))
137 }
138 EthMessage::NodeData(RequestPair::decode(buf)?)
139 }
140 EthMessageID::GetReceipts => {
141 if version >= EthVersion::Eth70 {
142 EthMessage::GetReceipts70(RequestPair::decode(buf)?)
143 } else {
144 EthMessage::GetReceipts(RequestPair::decode(buf)?)
145 }
146 }
147 EthMessageID::Receipts => {
148 match version {
149 v if v >= EthVersion::Eth70 => {
150 EthMessage::Receipts70(RequestPair::decode(buf)?)
154 }
155 EthVersion::Eth69 => {
156 EthMessage::Receipts69(RequestPair::decode(buf)?)
158 }
159 _ => {
160 EthMessage::Receipts(RequestPair::decode(buf)?)
162 }
163 }
164 }
165 EthMessageID::BlockRangeUpdate => {
166 if version < EthVersion::Eth69 {
167 return Err(MessageError::Invalid(version, EthMessageID::BlockRangeUpdate))
168 }
169 EthMessage::BlockRangeUpdate(BlockRangeUpdate::decode(buf)?)
170 }
171 EthMessageID::GetBlockAccessLists => {
172 if version < EthVersion::Eth71 {
173 return Err(MessageError::Invalid(version, EthMessageID::GetBlockAccessLists))
174 }
175 EthMessage::GetBlockAccessLists(RequestPair::decode(buf)?)
176 }
177 EthMessageID::BlockAccessLists => {
178 if version < EthVersion::Eth71 {
179 return Err(MessageError::Invalid(version, EthMessageID::BlockAccessLists))
180 }
181 EthMessage::BlockAccessLists(RequestPair::decode(buf)?)
182 }
183 EthMessageID::Other(_) => {
184 let raw_payload = Bytes::copy_from_slice(buf);
185 buf.advance(raw_payload.len());
186 EthMessage::Other(RawCapabilityMessage::new(
187 message_type.to_u8() as usize,
188 raw_payload.into(),
189 ))
190 }
191 };
192 Ok(Self { message_type, message })
193 }
194}
195
196impl<N: NetworkPrimitives> Encodable for ProtocolMessage<N> {
197 fn encode(&self, out: &mut dyn BufMut) {
200 self.message_type.encode(out);
201 self.message.encode(out);
202 }
203 fn length(&self) -> usize {
204 self.message_type.length() + self.message.length()
205 }
206}
207
208impl<N: NetworkPrimitives> From<EthMessage<N>> for ProtocolMessage<N> {
209 fn from(message: EthMessage<N>) -> Self {
210 Self { message_type: message.message_id(), message }
211 }
212}
213
/// A broadcast message together with its wire message id; the broadcast counterpart of
/// [`ProtocolMessage`].
#[derive(Clone, Debug)]
pub struct ProtocolBroadcastMessage<N: NetworkPrimitives = EthNetworkPrimitives> {
    /// Wire id of `message`, written before the message body.
    pub message_type: EthMessageID,
    /// The broadcast payload.
    pub message: EthBroadcastMessage<N>,
}
223
224impl<N: NetworkPrimitives> Encodable for ProtocolBroadcastMessage<N> {
225 fn encode(&self, out: &mut dyn BufMut) {
228 self.message_type.encode(out);
229 self.message.encode(out);
230 }
231 fn length(&self) -> usize {
232 self.message_type.length() + self.message.length()
233 }
234}
235
236impl<N: NetworkPrimitives> From<EthBroadcastMessage<N>> for ProtocolBroadcastMessage<N> {
237 fn from(message: EthBroadcastMessage<N>) -> Self {
238 Self { message_type: message.message_id(), message }
239 }
240}
241
242#[derive(Clone, Debug, PartialEq, Eq)]
268#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
269pub enum EthMessage<N: NetworkPrimitives = EthNetworkPrimitives> {
270 Status(StatusMessage),
272 NewBlockHashes(NewBlockHashes),
274 #[cfg_attr(
276 feature = "serde",
277 serde(bound = "N::NewBlockPayload: serde::Serialize + serde::de::DeserializeOwned")
278 )]
279 NewBlock(Box<N::NewBlockPayload>),
280 #[cfg_attr(
282 feature = "serde",
283 serde(bound = "N::BroadcastedTransaction: serde::Serialize + serde::de::DeserializeOwned")
284 )]
285 Transactions(Transactions<N::BroadcastedTransaction>),
286 NewPooledTransactionHashes66(NewPooledTransactionHashes66),
288 NewPooledTransactionHashes68(NewPooledTransactionHashes68),
290 GetBlockHeaders(RequestPair<GetBlockHeaders>),
293 #[cfg_attr(
295 feature = "serde",
296 serde(bound = "N::BlockHeader: serde::Serialize + serde::de::DeserializeOwned")
297 )]
298 BlockHeaders(RequestPair<BlockHeaders<N::BlockHeader>>),
299 GetBlockBodies(RequestPair<GetBlockBodies>),
301 #[cfg_attr(
303 feature = "serde",
304 serde(bound = "N::BlockBody: serde::Serialize + serde::de::DeserializeOwned")
305 )]
306 BlockBodies(RequestPair<BlockBodies<N::BlockBody>>),
307 GetPooledTransactions(RequestPair<GetPooledTransactions>),
309 #[cfg_attr(
311 feature = "serde",
312 serde(bound = "N::PooledTransaction: serde::Serialize + serde::de::DeserializeOwned")
313 )]
314 PooledTransactions(RequestPair<PooledTransactions<N::PooledTransaction>>),
315 GetNodeData(RequestPair<GetNodeData>),
317 NodeData(RequestPair<NodeData>),
319 GetReceipts(RequestPair<GetReceipts>),
321 GetReceipts70(RequestPair<GetReceipts70>),
327 GetBlockAccessLists(RequestPair<GetBlockAccessLists>),
329 #[cfg_attr(
331 feature = "serde",
332 serde(bound = "N::Receipt: serde::Serialize + serde::de::DeserializeOwned")
333 )]
334 Receipts(RequestPair<Receipts<N::Receipt>>),
335 #[cfg_attr(
337 feature = "serde",
338 serde(bound = "N::Receipt: serde::Serialize + serde::de::DeserializeOwned")
339 )]
340 Receipts69(RequestPair<Receipts69<N::Receipt>>),
341 #[cfg_attr(
343 feature = "serde",
344 serde(bound = "N::Receipt: serde::Serialize + serde::de::DeserializeOwned")
345 )]
346 Receipts70(RequestPair<Receipts70<N::Receipt>>),
351 BlockAccessLists(RequestPair<BlockAccessLists>),
353 #[cfg_attr(
355 feature = "serde",
356 serde(bound = "N::BroadcastedTransaction: serde::Serialize + serde::de::DeserializeOwned")
357 )]
358 BlockRangeUpdate(BlockRangeUpdate),
359 Other(RawCapabilityMessage),
361}
362
363impl<N: NetworkPrimitives> EthMessage<N> {
364 pub const fn message_id(&self) -> EthMessageID {
366 match self {
367 Self::Status(_) => EthMessageID::Status,
368 Self::NewBlockHashes(_) => EthMessageID::NewBlockHashes,
369 Self::NewBlock(_) => EthMessageID::NewBlock,
370 Self::Transactions(_) => EthMessageID::Transactions,
371 Self::NewPooledTransactionHashes66(_) | Self::NewPooledTransactionHashes68(_) => {
372 EthMessageID::NewPooledTransactionHashes
373 }
374 Self::GetBlockHeaders(_) => EthMessageID::GetBlockHeaders,
375 Self::BlockHeaders(_) => EthMessageID::BlockHeaders,
376 Self::GetBlockBodies(_) => EthMessageID::GetBlockBodies,
377 Self::BlockBodies(_) => EthMessageID::BlockBodies,
378 Self::GetPooledTransactions(_) => EthMessageID::GetPooledTransactions,
379 Self::PooledTransactions(_) => EthMessageID::PooledTransactions,
380 Self::GetNodeData(_) => EthMessageID::GetNodeData,
381 Self::NodeData(_) => EthMessageID::NodeData,
382 Self::GetReceipts(_) | Self::GetReceipts70(_) => EthMessageID::GetReceipts,
383 Self::Receipts(_) | Self::Receipts69(_) | Self::Receipts70(_) => EthMessageID::Receipts,
384 Self::BlockRangeUpdate(_) => EthMessageID::BlockRangeUpdate,
385 Self::GetBlockAccessLists(_) => EthMessageID::GetBlockAccessLists,
386 Self::BlockAccessLists(_) => EthMessageID::BlockAccessLists,
387 Self::Other(msg) => EthMessageID::Other(msg.id as u8),
388 }
389 }
390
391 pub const fn is_request(&self) -> bool {
393 matches!(
394 self,
395 Self::GetBlockBodies(_) |
396 Self::GetBlockHeaders(_) |
397 Self::GetReceipts(_) |
398 Self::GetReceipts70(_) |
399 Self::GetBlockAccessLists(_) |
400 Self::GetPooledTransactions(_) |
401 Self::GetNodeData(_)
402 )
403 }
404
405 pub const fn is_response(&self) -> bool {
407 matches!(
408 self,
409 Self::PooledTransactions(_) |
410 Self::Receipts(_) |
411 Self::Receipts69(_) |
412 Self::Receipts70(_) |
413 Self::BlockAccessLists(_) |
414 Self::BlockHeaders(_) |
415 Self::BlockBodies(_) |
416 Self::NodeData(_)
417 )
418 }
419
420 pub fn map_versioned(self, version: EthVersion) -> Self {
425 if version >= EthVersion::Eth70 {
429 return match self {
430 Self::GetReceipts(pair) => {
431 let RequestPair { request_id, message } = pair;
432 let req = RequestPair {
433 request_id,
434 message: GetReceipts70 {
435 first_block_receipt_index: 0,
436 block_hashes: message.0,
437 },
438 };
439 Self::GetReceipts70(req)
440 }
441 other => other,
442 }
443 }
444
445 self
446 }
447}
448
449impl<N: NetworkPrimitives> Encodable for EthMessage<N> {
450 fn encode(&self, out: &mut dyn BufMut) {
451 match self {
452 Self::Status(status) => status.encode(out),
453 Self::NewBlockHashes(new_block_hashes) => new_block_hashes.encode(out),
454 Self::NewBlock(new_block) => new_block.encode(out),
455 Self::Transactions(transactions) => transactions.encode(out),
456 Self::NewPooledTransactionHashes66(hashes) => hashes.encode(out),
457 Self::NewPooledTransactionHashes68(hashes) => hashes.encode(out),
458 Self::GetBlockHeaders(request) => request.encode(out),
459 Self::BlockHeaders(headers) => headers.encode(out),
460 Self::GetBlockBodies(request) => request.encode(out),
461 Self::BlockBodies(bodies) => bodies.encode(out),
462 Self::GetPooledTransactions(request) => request.encode(out),
463 Self::PooledTransactions(transactions) => transactions.encode(out),
464 Self::GetNodeData(request) => request.encode(out),
465 Self::NodeData(data) => data.encode(out),
466 Self::GetReceipts(request) => request.encode(out),
467 Self::GetReceipts70(request) => request.encode(out),
468 Self::GetBlockAccessLists(request) => request.encode(out),
469 Self::Receipts(receipts) => receipts.encode(out),
470 Self::Receipts69(receipt69) => receipt69.encode(out),
471 Self::Receipts70(receipt70) => receipt70.encode(out),
472 Self::BlockAccessLists(block_access_lists) => block_access_lists.encode(out),
473 Self::BlockRangeUpdate(block_range_update) => block_range_update.encode(out),
474 Self::Other(unknown) => out.put_slice(&unknown.payload),
475 }
476 }
477 fn length(&self) -> usize {
478 match self {
479 Self::Status(status) => status.length(),
480 Self::NewBlockHashes(new_block_hashes) => new_block_hashes.length(),
481 Self::NewBlock(new_block) => new_block.length(),
482 Self::Transactions(transactions) => transactions.length(),
483 Self::NewPooledTransactionHashes66(hashes) => hashes.length(),
484 Self::NewPooledTransactionHashes68(hashes) => hashes.length(),
485 Self::GetBlockHeaders(request) => request.length(),
486 Self::BlockHeaders(headers) => headers.length(),
487 Self::GetBlockBodies(request) => request.length(),
488 Self::BlockBodies(bodies) => bodies.length(),
489 Self::GetPooledTransactions(request) => request.length(),
490 Self::PooledTransactions(transactions) => transactions.length(),
491 Self::GetNodeData(request) => request.length(),
492 Self::NodeData(data) => data.length(),
493 Self::GetReceipts(request) => request.length(),
494 Self::GetReceipts70(request) => request.length(),
495 Self::GetBlockAccessLists(request) => request.length(),
496 Self::Receipts(receipts) => receipts.length(),
497 Self::Receipts69(receipt69) => receipt69.length(),
498 Self::Receipts70(receipt70) => receipt70.length(),
499 Self::BlockAccessLists(block_access_lists) => block_access_lists.length(),
500 Self::BlockRangeUpdate(block_range_update) => block_range_update.length(),
501 Self::Other(unknown) => unknown.length(),
502 }
503 }
504}
505
/// Messages that can be broadcast to peers.
///
/// The `NewBlock` payload is held behind an [`Arc`], so cloning the message does not copy the
/// block.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum EthBroadcastMessage<N: NetworkPrimitives = EthNetworkPrimitives> {
    /// New block announcement with a shared payload.
    NewBlock(Arc<N::NewBlockPayload>),
    /// Transactions to broadcast.
    Transactions(SharedTransactions<N::BroadcastedTransaction>),
}
520
521impl<N: NetworkPrimitives> EthBroadcastMessage<N> {
524 pub const fn message_id(&self) -> EthMessageID {
526 match self {
527 Self::NewBlock(_) => EthMessageID::NewBlock,
528 Self::Transactions(_) => EthMessageID::Transactions,
529 }
530 }
531}
532
533impl<N: NetworkPrimitives> Encodable for EthBroadcastMessage<N> {
534 fn encode(&self, out: &mut dyn BufMut) {
535 match self {
536 Self::NewBlock(new_block) => new_block.encode(out),
537 Self::Transactions(transactions) => transactions.encode(out),
538 }
539 }
540
541 fn length(&self) -> usize {
542 match self {
543 Self::NewBlock(new_block) => new_block.length(),
544 Self::Transactions(transactions) => transactions.length(),
545 }
546 }
547}
548
/// Message ids of the `eth` wire protocol.
///
/// Known ids carry their fixed wire value as the discriminant; ids not known to this crate are
/// kept as [`EthMessageID::Other`].
#[repr(u8)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum EthMessageID {
    /// Status message.
    Status = 0x00,
    /// New block hashes announcement.
    NewBlockHashes = 0x01,
    /// Full transactions broadcast.
    Transactions = 0x02,
    /// Block headers request.
    GetBlockHeaders = 0x03,
    /// Block headers response.
    BlockHeaders = 0x04,
    /// Block bodies request.
    GetBlockBodies = 0x05,
    /// Block bodies response.
    BlockBodies = 0x06,
    /// New block announcement.
    NewBlock = 0x07,
    /// Pooled transaction hashes announcement.
    NewPooledTransactionHashes = 0x08,
    /// Pooled transactions request.
    GetPooledTransactions = 0x09,
    /// Pooled transactions response.
    PooledTransactions = 0x0a,
    /// Node data request; rejected by `ProtocolMessage::decode_message` from eth/67.
    GetNodeData = 0x0d,
    /// Node data response; rejected by `ProtocolMessage::decode_message` from eth/67.
    NodeData = 0x0e,
    /// Receipts request.
    GetReceipts = 0x0f,
    /// Receipts response.
    Receipts = 0x10,
    /// Block range update; only accepted from eth/69.
    BlockRangeUpdate = 0x11,
    /// Block access lists request; only accepted from eth/71.
    GetBlockAccessLists = 0x12,
    /// Block access lists response; only accepted from eth/71.
    BlockAccessLists = 0x13,
    /// Any id not listed above, holding the raw id byte.
    Other(u8),
}
599
600impl EthMessageID {
601 pub const fn to_u8(&self) -> u8 {
603 match self {
604 Self::Status => 0x00,
605 Self::NewBlockHashes => 0x01,
606 Self::Transactions => 0x02,
607 Self::GetBlockHeaders => 0x03,
608 Self::BlockHeaders => 0x04,
609 Self::GetBlockBodies => 0x05,
610 Self::BlockBodies => 0x06,
611 Self::NewBlock => 0x07,
612 Self::NewPooledTransactionHashes => 0x08,
613 Self::GetPooledTransactions => 0x09,
614 Self::PooledTransactions => 0x0a,
615 Self::GetNodeData => 0x0d,
616 Self::NodeData => 0x0e,
617 Self::GetReceipts => 0x0f,
618 Self::Receipts => 0x10,
619 Self::BlockRangeUpdate => 0x11,
620 Self::GetBlockAccessLists => 0x12,
621 Self::BlockAccessLists => 0x13,
622 Self::Other(value) => *value, }
624 }
625
626 pub const fn max(version: EthVersion) -> u8 {
628 if version.is_eth71() {
629 Self::BlockAccessLists.to_u8()
630 } else if version.is_eth69_or_newer() {
631 Self::BlockRangeUpdate.to_u8()
632 } else {
633 Self::Receipts.to_u8()
634 }
635 }
636
637 pub const fn message_count(version: EthVersion) -> u8 {
643 Self::max(version) + 1
644 }
645}
646
647impl Encodable for EthMessageID {
648 fn encode(&self, out: &mut dyn BufMut) {
649 out.put_u8(self.to_u8());
650 }
651 fn length(&self) -> usize {
652 1
653 }
654}
655
656impl Decodable for EthMessageID {
657 fn decode(buf: &mut &[u8]) -> alloy_rlp::Result<Self> {
658 let id = match buf.first().ok_or(alloy_rlp::Error::InputTooShort)? {
659 0x00 => Self::Status,
660 0x01 => Self::NewBlockHashes,
661 0x02 => Self::Transactions,
662 0x03 => Self::GetBlockHeaders,
663 0x04 => Self::BlockHeaders,
664 0x05 => Self::GetBlockBodies,
665 0x06 => Self::BlockBodies,
666 0x07 => Self::NewBlock,
667 0x08 => Self::NewPooledTransactionHashes,
668 0x09 => Self::GetPooledTransactions,
669 0x0a => Self::PooledTransactions,
670 0x0d => Self::GetNodeData,
671 0x0e => Self::NodeData,
672 0x0f => Self::GetReceipts,
673 0x10 => Self::Receipts,
674 0x11 => Self::BlockRangeUpdate,
675 0x12 => Self::GetBlockAccessLists,
676 0x13 => Self::BlockAccessLists,
677 unknown => Self::Other(*unknown),
678 };
679 buf.advance(1);
680 Ok(id)
681 }
682}
683
684impl TryFrom<usize> for EthMessageID {
685 type Error = &'static str;
686
687 fn try_from(value: usize) -> Result<Self, Self::Error> {
688 match value {
689 0x00 => Ok(Self::Status),
690 0x01 => Ok(Self::NewBlockHashes),
691 0x02 => Ok(Self::Transactions),
692 0x03 => Ok(Self::GetBlockHeaders),
693 0x04 => Ok(Self::BlockHeaders),
694 0x05 => Ok(Self::GetBlockBodies),
695 0x06 => Ok(Self::BlockBodies),
696 0x07 => Ok(Self::NewBlock),
697 0x08 => Ok(Self::NewPooledTransactionHashes),
698 0x09 => Ok(Self::GetPooledTransactions),
699 0x0a => Ok(Self::PooledTransactions),
700 0x0d => Ok(Self::GetNodeData),
701 0x0e => Ok(Self::NodeData),
702 0x0f => Ok(Self::GetReceipts),
703 0x10 => Ok(Self::Receipts),
704 0x11 => Ok(Self::BlockRangeUpdate),
705 0x12 => Ok(Self::GetBlockAccessLists),
706 0x13 => Ok(Self::BlockAccessLists),
707 _ => Err("Invalid message ID"),
708 }
709 }
710}
711
/// A request or response message paired with a request id.
///
/// Encoded as the RLP list `[request_id, message]`.
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(any(test, feature = "arbitrary"), derive(arbitrary::Arbitrary))]
pub struct RequestPair<T> {
    /// The request id, encoded as the first list element.
    pub request_id: u64,

    /// The request or response payload, encoded as the second list element.
    pub message: T,
}
725
726impl<T> RequestPair<T> {
727 pub fn map<F, R>(self, f: F) -> RequestPair<R>
729 where
730 F: FnOnce(T) -> R,
731 {
732 let Self { request_id, message } = self;
733 RequestPair { request_id, message: f(message) }
734 }
735}
736
737impl<T> Encodable for RequestPair<T>
739where
740 T: Encodable,
741{
742 fn encode(&self, out: &mut dyn alloy_rlp::BufMut) {
743 let header =
744 Header { list: true, payload_length: self.request_id.length() + self.message.length() };
745
746 header.encode(out);
747 self.request_id.encode(out);
748 self.message.encode(out);
749 }
750
751 fn length(&self) -> usize {
752 let mut length = 0;
753 length += self.request_id.length();
754 length += self.message.length();
755 length += length_of_length(length);
756 length
757 }
758}
759
760impl<T> Decodable for RequestPair<T>
762where
763 T: Decodable,
764{
765 fn decode(buf: &mut &[u8]) -> alloy_rlp::Result<Self> {
766 let header = Header::decode(buf)?;
767
768 let initial_length = buf.len();
769 let request_id = u64::decode(buf)?;
770 let message = T::decode(buf)?;
771
772 let consumed_len = initial_length - buf.len();
775 if consumed_len != header.payload_length {
776 return Err(alloy_rlp::Error::UnexpectedLength)
777 }
778
779 Ok(Self { request_id, message })
780 }
781}
782
#[cfg(test)]
mod tests {
    use super::MessageError;
    use crate::{
        message::RequestPair, BlockAccessLists, EthMessage, EthMessageID, EthNetworkPrimitives,
        EthVersion, GetBlockAccessLists, GetNodeData, NodeData, ProtocolMessage,
        RawCapabilityMessage,
    };
    use alloy_primitives::hex;
    use alloy_rlp::{Decodable, Encodable, Error};
    use reth_ethereum_primitives::BlockBody;

    /// Encodes `value` into a fresh buffer.
    fn encode<T: Encodable>(value: T) -> Vec<u8> {
        let mut buf = vec![];
        value.encode(&mut buf);
        buf
    }

    /// `GetNodeData` and `NodeData` must be rejected when decoding at eth/67.
    #[test]
    fn test_removed_message_at_eth67() {
        let get_node_data = EthMessage::<EthNetworkPrimitives>::GetNodeData(RequestPair {
            request_id: 1337,
            message: GetNodeData(vec![]),
        });
        let buf = encode(ProtocolMessage {
            message_type: EthMessageID::GetNodeData,
            message: get_node_data,
        });
        let msg = ProtocolMessage::<EthNetworkPrimitives>::decode_message(
            crate::EthVersion::Eth67,
            &mut &buf[..],
        );
        assert!(matches!(msg, Err(MessageError::Invalid(..))));

        let node_data = EthMessage::<EthNetworkPrimitives>::NodeData(RequestPair {
            request_id: 1337,
            message: NodeData(vec![]),
        });
        let buf =
            encode(ProtocolMessage { message_type: EthMessageID::NodeData, message: node_data });
        let msg = ProtocolMessage::<EthNetworkPrimitives>::decode_message(
            crate::EthVersion::Eth67,
            &mut &buf[..],
        );
        assert!(matches!(msg, Err(MessageError::Invalid(..))));
    }

    /// Block access list messages must be rejected before eth/71.
    #[test]
    fn test_bal_message_version_gating() {
        let get_block_access_lists =
            EthMessage::<EthNetworkPrimitives>::GetBlockAccessLists(RequestPair {
                request_id: 1337,
                message: GetBlockAccessLists(vec![]),
            });
        let buf = encode(ProtocolMessage {
            message_type: EthMessageID::GetBlockAccessLists,
            message: get_block_access_lists,
        });
        let msg = ProtocolMessage::<EthNetworkPrimitives>::decode_message(
            EthVersion::Eth70,
            &mut &buf[..],
        );
        assert!(matches!(
            msg,
            Err(MessageError::Invalid(EthVersion::Eth70, EthMessageID::GetBlockAccessLists))
        ));

        let block_access_lists =
            EthMessage::<EthNetworkPrimitives>::BlockAccessLists(RequestPair {
                request_id: 1337,
                message: BlockAccessLists(vec![]),
            });
        let buf = encode(ProtocolMessage {
            message_type: EthMessageID::BlockAccessLists,
            message: block_access_lists,
        });
        let msg = ProtocolMessage::<EthNetworkPrimitives>::decode_message(
            EthVersion::Eth70,
            &mut &buf[..],
        );
        assert!(matches!(
            msg,
            Err(MessageError::Invalid(EthVersion::Eth70, EthMessageID::BlockAccessLists))
        ));
    }

    /// Block access list messages round-trip at eth/71.
    #[test]
    fn test_bal_message_eth71_roundtrip() {
        let msg = ProtocolMessage::from(EthMessage::<EthNetworkPrimitives>::GetBlockAccessLists(
            RequestPair { request_id: 42, message: GetBlockAccessLists(vec![]) },
        ));
        let encoded = encode(msg.clone());
        let decoded = ProtocolMessage::<EthNetworkPrimitives>::decode_message(
            EthVersion::Eth71,
            &mut &encoded[..],
        )
        .unwrap();

        assert_eq!(decoded, msg);
    }

    /// `[1337, [5]]` encodes as list(5) | 0x820539 | list(1) 0x05.
    #[test]
    fn request_pair_encode() {
        let request_pair = RequestPair { request_id: 1337, message: vec![5u8] };

        let expected = hex!("c5820539c105");
        let got = encode(request_pair);
        assert_eq!(expected[..], got, "expected: {expected:X?}, got: {got:X?}",);
    }

    /// Decoding the canonical encoding yields the original pair, and `length` matches it.
    #[test]
    fn request_pair_decode() {
        let raw_pair = &hex!("c5820539c105")[..];

        let expected = RequestPair { request_id: 1337, message: vec![5u8] };

        let got = RequestPair::<Vec<u8>>::decode(&mut &*raw_pair).unwrap();
        assert_eq!(expected.length(), raw_pair.len());
        assert_eq!(expected, got);
    }

    /// The outer header declares a 5-byte payload, but the fields (`820539` and `c20505`) span
    /// 6 bytes, so decoding must fail with `UnexpectedLength`.
    #[test]
    fn malicious_request_pair_decode() {
        let raw_pair = &hex!("c5820539c20505")[..];

        let result = RequestPair::<Vec<u8>>::decode(&mut &*raw_pair);
        assert!(matches!(result, Err(Error::UnexpectedLength)));
    }

    /// An empty `BlockBodies` response round-trips.
    #[test]
    fn empty_block_bodies_protocol() {
        let empty_block_bodies =
            ProtocolMessage::from(EthMessage::<EthNetworkPrimitives>::BlockBodies(RequestPair {
                request_id: 0,
                message: Default::default(),
            }));
        let mut buf = Vec::new();
        empty_block_bodies.encode(&mut buf);
        let decoded =
            ProtocolMessage::decode_message(EthVersion::Eth68, &mut buf.as_slice()).unwrap();
        assert_eq!(empty_block_bodies, decoded);
    }

    /// A `BlockBodies` response with one empty body (with withdrawals) round-trips.
    #[test]
    fn empty_block_body_protocol() {
        let empty_block_bodies =
            ProtocolMessage::from(EthMessage::<EthNetworkPrimitives>::BlockBodies(RequestPair {
                request_id: 0,
                message: vec![BlockBody {
                    transactions: vec![],
                    ommers: vec![],
                    withdrawals: Some(Default::default()),
                }]
                .into(),
            }));
        let mut buf = Vec::new();
        empty_block_bodies.encode(&mut buf);
        let decoded =
            ProtocolMessage::decode_message(EthVersion::Eth68, &mut buf.as_slice()).unwrap();
        assert_eq!(empty_block_bodies, decoded);
    }

    /// A malformed `BlockBodies` message whose inner body is truncated must fail with
    /// `InputTooShort` rather than panic.
    #[test]
    fn decode_block_bodies_message() {
        let buf = hex!("06c48199c1c0");
        let msg = ProtocolMessage::<EthNetworkPrimitives>::decode_message(
            EthVersion::Eth68,
            &mut &buf[..],
        )
        .unwrap_err();
        assert!(matches!(msg, MessageError::RlpError(alloy_rlp::Error::InputTooShort)));
    }

    /// Messages with unknown ids round-trip through `EthMessage::Other`.
    #[test]
    fn custom_message_roundtrip() {
        let custom_payload = vec![1, 2, 3, 4, 5];
        let custom_message = RawCapabilityMessage::new(0x20, custom_payload.into());
        let protocol_message = ProtocolMessage::<EthNetworkPrimitives> {
            message_type: EthMessageID::Other(0x20),
            message: EthMessage::Other(custom_message),
        };

        let encoded = encode(protocol_message.clone());
        let decoded = ProtocolMessage::<EthNetworkPrimitives>::decode_message(
            EthVersion::Eth68,
            &mut &encoded[..],
        )
        .unwrap();

        assert_eq!(protocol_message, decoded);
    }

    /// Unknown-id messages with an empty payload also round-trip.
    #[test]
    fn custom_message_empty_payload_roundtrip() {
        let custom_message = RawCapabilityMessage::new(0x30, vec![].into());
        let protocol_message = ProtocolMessage::<EthNetworkPrimitives> {
            message_type: EthMessageID::Other(0x30),
            message: EthMessage::Other(custom_message),
        };

        let encoded = encode(protocol_message.clone());
        let decoded = ProtocolMessage::<EthNetworkPrimitives>::decode_message(
            EthVersion::Eth68,
            &mut &encoded[..],
        )
        .unwrap();

        assert_eq!(protocol_message, decoded);
    }

    /// `decode_status` accepts a legacy status message at eth/68.
    #[test]
    fn decode_status_success() {
        use crate::{Status, StatusMessage};
        use alloy_hardforks::{ForkHash, ForkId};
        use alloy_primitives::{B256, U256};

        let status = Status {
            version: EthVersion::Eth68,
            chain: alloy_chains::Chain::mainnet(),
            total_difficulty: U256::from(100u64),
            blockhash: B256::random(),
            genesis: B256::random(),
            forkid: ForkId { hash: ForkHash([0xb7, 0x15, 0x07, 0x7d]), next: 0 },
        };

        let protocol_message = ProtocolMessage::<EthNetworkPrimitives>::from(EthMessage::Status(
            StatusMessage::Legacy(status),
        ));
        let encoded = encode(protocol_message);

        let decoded = ProtocolMessage::<EthNetworkPrimitives>::decode_status(
            EthVersion::Eth68,
            &mut &encoded[..],
        )
        .unwrap();

        assert!(matches!(decoded, StatusMessage::Legacy(s) if s == status));
    }

    /// `max` covers `BlockRangeUpdate` from eth/69 and stops at `Receipts` before that.
    #[test]
    fn eth_message_id_max_includes_block_range_update() {
        assert_eq!(EthMessageID::max(EthVersion::Eth69), EthMessageID::BlockRangeUpdate.to_u8(),);
        assert_eq!(EthMessageID::max(EthVersion::Eth70), EthMessageID::BlockRangeUpdate.to_u8(),);
        assert_eq!(EthMessageID::max(EthVersion::Eth68), EthMessageID::Receipts.to_u8());
    }

    /// On eth/71, `max` and `message_count` cover the block access list messages.
    #[test]
    fn eth_message_id_max_includes_block_access_lists_at_eth71() {
        assert_eq!(EthMessageID::max(EthVersion::Eth71), EthMessageID::BlockAccessLists.to_u8());
        assert_eq!(
            EthMessageID::message_count(EthVersion::Eth71),
            EthMessageID::BlockAccessLists.to_u8() + 1
        );
    }

    /// `map_versioned` upgrades `GetReceipts` to the eth/70 layout and leaves it unchanged on
    /// older versions.
    #[test]
    fn map_versioned_upgrades_get_receipts_at_eth70() {
        use crate::{GetReceipts, GetReceipts70};
        use alloy_primitives::B256;

        let hashes = vec![B256::ZERO];
        let msg = EthMessage::<EthNetworkPrimitives>::GetReceipts(RequestPair {
            request_id: 7,
            message: GetReceipts(hashes.clone()),
        });

        assert_eq!(msg.clone().map_versioned(EthVersion::Eth69), msg);
        assert_eq!(
            msg.map_versioned(EthVersion::Eth70),
            EthMessage::GetReceipts70(RequestPair {
                request_id: 7,
                message: GetReceipts70 { first_block_receipt_index: 0, block_hashes: hashes },
            })
        );
    }

    /// `decode_status` rejects any non-status message id.
    #[test]
    fn decode_status_rejects_non_status() {
        let msg = EthMessage::<EthNetworkPrimitives>::GetBlockBodies(RequestPair {
            request_id: 1,
            message: crate::GetBlockBodies::default(),
        });
        let protocol_message =
            ProtocolMessage { message_type: EthMessageID::GetBlockBodies, message: msg };
        let encoded = encode(protocol_message);

        let result = ProtocolMessage::<EthNetworkPrimitives>::decode_status(
            EthVersion::Eth68,
            &mut &encoded[..],
        );

        assert!(matches!(
            result,
            Err(MessageError::ExpectedStatusMessage(EthMessageID::GetBlockBodies))
        ));
    }
}