use super::{
broadcast::NewBlockHashes, BlockBodies, BlockHeaders, GetBlockBodies, GetBlockHeaders,
GetNodeData, GetPooledTransactions, GetReceipts, NewBlock, NewPooledTransactionHashes66,
NewPooledTransactionHashes68, NodeData, PooledTransactions, Receipts, Status, Transactions,
};
use crate::{EthNetworkPrimitives, EthVersion, NetworkPrimitives, SharedTransactions};
use alloy_primitives::bytes::{Buf, BufMut};
use alloy_rlp::{length_of_length, Decodable, Encodable, Header};
use std::{fmt::Debug, sync::Arc};
/// Maximum eth message size in bytes: 10 MiB (`10 * 1024 * 1024`).
///
/// NOTE(review): not enforced anywhere in this module; presumably checked by the stream/codec
/// layer — confirm against callers.
pub const MAX_MESSAGE_SIZE: usize = 10 << 20;
/// Error produced while decoding an eth wire message (see `ProtocolMessage::decode_message`).
#[derive(thiserror::Error, Debug)]
pub enum MessageError {
    /// The message id is not allowed for the negotiated protocol version.
    ///
    /// Fields: the negotiated version, then the offending message id.
    #[error("message id {1:?} is invalid for version {0:?}")]
    Invalid(EthVersion, EthMessageID),
    /// The message id or payload could not be RLP-decoded.
    #[error("RLP error: {0}")]
    RlpError(#[from] alloy_rlp::Error),
}
/// An eth protocol message: the one-byte message id together with its decoded payload.
///
/// On the wire this is the raw id byte immediately followed by the payload's RLP encoding
/// (no outer list header wraps the two).
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ProtocolMessage<N: NetworkPrimitives = EthNetworkPrimitives> {
    /// Id byte written ahead of the payload.
    pub message_type: EthMessageID,
    /// The message payload.
    #[cfg_attr(
        feature = "serde",
        serde(bound = "EthMessage<N>: serde::Serialize + serde::de::DeserializeOwned")
    )]
    pub message: EthMessage<N>,
}
impl<N: NetworkPrimitives> ProtocolMessage<N> {
pub fn decode_message(version: EthVersion, buf: &mut &[u8]) -> Result<Self, MessageError> {
let message_type = EthMessageID::decode(buf)?;
let message = match message_type {
EthMessageID::Status => EthMessage::Status(Status::decode(buf)?),
EthMessageID::NewBlockHashes => {
if version.is_eth69() {
return Err(MessageError::Invalid(version, EthMessageID::NewBlockHashes));
}
EthMessage::NewBlockHashes(NewBlockHashes::decode(buf)?)
}
EthMessageID::NewBlock => {
if version.is_eth69() {
return Err(MessageError::Invalid(version, EthMessageID::NewBlock));
}
EthMessage::NewBlock(Box::new(NewBlock::decode(buf)?))
}
EthMessageID::Transactions => EthMessage::Transactions(Transactions::decode(buf)?),
EthMessageID::NewPooledTransactionHashes => {
if version >= EthVersion::Eth68 {
EthMessage::NewPooledTransactionHashes68(NewPooledTransactionHashes68::decode(
buf,
)?)
} else {
EthMessage::NewPooledTransactionHashes66(NewPooledTransactionHashes66::decode(
buf,
)?)
}
}
EthMessageID::GetBlockHeaders => EthMessage::GetBlockHeaders(RequestPair::decode(buf)?),
EthMessageID::BlockHeaders => EthMessage::BlockHeaders(RequestPair::decode(buf)?),
EthMessageID::GetBlockBodies => EthMessage::GetBlockBodies(RequestPair::decode(buf)?),
EthMessageID::BlockBodies => EthMessage::BlockBodies(RequestPair::decode(buf)?),
EthMessageID::GetPooledTransactions => {
EthMessage::GetPooledTransactions(RequestPair::decode(buf)?)
}
EthMessageID::PooledTransactions => {
EthMessage::PooledTransactions(RequestPair::decode(buf)?)
}
EthMessageID::GetNodeData => {
if version >= EthVersion::Eth67 {
return Err(MessageError::Invalid(version, EthMessageID::GetNodeData))
}
EthMessage::GetNodeData(RequestPair::decode(buf)?)
}
EthMessageID::NodeData => {
if version >= EthVersion::Eth67 {
return Err(MessageError::Invalid(version, EthMessageID::GetNodeData))
}
EthMessage::NodeData(RequestPair::decode(buf)?)
}
EthMessageID::GetReceipts => EthMessage::GetReceipts(RequestPair::decode(buf)?),
EthMessageID::Receipts => EthMessage::Receipts(RequestPair::decode(buf)?),
};
Ok(Self { message_type, message })
}
}
impl<N: NetworkPrimitives> Encodable for ProtocolMessage<N> {
    /// Writes the raw message id byte, then the payload's own encoding, back to back.
    fn encode(&self, out: &mut dyn BufMut) {
        let Self { message_type, message } = self;
        message_type.encode(out);
        message.encode(out);
    }

