1use super::message::MAX_MESSAGE_SIZE;
7use crate::{
8 message::{EthBroadcastMessage, ProtocolBroadcastMessage},
9 EthMessage, EthMessageID, EthNetworkPrimitives, EthVersion, NetworkPrimitives, ProtocolMessage,
10 RawCapabilityMessage, SnapMessageId, SnapProtocolMessage,
11};
12use alloy_rlp::{Bytes, BytesMut, Encodable};
13use core::fmt::Debug;
14use futures::{Sink, SinkExt};
15use pin_project::pin_project;
16use std::{
17 marker::PhantomData,
18 pin::Pin,
19 task::{ready, Context, Poll},
20};
21use tokio_stream::Stream;
22
/// Errors produced while reading from or writing to an [`EthSnapStream`].
#[derive(thiserror::Error, Debug)]
pub enum EthSnapStreamError {
    /// An eth message was rejected for the given eth version; the string describes the problem.
    #[error("invalid message for version {0:?}: {1}")]
    InvalidMessage(EthVersion, String),

    /// The message id lies outside both the eth and the snap id ranges.
    #[error("unknown message id: {0}")]
    UnknownMessageId(u8),

    /// An inbound frame exceeded the configured size limit: `(actual length, limit)`.
    #[error("message too large: {0} > {1}")]
    MessageTooLarge(usize, usize),

