Skip to main content

reth_eth_wire/
eth_snap_stream.rs

//! Ethereum and snap combined protocol stream implementation.
//!
//! A stream type for handling both eth and snap protocol messages over a single `RLPx` connection.
//! Provides message encoding/decoding, ID multiplexing, and protocol message processing.
5
6use super::message::MAX_MESSAGE_SIZE;
7use crate::{
8    message::{EthBroadcastMessage, ProtocolBroadcastMessage, TX_MEMORY_BUDGET_MULTIPLIER},
9    EthMessage, EthMessageID, EthNetworkPrimitives, EthVersion, NetworkPrimitives, ProtocolMessage,
10    RawCapabilityMessage, SnapProtocolMessage, SnapVersion,
11};
12use alloy_rlp::{Bytes, BytesMut, Encodable};
13use core::fmt::Debug;
14use futures::{Sink, SinkExt};
15use pin_project::pin_project;
16use std::{
17    marker::PhantomData,
18    pin::Pin,
19    task::{ready, Context, Poll},
20};
21use tokio_stream::Stream;
22
23/// Error type for the eth and snap stream
24#[derive(thiserror::Error, Debug)]
25pub enum EthSnapStreamError {
26    /// Invalid message for protocol version
27    #[error("invalid message for version {0:?}: {1}")]
28    InvalidMessage(EthVersion, String),
29
30    /// Unknown message ID
31    #[error("unknown message id: {0}")]
32    UnknownMessageId(u8),
33
34    /// Message too large
35    #[error("message too large: {0} > {1}")]
36    MessageTooLarge(usize, usize),
37
38    /// RLP decoding error
39    #[error("rlp error: {0}")]
40    Rlp(#[from] alloy_rlp::Error),
41
42    /// Status message received outside handshake
43    #[error("status message received outside handshake")]
44    StatusNotInHandshake,
45}
46
47/// Combined message type that include either eth or snap protocol messages
48#[derive(Debug)]
49pub enum EthSnapMessage<N: NetworkPrimitives = EthNetworkPrimitives> {
50    /// An Ethereum protocol message
51    Eth(EthMessage<N>),
52    /// A snap protocol message
53    Snap(SnapProtocolMessage),
54}
55
56/// A stream implementation that can handle both eth and snap protocol messages
57/// over a single connection.
58#[pin_project]
59#[derive(Debug, Clone)]
60pub struct EthSnapStream<S, N = EthNetworkPrimitives> {
61    /// Protocol logic
62    eth_snap: EthSnapStreamInner<N>,
63    /// Inner byte stream
64    #[pin]
65    inner: S,
66}
67
68impl<S, N> EthSnapStream<S, N>
69where
70    N: NetworkPrimitives,
71{
72    /// Create a new eth and snap protocol stream
73    pub const fn new(stream: S, eth_version: EthVersion) -> Self {
74        Self { eth_snap: EthSnapStreamInner::new(eth_version), inner: stream }
75    }
76
77    /// Create a new eth and snap protocol stream with an explicit snap version.
78    pub const fn new_with_snap_version(
79        stream: S,
80        eth_version: EthVersion,
81        snap_version: SnapVersion,
82    ) -> Self {
83        Self {
84            eth_snap: EthSnapStreamInner::new_with_snap_version(eth_version, snap_version),
85            inner: stream,
86        }
87    }
88
89    /// Create a new eth and snap protocol stream with a custom max message size.
90    pub const fn with_max_message_size(
91        stream: S,
92        eth_version: EthVersion,
93        max_message_size: usize,
94    ) -> Self {
95        Self {
96            eth_snap: EthSnapStreamInner::with_max_message_size(eth_version, max_message_size),
97            inner: stream,
98        }
99    }
100
101    /// Create a new eth and snap protocol stream with a custom max message size and snap version.
102    pub const fn with_max_message_size_and_snap_version(
103        stream: S,
104        eth_version: EthVersion,
105        snap_version: SnapVersion,
106        max_message_size: usize,
107    ) -> Self {
108        Self {
109            eth_snap: EthSnapStreamInner::with_max_message_size_and_snap_version(
110                eth_version,
111                snap_version,
112                max_message_size,
113            ),
114            inner: stream,
115        }
116    }
117
118    /// Returns the eth version
119    #[inline]
120    pub const fn eth_version(&self) -> EthVersion {
121        self.eth_snap.eth_version()
122    }
123
124    /// Returns the snap version.
125    #[inline]
126    pub const fn snap_version(&self) -> SnapVersion {
127        self.eth_snap.snap_version()
128    }
129
130    /// Returns the underlying stream
131    #[inline]
132    pub const fn inner(&self) -> &S {
133        &self.inner
134    }
135
136    /// Returns mutable access to the underlying stream
137    #[inline]
138    pub const fn inner_mut(&mut self) -> &mut S {
139        &mut self.inner
140    }
141
142    /// Consumes this type and returns the wrapped stream
143    #[inline]
144    pub fn into_inner(self) -> S {
145        self.inner
146    }
147}
148
149impl<S, E, N> EthSnapStream<S, N>
150where
151    S: Sink<Bytes, Error = E> + Unpin,
152    EthSnapStreamError: From<E>,
153    N: NetworkPrimitives,
154{
155    /// Same as [`Sink::start_send`] but accepts a [`EthBroadcastMessage`] instead.
156    pub fn start_send_broadcast(
157        &mut self,
158        item: EthBroadcastMessage<N>,
159    ) -> Result<(), EthSnapStreamError> {
160        self.inner.start_send_unpin(Bytes::from(alloy_rlp::encode(
161            ProtocolBroadcastMessage::from(item),
162        )))?;
163
164        Ok(())
165    }
166
167    /// Sends a raw capability message directly over the stream
168    pub fn start_send_raw(&mut self, msg: RawCapabilityMessage) -> Result<(), EthSnapStreamError> {
169        let mut bytes = Vec::with_capacity(msg.payload.len() + 1);
170        msg.id.encode(&mut bytes);
171        bytes.extend_from_slice(&msg.payload);
172
173        self.inner.start_send_unpin(bytes.into())?;
174        Ok(())
175    }
176}
177
178impl<S, E, N> Stream for EthSnapStream<S, N>
179where
180    S: Stream<Item = Result<BytesMut, E>> + Unpin,
181    EthSnapStreamError: From<E>,
182    N: NetworkPrimitives,
183{
184    type Item = Result<EthSnapMessage<N>, EthSnapStreamError>;
185
186    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
187        let this = self.project();
188        let res = ready!(this.inner.poll_next(cx));
189
190        match res {
191            Some(Ok(bytes)) => Poll::Ready(Some(this.eth_snap.decode_message(bytes))),
192            Some(Err(err)) => Poll::Ready(Some(Err(err.into()))),
193            None => Poll::Ready(None),
194        }
195    }
196}
197
198impl<S, E, N> Sink<EthSnapMessage<N>> for EthSnapStream<S, N>
199where
200    S: Sink<Bytes, Error = E> + Unpin,
201    EthSnapStreamError: From<E>,
202    N: NetworkPrimitives,
203{
204    type Error = EthSnapStreamError;
205
206    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
207        self.project().inner.poll_ready(cx).map_err(Into::into)
208    }
209
210    fn start_send(mut self: Pin<&mut Self>, item: EthSnapMessage<N>) -> Result<(), Self::Error> {
211        let mut this = self.as_mut().project();
212
213        let bytes = match item {
214            EthSnapMessage::Eth(eth_msg) => this.eth_snap.encode_eth_message(eth_msg)?,
215            EthSnapMessage::Snap(snap_msg) => this.eth_snap.encode_snap_message(snap_msg),
216        };
217
218        this.inner.start_send_unpin(bytes)?;
219        Ok(())
220    }
221
222    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
223        self.project().inner.poll_flush(cx).map_err(Into::into)
224    }
225
226    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
227        self.project().inner.poll_close(cx).map_err(Into::into)
228    }
229}
230
231/// Stream handling combined eth and snap protocol logic.
232#[derive(Debug, Clone)]
233struct EthSnapStreamInner<N> {
234    /// Eth protocol version
235    eth_version: EthVersion,
236    /// Snap protocol version
237    snap_version: SnapVersion,
238    /// Maximum allowed ETH/Snap message size.
239    max_message_size: usize,
240    /// Type marker
241    _pd: PhantomData<N>,
242}
243
244impl<N> EthSnapStreamInner<N>
245where
246    N: NetworkPrimitives,
247{
248    /// Create a new eth and snap protocol stream
249    const fn new(eth_version: EthVersion) -> Self {
250        Self::new_with_snap_version(eth_version, SnapVersion::V1)
251    }
252
253    /// Create a new eth and snap protocol stream with an explicit snap version.
254    const fn new_with_snap_version(eth_version: EthVersion, snap_version: SnapVersion) -> Self {
255        Self::with_max_message_size_and_snap_version(eth_version, snap_version, MAX_MESSAGE_SIZE)
256    }
257
258    /// Create a new eth and snap protocol stream with a custom max message size.
259    const fn with_max_message_size(eth_version: EthVersion, max_message_size: usize) -> Self {
260        Self::with_max_message_size_and_snap_version(eth_version, SnapVersion::V1, max_message_size)
261    }
262
263    /// Create a new eth and snap protocol stream with a custom max message size and snap version.
264    const fn with_max_message_size_and_snap_version(
265        eth_version: EthVersion,
266        snap_version: SnapVersion,
267        max_message_size: usize,
268    ) -> Self {
269        Self { eth_version, snap_version, max_message_size, _pd: PhantomData }
270    }
271
272    #[inline]
273    const fn eth_version(&self) -> EthVersion {
274        self.eth_version
275    }
276
277    #[inline]
278    const fn snap_version(&self) -> SnapVersion {
279        self.snap_version
280    }
281
282    /// Decode a message from the stream
283    fn decode_message(&self, bytes: BytesMut) -> Result<EthSnapMessage<N>, EthSnapStreamError> {
284        if bytes.len() > self.max_message_size {
285            return Err(EthSnapStreamError::MessageTooLarge(bytes.len(), self.max_message_size));
286        }
287
288        if bytes.is_empty() {
289            return Err(EthSnapStreamError::Rlp(alloy_rlp::Error::InputTooShort));
290        }
291
292        let message_id = bytes[0];
293
294        // This check works because capabilities are sorted lexicographically
295        // if "eth" before "snap", giving eth messages lower IDs than snap messages,
296        // and eth message IDs are <= [`EthMessageID::max()`],
297        // snap message IDs are > [`EthMessageID::max()`].
298        // See also <https://github.com/paradigmxyz/reth/blob/main/crates/net/eth-wire/src/capability.rs#L272-L283>.
299        if message_id <= EthMessageID::max(self.eth_version) {
300            let mut buf = bytes.as_ref();
301            match ProtocolMessage::decode_message_with_tx_memory_budget(
302                self.eth_version,
303                &mut buf,
304                self.max_message_size * TX_MEMORY_BUDGET_MULTIPLIER,
305            ) {
306                Ok(protocol_msg) => {
307                    if matches!(protocol_msg.message, EthMessage::Status(_)) {
308                        return Err(EthSnapStreamError::StatusNotInHandshake);
309                    }
310                    Ok(EthSnapMessage::Eth(protocol_msg.message))
311                }
312                Err(err) => {
313                    Err(EthSnapStreamError::InvalidMessage(self.eth_version, err.to_string()))
314                }
315            }
316        } else if message_id > EthMessageID::max(self.eth_version) &&
317            message_id <
318                EthMessageID::message_count(self.eth_version) +
319                    self.snap_version.message_count()
320        {
321            // Checks for multiplexed snap message IDs :
322            // - message_id > EthMessageID::max() : ensures it's not an eth message
323            // - message_id <= EthMessageID::message_count() + snap_max : ensures it's within valid
324            //   snap range
325            // Message IDs are assigned lexicographically during capability negotiation
326            // So real_snap_id = multiplexed_id - num_eth_messages
327            let adjusted_message_id = message_id - EthMessageID::message_count(self.eth_version);
328            let mut buf = &bytes[1..];
329
330            match SnapProtocolMessage::decode(adjusted_message_id, &mut buf) {
331                Ok(snap_msg) => Ok(EthSnapMessage::Snap(snap_msg)),
332                Err(err) => Err(EthSnapStreamError::Rlp(err)),
333            }
334        } else {
335            Err(EthSnapStreamError::UnknownMessageId(message_id))
336        }
337    }
338
339    /// Encode an eth message
340    fn encode_eth_message(&self, item: EthMessage<N>) -> Result<Bytes, EthSnapStreamError> {
341        if matches!(item, EthMessage::Status(_)) {
342            return Err(EthSnapStreamError::StatusNotInHandshake);
343        }
344
345        let protocol_msg = ProtocolMessage::from(item);
346        let mut buf = Vec::new();
347        protocol_msg.encode(&mut buf);
348        Ok(Bytes::from(buf))
349    }
350
351    /// Encode a snap protocol message, adjusting the message ID to follow eth message IDs
352    /// for proper multiplexing.
353    fn encode_snap_message(&self, message: SnapProtocolMessage) -> Bytes {
354        let encoded = message.encode();
355
356        let message_id = encoded[0];
357        let adjusted_id = message_id + EthMessageID::message_count(self.eth_version);
358
359        let mut adjusted = Vec::with_capacity(encoded.len());
360        adjusted.push(adjusted_id);
361        adjusted.extend_from_slice(&encoded[1..]);
362
363        Bytes::from(adjusted)
364    }
365}
366
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{EthMessage, SnapProtocolMessage};
    use alloy_eips::BlockHashOrNumber;
    use alloy_primitives::B256;
    use alloy_rlp::Encodable;
    use reth_eth_wire_types::{
        message::RequestPair, BlockAccessLists, BlockAccessListsMessage, GetAccountRangeMessage,
        GetBlockAccessLists, GetBlockAccessListsMessage, GetBlockHeaders, HeadersDirection,
    };

    // Helper: an eth `GetBlockHeaders` request and its reference wire encoding.
    fn create_eth_message() -> (EthMessage<EthNetworkPrimitives>, BytesMut) {
        let eth_msg = EthMessage::<EthNetworkPrimitives>::GetBlockHeaders(RequestPair {
            request_id: 1,
            message: GetBlockHeaders {
                start_block: BlockHashOrNumber::Number(1),
                limit: 10,
                skip: 0,
                direction: HeadersDirection::Rising,
            },
        });

        let protocol_msg = ProtocolMessage::from(eth_msg.clone());
        let mut buf = Vec::new();
        protocol_msg.encode(&mut buf);

        (eth_msg, BytesMut::from(&buf[..]))
    }

    // Helper: a snap/1 message and its multiplexed encoding for an eth/67 + snap/1 codec.
    fn create_snap_message() -> (SnapProtocolMessage, BytesMut) {
        let snap_msg = SnapProtocolMessage::GetAccountRange(GetAccountRangeMessage {
            request_id: 1,
            root_hash: B256::default(),
            starting_hash: B256::default(),
            limit_hash: B256::default(),
            response_bytes: 1000,
        });

        let inner = EthSnapStreamInner::<EthNetworkPrimitives>::new(EthVersion::Eth67);
        let encoded = inner.encode_snap_message(snap_msg.clone());

        (snap_msg, BytesMut::from(&encoded[..]))
    }

    // Helper: a snap/2 message and its multiplexed encoding for an eth/67 + snap/2 codec.
    fn create_snap2_message() -> (SnapProtocolMessage, BytesMut) {
        let snap_msg = SnapProtocolMessage::GetBlockAccessLists(GetBlockAccessListsMessage {
            request_id: 1,
            block_hashes: vec![B256::default()],
            response_bytes: 1000,
        });

        let inner = EthSnapStreamInner::<EthNetworkPrimitives>::new_with_snap_version(
            EthVersion::Eth67,
            SnapVersion::V2,
        );
        let encoded = inner.encode_snap_message(snap_msg.clone());

        (snap_msg, BytesMut::from(&encoded[..]))
    }

    #[test]
    fn test_eth_message_roundtrip() {
        let inner = EthSnapStreamInner::<EthNetworkPrimitives>::new(EthVersion::Eth67);
        let (eth_msg, eth_bytes) = create_eth_message();

        // Encoding through the codec must reproduce the reference wire bytes exactly.
        let encoded = inner.encode_eth_message(eth_msg.clone()).unwrap();
        assert_eq!(&encoded[..], &eth_bytes[..]);

        // Decoding must yield the original message.
        let Ok(EthSnapMessage::Eth(decoded_msg)) = inner.decode_message(eth_bytes) else {
            panic!("expected eth message");
        };
        assert_eq!(decoded_msg, eth_msg);

        // A second encode/decode pass must be stable.
        let re_encoded = inner.encode_eth_message(decoded_msg.clone()).unwrap();
        let Ok(EthSnapMessage::Eth(final_msg)) =
            inner.decode_message(BytesMut::from(&re_encoded[..]))
        else {
            panic!("expected eth message after re-encoding");
        };
        assert_eq!(final_msg, decoded_msg);
    }

    #[test]
    fn test_snap_protocol() {
        let inner = EthSnapStreamInner::<EthNetworkPrimitives>::new(EthVersion::Eth67);
        let (snap_msg, snap_bytes) = create_snap_message();

        // Encoding must match the helper's reference bytes (same codec configuration).
        let encoded_bytes = inner.encode_snap_message(snap_msg.clone());
        assert_eq!(&encoded_bytes[..], &snap_bytes[..]);

        let Ok(EthSnapMessage::Snap(decoded_msg)) = inner.decode_message(snap_bytes) else {
            panic!("expected snap message");
        };
        assert_eq!(decoded_msg, snap_msg);

        // Re-encode and decode again with the properly adjusted ID.
        let re_encoded = inner.encode_snap_message(decoded_msg.clone());
        let Ok(EthSnapMessage::Snap(final_msg)) =
            inner.decode_message(BytesMut::from(&re_encoded[..]))
        else {
            panic!("expected snap message after re-encoding");
        };
        assert_eq!(final_msg, decoded_msg);
    }

    #[test]
    fn test_snap2_protocol() {
        let inner = EthSnapStreamInner::<EthNetworkPrimitives>::new_with_snap_version(
            EthVersion::Eth67,
            SnapVersion::V2,
        );
        let (snap_msg, snap_bytes) = create_snap2_message();

        let encoded_bytes = inner.encode_snap_message(snap_msg.clone());
        assert_eq!(&encoded_bytes[..], &snap_bytes[..]);

        let Ok(EthSnapMessage::Snap(decoded_msg)) = inner.decode_message(snap_bytes) else {
            panic!("expected snap/2 message");
        };
        assert_eq!(decoded_msg, snap_msg);
    }

    #[test]
    fn test_message_id_boundaries() {
        let inner = EthSnapStreamInner::<EthNetworkPrimitives>::new(EthVersion::Eth67);

        // Highest eth ID followed by a junk payload: must be routed to the eth decoder, which
        // can only yield an eth message or an `InvalidMessage` error.
        let eth_max_id = EthMessageID::max(EthVersion::Eth67);
        let mut eth_boundary_bytes = BytesMut::new();
        eth_boundary_bytes.extend_from_slice(&[eth_max_id, 0, 0]);

        let eth_boundary_result = inner.decode_message(eth_boundary_bytes);
        assert!(
            matches!(
                eth_boundary_result,
                Ok(EthSnapMessage::Eth(_)) | Err(EthSnapStreamError::InvalidMessage(..))
            ),
            "max eth id must be routed to the eth decoder, got {eth_boundary_result:?}"
        );

        // One past the eth range is the first snap ID: the junk payload must reach the snap
        // decoder and fail there with an RLP error (not `UnknownMessageId`).
        let snap_min_id = eth_max_id + 1;
        let mut snap_boundary_bytes = BytesMut::new();
        snap_boundary_bytes.extend_from_slice(&[snap_min_id, 0, 0]);

        let snap_boundary_result = inner.decode_message(snap_boundary_bytes);
        assert!(
            matches!(snap_boundary_result, Err(EthSnapStreamError::Rlp(_))),
            "first snap id must be routed to the snap decoder, got {snap_boundary_result:?}"
        );
    }

    #[test]
    fn test_eth70_message_id_0x12_is_snap() {
        let inner = EthSnapStreamInner::<EthNetworkPrimitives>::new(EthVersion::Eth70);
        let snap_msg = SnapProtocolMessage::GetAccountRange(GetAccountRangeMessage {
            request_id: 1,
            root_hash: B256::default(),
            starting_hash: B256::default(),
            limit_hash: B256::default(),
            response_bytes: 1000,
        });

        let encoded = inner.encode_snap_message(snap_msg);
        assert_eq!(encoded[0], EthMessageID::message_count(EthVersion::Eth70));

        let decoded = inner.decode_message(BytesMut::from(&encoded[..])).unwrap();
        assert!(matches!(decoded, EthSnapMessage::Snap(_)));
    }

    #[test]
    fn test_eth71_message_id_0x12_is_eth() {
        let inner = EthSnapStreamInner::<EthNetworkPrimitives>::new(EthVersion::Eth71);
        let eth_msg = EthMessage::<EthNetworkPrimitives>::GetBlockAccessLists(RequestPair {
            request_id: 1,
            message: GetBlockAccessLists(vec![B256::ZERO]),
        });
        let protocol_msg = ProtocolMessage::from(eth_msg.clone());
        let mut buf = Vec::new();
        protocol_msg.encode(&mut buf);

        let decoded = inner.decode_message(BytesMut::from(&buf[..])).unwrap();
        let EthSnapMessage::Eth(decoded_eth) = decoded else {
            panic!("expected eth message");
        };
        assert_eq!(decoded_eth, eth_msg);
    }

    #[test]
    fn test_snap1_rejects_snap2_message_ids() {
        let inner = EthSnapStreamInner::<EthNetworkPrimitives>::new(EthVersion::Eth67);
        let snap2_msg = SnapProtocolMessage::BlockAccessLists(BlockAccessListsMessage {
            request_id: 1,
            block_access_lists: BlockAccessLists(vec![Some(alloy_primitives::Bytes::from_static(
                &[alloy_rlp::EMPTY_LIST_CODE],
            ))]),
        });

        let encoded = EthSnapStreamInner::<EthNetworkPrimitives>::new_with_snap_version(
            EthVersion::Eth67,
            SnapVersion::V2,
        )
        .encode_snap_message(snap2_msg);

        let decoded = inner.decode_message(BytesMut::from(&encoded[..]));
        assert!(matches!(decoded, Err(EthSnapStreamError::UnknownMessageId(_))));
    }
}