    /// Total encoded size: the id's length plus the payload's length.
    fn length(&self) -> usize {
        let id_len = self.message_type.length();
        let payload_len = self.message.length();
        id_len + payload_len
    }
}
impl<N: NetworkPrimitives> From<EthMessage<N>> for ProtocolMessage<N> {
    /// Wraps `message`, taking the id from [`EthMessage::message_id`].
    fn from(message: EthMessage<N>) -> Self {
        let message_type = message.message_id();
        Self { message_type, message }
    }
}
/// A broadcast message together with its message id.
///
/// Encoded the same way as `ProtocolMessage`: the raw id byte followed by the payload's RLP.
#[derive(Clone, Debug)]
pub struct ProtocolBroadcastMessage<N: NetworkPrimitives = EthNetworkPrimitives> {
    /// Id byte written ahead of the payload.
    pub message_type: EthMessageID,
    /// The broadcast payload.
    pub message: EthBroadcastMessage<N>,
}
impl<N: NetworkPrimitives> Encodable for ProtocolBroadcastMessage<N> {
    /// Writes the raw message id byte, then the broadcast payload's encoding.
    fn encode(&self, out: &mut dyn BufMut) {
        let Self { message_type, message } = self;
        message_type.encode(out);
        message.encode(out);
    }

    /// Id length plus payload length.
    fn length(&self) -> usize {
        let Self { message_type, message } = self;
        message_type.length() + message.length()
    }
}
impl<N: NetworkPrimitives> From<EthBroadcastMessage<N>> for ProtocolBroadcastMessage<N> {
    /// Wraps `message`, taking the id from [`EthBroadcastMessage::message_id`].
    fn from(message: EthBroadcastMessage<N>) -> Self {
        let message_type = message.message_id();
        Self { message_type, message }
    }
}
/// Decoded payload of an eth protocol message, one variant per supported message.
///
/// The wire id for each variant is given by `EthMessage::message_id`. Request/response
/// messages are wrapped in [`RequestPair`], so they carry a request id.
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum EthMessage<N: NetworkPrimitives = EthNetworkPrimitives> {
    /// `Status` message (id `0x00`).
    Status(Status),
    /// `NewBlockHashes` announcement (id `0x01`); rejected by the decoder under eth/69.
    NewBlockHashes(NewBlockHashes),
    /// `NewBlock` announcement (id `0x07`); rejected by the decoder under eth/69.
    #[cfg_attr(
        feature = "serde",
        serde(bound = "N::Block: serde::Serialize + serde::de::DeserializeOwned")
    )]
    NewBlock(Box<NewBlock<N::Block>>),
    /// `Transactions` broadcast (id `0x02`).
    #[cfg_attr(
        feature = "serde",
        serde(bound = "N::BroadcastedTransaction: serde::Serialize + serde::de::DeserializeOwned")
    )]
    Transactions(Transactions<N::BroadcastedTransaction>),
    /// `NewPooledTransactionHashes` (id `0x08`) in the pre-eth/68 layout.
    NewPooledTransactionHashes66(NewPooledTransactionHashes66),
    /// `NewPooledTransactionHashes` (id `0x08`) in the eth/68+ layout.
    NewPooledTransactionHashes68(NewPooledTransactionHashes68),
    /// `GetBlockHeaders` request (id `0x03`).
    GetBlockHeaders(RequestPair<GetBlockHeaders>),
    /// `BlockHeaders` response (id `0x04`).
    #[cfg_attr(
        feature = "serde",
        serde(bound = "N::BlockHeader: serde::Serialize + serde::de::DeserializeOwned")
    )]
    BlockHeaders(RequestPair<BlockHeaders<N::BlockHeader>>),
    /// `GetBlockBodies` request (id `0x05`).
    GetBlockBodies(RequestPair<GetBlockBodies>),
    /// `BlockBodies` response (id `0x06`).
    #[cfg_attr(
        feature = "serde",
        serde(bound = "N::BlockBody: serde::Serialize + serde::de::DeserializeOwned")
    )]
    BlockBodies(RequestPair<BlockBodies<N::BlockBody>>),
    /// `GetPooledTransactions` request (id `0x09`).
    GetPooledTransactions(RequestPair<GetPooledTransactions>),
    /// `PooledTransactions` response (id `0x0a`).
    #[cfg_attr(
        feature = "serde",
        serde(bound = "N::PooledTransaction: serde::Serialize + serde::de::DeserializeOwned")
    )]
    PooledTransactions(RequestPair<PooledTransactions<N::PooledTransaction>>),
    /// `GetNodeData` request (id `0x0d`); rejected by the decoder from eth/67 onwards.
    GetNodeData(RequestPair<GetNodeData>),
    /// `NodeData` response (id `0x0e`); rejected by the decoder from eth/67 onwards.
    NodeData(RequestPair<NodeData>),
    /// `GetReceipts` request (id `0x0f`).
    GetReceipts(RequestPair<GetReceipts>),
    /// `Receipts` response (id `0x10`).
    Receipts(RequestPair<Receipts>),
}
impl<N: NetworkPrimitives> EthMessage<N> {
    /// Returns the wire [`EthMessageID`] for this message.
    ///
    /// Both `NewPooledTransactionHashes66` and `NewPooledTransactionHashes68` map to the same
    /// `NewPooledTransactionHashes` id; only the payload layout differs between them.
    pub const fn message_id(&self) -> EthMessageID {
        // Arms are listed in ascending wire-id order.
        match self {
            Self::Status(..) => EthMessageID::Status,
            Self::NewBlockHashes(..) => EthMessageID::NewBlockHashes,
            Self::Transactions(..) => EthMessageID::Transactions,
            Self::GetBlockHeaders(..) => EthMessageID::GetBlockHeaders,
            Self::BlockHeaders(..) => EthMessageID::BlockHeaders,
            Self::GetBlockBodies(..) => EthMessageID::GetBlockBodies,
            Self::BlockBodies(..) => EthMessageID::BlockBodies,
            Self::NewBlock(..) => EthMessageID::NewBlock,
            Self::NewPooledTransactionHashes66(..) => EthMessageID::NewPooledTransactionHashes,
            Self::NewPooledTransactionHashes68(..) => EthMessageID::NewPooledTransactionHashes,
            Self::GetPooledTransactions(..) => EthMessageID::GetPooledTransactions,
            Self::PooledTransactions(..) => EthMessageID::PooledTransactions,
            Self::GetNodeData(..) => EthMessageID::GetNodeData,
            Self::NodeData(..) => EthMessageID::NodeData,
            Self::GetReceipts(..) => EthMessageID::GetReceipts,
            Self::Receipts(..) => EthMessageID::Receipts,
        }
    }
}
impl<N: NetworkPrimitives> Encodable for EthMessage<N> {
    /// RLP-encodes the inner payload only; the message id byte is written by
    /// `ProtocolMessage`'s encoder, not here.
    fn encode(&self, out: &mut dyn BufMut) {
        match self {
            Self::Status(msg) => msg.encode(out),
            Self::NewBlockHashes(msg) => msg.encode(out),
            Self::Transactions(msg) => msg.encode(out),
            Self::GetBlockHeaders(msg) => msg.encode(out),
            Self::BlockHeaders(msg) => msg.encode(out),
            Self::GetBlockBodies(msg) => msg.encode(out),
            Self::BlockBodies(msg) => msg.encode(out),
            Self::NewBlock(msg) => msg.encode(out),
            Self::NewPooledTransactionHashes66(msg) => msg.encode(out),
            Self::NewPooledTransactionHashes68(msg) => msg.encode(out),
            Self::GetPooledTransactions(msg) => msg.encode(out),
            Self::PooledTransactions(msg) => msg.encode(out),
            Self::GetNodeData(msg) => msg.encode(out),
            Self::NodeData(msg) => msg.encode(out),
            Self::GetReceipts(msg) => msg.encode(out),
            Self::Receipts(msg) => msg.encode(out),
        }
    }