    /// RLP decoding failed. Also returned for an empty inbound frame.
    #[error("rlp error: {0}")]
    Rlp(#[from] alloy_rlp::Error),

    /// A `Status` message was sent or received outside of the handshake.
    #[error("status message received outside handshake")]
    StatusNotInHandshake,
}
46
/// A message carried over a combined eth + snap stream.
#[derive(Debug)]
pub enum EthSnapMessage<N: NetworkPrimitives = EthNetworkPrimitives> {
    /// An eth protocol message.
    Eth(EthMessage<N>),
    /// A snap protocol message (its id is already rebased to the snap-local id space).
    Snap(SnapProtocolMessage),
}
55
/// Adapter that speaks both eth and snap over a single frame stream/sink `S`.
///
/// The negotiated eth version decides where the eth message id range ends and the snap id range
/// begins, both when decoding inbound frames and when encoding outbound messages.
#[pin_project]
#[derive(Debug, Clone)]
pub struct EthSnapStream<S, N = EthNetworkPrimitives> {
    /// Codec holding the eth version and the inbound frame size limit.
    eth_snap: EthSnapStreamInner<N>,
    /// Underlying transport of raw frames (structurally pinned).
    #[pin]
    inner: S,
}
67
68impl<S, N> EthSnapStream<S, N>
69where
70 N: NetworkPrimitives,
71{
72 pub const fn new(stream: S, eth_version: EthVersion) -> Self {
74 Self { eth_snap: EthSnapStreamInner::new(eth_version), inner: stream }
75 }
76
77 pub const fn with_max_message_size(
79 stream: S,
80 eth_version: EthVersion,
81 max_message_size: usize,
82 ) -> Self {
83 Self {
84 eth_snap: EthSnapStreamInner::with_max_message_size(eth_version, max_message_size),
85 inner: stream,
86 }
87 }
88
89 #[inline]
91 pub const fn eth_version(&self) -> EthVersion {
92 self.eth_snap.eth_version()
93 }
94
95 #[inline]
97 pub const fn inner(&self) -> &S {
98 &self.inner
99 }
100
101 #[inline]
103 pub const fn inner_mut(&mut self) -> &mut S {
104 &mut self.inner
105 }
106
107 #[inline]
109 pub fn into_inner(self) -> S {
110 self.inner
111 }
112}
113
114impl<S, E, N> EthSnapStream<S, N>
115where
116 S: Sink<Bytes, Error = E> + Unpin,
117 EthSnapStreamError: From<E>,
118 N: NetworkPrimitives,
119{
120 pub fn start_send_broadcast(
122 &mut self,
123 item: EthBroadcastMessage<N>,
124 ) -> Result<(), EthSnapStreamError> {
125 self.inner.start_send_unpin(Bytes::from(alloy_rlp::encode(
126 ProtocolBroadcastMessage::from(item),
127 )))?;
128
129 Ok(())
130 }
131
132 pub fn start_send_raw(&mut self, msg: RawCapabilityMessage) -> Result<(), EthSnapStreamError> {
134 let mut bytes = Vec::with_capacity(msg.payload.len() + 1);
135 msg.id.encode(&mut bytes);
136 bytes.extend_from_slice(&msg.payload);
137
138 self.inner.start_send_unpin(bytes.into())?;
139 Ok(())
140 }
141}
142
143impl<S, E, N> Stream for EthSnapStream<S, N>
144where
145 S: Stream<Item = Result<BytesMut, E>> + Unpin,
146 EthSnapStreamError: From<E>,
147 N: NetworkPrimitives,
148{
149 type Item = Result<EthSnapMessage<N>, EthSnapStreamError>;
150
151 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
152 let this = self.project();
153 let res = ready!(this.inner.poll_next(cx));
154
155 match res {
156 Some(Ok(bytes)) => Poll::Ready(Some(this.eth_snap.decode_message(bytes))),
157 Some(Err(err)) => Poll::Ready(Some(Err(err.into()))),
158 None => Poll::Ready(None),
159 }
160 }
161}
162
163impl<S, E, N> Sink<EthSnapMessage<N>> for EthSnapStream<S, N>
164where
165 S: Sink<Bytes, Error = E> + Unpin,
166 EthSnapStreamError: From<E>,
167 N: NetworkPrimitives,
168{
169 type Error = EthSnapStreamError;
170
171 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
172 self.project().inner.poll_ready(cx).map_err(Into::into)
173 }
174
175 fn start_send(mut self: Pin<&mut Self>, item: EthSnapMessage<N>) -> Result<(), Self::Error> {
176 let mut this = self.as_mut().project();
177
178 let bytes = match item {
179 EthSnapMessage::Eth(eth_msg) => this.eth_snap.encode_eth_message(eth_msg)?,
180 EthSnapMessage::Snap(snap_msg) => this.eth_snap.encode_snap_message(snap_msg),
181 };
182
183 this.inner.start_send_unpin(bytes)?;
184 Ok(())
185 }
186
187 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
188 self.project().inner.poll_flush(cx).map_err(Into::into)
189 }
190
191 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
192 self.project().inner.poll_close(cx).map_err(Into::into)
193 }
194}
195
/// Version-aware eth + snap codec used by [`EthSnapStream`].
///
/// It splits the message id space between eth and snap based on `eth_version` and enforces a
/// size limit on inbound frames.
#[derive(Debug, Clone)]
struct EthSnapStreamInner<N> {
    /// Negotiated eth version; decides where eth ids end and snap ids begin.
    eth_version: EthVersion,
    /// Maximum accepted length of an inbound frame, in bytes.
    max_message_size: usize,
    /// Ties the codec to the network primitives type `N` without storing a value of it.
    _pd: PhantomData<N>,
}
208
209impl<N> EthSnapStreamInner<N>
210where
211 N: NetworkPrimitives,
212{
213 const fn new(eth_version: EthVersion) -> Self {
215 Self::with_max_message_size(eth_version, MAX_MESSAGE_SIZE)
216 }
217
218 const fn with_max_message_size(eth_version: EthVersion, max_message_size: usize) -> Self {
220 Self { eth_version, max_message_size, _pd: PhantomData }
221 }
222
223 #[inline]
224 const fn eth_version(&self) -> EthVersion {
225 self.eth_version
226 }
227
228 fn decode_message(&self, bytes: BytesMut) -> Result<EthSnapMessage<N>, EthSnapStreamError> {
230 if bytes.len() > self.max_message_size {
231 return Err(EthSnapStreamError::MessageTooLarge(bytes.len(), self.max_message_size));
232 }
233
234 if bytes.is_empty() {
235 return Err(EthSnapStreamError::Rlp(alloy_rlp::Error::InputTooShort));
236 }
237
238 let message_id = bytes[0];
239
240 if message_id <= EthMessageID::max(self.eth_version) {
246 let mut buf = bytes.as_ref();
247 match ProtocolMessage::decode_message(self.eth_version, &mut buf) {
248 Ok(protocol_msg) => {
249 if matches!(protocol_msg.message, EthMessage::Status(_)) {
250 return Err(EthSnapStreamError::StatusNotInHandshake);
251 }
252 Ok(EthSnapMessage::Eth(protocol_msg.message))
253 }
254 Err(err) => {
255 Err(EthSnapStreamError::InvalidMessage(self.eth_version, err.to_string()))
256 }
257 }
258 } else if message_id > EthMessageID::max(self.eth_version) &&
259 message_id <=
260 EthMessageID::message_count(self.eth_version) + SnapMessageId::TrieNodes as u8
261 {
262 let adjusted_message_id = message_id - EthMessageID::message_count(self.eth_version);
269 let mut buf = &bytes[1..];
270
271 match SnapProtocolMessage::decode(adjusted_message_id, &mut buf) {
272 Ok(snap_msg) => Ok(EthSnapMessage::Snap(snap_msg)),
273 Err(err) => Err(EthSnapStreamError::Rlp(err)),
274 }
275 } else {
276 Err(EthSnapStreamError::UnknownMessageId(message_id))
277 }
278 }
279
280 fn encode_eth_message(&self, item: EthMessage<N>) -> Result<Bytes, EthSnapStreamError> {
282 if matches!(item, EthMessage::Status(_)) {
283 return Err(EthSnapStreamError::StatusNotInHandshake);
284 }
285
286 let protocol_msg = ProtocolMessage::from(item);
287 let mut buf = Vec::new();
288 protocol_msg.encode(&mut buf);
289 Ok(Bytes::from(buf))
290 }
291
292 fn encode_snap_message(&self, message: SnapProtocolMessage) -> Bytes {
295 let encoded = message.encode();
296
297 let message_id = encoded[0];
298 let adjusted_id = message_id + EthMessageID::message_count(self.eth_version);
299
300 let mut adjusted = Vec::with_capacity(encoded.len());
301 adjusted.push(adjusted_id);
302 adjusted.extend_from_slice(&encoded[1..]);
303
304 Bytes::from(adjusted)
305 }
306}
307
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{EthMessage, SnapProtocolMessage};
    use alloy_eips::BlockHashOrNumber;
    use alloy_primitives::B256;
    use alloy_rlp::Encodable;
    use reth_eth_wire_types::{
        message::RequestPair, GetAccountRangeMessage, GetBlockAccessLists, GetBlockHeaders,
        HeadersDirection,
    };

    /// Returns a `GetBlockHeaders` request together with its `ProtocolMessage` wire encoding.
    fn create_eth_message() -> (EthMessage<EthNetworkPrimitives>, BytesMut) {
        let eth_msg = EthMessage::<EthNetworkPrimitives>::GetBlockHeaders(RequestPair {
            request_id: 1,
            message: GetBlockHeaders {
                start_block: BlockHashOrNumber::Number(1),
                limit: 10,
                skip: 0,
                direction: HeadersDirection::Rising,
            },
        });

        let protocol_msg = ProtocolMessage::from(eth_msg.clone());
        let mut buf = Vec::new();
        protocol_msg.encode(&mut buf);

        (eth_msg, BytesMut::from(&buf[..]))
    }

    /// Returns a snap `GetAccountRange` request together with its eth/67 wire encoding.
    fn create_snap_message() -> (SnapProtocolMessage, BytesMut) {
        let snap_msg = SnapProtocolMessage::GetAccountRange(GetAccountRangeMessage {
            request_id: 1,
            root_hash: B256::default(),
            starting_hash: B256::default(),
            limit_hash: B256::default(),
            response_bytes: 1000,
        });

        let inner = EthSnapStreamInner::<EthNetworkPrimitives>::new(EthVersion::Eth67);
        let encoded = inner.encode_snap_message(snap_msg.clone());

        (snap_msg, BytesMut::from(&encoded[..]))
    }

    #[test]
    fn test_eth_message_roundtrip() {
        let inner = EthSnapStreamInner::<EthNetworkPrimitives>::new(EthVersion::Eth67);
        let (eth_msg, eth_bytes) = create_eth_message();

        // The codec must produce exactly the plain `ProtocolMessage` wire bytes.
        let encoded = inner
            .encode_eth_message(eth_msg.clone())
            .expect("GetBlockHeaders is a valid eth/67 message");
        assert_eq!(&encoded[..], &eth_bytes[..]);

        let Ok(EthSnapMessage::Eth(decoded_msg)) = inner.decode_message(eth_bytes) else {
            panic!("expected an eth message");
        };
        assert_eq!(decoded_msg, eth_msg);

        let re_encoded = inner.encode_eth_message(decoded_msg.clone()).unwrap();
        let Ok(EthSnapMessage::Eth(final_msg)) =
            inner.decode_message(BytesMut::from(&re_encoded[..]))
        else {
            panic!("expected an eth message after re-encoding");
        };
        assert_eq!(final_msg, decoded_msg);
    }

    #[test]
    fn test_snap_protocol() {
        let inner = EthSnapStreamInner::<EthNetworkPrimitives>::new(EthVersion::Eth67);
        let (snap_msg, snap_bytes) = create_snap_message();

        // `GetAccountRange` is the first snap message, so it sits right after the eth ids.
        assert_eq!(snap_bytes[0], EthMessageID::message_count(EthVersion::Eth67));

        let Ok(EthSnapMessage::Snap(decoded_msg)) = inner.decode_message(snap_bytes) else {
            panic!("expected a snap message");
        };
        assert_eq!(decoded_msg, snap_msg);

        let re_encoded = inner.encode_snap_message(decoded_msg.clone());
        let Ok(EthSnapMessage::Snap(final_msg)) =
            inner.decode_message(BytesMut::from(&re_encoded[..]))
        else {
            panic!("expected a snap message after re-encoding");
        };
        assert_eq!(final_msg, decoded_msg);
    }

    #[test]
    fn test_message_id_boundaries() {
        let version = EthVersion::Eth67;
        let inner = EthSnapStreamInner::<EthNetworkPrimitives>::new(version);
        let eth_max_id = EthMessageID::max(version);

        // The highest eth id must reach the eth decoder: a garbage payload may fail, but only as
        // an eth decoding error, never as snap or unknown.
        let eth_boundary = inner.decode_message(BytesMut::from(&[eth_max_id, 0, 0][..]));
        assert!(matches!(
            eth_boundary,
            Ok(EthSnapMessage::Eth(_)) | Err(EthSnapStreamError::InvalidMessage(..))
        ));

        // The next id is the first snap id; the garbage payload must fail in the snap decoder.
        let snap_boundary = inner.decode_message(BytesMut::from(&[eth_max_id + 1, 0, 0][..]));
        assert!(matches!(snap_boundary, Err(EthSnapStreamError::Rlp(_))));

        // One past the last snap id belongs to neither protocol.
        let past_snap_id = EthMessageID::message_count(version) + SnapMessageId::TrieNodes as u8 + 1;
        let unknown = inner.decode_message(BytesMut::from(&[past_snap_id, 0, 0][..]));
        assert!(matches!(
            unknown,
            Err(EthSnapStreamError::UnknownMessageId(id)) if id == past_snap_id
        ));
    }

    #[test]
    fn test_decode_rejects_empty_and_oversized_frames() {
        let inner = EthSnapStreamInner::<EthNetworkPrimitives>::new(EthVersion::Eth67);
        assert!(matches!(
            inner.decode_message(BytesMut::new()),
            Err(EthSnapStreamError::Rlp(alloy_rlp::Error::InputTooShort))
        ));

        let limited =
            EthSnapStreamInner::<EthNetworkPrimitives>::with_max_message_size(EthVersion::Eth67, 2);
        assert!(matches!(
            limited.decode_message(BytesMut::from(&[0u8, 0, 0][..])),
            Err(EthSnapStreamError::MessageTooLarge(3, 2))
        ));
    }

    #[test]
    fn test_eth70_message_id_0x12_is_snap() {
        let inner = EthSnapStreamInner::<EthNetworkPrimitives>::new(EthVersion::Eth70);
        let snap_msg = SnapProtocolMessage::GetAccountRange(GetAccountRangeMessage {
            request_id: 1,
            root_hash: B256::default(),
            starting_hash: B256::default(),
            limit_hash: B256::default(),
            response_bytes: 1000,
        });

        let encoded = inner.encode_snap_message(snap_msg);
        assert_eq!(encoded[0], EthMessageID::message_count(EthVersion::Eth70));

        let decoded = inner.decode_message(BytesMut::from(&encoded[..])).unwrap();
        assert!(matches!(decoded, EthSnapMessage::Snap(_)));
    }

    #[test]
    fn test_eth71_message_id_0x12_is_eth() {
        let inner = EthSnapStreamInner::<EthNetworkPrimitives>::new(EthVersion::Eth71);
        let eth_msg = EthMessage::<EthNetworkPrimitives>::GetBlockAccessLists(RequestPair {
            request_id: 1,
            message: GetBlockAccessLists(vec![B256::ZERO]),
        });
        let protocol_msg = ProtocolMessage::from(eth_msg.clone());
        let mut buf = Vec::new();
        protocol_msg.encode(&mut buf);

        let decoded = inner.decode_message(BytesMut::from(&buf[..])).unwrap();
        let EthSnapMessage::Eth(decoded_eth) = decoded else {
            panic!("expected eth message");
        };
        assert_eq!(decoded_eth, eth_msg);
    }
}