1use super::{
10 broadcast::NewBlockHashes, BlockBodies, BlockHeaders, GetBlockBodies, GetBlockHeaders,
11 GetNodeData, GetPooledTransactions, GetReceipts, NewBlock, NewPooledTransactionHashes66,
12 NewPooledTransactionHashes68, NodeData, PooledTransactions, Receipts, Status, Transactions,
13};
14use crate::{EthNetworkPrimitives, EthVersion, NetworkPrimitives, SharedTransactions};
15use alloc::{boxed::Box, sync::Arc};
16use alloy_primitives::bytes::{Buf, BufMut};
17use alloy_rlp::{length_of_length, Decodable, Encodable, Header};
18use core::fmt::Debug;
19
/// Size limit for a message, in bytes: `10 * 1024 * 1024` (10 MiB).
///
/// NOTE(review): nothing in this file enforces the limit; presumably it is applied by the
/// code that reads or writes framed messages — confirm against callers.
pub const MAX_MESSAGE_SIZE: usize = 10 * 1024 * 1024;
23
/// Errors returned by [`ProtocolMessage::decode_message`].
#[derive(thiserror::Error, Debug)]
pub enum MessageError {
    /// The decoded message id is not accepted for the given protocol version.
    ///
    /// Field order is `(version, message id)`; the display string prints the id first.
    #[error("message id {1:?} is invalid for version {0:?}")]
    Invalid(EthVersion, EthMessageID),
    /// RLP decoding of the message id or of the message body failed.
    #[error("RLP error: {0}")]
    RlpError(#[from] alloy_rlp::Error),
}
34
/// A message together with its [`EthMessageID`].
///
/// On the wire (see the `Encodable` impl) this is the message id byte followed by the encoded
/// [`EthMessage`] body.
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ProtocolMessage<N: NetworkPrimitives = EthNetworkPrimitives> {
    /// Id identifying which kind of message `message` is.
    pub message_type: EthMessageID,
    /// The message body.
    // The serde bound is stated on the whole `EthMessage<N>` rather than on `N` itself.
    #[cfg_attr(
        feature = "serde",
        serde(bound = "EthMessage<N>: serde::Serialize + serde::de::DeserializeOwned")
    )]
    pub message: EthMessage<N>,
}
48
49impl<N: NetworkPrimitives> ProtocolMessage<N> {
50 pub fn decode_message(version: EthVersion, buf: &mut &[u8]) -> Result<Self, MessageError> {
52 let message_type = EthMessageID::decode(buf)?;
53
54 let message = match message_type {
55 EthMessageID::Status => EthMessage::Status(Status::decode(buf)?),
56 EthMessageID::NewBlockHashes => {
57 if version.is_eth69() {
58 return Err(MessageError::Invalid(version, EthMessageID::NewBlockHashes));
59 }
60 EthMessage::NewBlockHashes(NewBlockHashes::decode(buf)?)
61 }
62 EthMessageID::NewBlock => {
63 if version.is_eth69() {
64 return Err(MessageError::Invalid(version, EthMessageID::NewBlock));
65 }
66 EthMessage::NewBlock(Box::new(NewBlock::decode(buf)?))
67 }
68 EthMessageID::Transactions => EthMessage::Transactions(Transactions::decode(buf)?),
69 EthMessageID::NewPooledTransactionHashes => {
70 if version >= EthVersion::Eth68 {
71 EthMessage::NewPooledTransactionHashes68(NewPooledTransactionHashes68::decode(
72 buf,
73 )?)
74 } else {
75 EthMessage::NewPooledTransactionHashes66(NewPooledTransactionHashes66::decode(
76 buf,
77 )?)
78 }
79 }
80 EthMessageID::GetBlockHeaders => EthMessage::GetBlockHeaders(RequestPair::decode(buf)?),
81 EthMessageID::BlockHeaders => EthMessage::BlockHeaders(RequestPair::decode(buf)?),
82 EthMessageID::GetBlockBodies => EthMessage::GetBlockBodies(RequestPair::decode(buf)?),
83 EthMessageID::BlockBodies => EthMessage::BlockBodies(RequestPair::decode(buf)?),
84 EthMessageID::GetPooledTransactions => {
85 EthMessage::GetPooledTransactions(RequestPair::decode(buf)?)
86 }
87 EthMessageID::PooledTransactions => {
88 EthMessage::PooledTransactions(RequestPair::decode(buf)?)
89 }
90 EthMessageID::GetNodeData => {
91 if version >= EthVersion::Eth67 {
92 return Err(MessageError::Invalid(version, EthMessageID::GetNodeData))
93 }
94 EthMessage::GetNodeData(RequestPair::decode(buf)?)
95 }
96 EthMessageID::NodeData => {
97 if version >= EthVersion::Eth67 {
98 return Err(MessageError::Invalid(version, EthMessageID::GetNodeData))
99 }
100 EthMessage::NodeData(RequestPair::decode(buf)?)
101 }
102 EthMessageID::GetReceipts => EthMessage::GetReceipts(RequestPair::decode(buf)?),
103 EthMessageID::Receipts => EthMessage::Receipts(RequestPair::decode(buf)?),
104 };
105 Ok(Self { message_type, message })
106 }
107}
108
109impl<N: NetworkPrimitives> Encodable for ProtocolMessage<N> {
110 fn encode(&self, out: &mut dyn BufMut) {
113 self.message_type.encode(out);
114 self.message.encode(out);
115 }
116 fn length(&self) -> usize {
117 self.message_type.length() + self.message.length()
118 }
119}
120
121impl<N: NetworkPrimitives> From<EthMessage<N>> for ProtocolMessage<N> {
122 fn from(message: EthMessage<N>) -> Self {
123 Self { message_type: message.message_id(), message }
124 }
125}
126
/// An [`EthBroadcastMessage`] together with its [`EthMessageID`].
///
/// Mirrors [`ProtocolMessage`] for the broadcast message set. This file only provides an
/// `Encodable` impl for it (no decoder is defined here).
#[derive(Clone, Debug)]
pub struct ProtocolBroadcastMessage<N: NetworkPrimitives = EthNetworkPrimitives> {
    /// Id identifying which kind of message `message` is.
    pub message_type: EthMessageID,
    /// The broadcast message body.
    pub message: EthBroadcastMessage<N>,
}
136
137impl<N: NetworkPrimitives> Encodable for ProtocolBroadcastMessage<N> {
138 fn encode(&self, out: &mut dyn BufMut) {
141 self.message_type.encode(out);
142 self.message.encode(out);
143 }
144 fn length(&self) -> usize {
145 self.message_type.length() + self.message.length()
146 }
147}
148
149impl<N: NetworkPrimitives> From<EthBroadcastMessage<N>> for ProtocolBroadcastMessage<N> {
150 fn from(message: EthBroadcastMessage<N>) -> Self {
151 Self { message_type: message.message_id(), message }
152 }
153}
154
/// A single message body, one variant per kind of message this module can encode and decode.
///
/// Request and response variants wrap their payload in a [`RequestPair`], which adds a
/// `request_id`. The variants that depend on the network's primitive types carry a
/// `serde(bound = …)` so the serde derives constrain the associated type actually stored,
/// not `N` itself.
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum EthMessage<N: NetworkPrimitives = EthNetworkPrimitives> {
    /// A `Status` message.
    Status(Status),
    /// A `NewBlockHashes` message. Rejected by `decode_message` when `version.is_eth69()`.
    NewBlockHashes(NewBlockHashes),
    /// A `NewBlock` message (boxed). Rejected by `decode_message` when `version.is_eth69()`.
    #[cfg_attr(
        feature = "serde",
        serde(bound = "N::Block: serde::Serialize + serde::de::DeserializeOwned")
    )]
    NewBlock(Box<NewBlock<N::Block>>),
    /// A `Transactions` message.
    #[cfg_attr(
        feature = "serde",
        serde(bound = "N::BroadcastedTransaction: serde::Serialize + serde::de::DeserializeOwned")
    )]
    Transactions(Transactions<N::BroadcastedTransaction>),
    /// A `NewPooledTransactionHashes` message in the layout decoded for versions below `Eth68`.
    NewPooledTransactionHashes66(NewPooledTransactionHashes66),
    /// A `NewPooledTransactionHashes` message in the layout decoded for `Eth68` and later.
    NewPooledTransactionHashes68(NewPooledTransactionHashes68),
    /// A `GetBlockHeaders` request.
    GetBlockHeaders(RequestPair<GetBlockHeaders>),
    /// A `BlockHeaders` response.
    #[cfg_attr(
        feature = "serde",
        serde(bound = "N::BlockHeader: serde::Serialize + serde::de::DeserializeOwned")
    )]
    BlockHeaders(RequestPair<BlockHeaders<N::BlockHeader>>),
    /// A `GetBlockBodies` request.
    GetBlockBodies(RequestPair<GetBlockBodies>),
    /// A `BlockBodies` response.
    #[cfg_attr(
        feature = "serde",
        serde(bound = "N::BlockBody: serde::Serialize + serde::de::DeserializeOwned")
    )]
    BlockBodies(RequestPair<BlockBodies<N::BlockBody>>),
    /// A `GetPooledTransactions` request.
    GetPooledTransactions(RequestPair<GetPooledTransactions>),
    /// A `PooledTransactions` response.
    #[cfg_attr(
        feature = "serde",
        serde(bound = "N::PooledTransaction: serde::Serialize + serde::de::DeserializeOwned")
    )]
    PooledTransactions(RequestPair<PooledTransactions<N::PooledTransaction>>),
    /// A `GetNodeData` request. Rejected by `decode_message` for `Eth67` and later.
    GetNodeData(RequestPair<GetNodeData>),
    /// A `NodeData` response. Rejected by `decode_message` for `Eth67` and later.
    NodeData(RequestPair<NodeData>),
    /// A `GetReceipts` request.
    GetReceipts(RequestPair<GetReceipts>),
    /// A `Receipts` response.
    #[cfg_attr(
        feature = "serde",
        serde(bound = "N::Receipt: serde::Serialize + serde::de::DeserializeOwned")
    )]
    Receipts(RequestPair<Receipts<N::Receipt>>),
}
233
234impl<N: NetworkPrimitives> EthMessage<N> {
235 pub const fn message_id(&self) -> EthMessageID {
237 match self {
238 Self::Status(_) => EthMessageID::Status,
239 Self::NewBlockHashes(_) => EthMessageID::NewBlockHashes,
240 Self::NewBlock(_) => EthMessageID::NewBlock,
241 Self::Transactions(_) => EthMessageID::Transactions,
242 Self::NewPooledTransactionHashes66(_) | Self::NewPooledTransactionHashes68(_) => {
243 EthMessageID::NewPooledTransactionHashes
244 }
245 Self::GetBlockHeaders(_) => EthMessageID::GetBlockHeaders,
246 Self::BlockHeaders(_) => EthMessageID::BlockHeaders,
247 Self::GetBlockBodies(_) => EthMessageID::GetBlockBodies,
248 Self::BlockBodies(_) => EthMessageID::BlockBodies,
249 Self::GetPooledTransactions(_) => EthMessageID::GetPooledTransactions,
250 Self::PooledTransactions(_) => EthMessageID::PooledTransactions,
251 Self::GetNodeData(_) => EthMessageID::GetNodeData,
252 Self::NodeData(_) => EthMessageID::NodeData,
253 Self::GetReceipts(_) => EthMessageID::GetReceipts,
254 Self::Receipts(_) => EthMessageID::Receipts,
255 }
256 }
257
258 pub const fn is_request(&self) -> bool {
260 matches!(
261 self,
262 Self::GetBlockBodies(_) |
263 Self::GetBlockHeaders(_) |
264 Self::GetReceipts(_) |
265 Self::GetPooledTransactions(_) |
266 Self::GetNodeData(_)
267 )
268 }
269
270 pub const fn is_response(&self) -> bool {
272 matches!(
273 self,
274 Self::PooledTransactions(_) |
275 Self::Receipts(_) |
276 Self::BlockHeaders(_) |
277 Self::BlockBodies(_) |
278 Self::NodeData(_)
279 )
280 }
281}
282
283impl<N: NetworkPrimitives> Encodable for EthMessage<N> {
284 fn encode(&self, out: &mut dyn BufMut) {
285 match self {
286 Self::Status(status) => status.encode(out),
287 Self::NewBlockHashes(new_block_hashes) => new_block_hashes.encode(out),
288 Self::NewBlock(new_block) => new_block.encode(out),
289 Self::Transactions(transactions) => transactions.encode(out),
290 Self::NewPooledTransactionHashes66(hashes) => hashes.encode(out),
291 Self::NewPooledTransactionHashes68(hashes) => hashes.encode(out),
292 Self::GetBlockHeaders(request) => request.encode(out),
293 Self::BlockHeaders(headers) => headers.encode(out),
294 Self::GetBlockBodies(request) => request.encode(out),
295 Self::BlockBodies(bodies) => bodies.encode(out),
296 Self::GetPooledTransactions(request) => request.encode(out),
297 Self::PooledTransactions(transactions) => transactions.encode(out),
298 Self::GetNodeData(request) => request.encode(out),
299 Self::NodeData(data) => data.encode(out),
300 Self::GetReceipts(request) => request.encode(out),
301 Self::Receipts(receipts) => receipts.encode(out),
302 }
303 }
304 fn length(&self) -> usize {
305 match self {
306 Self::Status(status) => status.length(),
307 Self::NewBlockHashes(new_block_hashes) => new_block_hashes.length(),
308 Self::NewBlock(new_block) => new_block.length(),
309 Self::Transactions(transactions) => transactions.length(),
310 Self::NewPooledTransactionHashes66(hashes) => hashes.length(),
311 Self::NewPooledTransactionHashes68(hashes) => hashes.length(),
312 Self::GetBlockHeaders(request) => request.length(),
313 Self::BlockHeaders(headers) => headers.length(),
314 Self::GetBlockBodies(request) => request.length(),
315 Self::BlockBodies(bodies) => bodies.length(),
316 Self::GetPooledTransactions(request) => request.length(),
317 Self::PooledTransactions(transactions) => transactions.length(),
318 Self::GetNodeData(request) => request.length(),
319 Self::NodeData(data) => data.length(),
320 Self::GetReceipts(request) => request.length(),
321 Self::Receipts(receipts) => receipts.length(),
322 }
323 }
324}
325
/// The subset of messages available as broadcast messages: `NewBlock` and `Transactions`.
///
/// Unlike [`EthMessage`], the `NewBlock` payload is held in an [`Arc`] rather than a `Box`.
/// NOTE(review): `SharedTransactions` is presumably a shared (cheaply clonable) transaction
/// list as well — confirm against its definition.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum EthBroadcastMessage<N: NetworkPrimitives = EthNetworkPrimitives> {
    /// A `NewBlock` message behind an `Arc`.
    NewBlock(Arc<NewBlock<N::Block>>),
    /// A `Transactions` message.
    Transactions(SharedTransactions<N::BroadcastedTransaction>),
}
340
341impl<N: NetworkPrimitives> EthBroadcastMessage<N> {
344 pub const fn message_id(&self) -> EthMessageID {
346 match self {
347 Self::NewBlock(_) => EthMessageID::NewBlock,
348 Self::Transactions(_) => EthMessageID::Transactions,
349 }
350 }
351}
352
353impl<N: NetworkPrimitives> Encodable for EthBroadcastMessage<N> {
354 fn encode(&self, out: &mut dyn BufMut) {
355 match self {
356 Self::NewBlock(new_block) => new_block.encode(out),
357 Self::Transactions(transactions) => transactions.encode(out),
358 }
359 }
360
361 fn length(&self) -> usize {
362 match self {
363 Self::NewBlock(new_block) => new_block.length(),
364 Self::Transactions(transactions) => transactions.length(),
365 }
366 }
367}
368
/// One-byte identifier for each kind of message handled by this module.
///
/// The discriminants are the values written to and read from the wire. Note that the numbering
/// is not contiguous: `0x0b` and `0x0c` are not assigned to any variant.
#[repr(u8)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum EthMessageID {
    /// Id of the `Status` message.
    Status = 0x00,
    /// Id of the `NewBlockHashes` message.
    NewBlockHashes = 0x01,
    /// Id of the `Transactions` message.
    Transactions = 0x02,
    /// Id of the `GetBlockHeaders` message.
    GetBlockHeaders = 0x03,
    /// Id of the `BlockHeaders` message.
    BlockHeaders = 0x04,
    /// Id of the `GetBlockBodies` message.
    GetBlockBodies = 0x05,
    /// Id of the `BlockBodies` message.
    BlockBodies = 0x06,
    /// Id of the `NewBlock` message.
    NewBlock = 0x07,
    /// Id of the `NewPooledTransactionHashes` message (shared by both body layouts).
    NewPooledTransactionHashes = 0x08,
    /// Id of the `GetPooledTransactions` message.
    GetPooledTransactions = 0x09,
    /// Id of the `PooledTransactions` message.
    PooledTransactions = 0x0a,
    /// Id of the `GetNodeData` message.
    GetNodeData = 0x0d,
    /// Id of the `NodeData` message.
    NodeData = 0x0e,
    /// Id of the `GetReceipts` message.
    GetReceipts = 0x0f,
    /// Id of the `Receipts` message.
    Receipts = 0x10,
}
405
406impl EthMessageID {
407 pub const fn max() -> u8 {
409 Self::Receipts as u8
410 }
411}
412
413impl Encodable for EthMessageID {
414 fn encode(&self, out: &mut dyn BufMut) {
415 out.put_u8(*self as u8);
416 }
417 fn length(&self) -> usize {
418 1
419 }
420}
421
422impl Decodable for EthMessageID {
423 fn decode(buf: &mut &[u8]) -> alloy_rlp::Result<Self> {
424 let id = match buf.first().ok_or(alloy_rlp::Error::InputTooShort)? {
425 0x00 => Self::Status,
426 0x01 => Self::NewBlockHashes,
427 0x02 => Self::Transactions,
428 0x03 => Self::GetBlockHeaders,
429 0x04 => Self::BlockHeaders,
430 0x05 => Self::GetBlockBodies,
431 0x06 => Self::BlockBodies,
432 0x07 => Self::NewBlock,
433 0x08 => Self::NewPooledTransactionHashes,
434 0x09 => Self::GetPooledTransactions,
435 0x0a => Self::PooledTransactions,
436 0x0d => Self::GetNodeData,
437 0x0e => Self::NodeData,
438 0x0f => Self::GetReceipts,
439 0x10 => Self::Receipts,
440 _ => return Err(alloy_rlp::Error::Custom("Invalid message ID")),
441 };
442 buf.advance(1);
443 Ok(id)
444 }
445}
446
447impl TryFrom<usize> for EthMessageID {
448 type Error = &'static str;
449
450 fn try_from(value: usize) -> Result<Self, Self::Error> {
451 match value {
452 0x00 => Ok(Self::Status),
453 0x01 => Ok(Self::NewBlockHashes),
454 0x02 => Ok(Self::Transactions),
455 0x03 => Ok(Self::GetBlockHeaders),
456 0x04 => Ok(Self::BlockHeaders),
457 0x05 => Ok(Self::GetBlockBodies),
458 0x06 => Ok(Self::BlockBodies),
459 0x07 => Ok(Self::NewBlock),
460 0x08 => Ok(Self::NewPooledTransactionHashes),
461 0x09 => Ok(Self::GetPooledTransactions),
462 0x0a => Ok(Self::PooledTransactions),
463 0x0d => Ok(Self::GetNodeData),
464 0x0e => Ok(Self::NodeData),
465 0x0f => Ok(Self::GetReceipts),
466 0x10 => Ok(Self::Receipts),
467 _ => Err("Invalid message ID"),
468 }
469 }
470}
471
/// A message paired with a `request_id`.
///
/// Encoded as an RLP list of two items, `[request_id, message]` (see the `Encodable` and
/// `Decodable` impls).
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(any(test, feature = "arbitrary"), derive(arbitrary::Arbitrary))]
pub struct RequestPair<T> {
    /// Identifier carried alongside the message.
    // NOTE(review): presumably used to match a response to its request — confirm with callers.
    pub request_id: u64,