    /// Encoded length of the inner payload (excluding the message id byte).
    fn length(&self) -> usize {
        match self {
            Self::Status(msg) => msg.length(),
            Self::NewBlockHashes(msg) => msg.length(),
            Self::Transactions(msg) => msg.length(),
            Self::GetBlockHeaders(msg) => msg.length(),
            Self::BlockHeaders(msg) => msg.length(),
            Self::GetBlockBodies(msg) => msg.length(),
            Self::BlockBodies(msg) => msg.length(),
            Self::NewBlock(msg) => msg.length(),
            Self::NewPooledTransactionHashes66(msg) => msg.length(),
            Self::NewPooledTransactionHashes68(msg) => msg.length(),
            Self::GetPooledTransactions(msg) => msg.length(),
            Self::PooledTransactions(msg) => msg.length(),
            Self::GetNodeData(msg) => msg.length(),
            Self::NodeData(msg) => msg.length(),
            Self::GetReceipts(msg) => msg.length(),
            Self::Receipts(msg) => msg.length(),
        }
    }
}
/// Broadcast-only subset of eth messages whose payloads are held behind shared pointers
/// (`Arc` / `SharedTransactions`).
///
/// NOTE(review): presumably shared so one payload can be sent to many peers without deep
/// clones — confirm with the broadcasting call sites.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum EthBroadcastMessage<N: NetworkPrimitives = EthNetworkPrimitives> {
    /// `NewBlock` announcement (id `0x07`).
    NewBlock(Arc<NewBlock<N::Block>>),
    /// `Transactions` broadcast (id `0x02`).
    Transactions(SharedTransactions<N::BroadcastedTransaction>),
}
impl<N: NetworkPrimitives> EthBroadcastMessage<N> {
    /// Returns the wire [`EthMessageID`] for this broadcast message.
    pub const fn message_id(&self) -> EthMessageID {
        match *self {
            Self::Transactions(..) => EthMessageID::Transactions,
            Self::NewBlock(..) => EthMessageID::NewBlock,
        }
    }
}
impl<N: NetworkPrimitives> Encodable for EthBroadcastMessage<N> {
    /// RLP-encodes the shared payload only (no message id byte).
    fn encode(&self, out: &mut dyn BufMut) {
        match self {
            Self::Transactions(txs) => txs.encode(out),
            Self::NewBlock(block) => block.encode(out),
        }
    }

