1use crate::NodeType;
7use alloy_consensus::Header;
8use alloy_primitives::{
9 bytes::{Buf, BufMut},
10 BlockHash, Bytes, B256,
11};
12use alloy_rlp::{BytesMut, Decodable, Encodable, RlpDecodable, RlpEncodable};
13use reth_eth_wire::{message::RequestPair, protocol::Protocol, Capability};
14use reth_ethereum_primitives::BlockBody;
15
16#[derive(PartialEq, Eq, Clone, Debug)]
18pub struct RessProtocolMessage {
19 pub message_type: RessMessageID,
21 pub message: RessMessage,
23}
24
#[cfg(any(test, feature = "arbitrary"))]
impl<'a> arbitrary::Arbitrary<'a> for RessProtocolMessage {
    /// Generates a random payload and wraps it so that `message_type` always matches the
    /// generated variant (a free-form derive could produce mismatched pairs).
    fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
        u.arbitrary::<RessMessage>().map(RessMessage::into_protocol_message)
    }
}
32
33impl RessProtocolMessage {
34 pub const fn capability() -> Capability {
36 Capability::new_static("ress", 1)
37 }
38
39 pub const fn protocol() -> Protocol {
41 Protocol::new(Self::capability(), 9)
42 }
43
44 pub const fn node_type(node_type: NodeType) -> Self {
46 RessMessage::NodeType(node_type).into_protocol_message()
47 }
48
49 pub const fn get_headers(request_id: u64, request: GetHeaders) -> Self {
51 RessMessage::GetHeaders(RequestPair { request_id, message: request })
52 .into_protocol_message()
53 }
54
55 pub const fn headers(request_id: u64, headers: Vec<Header>) -> Self {
57 RessMessage::Headers(RequestPair { request_id, message: headers }).into_protocol_message()
58 }
59
60 pub const fn get_block_bodies(request_id: u64, block_hashes: Vec<B256>) -> Self {
62 RessMessage::GetBlockBodies(RequestPair { request_id, message: block_hashes })
63 .into_protocol_message()
64 }
65
66 pub const fn block_bodies(request_id: u64, bodies: Vec<BlockBody>) -> Self {
68 RessMessage::BlockBodies(RequestPair { request_id, message: bodies })
69 .into_protocol_message()
70 }
71
72 pub const fn get_bytecode(request_id: u64, code_hash: B256) -> Self {
74 RessMessage::GetBytecode(RequestPair { request_id, message: code_hash })
75 .into_protocol_message()
76 }
77
78 pub const fn bytecode(request_id: u64, bytecode: Bytes) -> Self {
80 RessMessage::Bytecode(RequestPair { request_id, message: bytecode }).into_protocol_message()
81 }
82
83 pub const fn get_witness(request_id: u64, block_hash: BlockHash) -> Self {
85 RessMessage::GetWitness(RequestPair { request_id, message: block_hash })
86 .into_protocol_message()
87 }
88
89 pub const fn witness(request_id: u64, witness: Vec<Bytes>) -> Self {
91 RessMessage::Witness(RequestPair { request_id, message: witness }).into_protocol_message()
92 }
93
94 pub fn encoded(&self) -> BytesMut {
96 let mut buf = BytesMut::with_capacity(self.length());
97 self.encode(&mut buf);
98 buf
99 }
100
101 pub fn decode_message(buf: &mut &[u8]) -> alloy_rlp::Result<Self> {
103 let message_type = RessMessageID::decode(buf)?;
104 let message = match message_type {
105 RessMessageID::NodeType => RessMessage::NodeType(NodeType::decode(buf)?),
106 RessMessageID::GetHeaders => RessMessage::GetHeaders(RequestPair::decode(buf)?),
107 RessMessageID::Headers => RessMessage::Headers(RequestPair::decode(buf)?),
108 RessMessageID::GetBlockBodies => RessMessage::GetBlockBodies(RequestPair::decode(buf)?),
109 RessMessageID::BlockBodies => RessMessage::BlockBodies(RequestPair::decode(buf)?),
110 RessMessageID::GetBytecode => RessMessage::GetBytecode(RequestPair::decode(buf)?),
111 RessMessageID::Bytecode => RessMessage::Bytecode(RequestPair::decode(buf)?),
112 RessMessageID::GetWitness => RessMessage::GetWitness(RequestPair::decode(buf)?),
113 RessMessageID::Witness => RessMessage::Witness(RequestPair::decode(buf)?),
114 };
115 Ok(Self { message_type, message })
116 }
117}
118
119impl Encodable for RessProtocolMessage {
120 fn encode(&self, out: &mut dyn BufMut) {
121 self.message_type.encode(out);
122 self.message.encode(out);
123 }
124
125 fn length(&self) -> usize {
126 self.message_type.length() + self.message.length()
127 }
128}
129
/// One-byte discriminant that precedes every `ress` message.
///
/// Ids are contiguous from `0x00`; each `Get*` request id is immediately followed by the id
/// of its counterpart.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(any(test, feature = "arbitrary"), derive(arbitrary::Arbitrary))]
#[cfg_attr(test, derive(strum_macros::EnumCount))]
pub enum RessMessageID {
    /// Id of [`RessMessage::NodeType`].
    NodeType = 0x00,