    /// The wrapped request or response payload.
    pub message: T,
}
485
486impl<T> Encodable for RequestPair<T>
488where
489 T: Encodable,
490{
491 fn encode(&self, out: &mut dyn alloy_rlp::BufMut) {
492 let header =
493 Header { list: true, payload_length: self.request_id.length() + self.message.length() };
494
495 header.encode(out);
496 self.request_id.encode(out);
497 self.message.encode(out);
498 }
499
500 fn length(&self) -> usize {
501 let mut length = 0;
502 length += self.request_id.length();
503 length += self.message.length();
504 length += length_of_length(length);
505 length
506 }
507}
508
509impl<T> Decodable for RequestPair<T>
511where
512 T: Decodable,
513{
514 fn decode(buf: &mut &[u8]) -> alloy_rlp::Result<Self> {
515 let header = Header::decode(buf)?;
516
517 let initial_length = buf.len();
518 let request_id = u64::decode(buf)?;
519 let message = T::decode(buf)?;
520
521 let consumed_len = initial_length - buf.len();
524 if consumed_len != header.payload_length {
525 return Err(alloy_rlp::Error::UnexpectedLength)
526 }
527
528 Ok(Self { request_id, message })
529 }
530}
531
#[cfg(test)]
mod tests {
    use super::MessageError;
    use crate::{
        message::RequestPair, EthMessage, EthMessageID, EthNetworkPrimitives, EthVersion,
        GetNodeData, NodeData, ProtocolMessage,
    };
    use alloy_primitives::hex;
    use alloy_rlp::{Decodable, Encodable, Error};
    use reth_ethereum_primitives::BlockBody;