    /// Encoded length of the shared payload (excluding the message id byte).
    fn length(&self) -> usize {
        match self {
            Self::Transactions(txs) => txs.length(),
            Self::NewBlock(block) => block.length(),
        }
    }
}
/// Wire identifiers of eth protocol messages.
///
/// Each discriminant is the exact byte written ahead of a message payload (see the
/// `Encodable` impl). Ids `0x0b` and `0x0c` are not assigned here and are rejected when
/// decoding.
#[repr(u8)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum EthMessageID {
    /// `Status` message.
    Status = 0x00,
    /// `NewBlockHashes` announcement.
    NewBlockHashes = 0x01,
    /// `Transactions` broadcast.
    Transactions = 0x02,
    /// `GetBlockHeaders` request.
    GetBlockHeaders = 0x03,
    /// `BlockHeaders` response.
    BlockHeaders = 0x04,
    /// `GetBlockBodies` request.
    GetBlockBodies = 0x05,
    /// `BlockBodies` response.
    BlockBodies = 0x06,
    /// `NewBlock` announcement.
    NewBlock = 0x07,
    /// `NewPooledTransactionHashes` (payload layout depends on the negotiated version).
    NewPooledTransactionHashes = 0x08,
    /// `GetPooledTransactions` request.
    GetPooledTransactions = 0x09,
    /// `PooledTransactions` response.
    PooledTransactions = 0x0a,
    /// `GetNodeData` request; only accepted before eth/67.
    GetNodeData = 0x0d,
    /// `NodeData` response; only accepted before eth/67.
    NodeData = 0x0e,
    /// `GetReceipts` request.
    GetReceipts = 0x0f,
    /// `Receipts` response; currently the highest id (see `EthMessageID::max`).
    Receipts = 0x10,
}
impl EthMessageID {
    /// Returns the numerically largest message id defined by this enum (`Receipts`, `0x10`).
    pub const fn max() -> u8 {
        let highest = Self::Receipts;
        highest as u8
    }
}
impl Encodable for EthMessageID {
    /// Writes the `repr(u8)` discriminant as one raw byte (not as an RLP-encoded integer).
    fn encode(&self, out: &mut dyn BufMut) {
        let id_byte = *self as u8;
        out.put_u8(id_byte);
    }

