Skip to main content

reth_eth_wire/
eth_snap_stream.rs

1//! Ethereum and snap combined protocol stream implementation.
2//!
3//! A stream type for handling both eth and snap protocol messages over a single `RLPx` connection.
4//! Provides message encoding/decoding, ID multiplexing, and protocol message processing.
5
6use super::message::MAX_MESSAGE_SIZE;
7use crate::{
8    message::{EthBroadcastMessage, ProtocolBroadcastMessage},
9    EthMessage, EthMessageID, EthNetworkPrimitives, EthVersion, NetworkPrimitives, ProtocolMessage,
10    RawCapabilityMessage, SnapMessageId, SnapProtocolMessage,
11};
12use alloy_rlp::{Bytes, BytesMut, Encodable};
13use core::fmt::Debug;
14use futures::{Sink, SinkExt};
15use pin_project::pin_project;
16use std::{
17    marker::PhantomData,
18    pin::Pin,
19    task::{ready, Context, Poll},
20};
21use tokio_stream::Stream;
22
23/// Error type for the eth and snap stream
24#[derive(thiserror::Error, Debug)]
25pub enum EthSnapStreamError {
26    /// Invalid message for protocol version
27    #[error("invalid message for version {0:?}: {1}")]
28    InvalidMessage(EthVersion, String),
29
30    /// Unknown message ID
31    #[error("unknown message id: {0}")]
32    UnknownMessageId(u8),
33
34    /// Message too large
35    #[error("message too large: {0} > {1}")]
36    MessageTooLarge(usize, usize),
37
38    /// RLP decoding error
39    #[error("rlp error: {0}")]
40    Rlp(#[from] alloy_rlp::Error),
41
42    /// Status message received outside handshake
43    #[error("status message received outside handshake")]
44    StatusNotInHandshake,
45}
46
47/// Combined message type that include either eth or snap protocol messages
48#[derive(Debug)]
49pub enum EthSnapMessage<N: NetworkPrimitives = EthNetworkPrimitives> {
50    /// An Ethereum protocol message
51    Eth(EthMessage<N>),
52    /// A snap protocol message
53    Snap(SnapProtocolMessage),
54}
55
56/// A stream implementation that can handle both eth and snap protocol messages
57/// over a single connection.
58#[pin_project]
59#[derive(Debug, Clone)]
60pub struct EthSnapStream<S, N = EthNetworkPrimitives> {
61    /// Protocol logic
62    eth_snap: EthSnapStreamInner<N>,
63    /// Inner byte stream
64    #[pin]
65    inner: S,
66}
67
68impl<S, N> EthSnapStream<S, N>
69where
70    N: NetworkPrimitives,
71{
72    /// Create a new eth and snap protocol stream
73    pub const fn new(stream: S, eth_version: EthVersion) -> Self {
74        Self { eth_snap: EthSnapStreamInner::new(eth_version), inner: stream }
75    }
76
77    /// Create a new eth and snap protocol stream with a custom max message size.
78    pub const fn with_max_message_size(
79        stream: S,
80        eth_version: EthVersion,
81        max_message_size: usize,
82    ) -> Self {
83        Self {
84            eth_snap: EthSnapStreamInner::with_max_message_size(eth_version, max_message_size),
85            inner: stream,
86        }
87    }
88
89    /// Returns the eth version
90    #[inline]
91    pub const fn eth_version(&self) -> EthVersion {
92        self.eth_snap.eth_version()
93    }
94
95    /// Returns the underlying stream
96    #[inline]
97    pub const fn inner(&self) -> &S {
98        &self.inner
99    }
100
101    /// Returns mutable access to the underlying stream
102    #[inline]
103    pub const fn inner_mut(&mut self) -> &mut S {
104        &mut self.inner
105    }
106
107    /// Consumes this type and returns the wrapped stream
108    #[inline]
109    pub fn into_inner(self) -> S {
110        self.inner
111    }
112}
113
114impl<S, E, N> EthSnapStream<S, N>
115where
116    S: Sink<Bytes, Error = E> + Unpin,
117    EthSnapStreamError: From<E>,
118    N: NetworkPrimitives,
119{
120    /// Same as [`Sink::start_send`] but accepts a [`EthBroadcastMessage`] instead.
121    pub fn start_send_broadcast(
122        &mut self,
123        item: EthBroadcastMessage<N>,
124    ) -> Result<(), EthSnapStreamError> {
125        self.inner.start_send_unpin(Bytes::from(alloy_rlp::encode(
126            ProtocolBroadcastMessage::from(item),
127        )))?;
128
129        Ok(())
130    }
131
132    /// Sends a raw capability message directly over the stream
133    pub fn start_send_raw(&mut self, msg: RawCapabilityMessage) -> Result<(), EthSnapStreamError> {
134        let mut bytes = Vec::with_capacity(msg.payload.len() + 1);
135        msg.id.encode(&mut bytes);
136        bytes.extend_from_slice(&msg.payload);
137
138        self.inner.start_send_unpin(bytes.into())?;
139        Ok(())
140    }
141}
142
143impl<S, E, N> Stream for EthSnapStream<S, N>
144where
145    S: Stream<Item = Result<BytesMut, E>> + Unpin,
146    EthSnapStreamError: From<E>,
147    N: NetworkPrimitives,
148{
149    type Item = Result<EthSnapMessage<N>, EthSnapStreamError>;
150
151    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
152        let this = self.project();
153        let res = ready!(this.inner.poll_next(cx));
154
155        match res {
156            Some(Ok(bytes)) => Poll::Ready(Some(this.eth_snap.decode_message(bytes))),
157            Some(Err(err)) => Poll::Ready(Some(Err(err.into()))),
158            None => Poll::Ready(None),
159        }
160    }
161}
162
163impl<S, E, N> Sink<EthSnapMessage<N>> for EthSnapStream<S, N>
164where
165    S: Sink<Bytes, Error = E> + Unpin,
166    EthSnapStreamError: From<E>,
167    N: NetworkPrimitives,
168{
169    type Error = EthSnapStreamError;
170
171    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
172        self.project().inner.poll_ready(cx).map_err(Into::into)
173    }
174
175    fn start_send(mut self: Pin<&mut Self>, item: EthSnapMessage<N>) -> Result<(), Self::Error> {
176        let mut this = self.as_mut().project();
177
178        let bytes = match item {
179            EthSnapMessage::Eth(eth_msg) => this.eth_snap.encode_eth_message(eth_msg)?,
180            EthSnapMessage::Snap(snap_msg) => this.eth_snap.encode_snap_message(snap_msg),
181        };
182
183        this.inner.start_send_unpin(bytes)?;
184        Ok(())
185    }
186
187    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
188        self.project().inner.poll_flush(cx).map_err(Into::into)
189    }
190
191    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
192        self.project().inner.poll_close(cx).map_err(Into::into)
193    }
194}
195
196/// Stream handling combined eth and snap protocol logic
197/// Snap version is not critical to specify yet,
198/// Only one version, snap/1, does exist.
199#[derive(Debug, Clone)]
200struct EthSnapStreamInner<N> {
201    /// Eth protocol version
202    eth_version: EthVersion,
203    /// Maximum allowed ETH/Snap message size.
204    max_message_size: usize,
205    /// Type marker
206    _pd: PhantomData<N>,
207}
208
209impl<N> EthSnapStreamInner<N>
210where
211    N: NetworkPrimitives,
212{
213    /// Create a new eth and snap protocol stream
214    const fn new(eth_version: EthVersion) -> Self {
215        Self::with_max_message_size(eth_version, MAX_MESSAGE_SIZE)
216    }
217
218    /// Create a new eth and snap protocol stream with a custom max message size.
219    const fn with_max_message_size(eth_version: EthVersion, max_message_size: usize) -> Self {
220        Self { eth_version, max_message_size, _pd: PhantomData }
221    }
222
223    #[inline]
224    const fn eth_version(&self) -> EthVersion {
225        self.eth_version
226    }
227
228    /// Decode a message from the stream
229    fn decode_message(&self, bytes: BytesMut) -> Result<EthSnapMessage<N>, EthSnapStreamError> {
230        if bytes.len() > self.max_message_size {
231            return Err(EthSnapStreamError::MessageTooLarge(bytes.len(), self.max_message_size));
232        }
233
234        if bytes.is_empty() {
235            return Err(EthSnapStreamError::Rlp(alloy_rlp::Error::InputTooShort));
236        }
237
238        let message_id = bytes[0];
239
240        // This check works because capabilities are sorted lexicographically
241        // if "eth" before "snap", giving eth messages lower IDs than snap messages,
242        // and eth message IDs are <= [`EthMessageID::max()`],
243        // snap message IDs are > [`EthMessageID::max()`].
244        // See also <https://github.com/paradigmxyz/reth/blob/main/crates/net/eth-wire/src/capability.rs#L272-L283>.
245        if message_id <= EthMessageID::max(self.eth_version) {
246            let mut buf = bytes.as_ref();
247            match ProtocolMessage::decode_message(self.eth_version, &mut buf) {
248                Ok(protocol_msg) => {
249                    if matches!(protocol_msg.message, EthMessage::Status(_)) {
250                        return Err(EthSnapStreamError::StatusNotInHandshake);
251                    }
252                    Ok(EthSnapMessage::Eth(protocol_msg.message))
253                }
254                Err(err) => {
255                    Err(EthSnapStreamError::InvalidMessage(self.eth_version, err.to_string()))
256                }
257            }
258        } else if message_id > EthMessageID::max(self.eth_version) &&
259            message_id <=
260                EthMessageID::message_count(self.eth_version) + SnapMessageId::TrieNodes as u8
261        {
262            // Checks for multiplexed snap message IDs :
263            // - message_id > EthMessageID::max() : ensures it's not an eth message
264            // - message_id <= EthMessageID::message_count() + snap_max : ensures it's within valid
265            //   snap range
266            // Message IDs are assigned lexicographically during capability negotiation
267            // So real_snap_id = multiplexed_id - num_eth_messages
268            let adjusted_message_id = message_id - EthMessageID::message_count(self.eth_version);
269            let mut buf = &bytes[1..];
270
271            match SnapProtocolMessage::decode(adjusted_message_id, &mut buf) {
272                Ok(snap_msg) => Ok(EthSnapMessage::Snap(snap_msg)),
273                Err(err) => Err(EthSnapStreamError::Rlp(err)),
274            }
275        } else {
276            Err(EthSnapStreamError::UnknownMessageId(message_id))
277        }
278    }
279
280    /// Encode an eth message
281    fn encode_eth_message(&self, item: EthMessage<N>) -> Result<Bytes, EthSnapStreamError> {
282        if matches!(item, EthMessage::Status(_)) {
283            return Err(EthSnapStreamError::StatusNotInHandshake);
284        }
285
286        let protocol_msg = ProtocolMessage::from(item);
287        let mut buf = Vec::new();
288        protocol_msg.encode(&mut buf);
289        Ok(Bytes::from(buf))
290    }
291
292    /// Encode a snap protocol message, adjusting the message ID to follow eth message IDs
293    /// for proper multiplexing.
294    fn encode_snap_message(&self, message: SnapProtocolMessage) -> Bytes {
295        let encoded = message.encode();
296
297        let message_id = encoded[0];
298        let adjusted_id = message_id + EthMessageID::message_count(self.eth_version);
299
300        let mut adjusted = Vec::with_capacity(encoded.len());
301        adjusted.push(adjusted_id);
302        adjusted.extend_from_slice(&encoded[1..]);
303
304        Bytes::from(adjusted)
305    }
306}
307
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{EthMessage, SnapProtocolMessage};
    use alloy_eips::BlockHashOrNumber;
    use alloy_primitives::B256;
    use alloy_rlp::Encodable;
    use reth_eth_wire_types::{
        message::RequestPair, GetAccountRangeMessage, GetBlockAccessLists, GetBlockHeaders,
        HeadersDirection,
    };

    /// Builds a `GetBlockHeaders` eth message together with its reference RLP encoding.
    fn create_eth_message() -> (EthMessage<EthNetworkPrimitives>, BytesMut) {
        let eth_msg = EthMessage::<EthNetworkPrimitives>::GetBlockHeaders(RequestPair {
            request_id: 1,
            message: GetBlockHeaders {
                start_block: BlockHashOrNumber::Number(1),
                limit: 10,
                skip: 0,
                direction: HeadersDirection::Rising,
            },
        });

        let protocol_msg = ProtocolMessage::from(eth_msg.clone());
        let mut buf = Vec::new();
        protocol_msg.encode(&mut buf);

        (eth_msg, BytesMut::from(&buf[..]))
    }

    /// Builds a `GetAccountRange` snap message together with its multiplexed (eth/67) encoding.
    fn create_snap_message() -> (SnapProtocolMessage, BytesMut) {
        let snap_msg = SnapProtocolMessage::GetAccountRange(GetAccountRangeMessage {
            request_id: 1,
            root_hash: B256::default(),
            starting_hash: B256::default(),
            limit_hash: B256::default(),
            response_bytes: 1000,
        });

        let inner = EthSnapStreamInner::<EthNetworkPrimitives>::new(EthVersion::Eth67);
        let encoded = inner.encode_snap_message(snap_msg.clone());

        (snap_msg, BytesMut::from(&encoded[..]))
    }

    #[test]
    fn test_eth_message_roundtrip() {
        let inner = EthSnapStreamInner::<EthNetworkPrimitives>::new(EthVersion::Eth67);
        let (eth_msg, eth_bytes) = create_eth_message();

        // Encoding must produce exactly the reference bytes.
        let encoded = inner.encode_eth_message(eth_msg.clone()).unwrap();
        assert_eq!(&encoded[..], &eth_bytes[..]);

        // Decoding the reference bytes must yield the original message.
        let EthSnapMessage::Eth(decoded_msg) = inner.decode_message(eth_bytes).unwrap() else {
            panic!("expected eth message");
        };
        assert_eq!(decoded_msg, eth_msg);

        // Re-encode and re-decode: the message must survive unchanged.
        let re_encoded = inner.encode_eth_message(decoded_msg.clone()).unwrap();
        let EthSnapMessage::Eth(final_msg) =
            inner.decode_message(BytesMut::from(&re_encoded[..])).unwrap()
        else {
            panic!("expected eth message after re-encoding");
        };
        assert_eq!(final_msg, decoded_msg);
    }

    #[test]
    fn test_snap_protocol() {
        let inner = EthSnapStreamInner::<EthNetworkPrimitives>::new(EthVersion::Eth67);
        let (snap_msg, snap_bytes) = create_snap_message();

        // Encoding must produce exactly the reference bytes.
        let encoded_bytes = inner.encode_snap_message(snap_msg.clone());
        assert_eq!(&encoded_bytes[..], &snap_bytes[..]);

        // Decoding the reference bytes must yield the original message.
        let EthSnapMessage::Snap(decoded_msg) = inner.decode_message(snap_bytes).unwrap() else {
            panic!("expected snap message");
        };
        assert_eq!(decoded_msg, snap_msg);

        // Re-encode (ID shifted past eth range) and re-decode.
        let encoded = inner.encode_snap_message(decoded_msg.clone());
        let EthSnapMessage::Snap(final_msg) =
            inner.decode_message(BytesMut::from(&encoded[..])).unwrap()
        else {
            panic!("expected snap message after re-encoding");
        };
        assert_eq!(final_msg, decoded_msg);
    }

    #[test]
    fn test_message_id_boundaries() {
        let inner = EthSnapStreamInner::<EthNetworkPrimitives>::new(EthVersion::Eth67);

        // The highest eth ID with a bogus payload must be routed to the eth decoder:
        // it may fail to decode, but never as a snap/unknown-ID error.
        let eth_max_id = EthMessageID::max(EthVersion::Eth67);
        let mut eth_boundary_bytes = BytesMut::new();
        eth_boundary_bytes.extend_from_slice(&[eth_max_id]);
        eth_boundary_bytes.extend_from_slice(&[0, 0]);

        let eth_boundary_result = inner.decode_message(eth_boundary_bytes);
        assert!(matches!(
            eth_boundary_result,
            Ok(EthSnapMessage::Eth(_)) | Err(EthSnapStreamError::InvalidMessage(..))
        ));

        // The ID just above the eth range is the first snap ID; with a bogus payload the snap
        // decoder must reject it with an RLP error.
        let snap_min_id = eth_max_id + 1;
        let mut snap_boundary_bytes = BytesMut::new();
        snap_boundary_bytes.extend_from_slice(&[snap_min_id]);
        snap_boundary_bytes.extend_from_slice(&[0, 0]);

        let snap_boundary_result = inner.decode_message(snap_boundary_bytes);
        assert!(matches!(snap_boundary_result, Err(EthSnapStreamError::Rlp(_))));
    }

    #[test]
    fn test_empty_oversized_and_unknown_frames() {
        let inner = EthSnapStreamInner::<EthNetworkPrimitives>::with_max_message_size(
            EthVersion::Eth67,
            4,
        );

        assert!(matches!(
            inner.decode_message(BytesMut::new()),
            Err(EthSnapStreamError::Rlp(alloy_rlp::Error::InputTooShort))
        ));

        assert!(matches!(
            inner.decode_message(BytesMut::from(&[0u8; 5][..])),
            Err(EthSnapStreamError::MessageTooLarge(5, 4))
        ));

        // One past the last multiplexed snap ID is neither eth nor snap.
        let unknown = EthMessageID::message_count(EthVersion::Eth67) +
            SnapMessageId::TrieNodes as u8 +
            1;
        assert!(matches!(
            inner.decode_message(BytesMut::from(&[unknown, 0][..])),
            Err(EthSnapStreamError::UnknownMessageId(id)) if id == unknown
        ));
    }

    #[test]
    fn test_eth70_message_id_0x12_is_snap() {
        let inner = EthSnapStreamInner::<EthNetworkPrimitives>::new(EthVersion::Eth70);
        let snap_msg = SnapProtocolMessage::GetAccountRange(GetAccountRangeMessage {
            request_id: 1,
            root_hash: B256::default(),
            starting_hash: B256::default(),
            limit_hash: B256::default(),
            response_bytes: 1000,
        });

        let encoded = inner.encode_snap_message(snap_msg);
        assert_eq!(encoded[0], EthMessageID::message_count(EthVersion::Eth70));

        let decoded = inner.decode_message(BytesMut::from(&encoded[..])).unwrap();
        assert!(matches!(decoded, EthSnapMessage::Snap(_)));
    }

    #[test]
    fn test_eth71_message_id_0x12_is_eth() {
        let inner = EthSnapStreamInner::<EthNetworkPrimitives>::new(EthVersion::Eth71);
        let eth_msg = EthMessage::<EthNetworkPrimitives>::GetBlockAccessLists(RequestPair {
            request_id: 1,
            message: GetBlockAccessLists(vec![B256::ZERO]),
        });
        let protocol_msg = ProtocolMessage::from(eth_msg.clone());
        let mut buf = Vec::new();
        protocol_msg.encode(&mut buf);

        let decoded = inner.decode_message(BytesMut::from(&buf[..])).unwrap();
        let EthSnapMessage::Eth(decoded_eth) = decoded else {
            panic!("expected eth message");
        };
        assert_eq!(decoded_eth, eth_msg);
    }
}