    /// Encodes `value` into a fresh byte vector.
    fn encode<T: Encodable>(value: T) -> Vec<u8> {
        let mut buf = vec![];
        value.encode(&mut buf);
        buf
    }

    /// `GetNodeData` and `NodeData` must both be rejected when decoding with `Eth67`.
    #[test]
    fn test_removed_message_at_eth67() {
        let get_node_data = EthMessage::<EthNetworkPrimitives>::GetNodeData(RequestPair {
            request_id: 1337,
            message: GetNodeData(vec![]),
        });
        let buf = encode(ProtocolMessage {
            message_type: EthMessageID::GetNodeData,
            message: get_node_data,
        });
        let msg = ProtocolMessage::<EthNetworkPrimitives>::decode_message(
            crate::EthVersion::Eth67,
            &mut &buf[..],
        );
        assert!(matches!(msg, Err(MessageError::Invalid(..))));

        let node_data = EthMessage::<EthNetworkPrimitives>::NodeData(RequestPair {
            request_id: 1337,
            message: NodeData(vec![]),
        });
        let buf =
            encode(ProtocolMessage { message_type: EthMessageID::NodeData, message: node_data });
        let msg = ProtocolMessage::<EthNetworkPrimitives>::decode_message(
            crate::EthVersion::Eth67,
            &mut &buf[..],
        );
        assert!(matches!(msg, Err(MessageError::Invalid(..))));
    }