    /// The id always occupies exactly one byte on the wire.
    fn length(&self) -> usize {
        const ID_LEN: usize = 1;
        ID_LEN
    }
}
impl Decodable for EthMessageID {
fn decode(buf: &mut &[u8]) -> alloy_rlp::Result<Self> {
let id = match buf.first().ok_or(alloy_rlp::Error::InputTooShort)? {
0x00 => Self::Status,
0x01 => Self::NewBlockHashes,
0x02 => Self::Transactions,
0x03 => Self::GetBlockHeaders,
0x04 => Self::BlockHeaders,
0x05 => Self::GetBlockBodies,
0x06 => Self::BlockBodies,
0x07 => Self::NewBlock,
0x08 => Self::NewPooledTransactionHashes,
0x09 => Self::GetPooledTransactions,
0x0a => Self::PooledTransactions,
0x0d => Self::GetNodeData,
0x0e => Self::NodeData,
0x0f => Self::GetReceipts,
0x10 => Self::Receipts,
_ => return Err(alloy_rlp::Error::Custom("Invalid message ID")),
};
buf.advance(1);
Ok(id)
}
}
impl TryFrom<usize> for EthMessageID {
    type Error = &'static str;

    /// Maps a numeric message id to its variant. Values with no assigned message (including
    /// `0x0b`, `0x0c` and anything above `0x10`) are rejected.
    fn try_from(value: usize) -> Result<Self, Self::Error> {
        let id = match value {
            0x00 => Self::Status,
            0x01 => Self::NewBlockHashes,
            0x02 => Self::Transactions,
            0x03 => Self::GetBlockHeaders,
            0x04 => Self::BlockHeaders,
            0x05 => Self::GetBlockBodies,
            0x06 => Self::BlockBodies,
            0x07 => Self::NewBlock,
            0x08 => Self::NewPooledTransactionHashes,
            0x09 => Self::GetPooledTransactions,
            0x0a => Self::PooledTransactions,
            0x0d => Self::GetNodeData,
            0x0e => Self::NodeData,
            0x0f => Self::GetReceipts,
            0x10 => Self::Receipts,
            _ => return Err("Invalid message ID"),
        };
        Ok(id)
    }
}
/// A message paired with a request id, RLP-encoded as the list `[request_id, message]`.
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct RequestPair<T> {
    /// Request id carried alongside the message (presumably used to match a response to its
    /// request — confirm at call sites).
    pub request_id: u64,
    /// The wrapped request or response payload.
    pub message: T,
}
impl<T> Encodable for RequestPair<T>
where
    T: Encodable,
{
    /// Encodes as the RLP list `[request_id, message]`.
    fn encode(&self, out: &mut dyn alloy_rlp::BufMut) {
        let payload_length = self.request_id.length() + self.message.length();
        Header { list: true, payload_length }.encode(out);
        self.request_id.encode(out);
        self.message.encode(out);
    }

    /// List header length plus the combined encoded length of both fields.
    fn length(&self) -> usize {
        let payload_length = self.request_id.length() + self.message.length();
        length_of_length(payload_length) + payload_length
    }
}
impl<T> Decodable for RequestPair<T>
where
T: Decodable,
{
fn decode(buf: &mut &[u8]) -> alloy_rlp::Result<Self> {
let header = Header::decode(buf)?;
let initial_length = buf.len();
let request_id = u64::decode(buf)?;
let message = T::decode(buf)?;
let consumed_len = initial_length - buf.len();
if consumed_len != header.payload_length {
return Err(alloy_rlp::Error::UnexpectedLength)
}
Ok(Self { request_id, message })
}
}
#[cfg(test)]
mod tests {
    use super::MessageError;
    use crate::{
        message::RequestPair, EthMessage, EthMessageID, EthNetworkPrimitives, EthVersion,
        GetNodeData, NodeData, ProtocolMessage,
    };
    use alloy_primitives::hex;
    use alloy_rlp::{Decodable, Encodable, Error};

