1use super::message::MAX_MESSAGE_SIZE;
7use crate::{
8 message::{EthBroadcastMessage, ProtocolBroadcastMessage, TX_MEMORY_BUDGET_MULTIPLIER},
9 EthMessage, EthMessageID, EthNetworkPrimitives, EthVersion, NetworkPrimitives, ProtocolMessage,
10 RawCapabilityMessage, SnapProtocolMessage, SnapVersion,
11};
12use alloy_rlp::{Bytes, BytesMut, Encodable};
13use core::fmt::Debug;
14use futures::{Sink, SinkExt};
15use pin_project::pin_project;
16use std::{
17 marker::PhantomData,
18 pin::Pin,
19 task::{ready, Context, Poll},
20};
21use tokio_stream::Stream;
22
23#[derive(thiserror::Error, Debug)]
25pub enum EthSnapStreamError {
26 #[error("invalid message for version {0:?}: {1}")]
28 InvalidMessage(EthVersion, String),
29
30 #[error("unknown message id: {0}")]
32 UnknownMessageId(u8),
33
34 #[error("message too large: {0} > {1}")]
36 MessageTooLarge(usize, usize),
37
38 #[error("rlp error: {0}")]
40 Rlp(#[from] alloy_rlp::Error),
41
42 #[error("status message received outside handshake")]
44 StatusNotInHandshake,
45}
46
47#[derive(Debug)]
49pub enum EthSnapMessage<N: NetworkPrimitives = EthNetworkPrimitives> {
50 Eth(EthMessage<N>),
52 Snap(SnapProtocolMessage),
54}
55
56#[pin_project]
59#[derive(Debug, Clone)]
60pub struct EthSnapStream<S, N = EthNetworkPrimitives> {
61 eth_snap: EthSnapStreamInner<N>,
63 #[pin]
65 inner: S,
66}
67
68impl<S, N> EthSnapStream<S, N>
69where
70 N: NetworkPrimitives,
71{
72 pub const fn new(stream: S, eth_version: EthVersion) -> Self {
74 Self { eth_snap: EthSnapStreamInner::new(eth_version), inner: stream }
75 }
76
77 pub const fn new_with_snap_version(
79 stream: S,
80 eth_version: EthVersion,
81 snap_version: SnapVersion,
82 ) -> Self {
83 Self {
84 eth_snap: EthSnapStreamInner::new_with_snap_version(eth_version, snap_version),
85 inner: stream,
86 }
87 }
88
89 pub const fn with_max_message_size(
91 stream: S,
92 eth_version: EthVersion,
93 max_message_size: usize,
94 ) -> Self {
95 Self {
96 eth_snap: EthSnapStreamInner::with_max_message_size(eth_version, max_message_size),
97 inner: stream,
98 }
99 }
100
101 pub const fn with_max_message_size_and_snap_version(
103 stream: S,
104 eth_version: EthVersion,
105 snap_version: SnapVersion,
106 max_message_size: usize,
107 ) -> Self {
108 Self {
109 eth_snap: EthSnapStreamInner::with_max_message_size_and_snap_version(
110 eth_version,
111 snap_version,
112 max_message_size,
113 ),
114 inner: stream,
115 }
116 }
117
118 #[inline]
120 pub const fn eth_version(&self) -> EthVersion {
121 self.eth_snap.eth_version()
122 }
123
124 #[inline]
126 pub const fn snap_version(&self) -> SnapVersion {
127 self.eth_snap.snap_version()
128 }
129
130 #[inline]
132 pub const fn inner(&self) -> &S {
133 &self.inner
134 }
135
136 #[inline]
138 pub const fn inner_mut(&mut self) -> &mut S {
139 &mut self.inner
140 }
141
142 #[inline]
144 pub fn into_inner(self) -> S {
145 self.inner
146 }
147}
148
149impl<S, E, N> EthSnapStream<S, N>
150where
151 S: Sink<Bytes, Error = E> + Unpin,
152 EthSnapStreamError: From<E>,
153 N: NetworkPrimitives,
154{
155 pub fn start_send_broadcast(
157 &mut self,
158 item: EthBroadcastMessage<N>,
159 ) -> Result<(), EthSnapStreamError> {
160 self.inner.start_send_unpin(Bytes::from(alloy_rlp::encode(
161 ProtocolBroadcastMessage::from(item),
162 )))?;
163
164 Ok(())
165 }
166
167 pub fn start_send_raw(&mut self, msg: RawCapabilityMessage) -> Result<(), EthSnapStreamError> {
169 let mut bytes = Vec::with_capacity(msg.payload.len() + 1);
170 msg.id.encode(&mut bytes);
171 bytes.extend_from_slice(&msg.payload);
172
173 self.inner.start_send_unpin(bytes.into())?;
174 Ok(())
175 }
176}
177
178impl<S, E, N> Stream for EthSnapStream<S, N>
179where
180 S: Stream<Item = Result<BytesMut, E>> + Unpin,
181 EthSnapStreamError: From<E>,
182 N: NetworkPrimitives,
183{
184 type Item = Result<EthSnapMessage<N>, EthSnapStreamError>;
185
186 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
187 let this = self.project();
188 let res = ready!(this.inner.poll_next(cx));
189
190 match res {
191 Some(Ok(bytes)) => Poll::Ready(Some(this.eth_snap.decode_message(bytes))),
192 Some(Err(err)) => Poll::Ready(Some(Err(err.into()))),
193 None => Poll::Ready(None),
194 }
195 }
196}
197
198impl<S, E, N> Sink<EthSnapMessage<N>> for EthSnapStream<S, N>
199where
200 S: Sink<Bytes, Error = E> + Unpin,
201 EthSnapStreamError: From<E>,
202 N: NetworkPrimitives,
203{
204 type Error = EthSnapStreamError;
205
206 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
207 self.project().inner.poll_ready(cx).map_err(Into::into)
208 }
209
210 fn start_send(mut self: Pin<&mut Self>, item: EthSnapMessage<N>) -> Result<(), Self::Error> {
211 let mut this = self.as_mut().project();
212
213 let bytes = match item {
214 EthSnapMessage::Eth(eth_msg) => this.eth_snap.encode_eth_message(eth_msg)?,
215 EthSnapMessage::Snap(snap_msg) => this.eth_snap.encode_snap_message(snap_msg),
216 };
217
218 this.inner.start_send_unpin(bytes)?;
219 Ok(())
220 }
221
222 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
223 self.project().inner.poll_flush(cx).map_err(Into::into)
224 }
225
226 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
227 self.project().inner.poll_close(cx).map_err(Into::into)
228 }
229}
230
231#[derive(Debug, Clone)]
233struct EthSnapStreamInner<N> {
234 eth_version: EthVersion,
236 snap_version: SnapVersion,
238 max_message_size: usize,
240 _pd: PhantomData<N>,
242}
243
244impl<N> EthSnapStreamInner<N>
245where
246 N: NetworkPrimitives,
247{
248 const fn new(eth_version: EthVersion) -> Self {
250 Self::new_with_snap_version(eth_version, SnapVersion::V1)
251 }
252
253 const fn new_with_snap_version(eth_version: EthVersion, snap_version: SnapVersion) -> Self {
255 Self::with_max_message_size_and_snap_version(eth_version, snap_version, MAX_MESSAGE_SIZE)
256 }
257
258 const fn with_max_message_size(eth_version: EthVersion, max_message_size: usize) -> Self {
260 Self::with_max_message_size_and_snap_version(eth_version, SnapVersion::V1, max_message_size)
261 }
262
263 const fn with_max_message_size_and_snap_version(
265 eth_version: EthVersion,
266 snap_version: SnapVersion,
267 max_message_size: usize,
268 ) -> Self {
269 Self { eth_version, snap_version, max_message_size, _pd: PhantomData }
270 }
271
272 #[inline]
273 const fn eth_version(&self) -> EthVersion {
274 self.eth_version
275 }
276
277 #[inline]
278 const fn snap_version(&self) -> SnapVersion {
279 self.snap_version
280 }
281
282 fn decode_message(&self, bytes: BytesMut) -> Result<EthSnapMessage<N>, EthSnapStreamError> {
284 if bytes.len() > self.max_message_size {
285 return Err(EthSnapStreamError::MessageTooLarge(bytes.len(), self.max_message_size));
286 }
287
288 if bytes.is_empty() {
289 return Err(EthSnapStreamError::Rlp(alloy_rlp::Error::InputTooShort));
290 }
291
292 let message_id = bytes[0];
293
294 if message_id <= EthMessageID::max(self.eth_version) {
300 let mut buf = bytes.as_ref();
301 match ProtocolMessage::decode_message_with_tx_memory_budget(
302 self.eth_version,
303 &mut buf,
304 self.max_message_size * TX_MEMORY_BUDGET_MULTIPLIER,
305 ) {
306 Ok(protocol_msg) => {
307 if matches!(protocol_msg.message, EthMessage::Status(_)) {
308 return Err(EthSnapStreamError::StatusNotInHandshake);
309 }
310 Ok(EthSnapMessage::Eth(protocol_msg.message))
311 }
312 Err(err) => {
313 Err(EthSnapStreamError::InvalidMessage(self.eth_version, err.to_string()))
314 }
315 }
316 } else if message_id > EthMessageID::max(self.eth_version) &&
317 message_id <
318 EthMessageID::message_count(self.eth_version) +
319 self.snap_version.message_count()
320 {
321 let adjusted_message_id = message_id - EthMessageID::message_count(self.eth_version);
328 let mut buf = &bytes[1..];
329
330 match SnapProtocolMessage::decode(adjusted_message_id, &mut buf) {
331 Ok(snap_msg) => Ok(EthSnapMessage::Snap(snap_msg)),
332 Err(err) => Err(EthSnapStreamError::Rlp(err)),
333 }
334 } else {
335 Err(EthSnapStreamError::UnknownMessageId(message_id))
336 }
337 }
338
339 fn encode_eth_message(&self, item: EthMessage<N>) -> Result<Bytes, EthSnapStreamError> {
341 if matches!(item, EthMessage::Status(_)) {
342 return Err(EthSnapStreamError::StatusNotInHandshake);
343 }
344
345 let protocol_msg = ProtocolMessage::from(item);
346 let mut buf = Vec::new();
347 protocol_msg.encode(&mut buf);
348 Ok(Bytes::from(buf))
349 }
350
351 fn encode_snap_message(&self, message: SnapProtocolMessage) -> Bytes {
354 let encoded = message.encode();
355
356 let message_id = encoded[0];
357 let adjusted_id = message_id + EthMessageID::message_count(self.eth_version);
358
359 let mut adjusted = Vec::with_capacity(encoded.len());
360 adjusted.push(adjusted_id);
361 adjusted.extend_from_slice(&encoded[1..]);
362
363 Bytes::from(adjusted)
364 }
365}
366
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{EthMessage, SnapProtocolMessage};
    use alloy_eips::BlockHashOrNumber;
    use alloy_primitives::B256;
    use alloy_rlp::Encodable;
    use reth_eth_wire_types::{
        message::RequestPair, BlockAccessLists, BlockAccessListsMessage, GetAccountRangeMessage,
        GetBlockAccessLists, GetBlockAccessListsMessage, GetBlockHeaders, HeadersDirection,
    };

    /// Unwraps a decode result that must be an eth message, reporting the actual result
    /// on mismatch instead of silently skipping later assertions.
    fn expect_eth(
        result: Result<EthSnapMessage<EthNetworkPrimitives>, EthSnapStreamError>,
    ) -> EthMessage<EthNetworkPrimitives> {
        match result {
            Ok(EthSnapMessage::Eth(msg)) => msg,
            other => panic!("expected eth message, got {other:?}"),
        }
    }

    /// Unwraps a decode result that must be a snap message, reporting the actual result
    /// on mismatch instead of silently skipping later assertions.
    fn expect_snap(
        result: Result<EthSnapMessage<EthNetworkPrimitives>, EthSnapStreamError>,
    ) -> SnapProtocolMessage {
        match result {
            Ok(EthSnapMessage::Snap(msg)) => msg,
            other => panic!("expected snap message, got {other:?}"),
        }
    }

    /// Builds a `GetBlockHeaders` request and its reference wire encoding.
    fn create_eth_message() -> (EthMessage<EthNetworkPrimitives>, BytesMut) {
        let eth_msg = EthMessage::<EthNetworkPrimitives>::GetBlockHeaders(RequestPair {
            request_id: 1,
            message: GetBlockHeaders {
                start_block: BlockHashOrNumber::Number(1),
                limit: 10,
                skip: 0,
                direction: HeadersDirection::Rising,
            },
        });

        let protocol_msg = ProtocolMessage::from(eth_msg.clone());
        let mut buf = Vec::new();
        protocol_msg.encode(&mut buf);

        (eth_msg, BytesMut::from(&buf[..]))
    }

    /// Builds a snap/1 `GetAccountRange` request and its encoding for eth/67.
    fn create_snap_message() -> (SnapProtocolMessage, BytesMut) {
        let snap_msg = SnapProtocolMessage::GetAccountRange(GetAccountRangeMessage {
            request_id: 1,
            root_hash: B256::default(),
            starting_hash: B256::default(),
            limit_hash: B256::default(),
            response_bytes: 1000,
        });

        let inner = EthSnapStreamInner::<EthNetworkPrimitives>::new(EthVersion::Eth67);
        let encoded = inner.encode_snap_message(snap_msg.clone());

        (snap_msg, BytesMut::from(&encoded[..]))
    }

    /// Builds a snap/2 `GetBlockAccessLists` request and its encoding for eth/67.
    fn create_snap2_message() -> (SnapProtocolMessage, BytesMut) {
        let snap_msg = SnapProtocolMessage::GetBlockAccessLists(GetBlockAccessListsMessage {
            request_id: 1,
            block_hashes: vec![B256::default()],
            response_bytes: 1000,
        });

        let inner = EthSnapStreamInner::<EthNetworkPrimitives>::new_with_snap_version(
            EthVersion::Eth67,
            SnapVersion::V2,
        );
        let encoded = inner.encode_snap_message(snap_msg.clone());

        (snap_msg, BytesMut::from(&encoded[..]))
    }

    #[test]
    fn test_eth_message_roundtrip() {
        let inner = EthSnapStreamInner::<EthNetworkPrimitives>::new(EthVersion::Eth67);
        let (eth_msg, eth_bytes) = create_eth_message();

        // The codec must produce exactly the reference encoding.
        let encoded = inner.encode_eth_message(eth_msg.clone()).unwrap();
        assert_eq!(&encoded[..], &eth_bytes[..]);

        let decoded_msg = expect_eth(inner.decode_message(eth_bytes));
        assert_eq!(decoded_msg, eth_msg);

        let re_encoded = inner.encode_eth_message(decoded_msg.clone()).unwrap();
        let final_msg = expect_eth(inner.decode_message(BytesMut::from(&re_encoded[..])));
        assert_eq!(final_msg, decoded_msg);
    }

    #[test]
    fn test_snap_protocol() {
        let inner = EthSnapStreamInner::<EthNetworkPrimitives>::new(EthVersion::Eth67);
        let (snap_msg, snap_bytes) = create_snap_message();

        let encoded_bytes = inner.encode_snap_message(snap_msg.clone());
        assert_eq!(&encoded_bytes[..], &snap_bytes[..]);

        let decoded_msg = expect_snap(inner.decode_message(snap_bytes));
        assert_eq!(decoded_msg, snap_msg);

        let re_encoded = inner.encode_snap_message(decoded_msg.clone());
        let final_msg = expect_snap(inner.decode_message(BytesMut::from(&re_encoded[..])));
        assert_eq!(final_msg, decoded_msg);
    }

    #[test]
    fn test_snap2_protocol() {
        let inner = EthSnapStreamInner::<EthNetworkPrimitives>::new_with_snap_version(
            EthVersion::Eth67,
            SnapVersion::V2,
        );
        let (snap_msg, snap_bytes) = create_snap2_message();

        let encoded_bytes = inner.encode_snap_message(snap_msg.clone());
        assert_eq!(&encoded_bytes[..], &snap_bytes[..]);

        let decoded_msg = expect_snap(inner.decode_message(snap_bytes));
        assert_eq!(decoded_msg, snap_msg);
    }

    #[test]
    fn test_message_id_boundaries() {
        let inner = EthSnapStreamInner::<EthNetworkPrimitives>::new(EthVersion::Eth67);

        // The highest eth id must reach the eth decoder. The junk payload may fail to
        // decode, but the frame must never be classified as snap or as an unknown id.
        let eth_max_id = EthMessageID::max(EthVersion::Eth67);
        let eth_boundary_result = inner.decode_message(BytesMut::from(&[eth_max_id, 0, 0][..]));
        assert!(
            matches!(
                eth_boundary_result,
                Ok(EthSnapMessage::Eth(_)) | Err(EthSnapStreamError::InvalidMessage(..))
            ),
            "unexpected result for eth boundary id: {eth_boundary_result:?}"
        );

        // The first id after the eth range must reach the snap decoder, which rejects the
        // junk payload with an RLP error (not an unknown-id error).
        let snap_min_id = eth_max_id + 1;
        let snap_boundary_result =
            inner.decode_message(BytesMut::from(&[snap_min_id, 0, 0][..]));
        assert!(
            matches!(snap_boundary_result, Err(EthSnapStreamError::Rlp(_))),
            "unexpected result for snap boundary id: {snap_boundary_result:?}"
        );
    }

    #[test]
    fn test_eth70_message_id_0x12_is_snap() {
        let inner = EthSnapStreamInner::<EthNetworkPrimitives>::new(EthVersion::Eth70);
        let snap_msg = SnapProtocolMessage::GetAccountRange(GetAccountRangeMessage {
            request_id: 1,
            root_hash: B256::default(),
            starting_hash: B256::default(),
            limit_hash: B256::default(),
            response_bytes: 1000,
        });

        let encoded = inner.encode_snap_message(snap_msg.clone());
        assert_eq!(encoded[0], EthMessageID::message_count(EthVersion::Eth70));

        let decoded = expect_snap(inner.decode_message(BytesMut::from(&encoded[..])));
        assert_eq!(decoded, snap_msg);
    }

    #[test]
    fn test_eth71_message_id_0x12_is_eth() {
        let inner = EthSnapStreamInner::<EthNetworkPrimitives>::new(EthVersion::Eth71);
        let eth_msg = EthMessage::<EthNetworkPrimitives>::GetBlockAccessLists(RequestPair {
            request_id: 1,
            message: GetBlockAccessLists(vec![B256::ZERO]),
        });
        let protocol_msg = ProtocolMessage::from(eth_msg.clone());
        let mut buf = Vec::new();
        protocol_msg.encode(&mut buf);

        let decoded_eth = expect_eth(inner.decode_message(BytesMut::from(&buf[..])));
        assert_eq!(decoded_eth, eth_msg);
    }

    #[test]
    fn test_snap1_rejects_snap2_message_ids() {
        let inner = EthSnapStreamInner::<EthNetworkPrimitives>::new(EthVersion::Eth67);
        let snap2_msg = SnapProtocolMessage::BlockAccessLists(BlockAccessListsMessage {
            request_id: 1,
            block_access_lists: BlockAccessLists(vec![Some(alloy_primitives::Bytes::from_static(
                &[alloy_rlp::EMPTY_LIST_CODE],
            ))]),
        });

        let encoded = EthSnapStreamInner::<EthNetworkPrimitives>::new_with_snap_version(
            EthVersion::Eth67,
            SnapVersion::V2,
        )
        .encode_snap_message(snap2_msg);

        let decoded = inner.decode_message(BytesMut::from(&encoded[..]));
        assert!(matches!(decoded, Err(EthSnapStreamError::UnknownMessageId(_))));
    }
}