    /// A `RequestPair` encodes as an RLP list of `[request_id, message]`.
    #[test]
    fn request_pair_encode() {
        let request_pair = RequestPair { request_id: 1337, message: vec![5u8] };

        // c5: list header, payload length 5
        // 82 05 39: two-byte string 0x0539 = 1337 (request_id)
        // c1: list header, payload length 1 (message)
        // 05: the single message element
        let expected = hex!("c5820539c105");
        let got = encode(request_pair);
        assert_eq!(expected[..], got, "expected: {expected:X?}, got: {got:X?}",);
    }

    /// Decoding the bytes produced by `request_pair_encode` yields the original pair.
    #[test]
    fn request_pair_decode() {
        // Same layout as in `request_pair_encode`.
        let raw_pair = &hex!("c5820539c105")[..];

        let expected = RequestPair { request_id: 1337, message: vec![5u8] };

        let got = RequestPair::<Vec<u8>>::decode(&mut &*raw_pair).unwrap();
        assert_eq!(expected.length(), raw_pair.len());
        assert_eq!(expected, got);
    }

    /// A pair whose contents are longer than the outer header declares must be rejected.
    #[test]
    fn malicious_request_pair_decode() {
        // c5: list header claiming a payload length of 5
        // 82 05 39: request_id 1337 (3 bytes)
        // c2 05 05: message list with two elements (3 bytes)
        // The items occupy 6 bytes, one more than the header declares.
        let raw_pair = &hex!("c5820539c20505")[..];

        let result = RequestPair::<Vec<u8>>::decode(&mut &*raw_pair);
        assert!(matches!(result, Err(Error::UnexpectedLength)));
    }