    /// RLP-encodes `value` into a fresh byte vector.
    fn encode<T: Encodable>(value: T) -> Vec<u8> {
        let mut out = Vec::new();
        value.encode(&mut out);
        out
    }

    /// Encodes `message` under `message_type`, then decodes the bytes as an eth/67 message.
    fn decode_as_eth67(
        message_type: EthMessageID,
        message: EthMessage<EthNetworkPrimitives>,
    ) -> Result<ProtocolMessage<EthNetworkPrimitives>, MessageError> {
        let bytes = encode(ProtocolMessage { message_type, message });
        let mut cursor = bytes.as_slice();
        ProtocolMessage::decode_message(EthVersion::Eth67, &mut cursor)
    }

    #[test]
    fn test_removed_message_at_eth67() {
        let get_node_data = EthMessage::GetNodeData(RequestPair {
            request_id: 1337,
            message: GetNodeData(vec![]),
        });
        assert!(matches!(
            decode_as_eth67(EthMessageID::GetNodeData, get_node_data),
            Err(MessageError::Invalid(..))
        ));

        let node_data =
            EthMessage::NodeData(RequestPair { request_id: 1337, message: NodeData(vec![]) });
        assert!(matches!(
            decode_as_eth67(EthMessageID::NodeData, node_data),
            Err(MessageError::Invalid(..))
        ));
    }

    #[test]
    fn request_pair_encode() {
        let expected = hex!("c5820539c105");
        let got = encode(RequestPair { request_id: 1337, message: vec![5u8] });
        assert_eq!(expected[..], got, "expected: {expected:X?}, got: {got:X?}");
    }

    #[test]
    fn request_pair_decode() {
        let raw: &[u8] = &hex!("c5820539c105");
        let mut cursor = raw;
        let got = RequestPair::<Vec<u8>>::decode(&mut cursor).unwrap();

        let expected = RequestPair { request_id: 1337, message: vec![5u8] };
        assert_eq!(expected.length(), raw.len());
        assert_eq!(expected, got);
    }

    #[test]
    fn malicious_request_pair_decode() {
        // The header declares a 5-byte payload, but the inner list `c2 05 05` ends one byte
        // past it.
        let mut raw: &[u8] = &hex!("c5820539c20505");
        let result = RequestPair::<Vec<u8>>::decode(&mut raw);
        assert!(matches!(result, Err(Error::UnexpectedLength)));
    }

    #[test]
    fn empty_block_bodies_protocol() {
        let bodies = EthMessage::<EthNetworkPrimitives>::BlockBodies(RequestPair {
            request_id: 0,
            message: Default::default(),
        });
        let empty_block_bodies = ProtocolMessage::from(bodies);

        let mut encoded = Vec::new();
        empty_block_bodies.encode(&mut encoded);

        let mut cursor = encoded.as_slice();
        let decoded = ProtocolMessage::decode_message(EthVersion::Eth68, &mut cursor).unwrap();
        assert_eq!(empty_block_bodies, decoded);
    }
}