Skip to main content

reth_eth_wire/
eth_snap_stream.rs

1//! Ethereum and snap combined protocol stream implementation.
2//!
3//! A stream type for handling both eth and snap protocol messages over a single `RLPx` connection.
4//! Provides message encoding/decoding, ID multiplexing, and protocol message processing.
5
6use super::message::MAX_MESSAGE_SIZE;
7use crate::{
8    message::{EthBroadcastMessage, ProtocolBroadcastMessage},
9    EthMessage, EthMessageID, EthNetworkPrimitives, EthVersion, NetworkPrimitives, ProtocolMessage,
10    RawCapabilityMessage, SnapMessageId, SnapProtocolMessage,
11};
12use alloy_rlp::{Bytes, BytesMut, Encodable};
13use core::fmt::Debug;
14use futures::{Sink, SinkExt};
15use pin_project::pin_project;
16use std::{
17    marker::PhantomData,
18    pin::Pin,
19    task::{ready, Context, Poll},
20};
21use tokio_stream::Stream;
22
23/// Error type for the eth and snap stream
24#[derive(thiserror::Error, Debug)]
25pub enum EthSnapStreamError {
26    /// Invalid message for protocol version
27    #[error("invalid message for version {0:?}: {1}")]
28    InvalidMessage(EthVersion, String),
29
30    /// Unknown message ID
31    #[error("unknown message id: {0}")]
32    UnknownMessageId(u8),
33
34    /// Message too large
35    #[error("message too large: {0} > {1}")]
36    MessageTooLarge(usize, usize),
37
38    /// RLP decoding error
39    #[error("rlp error: {0}")]
40    Rlp(#[from] alloy_rlp::Error),
41
42    /// Status message received outside handshake
43    #[error("status message received outside handshake")]
44    StatusNotInHandshake,
45}
46
47/// Combined message type that include either eth or snap protocol messages
48#[derive(Debug)]
49pub enum EthSnapMessage<N: NetworkPrimitives = EthNetworkPrimitives> {
50    /// An Ethereum protocol message
51    Eth(EthMessage<N>),
52    /// A snap protocol message
53    Snap(SnapProtocolMessage),
54}
55
56/// A stream implementation that can handle both eth and snap protocol messages
57/// over a single connection.
58#[pin_project]
59#[derive(Debug, Clone)]
60pub struct EthSnapStream<S, N = EthNetworkPrimitives> {
61    /// Protocol logic
62    eth_snap: EthSnapStreamInner<N>,
63    /// Inner byte stream
64    #[pin]
65    inner: S,
66}
67
68impl<S, N> EthSnapStream<S, N>
69where
70    N: NetworkPrimitives,
71{
72    /// Create a new eth and snap protocol stream
73    pub const fn new(stream: S, eth_version: EthVersion) -> Self {
74        Self { eth_snap: EthSnapStreamInner::new(eth_version), inner: stream }
75    }
76
77    /// Returns the eth version
78    #[inline]
79    pub const fn eth_version(&self) -> EthVersion {
80        self.eth_snap.eth_version()
81    }
82
83    /// Returns the underlying stream
84    #[inline]
85    pub const fn inner(&self) -> &S {
86        &self.inner
87    }
88
89    /// Returns mutable access to the underlying stream
90    #[inline]
91    pub const fn inner_mut(&mut self) -> &mut S {
92        &mut self.inner
93    }
94
95    /// Consumes this type and returns the wrapped stream
96    #[inline]
97    pub fn into_inner(self) -> S {
98        self.inner
99    }
100}
101
102impl<S, E, N> EthSnapStream<S, N>
103where
104    S: Sink<Bytes, Error = E> + Unpin,
105    EthSnapStreamError: From<E>,
106    N: NetworkPrimitives,
107{
108    /// Same as [`Sink::start_send`] but accepts a [`EthBroadcastMessage`] instead.
109    pub fn start_send_broadcast(
110        &mut self,
111        item: EthBroadcastMessage<N>,
112    ) -> Result<(), EthSnapStreamError> {
113        self.inner.start_send_unpin(Bytes::from(alloy_rlp::encode(
114            ProtocolBroadcastMessage::from(item),
115        )))?;
116
117        Ok(())
118    }
119
120    /// Sends a raw capability message directly over the stream
121    pub fn start_send_raw(&mut self, msg: RawCapabilityMessage) -> Result<(), EthSnapStreamError> {
122        let mut bytes = Vec::with_capacity(msg.payload.len() + 1);
123        msg.id.encode(&mut bytes);
124        bytes.extend_from_slice(&msg.payload);
125
126        self.inner.start_send_unpin(bytes.into())?;
127        Ok(())
128    }
129}
130
131impl<S, E, N> Stream for EthSnapStream<S, N>
132where
133    S: Stream<Item = Result<BytesMut, E>> + Unpin,
134    EthSnapStreamError: From<E>,
135    N: NetworkPrimitives,
136{
137    type Item = Result<EthSnapMessage<N>, EthSnapStreamError>;
138
139    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
140        let this = self.project();
141        let res = ready!(this.inner.poll_next(cx));
142
143        match res {
144            Some(Ok(bytes)) => Poll::Ready(Some(this.eth_snap.decode_message(bytes))),
145            Some(Err(err)) => Poll::Ready(Some(Err(err.into()))),
146            None => Poll::Ready(None),
147        }
148    }
149}
150
151impl<S, E, N> Sink<EthSnapMessage<N>> for EthSnapStream<S, N>
152where
153    S: Sink<Bytes, Error = E> + Unpin,
154    EthSnapStreamError: From<E>,
155    N: NetworkPrimitives,
156{
157    type Error = EthSnapStreamError;
158
159    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
160        self.project().inner.poll_ready(cx).map_err(Into::into)
161    }
162
163    fn start_send(mut self: Pin<&mut Self>, item: EthSnapMessage<N>) -> Result<(), Self::Error> {
164        let mut this = self.as_mut().project();
165
166        let bytes = match item {
167            EthSnapMessage::Eth(eth_msg) => this.eth_snap.encode_eth_message(eth_msg)?,
168            EthSnapMessage::Snap(snap_msg) => this.eth_snap.encode_snap_message(snap_msg),
169        };
170
171        this.inner.start_send_unpin(bytes)?;
172        Ok(())
173    }
174
175    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
176        self.project().inner.poll_flush(cx).map_err(Into::into)
177    }
178
179    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
180        self.project().inner.poll_close(cx).map_err(Into::into)
181    }
182}
183
184/// Stream handling combined eth and snap protocol logic
185/// Snap version is not critical to specify yet,
186/// Only one version, snap/1, does exist.
187#[derive(Debug, Clone)]
188struct EthSnapStreamInner<N> {
189    /// Eth protocol version
190    eth_version: EthVersion,
191    /// Type marker
192    _pd: PhantomData<N>,
193}
194
195impl<N> EthSnapStreamInner<N>
196where
197    N: NetworkPrimitives,
198{
199    /// Create a new eth and snap protocol stream
200    const fn new(eth_version: EthVersion) -> Self {
201        Self { eth_version, _pd: PhantomData }
202    }
203
204    #[inline]
205    const fn eth_version(&self) -> EthVersion {
206        self.eth_version
207    }
208
209    /// Decode a message from the stream
210    fn decode_message(&self, bytes: BytesMut) -> Result<EthSnapMessage<N>, EthSnapStreamError> {
211        if bytes.len() > MAX_MESSAGE_SIZE {
212            return Err(EthSnapStreamError::MessageTooLarge(bytes.len(), MAX_MESSAGE_SIZE));
213        }
214
215        if bytes.is_empty() {
216            return Err(EthSnapStreamError::Rlp(alloy_rlp::Error::InputTooShort));
217        }
218
219        let message_id = bytes[0];
220
221        // This check works because capabilities are sorted lexicographically
222        // if "eth" before "snap", giving eth messages lower IDs than snap messages,
223        // and eth message IDs are <= [`EthMessageID::max()`],
224        // snap message IDs are > [`EthMessageID::max()`].
225        // See also <https://github.com/paradigmxyz/reth/blob/main/crates/net/eth-wire/src/capability.rs#L272-L283>.
226        if message_id <= EthMessageID::max(self.eth_version) {
227            let mut buf = bytes.as_ref();
228            match ProtocolMessage::decode_message(self.eth_version, &mut buf) {
229                Ok(protocol_msg) => {
230                    if matches!(protocol_msg.message, EthMessage::Status(_)) {
231                        return Err(EthSnapStreamError::StatusNotInHandshake);
232                    }
233                    Ok(EthSnapMessage::Eth(protocol_msg.message))
234                }
235                Err(err) => {
236                    Err(EthSnapStreamError::InvalidMessage(self.eth_version, err.to_string()))
237                }
238            }
239        } else if message_id > EthMessageID::max(self.eth_version) &&
240            message_id <=
241                EthMessageID::message_count(self.eth_version) + SnapMessageId::TrieNodes as u8
242        {
243            // Checks for multiplexed snap message IDs :
244            // - message_id > EthMessageID::max() : ensures it's not an eth message
245            // - message_id <= EthMessageID::message_count() + snap_max : ensures it's within valid
246            //   snap range
247            // Message IDs are assigned lexicographically during capability negotiation
248            // So real_snap_id = multiplexed_id - num_eth_messages
249            let adjusted_message_id = message_id - EthMessageID::message_count(self.eth_version);
250            let mut buf = &bytes[1..];
251
252            match SnapProtocolMessage::decode(adjusted_message_id, &mut buf) {
253                Ok(snap_msg) => Ok(EthSnapMessage::Snap(snap_msg)),
254                Err(err) => Err(EthSnapStreamError::Rlp(err)),
255            }
256        } else {
257            Err(EthSnapStreamError::UnknownMessageId(message_id))
258        }
259    }
260
261    /// Encode an eth message
262    fn encode_eth_message(&self, item: EthMessage<N>) -> Result<Bytes, EthSnapStreamError> {
263        if matches!(item, EthMessage::Status(_)) {
264            return Err(EthSnapStreamError::StatusNotInHandshake);
265        }
266
267        let protocol_msg = ProtocolMessage::from(item);
268        let mut buf = Vec::new();
269        protocol_msg.encode(&mut buf);
270        Ok(Bytes::from(buf))
271    }
272
273    /// Encode a snap protocol message, adjusting the message ID to follow eth message IDs
274    /// for proper multiplexing.
275    fn encode_snap_message(&self, message: SnapProtocolMessage) -> Bytes {
276        let encoded = message.encode();
277
278        let message_id = encoded[0];
279        let adjusted_id = message_id + EthMessageID::message_count(self.eth_version);
280
281        let mut adjusted = Vec::with_capacity(encoded.len());
282        adjusted.push(adjusted_id);
283        adjusted.extend_from_slice(&encoded[1..]);
284
285        Bytes::from(adjusted)
286    }
287}
288
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{EthMessage, SnapProtocolMessage};
    use alloy_eips::BlockHashOrNumber;
    use alloy_primitives::B256;
    use alloy_rlp::Encodable;
    use reth_eth_wire_types::{
        message::RequestPair, GetAccountRangeMessage, GetBlockAccessLists, GetBlockHeaders,
        HeadersDirection,
    };

    /// Builds a `GetBlockHeaders` eth message together with its plain `ProtocolMessage` encoding.
    fn create_eth_message() -> (EthMessage<EthNetworkPrimitives>, BytesMut) {
        let eth_msg = EthMessage::<EthNetworkPrimitives>::GetBlockHeaders(RequestPair {
            request_id: 1,
            message: GetBlockHeaders {
                start_block: BlockHashOrNumber::Number(1),
                limit: 10,
                skip: 0,
                direction: HeadersDirection::Rising,
            },
        });

        let protocol_msg = ProtocolMessage::from(eth_msg.clone());
        let mut buf = Vec::new();
        protocol_msg.encode(&mut buf);

        (eth_msg, BytesMut::from(&buf[..]))
    }

    /// Builds a `GetAccountRange` snap message together with its multiplexed eth/67 encoding.
    fn create_snap_message() -> (SnapProtocolMessage, BytesMut) {
        let snap_msg = SnapProtocolMessage::GetAccountRange(GetAccountRangeMessage {
            request_id: 1,
            root_hash: B256::default(),
            starting_hash: B256::default(),
            limit_hash: B256::default(),
            response_bytes: 1000,
        });

        let inner = EthSnapStreamInner::<EthNetworkPrimitives>::new(EthVersion::Eth67);
        let encoded = inner.encode_snap_message(snap_msg.clone());

        (snap_msg, BytesMut::from(&encoded[..]))
    }

    #[test]
    fn test_eth_message_roundtrip() {
        let inner = EthSnapStreamInner::<EthNetworkPrimitives>::new(EthVersion::Eth67);
        let (eth_msg, eth_bytes) = create_eth_message();

        // Encoding yields exactly the plain `ProtocolMessage` encoding.
        let encoded = inner.encode_eth_message(eth_msg.clone()).unwrap();
        assert_eq!(&encoded[..], &eth_bytes[..]);

        // Decoding yields the original message.
        let Ok(EthSnapMessage::Eth(decoded_msg)) = inner.decode_message(eth_bytes) else {
            panic!("expected eth message");
        };
        assert_eq!(decoded_msg, eth_msg);

        // Re-encoding the decoded message round-trips again.
        let re_encoded = inner.encode_eth_message(decoded_msg.clone()).unwrap();
        let Ok(EthSnapMessage::Eth(final_msg)) =
            inner.decode_message(BytesMut::from(&re_encoded[..]))
        else {
            panic!("expected eth message after re-encoding");
        };
        assert_eq!(final_msg, decoded_msg);
    }

    #[test]
    fn test_snap_protocol() {
        let inner = EthSnapStreamInner::<EthNetworkPrimitives>::new(EthVersion::Eth67);
        let (snap_msg, snap_bytes) = create_snap_message();

        // Encoding is deterministic and shifts `GetAccountRange` (snap ID 0) to the first ID
        // after the eth/67 range.
        let encoded_bytes = inner.encode_snap_message(snap_msg.clone());
        assert_eq!(&encoded_bytes[..], &snap_bytes[..]);
        assert_eq!(encoded_bytes[0], EthMessageID::message_count(EthVersion::Eth67));

        // Decoding undoes the shift and yields the original message.
        let Ok(EthSnapMessage::Snap(decoded_msg)) = inner.decode_message(snap_bytes) else {
            panic!("expected snap message");
        };
        assert_eq!(decoded_msg, snap_msg);

        // Re-encoding the decoded message round-trips again.
        let encoded = inner.encode_snap_message(decoded_msg.clone());
        let Ok(EthSnapMessage::Snap(final_msg)) =
            inner.decode_message(BytesMut::from(&encoded[..]))
        else {
            panic!("expected snap message after re-encoding");
        };
        assert_eq!(final_msg, decoded_msg);
    }

    #[test]
    fn test_message_id_boundaries() {
        let inner = EthSnapStreamInner::<EthNetworkPrimitives>::new(EthVersion::Eth67);

        // The max eth ID must be routed to the eth decoder. The garbage payload may fail to
        // decode, but only as an eth `InvalidMessage`, never as a snap/RLP or unknown-ID error.
        let eth_max_id = EthMessageID::max(EthVersion::Eth67);
        let mut eth_boundary_bytes = BytesMut::new();
        eth_boundary_bytes.extend_from_slice(&[eth_max_id]);
        eth_boundary_bytes.extend_from_slice(&[0, 0]);

        let eth_boundary_result = inner.decode_message(eth_boundary_bytes);
        assert!(
            matches!(
                eth_boundary_result,
                Ok(EthSnapMessage::Eth(_)) | Err(EthSnapStreamError::InvalidMessage(..))
            ),
            "id {eth_max_id:#04x} must be routed to the eth decoder, got {eth_boundary_result:?}"
        );

        // The ID just above the eth max is the first snap ID. The garbage payload must be
        // rejected by the snap (RLP) decoder.
        let snap_min_id = eth_max_id + 1;
        let mut snap_boundary_bytes = BytesMut::new();
        snap_boundary_bytes.extend_from_slice(&[snap_min_id]);
        snap_boundary_bytes.extend_from_slice(&[0, 0]);

        let snap_boundary_result = inner.decode_message(snap_boundary_bytes);
        assert!(
            matches!(snap_boundary_result, Err(EthSnapStreamError::Rlp(_))),
            "id {snap_min_id:#04x} must be routed to the snap decoder, got {snap_boundary_result:?}"
        );
    }

    #[test]
    fn test_message_id_above_snap_range_is_unknown() {
        let inner = EthSnapStreamInner::<EthNetworkPrimitives>::new(EthVersion::Eth67);
        let first_unknown_id =
            EthMessageID::message_count(EthVersion::Eth67) + SnapMessageId::TrieNodes as u8 + 1;

        let result = inner.decode_message(BytesMut::from(&[first_unknown_id, 0, 0][..]));
        assert!(
            matches!(result, Err(EthSnapStreamError::UnknownMessageId(id)) if id == first_unknown_id),
            "expected unknown message id {first_unknown_id:#04x}, got {result:?}"
        );
    }

    #[test]
    fn test_empty_message_is_rejected() {
        let inner = EthSnapStreamInner::<EthNetworkPrimitives>::new(EthVersion::Eth67);
        let result = inner.decode_message(BytesMut::new());
        assert!(
            matches!(result, Err(EthSnapStreamError::Rlp(alloy_rlp::Error::InputTooShort))),
            "expected InputTooShort, got {result:?}"
        );
    }

    #[test]
    fn test_eth70_message_id_0x12_is_snap() {
        let inner = EthSnapStreamInner::<EthNetworkPrimitives>::new(EthVersion::Eth70);
        let snap_msg = SnapProtocolMessage::GetAccountRange(GetAccountRangeMessage {
            request_id: 1,
            root_hash: B256::default(),
            starting_hash: B256::default(),
            limit_hash: B256::default(),
            response_bytes: 1000,
        });

        let encoded = inner.encode_snap_message(snap_msg);
        assert_eq!(encoded[0], EthMessageID::message_count(EthVersion::Eth70));

        let decoded = inner.decode_message(BytesMut::from(&encoded[..])).unwrap();
        assert!(matches!(decoded, EthSnapMessage::Snap(_)));
    }

    #[test]
    fn test_eth71_message_id_0x12_is_eth() {
        let inner = EthSnapStreamInner::<EthNetworkPrimitives>::new(EthVersion::Eth71);
        let eth_msg = EthMessage::<EthNetworkPrimitives>::GetBlockAccessLists(RequestPair {
            request_id: 1,
            message: GetBlockAccessLists(vec![B256::ZERO]),
        });
        let protocol_msg = ProtocolMessage::from(eth_msg.clone());
        let mut buf = Vec::new();
        protocol_msg.encode(&mut buf);

        let decoded = inner.decode_message(BytesMut::from(&buf[..])).unwrap();
        let EthSnapMessage::Eth(decoded_eth) = decoded else {
            panic!("expected eth message");
        };
        assert_eq!(decoded_eth, eth_msg);
    }
}