    /// A `BlockBodies` message with no bodies round-trips through encode/decode.
    #[test]
    fn empty_block_bodies_protocol() {
        let empty_block_bodies =
            ProtocolMessage::from(EthMessage::<EthNetworkPrimitives>::BlockBodies(RequestPair {
                request_id: 0,
                message: Default::default(),
            }));
        let mut buf = Vec::new();
        empty_block_bodies.encode(&mut buf);
        let decoded =
            ProtocolMessage::decode_message(EthVersion::Eth68, &mut buf.as_slice()).unwrap();
        assert_eq!(empty_block_bodies, decoded);
    }

    /// A `BlockBodies` message holding one body with no transactions, no ommers and default
    /// withdrawals round-trips through encode/decode.
    #[test]
    fn empty_block_body_protocol() {
        let empty_block_bodies =
            ProtocolMessage::from(EthMessage::<EthNetworkPrimitives>::BlockBodies(RequestPair {
                request_id: 0,
                message: vec![BlockBody {
                    transactions: vec![],
                    ommers: vec![],
                    withdrawals: Some(Default::default()),
                }]
                .into(),
            }));
        let mut buf = Vec::new();
        empty_block_bodies.encode(&mut buf);
        let decoded =
            ProtocolMessage::decode_message(EthVersion::Eth68, &mut buf.as_slice()).unwrap();
        assert_eq!(empty_block_bodies, decoded);
    }

    /// A `BlockBodies` message whose only body is an empty list fails to decode.
    #[test]
    fn decode_block_bodies_message() {
        // 06: BlockBodies message id
        // c4: list header, payload length 4
        // 81 99: request_id 0x99
        // c1 c0: list containing a single empty list where a block body is expected
        let buf = hex!("06c48199c1c0");
        let msg = ProtocolMessage::<EthNetworkPrimitives>::decode_message(
            EthVersion::Eth68,
            &mut &buf[..],
        )
        .unwrap_err();
        assert!(matches!(msg, MessageError::RlpError(alloy_rlp::Error::InputTooShort)));
    }
}