    /// Id of [`RessMessage::GetHeaders`].
    GetHeaders = 0x01,
    /// Id of [`RessMessage::Headers`].
    Headers = 0x02,

    /// Id of [`RessMessage::GetBlockBodies`].
    GetBlockBodies = 0x03,
    /// Id of [`RessMessage::BlockBodies`].
    BlockBodies = 0x04,

    /// Id of [`RessMessage::GetBytecode`].
    GetBytecode = 0x05,
    /// Id of [`RessMessage::Bytecode`].
    Bytecode = 0x06,

    /// Id of [`RessMessage::GetWitness`].
    GetWitness = 0x07,
    /// Id of [`RessMessage::Witness`].
    Witness = 0x08,
}
159
160impl Encodable for RessMessageID {
161 fn encode(&self, out: &mut dyn BufMut) {
162 out.put_u8(*self as u8);
163 }
164
165 fn length(&self) -> usize {
166 1
167 }
168}
169
170impl Decodable for RessMessageID {
171 fn decode(buf: &mut &[u8]) -> alloy_rlp::Result<Self> {
172 let id = match buf.first().ok_or(alloy_rlp::Error::InputTooShort)? {
173 0x00 => Self::NodeType,
174 0x01 => Self::GetHeaders,
175 0x02 => Self::Headers,
176 0x03 => Self::GetBlockBodies,
177 0x04 => Self::BlockBodies,
178 0x05 => Self::GetBytecode,
179 0x06 => Self::Bytecode,
180 0x07 => Self::GetWitness,
181 0x08 => Self::Witness,
182 _ => return Err(alloy_rlp::Error::Custom("Invalid message type")),
183 };
184 buf.advance(1);
185 Ok(id)
186 }
187}
188
189#[derive(PartialEq, Eq, Clone, Debug)]
191#[cfg_attr(any(test, feature = "arbitrary"), derive(arbitrary::Arbitrary))]
192pub enum RessMessage {
193 NodeType(NodeType),
195
196 GetHeaders(RequestPair<GetHeaders>),
198 Headers(RequestPair<Vec<Header>>),
200
201 GetBlockBodies(RequestPair<Vec<B256>>),
203 BlockBodies(RequestPair<Vec<BlockBody>>),
205
206 GetBytecode(RequestPair<B256>),
208 Bytecode(RequestPair<Bytes>),
210
211 GetWitness(RequestPair<BlockHash>),
213 Witness(RequestPair<Vec<Bytes>>),
215}
216
217impl RessMessage {
218 pub const fn message_id(&self) -> RessMessageID {
220 match self {
221 Self::NodeType(_) => RessMessageID::NodeType,
222 Self::GetHeaders(_) => RessMessageID::GetHeaders,
223 Self::Headers(_) => RessMessageID::Headers,
224 Self::GetBlockBodies(_) => RessMessageID::GetBlockBodies,
225 Self::BlockBodies(_) => RessMessageID::BlockBodies,
226 Self::GetBytecode(_) => RessMessageID::GetBytecode,
227 Self::Bytecode(_) => RessMessageID::Bytecode,
228 Self::GetWitness(_) => RessMessageID::GetWitness,
229 Self::Witness(_) => RessMessageID::Witness,
230 }
231 }
232
233 pub const fn into_protocol_message(self) -> RessProtocolMessage {
235 let message_type = self.message_id();
236 RessProtocolMessage { message_type, message: self }
237 }
238}
239
240impl From<RessMessage> for RessProtocolMessage {
241 fn from(value: RessMessage) -> Self {
242 value.into_protocol_message()
243 }
244}
245
246impl Encodable for RessMessage {
247 fn encode(&self, out: &mut dyn BufMut) {
248 match self {
249 Self::NodeType(node_type) => node_type.encode(out),
250 Self::GetHeaders(request) => request.encode(out),
251 Self::Headers(header) => header.encode(out),
252 Self::GetBlockBodies(request) => request.encode(out),
253 Self::BlockBodies(body) => body.encode(out),
254 Self::GetBytecode(request) | Self::GetWitness(request) => request.encode(out),
255 Self::Bytecode(bytecode) => bytecode.encode(out),
256 Self::Witness(witness) => witness.encode(out),
257 }
258 }
259
260 fn length(&self) -> usize {
261 match self {
262 Self::NodeType(node_type) => node_type.length(),
263 Self::GetHeaders(request) => request.length(),
264 Self::Headers(header) => header.length(),
265 Self::GetBlockBodies(request) => request.length(),
266 Self::BlockBodies(body) => body.length(),
267 Self::GetBytecode(request) | Self::GetWitness(request) => request.length(),
268 Self::Bytecode(bytecode) => bytecode.length(),
269 Self::Witness(witness) => witness.length(),
270 }
271 }
272}
273
274#[derive(PartialEq, Eq, Clone, Copy, Debug, RlpEncodable, RlpDecodable)]
279#[cfg_attr(any(test, feature = "arbitrary"), derive(arbitrary::Arbitrary))]
280pub struct GetHeaders {
281 pub start_hash: BlockHash,
283
284 pub limit: u64,
286}
287
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;
    use proptest_arbitrary_interop::arb;
    use std::fmt;
    use strum::EnumCount;

    /// RLP-encodes `value` and asserts that decoding the bytes yields an equal value.
    fn rlp_roundtrip<V>(value: V)
    where
        V: Encodable + Decodable + PartialEq + fmt::Debug,
    {
        let bytes = alloy_rlp::encode(&value);
        let mut cursor: &[u8] = &bytes;
        assert_eq!(Ok(value), V::decode(&mut cursor));
    }

    /// The protocol must reserve exactly one slot per message id.
    #[test]
    fn protocol_message_count() {
        let expected = RessMessageID::COUNT as u8;
        assert_eq!(RessProtocolMessage::protocol().messages(), expected);
    }

    proptest! {
        #[test]
        fn message_type_roundtrip(message_type in arb::<RessMessageID>()) {
            rlp_roundtrip(message_type);
        }

        #[test]
        fn message_roundtrip(message in arb::<RessProtocolMessage>()) {
            let bytes = alloy_rlp::encode(&message);
            let decoded = RessProtocolMessage::decode_message(&mut bytes.as_slice());
            assert_eq!(Ok(message), decoded);
        }
    }
}