Skip to main content

reth_eth_wire/
ethstream.rs

1//! Ethereum protocol stream implementations.
2//!
3//! Provides stream types for the Ethereum wire protocol.
4//! It separates protocol logic [`EthStreamInner`] from transport concerns [`EthStream`].
5//! Handles handshaking, message processing, and RLP serialization.
6
7use crate::{
8    errors::{EthHandshakeError, EthStreamError},
9    handshake::EthereumEthHandshake,
10    message::{
11        EthBroadcastMessage, ProtocolBroadcastMessage, MAX_MESSAGE_SIZE,
12        TX_MEMORY_BUDGET_MULTIPLIER,
13    },
14    p2pstream::HANDSHAKE_TIMEOUT,
15    CanDisconnect, DisconnectReason, EthMessage, EthNetworkPrimitives, EthVersion, ProtocolMessage,
16    UnifiedStatus,
17};
18use alloy_primitives::bytes::{Bytes, BytesMut};
19use alloy_rlp::Encodable;
20use futures::{ready, Sink, SinkExt};
21use pin_project::pin_project;
22use reth_eth_wire_types::{EthMessageID, NetworkPrimitives, RawCapabilityMessage};
23use reth_ethereum_forks::ForkFilter;
24use std::{
25    future::Future,
26    pin::Pin,
27    task::{Context, Poll},
28    time::Duration,
29};
30use tokio::time::timeout;
31use tokio_stream::Stream;
32use tracing::{debug, trace};
33
/// An un-authenticated [`EthStream`]. This is consumed and returns a [`EthStream`] after the
/// `Status` handshake is completed.
///
/// This type only holds the transport; it exposes no eth message encoding/decoding of its own.
/// One of the `handshake*` methods upgrades it into an [`EthStream`].
#[pin_project]
#[derive(Debug)]
pub struct UnauthedEthStream<S> {
    /// The wrapped transport. The handshake methods require it to be a `Stream` of `BytesMut`
    /// frames and a `Sink` of `Bytes` (via `CanDisconnect<Bytes>`).
    #[pin]
    inner: S,
}
42
43impl<S> UnauthedEthStream<S> {
44    /// Create a new `UnauthedEthStream` from a type `S` which implements `Stream` and `Sink`.
45    pub const fn new(inner: S) -> Self {
46        Self { inner }
47    }
48
49    /// Consumes the type and returns the wrapped stream
50    pub fn into_inner(self) -> S {
51        self.inner
52    }
53}
54
55impl<S, E> UnauthedEthStream<S>
56where
57    S: Stream<Item = Result<BytesMut, E>> + CanDisconnect<Bytes> + Send + Unpin,
58    EthStreamError: From<E> + From<<S as Sink<Bytes>>::Error>,
59{
60    /// Consumes the [`UnauthedEthStream`] and returns an [`EthStream`] after the `Status`
61    /// handshake is completed successfully. This also returns the `Status` message sent by the
62    /// remote peer.
63    ///
64    /// Caution: This expects that the [`UnifiedStatus`] has the proper eth version configured, with
65    /// ETH69 the initial status message changed.
66    pub async fn handshake<N: NetworkPrimitives>(
67        self,
68        status: UnifiedStatus,
69        fork_filter: ForkFilter,
70    ) -> Result<(EthStream<S, N>, UnifiedStatus), EthStreamError> {
71        self.handshake_with_timeout(status, fork_filter, HANDSHAKE_TIMEOUT).await
72    }
73
74    /// Wrapper around handshake which enforces a timeout.
75    pub async fn handshake_with_timeout<N: NetworkPrimitives>(
76        self,
77        status: UnifiedStatus,
78        fork_filter: ForkFilter,
79        timeout_limit: Duration,
80    ) -> Result<(EthStream<S, N>, UnifiedStatus), EthStreamError> {
81        timeout(timeout_limit, Self::handshake_without_timeout(self, status, fork_filter))
82            .await
83            .map_err(|_| EthStreamError::StreamTimeout)?
84    }
85
86    /// Handshake with no timeout
87    pub async fn handshake_without_timeout<N: NetworkPrimitives>(
88        mut self,
89        status: UnifiedStatus,
90        fork_filter: ForkFilter,
91    ) -> Result<(EthStream<S, N>, UnifiedStatus), EthStreamError> {
92        trace!(
93            status = %status.into_message(),
94            "sending eth status to peer"
95        );
96        let their_status =
97            EthereumEthHandshake(&mut self.inner).eth_handshake(status, fork_filter).await?;
98
99        // now we can create the `EthStream` because the peer has successfully completed
100        // the handshake
101        let stream = EthStream::new(status.version, self.inner);
102
103        Ok((stream, their_status))
104    }
105}
106
/// Contains eth protocol specific logic for processing messages.
///
/// Holds no transport: it only checks, decodes and encodes message bytes for a negotiated
/// [`EthVersion`]. It is embedded in [`EthStream`].
#[derive(Debug)]
pub struct EthStreamInner<N> {
    /// Negotiated eth version, passed to the decoder for incoming messages.
    version: EthVersion,
    /// Maximum allowed ETH message size in bytes. Larger incoming payloads are rejected with
    /// [`EthStreamError::MessageTooBig`] before decoding.
    max_message_size: usize,
    /// When true, `NewBlock` (0x07) and `NewBlockHashes` (0x01) messages are rejected before RLP
    /// decoding to avoid any memory impact for non-PoW networks.
    reject_block_announcements: bool,
    /// Ties the type to the network primitives `N` without storing a value of it.
    _pd: std::marker::PhantomData<N>,
}
119
120impl<N> EthStreamInner<N>
121where
122    N: NetworkPrimitives,
123{
124    /// Creates a new [`EthStreamInner`] with the given eth version
125    pub const fn new(version: EthVersion) -> Self {
126        Self::with_max_message_size(version, MAX_MESSAGE_SIZE)
127    }
128
129    /// Creates a new [`EthStreamInner`] with the given eth version and message size limit.
130    pub const fn with_max_message_size(version: EthVersion, max_message_size: usize) -> Self {
131        Self {
132            version,
133            max_message_size,
134            reject_block_announcements: false,
135            _pd: std::marker::PhantomData,
136        }
137    }
138
139    /// Returns the eth version
140    #[inline]
141    pub const fn version(&self) -> EthVersion {
142        self.version
143    }
144
145    /// Sets whether to reject block announcement messages (`NewBlock`, `NewBlockHashes`) before
146    /// RLP decoding.
147    pub const fn set_reject_block_announcements(&mut self, reject: bool) {
148        self.reject_block_announcements = reject;
149    }
150
151    /// Decodes incoming bytes into an [`EthMessage`].
152    pub fn decode_message(&self, bytes: BytesMut) -> Result<EthMessage<N>, EthStreamError> {
153        if bytes.len() > self.max_message_size {
154            return Err(EthStreamError::MessageTooBig(bytes.len()));
155        }
156
157        if self.reject_block_announcements &&
158            let Some(&id) = bytes.first() &&
159            (id == EthMessageID::NewBlock.to_u8() || id == EthMessageID::NewBlockHashes.to_u8())
160        {
161            return Err(EthStreamError::UnsupportedMessage { message_id: id });
162        }
163
164        let msg = match ProtocolMessage::decode_message_with_tx_memory_budget(
165            self.version,
166            &mut bytes.as_ref(),
167            self.max_message_size * TX_MEMORY_BUDGET_MULTIPLIER,
168        ) {
169            Ok(m) => m,
170            Err(err) => {
171                let msg = if bytes.len() > 50 {
172                    format!("{:02x?}...{:x?}", &bytes[..10], &bytes[bytes.len() - 10..])
173                } else {
174                    format!("{bytes:02x?}")
175                };
176                debug!(
177                    version=?self.version,
178                    %msg,
179                    "failed to decode protocol message"
180                );
181                return Err(EthStreamError::InvalidMessage(err));
182            }
183        };
184
185        if matches!(msg.message, EthMessage::Status(_)) {
186            return Err(EthStreamError::EthHandshakeError(EthHandshakeError::StatusNotInHandshake));
187        }
188
189        Ok(msg.message)
190    }
191
192    /// Encodes an [`EthMessage`] to bytes.
193    ///
194    /// Validates that Status messages are not sent after handshake, enforcing protocol rules.
195    pub fn encode_message(&self, item: EthMessage<N>) -> Result<Bytes, EthStreamError> {
196        if matches!(item, EthMessage::Status(_)) {
197            return Err(EthStreamError::EthHandshakeError(EthHandshakeError::StatusNotInHandshake));
198        }
199
200        Ok(Bytes::from(alloy_rlp::encode(ProtocolMessage::from(item))))
201    }
202}
203
/// An `EthStream` wraps over any `Stream` that yields bytes and makes it
/// compatible with eth-networking protocol messages, which get RLP encoded/decoded.
///
/// Decoding and validation of incoming messages is delegated to [`EthStreamInner`]; this type
/// provides the `Stream`/`Sink` plumbing over the transport `S`.
#[pin_project]
#[derive(Debug)]
pub struct EthStream<S, N = EthNetworkPrimitives> {
    /// Eth-specific logic (negotiated version, size limit, message checks).
    eth: EthStreamInner<N>,
    /// Underlying transport carrying raw message bytes.
    #[pin]
    inner: S,
}
214
215impl<S, N: NetworkPrimitives> EthStream<S, N> {
216    /// Creates a new unauthed [`EthStream`] from a provided stream. You will need
217    /// to manually handshake a peer.
218    #[inline]
219    pub const fn new(version: EthVersion, inner: S) -> Self {
220        Self::with_max_message_size(version, inner, MAX_MESSAGE_SIZE)
221    }
222
223    /// Creates a new unauthed [`EthStream`] with a custom max message size.
224    #[inline]
225    pub const fn with_max_message_size(
226        version: EthVersion,
227        inner: S,
228        max_message_size: usize,
229    ) -> Self {
230        Self { eth: EthStreamInner::with_max_message_size(version, max_message_size), inner }
231    }
232
233    /// Returns the eth version.
234    #[inline]
235    pub const fn version(&self) -> EthVersion {
236        self.eth.version()
237    }
238
239    /// Sets whether to reject block announcement messages (`NewBlock`, `NewBlockHashes`) before
240    /// RLP decoding.
241    pub const fn set_reject_block_announcements(&mut self, reject: bool) {
242        self.eth.set_reject_block_announcements(reject);
243    }
244
245    /// Returns the underlying stream.
246    #[inline]
247    pub const fn inner(&self) -> &S {
248        &self.inner
249    }
250
251    /// Returns mutable access to the underlying stream.
252    #[inline]
253    pub const fn inner_mut(&mut self) -> &mut S {
254        &mut self.inner
255    }
256
257    /// Consumes this type and returns the wrapped stream.
258    #[inline]
259    pub fn into_inner(self) -> S {
260        self.inner
261    }
262}
263
264impl<S, E, N> EthStream<S, N>
265where
266    S: Sink<Bytes, Error = E> + Unpin,
267    EthStreamError: From<E>,
268    N: NetworkPrimitives,
269{
270    /// Same as [`Sink::start_send`] but accepts a [`EthBroadcastMessage`] instead.
271    pub fn start_send_broadcast(
272        &mut self,
273        item: EthBroadcastMessage<N>,
274    ) -> Result<(), EthStreamError> {
275        self.inner.start_send_unpin(Bytes::from(alloy_rlp::encode(
276            ProtocolBroadcastMessage::from(item),
277        )))?;
278
279        Ok(())
280    }
281
282    /// Sends a raw capability message directly over the stream
283    pub fn start_send_raw(&mut self, msg: RawCapabilityMessage) -> Result<(), EthStreamError> {
284        let mut bytes = Vec::with_capacity(msg.payload.len() + 1);
285        msg.id.encode(&mut bytes);
286        bytes.extend_from_slice(&msg.payload);
287
288        self.inner.start_send_unpin(bytes.into())?;
289        Ok(())
290    }
291}
292
293impl<S, E, N> Stream for EthStream<S, N>
294where
295    S: Stream<Item = Result<BytesMut, E>> + Unpin,
296    EthStreamError: From<E>,
297    N: NetworkPrimitives,
298{
299    type Item = Result<EthMessage<N>, EthStreamError>;
300
301    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
302        let this = self.project();
303        let res = ready!(this.inner.poll_next(cx));
304
305        match res {
306            Some(Ok(bytes)) => Poll::Ready(Some(this.eth.decode_message(bytes))),
307            Some(Err(err)) => Poll::Ready(Some(Err(err.into()))),
308            None => Poll::Ready(None),
309        }
310    }
311}
312
313impl<S, N> Sink<EthMessage<N>> for EthStream<S, N>
314where
315    S: CanDisconnect<Bytes> + Unpin,
316    EthStreamError: From<<S as Sink<Bytes>>::Error>,
317    N: NetworkPrimitives,
318{
319    type Error = EthStreamError;
320
321    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
322        self.project().inner.poll_ready(cx).map_err(Into::into)
323    }
324
325    fn start_send(self: Pin<&mut Self>, item: EthMessage<N>) -> Result<(), Self::Error> {
326        if matches!(item, EthMessage::Status(_)) {
327            // Attempt to disconnect the peer for protocol breach when trying to send Status
328            // message after handshake is complete
329            let mut this = self.project();
330            // We can't await the disconnect future here since this is a synchronous method,
331            // but we can start the disconnect process. The actual disconnect will be handled
332            // asynchronously by the caller or the stream's poll methods.
333            let _disconnect_future = this.inner.disconnect(DisconnectReason::ProtocolBreach);
334            return Err(EthStreamError::EthHandshakeError(EthHandshakeError::StatusNotInHandshake))
335        }
336
337        self.project()
338            .inner
339            .start_send(Bytes::from(alloy_rlp::encode(ProtocolMessage::from(item))))?;
340
341        Ok(())
342    }
343
344    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
345        self.project().inner.poll_flush(cx).map_err(Into::into)
346    }
347
348    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
349        self.project().inner.poll_close(cx).map_err(Into::into)
350    }
351}
352
353impl<S, N> CanDisconnect<EthMessage<N>> for EthStream<S, N>
354where
355    S: CanDisconnect<Bytes> + Send,
356    EthStreamError: From<<S as Sink<Bytes>>::Error>,
357    N: NetworkPrimitives,
358{
359    fn disconnect(
360        &mut self,
361        reason: DisconnectReason,
362    ) -> Pin<Box<dyn Future<Output = Result<(), EthStreamError>> + Send + '_>> {
363        Box::pin(async move { self.inner.disconnect(reason).await.map_err(Into::into) })
364    }
365}
366
#[cfg(test)]
mod tests {
    // Integration-style tests: handshakes and message round-trips over real loopback TCP,
    // cleartext (`PassthroughCodec`), ECIES and full p2p stacks.
    use super::UnauthedEthStream;
    use crate::{
        broadcast::BlockHashNumber,
        errors::{EthHandshakeError, EthStreamError},
        ethstream::RawCapabilityMessage,
        hello::DEFAULT_TCP_PORT,
        p2pstream::UnauthedP2PStream,
        EthMessage, EthStream, EthVersion, HelloMessageWithProtocols, PassthroughCodec,
        ProtocolVersion, Status, StatusMessage,
    };
    use alloy_chains::NamedChain;
    use alloy_primitives::{bytes::Bytes, B256, U256};
    use alloy_rlp::Decodable;
    use futures::{SinkExt, StreamExt};
    use reth_ecies::stream::ECIESStream;
    use reth_eth_wire_types::{EthNetworkPrimitives, UnifiedStatus};
    use reth_ethereum_forks::{ForkFilter, Head};
    use reth_network_peers::pk2id;
    use secp256k1::{SecretKey, SECP256K1};
    use std::time::Duration;
    use tokio::net::{TcpListener, TcpStream};
    use tokio_util::codec::Decoder;

    #[tokio::test]
    async fn can_handshake() {
        let genesis = B256::random();
        let fork_filter = ForkFilter::new(Head::default(), genesis, 0, Vec::new());

        let status = Status {
            version: EthVersion::Eth67,
            chain: NamedChain::Mainnet.into(),
            total_difficulty: U256::ZERO,
            blockhash: B256::random(),
            genesis,
            // Pass the current fork id.
            forkid: fork_filter.current(),
        };
        let unified_status = UnifiedStatus::from_message(StatusMessage::Legacy(status));

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();

        let status_clone = unified_status;
        let fork_filter_clone = fork_filter.clone();
        let handle = tokio::spawn(async move {
            // roughly based off of the design of tokio::net::TcpListener
            let (incoming, _) = listener.accept().await.unwrap();
            let stream = PassthroughCodec::default().framed(incoming);
            let (_, their_status) = UnauthedEthStream::new(stream)
                .handshake::<EthNetworkPrimitives>(status_clone, fork_filter_clone)
                .await
                .unwrap();

            // just make sure it equals our status (our status is a clone of their status)
            assert_eq!(their_status, status_clone);
        });

        let outgoing = TcpStream::connect(local_addr).await.unwrap();
        let sink = PassthroughCodec::default().framed(outgoing);

        // try to connect
        let (_, their_status) = UnauthedEthStream::new(sink)
            .handshake::<EthNetworkPrimitives>(unified_status, fork_filter)
            .await
            .unwrap();

        // their status is a clone of our status, these should be equal
        assert_eq!(their_status, unified_status);

        // wait for it to finish
        handle.await.unwrap();
    }

    #[tokio::test]
    async fn pass_handshake_on_low_td_bitlen() {
        let genesis = B256::random();
        let fork_filter = ForkFilter::new(Head::default(), genesis, 0, Vec::new());

        let status = Status {
            version: EthVersion::Eth67,
            chain: NamedChain::Mainnet.into(),
            total_difficulty: U256::from(2).pow(U256::from(100)) - U256::from(1),
            blockhash: B256::random(),
            genesis,
            // Pass the current fork id.
            forkid: fork_filter.current(),
        };
        let unified_status = UnifiedStatus::from_message(StatusMessage::Legacy(status));

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();

        let status_clone = unified_status;
        let fork_filter_clone = fork_filter.clone();
        let handle = tokio::spawn(async move {
            // roughly based off of the design of tokio::net::TcpListener
            let (incoming, _) = listener.accept().await.unwrap();
            let stream = PassthroughCodec::default().framed(incoming);
            let (_, their_status) = UnauthedEthStream::new(stream)
                .handshake::<EthNetworkPrimitives>(status_clone, fork_filter_clone)
                .await
                .unwrap();

            // just make sure it equals our status, and that the handshake succeeded
            assert_eq!(their_status, status_clone);
        });

        let outgoing = TcpStream::connect(local_addr).await.unwrap();
        let sink = PassthroughCodec::default().framed(outgoing);

        // try to connect
        let (_, their_status) = UnauthedEthStream::new(sink)
            .handshake::<EthNetworkPrimitives>(unified_status, fork_filter)
            .await
            .unwrap();

        // their status is a clone of our status, these should be equal
        assert_eq!(their_status, unified_status);

        // await the other handshake
        handle.await.unwrap();
    }

    #[tokio::test]
    async fn fail_handshake_on_high_td_bitlen() {
        let genesis = B256::random();
        let fork_filter = ForkFilter::new(Head::default(), genesis, 0, Vec::new());

        let status = Status {
            version: EthVersion::Eth67,
            chain: NamedChain::Mainnet.into(),
            total_difficulty: U256::from(2).pow(U256::from(164)),
            blockhash: B256::random(),
            genesis,
            // Pass the current fork id.
            forkid: fork_filter.current(),
        };
        let unified_status = UnifiedStatus::from_message(StatusMessage::Legacy(status));

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();

        let status_clone = unified_status;
        let fork_filter_clone = fork_filter.clone();
        let handle = tokio::spawn(async move {
            // roughly based off of the design of tokio::net::TcpListener
            let (incoming, _) = listener.accept().await.unwrap();
            let stream = PassthroughCodec::default().framed(incoming);
            let handshake_res = UnauthedEthStream::new(stream)
                .handshake::<EthNetworkPrimitives>(status_clone, fork_filter_clone)
                .await;

            // make sure the handshake fails due to td too high
            assert!(matches!(
                handshake_res,
                Err(EthStreamError::EthHandshakeError(
                    EthHandshakeError::TotalDifficultyBitLenTooLarge { got: 165, maximum: 160 }
                ))
            ));
        });

        let outgoing = TcpStream::connect(local_addr).await.unwrap();
        let sink = PassthroughCodec::default().framed(outgoing);

        // try to connect
        let handshake_res = UnauthedEthStream::new(sink)
            .handshake::<EthNetworkPrimitives>(unified_status, fork_filter)
            .await;

        // this handshake should also fail due to td too high
        assert!(matches!(
            handshake_res,
            Err(EthStreamError::EthHandshakeError(
                EthHandshakeError::TotalDifficultyBitLenTooLarge { got: 165, maximum: 160 }
            ))
        ));

        // await the other handshake
        handle.await.unwrap();
    }

    #[tokio::test]
    async fn can_write_and_read_cleartext() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();
        let test_msg = EthMessage::<EthNetworkPrimitives>::NewBlockHashes(
            vec![
                BlockHashNumber { hash: B256::random(), number: 5 },
                BlockHashNumber { hash: B256::random(), number: 6 },
            ]
            .into(),
        );

        let test_msg_clone = test_msg.clone();
        let handle = tokio::spawn(async move {
            // roughly based off of the design of tokio::net::TcpListener
            let (incoming, _) = listener.accept().await.unwrap();
            let stream = PassthroughCodec::default().framed(incoming);
            let mut stream = EthStream::new(EthVersion::Eth67, stream);

            // use the stream to get the next message
            let message = stream.next().await.unwrap().unwrap();
            assert_eq!(message, test_msg_clone);
        });

        let outgoing = TcpStream::connect(local_addr).await.unwrap();
        let sink = PassthroughCodec::default().framed(outgoing);
        let mut client_stream = EthStream::new(EthVersion::Eth67, sink);

        client_stream.send(test_msg).await.unwrap();

        // make sure the server receives the message and asserts before ending the test
        handle.await.unwrap();
    }

    #[tokio::test]
    async fn can_write_and_read_ecies() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();
        let server_key = SecretKey::new(&mut rand_08::thread_rng());
        let test_msg = EthMessage::<EthNetworkPrimitives>::NewBlockHashes(
            vec![
                BlockHashNumber { hash: B256::random(), number: 5 },
                BlockHashNumber { hash: B256::random(), number: 6 },
            ]
            .into(),
        );

        let test_msg_clone = test_msg.clone();
        let handle = tokio::spawn(async move {
            // roughly based off of the design of tokio::net::TcpListener
            let (incoming, _) = listener.accept().await.unwrap();
            let stream = ECIESStream::incoming(incoming, server_key).await.unwrap();
            let mut stream = EthStream::new(EthVersion::Eth67, stream);

            // use the stream to get the next message
            let message = stream.next().await.unwrap().unwrap();
            assert_eq!(message, test_msg_clone);
        });

        // create the server pubkey
        let server_id = pk2id(&server_key.public_key(SECP256K1));

        let client_key = SecretKey::new(&mut rand_08::thread_rng());

        let outgoing = TcpStream::connect(local_addr).await.unwrap();
        let outgoing = ECIESStream::connect(outgoing, client_key, server_id).await.unwrap();
        let mut client_stream = EthStream::new(EthVersion::Eth67, outgoing);

        client_stream.send(test_msg).await.unwrap();

        // make sure the server receives the message and asserts before ending the test
        handle.await.unwrap();
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn ethstream_over_p2p() {
        // create a p2p stream and server, then confirm that the two are authed
        // create tcpstream
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();
        let server_key = SecretKey::new(&mut rand_08::thread_rng());
        let test_msg = EthMessage::<EthNetworkPrimitives>::NewBlockHashes(
            vec![
                BlockHashNumber { hash: B256::random(), number: 5 },
                BlockHashNumber { hash: B256::random(), number: 6 },
            ]
            .into(),
        );

        let genesis = B256::random();
        let fork_filter = ForkFilter::new(Head::default(), genesis, 0, Vec::new());

        let status = Status {
            version: EthVersion::Eth67,
            chain: NamedChain::Mainnet.into(),
            total_difficulty: U256::ZERO,
            blockhash: B256::random(),
            genesis,
            // Pass the current fork id.
            forkid: fork_filter.current(),
        };
        let unified_status = UnifiedStatus::from_message(StatusMessage::Legacy(status));

        let status_copy = unified_status;
        let fork_filter_clone = fork_filter.clone();
        let test_msg_clone = test_msg.clone();
        let handle = tokio::spawn(async move {
            // roughly based off of the design of tokio::net::TcpListener
            let (incoming, _) = listener.accept().await.unwrap();
            let stream = ECIESStream::incoming(incoming, server_key).await.unwrap();

            let server_hello = HelloMessageWithProtocols {
                protocol_version: ProtocolVersion::V5,
                client_version: "bitcoind/1.0.0".to_string(),
                protocols: vec![EthVersion::Eth67.into()],
                port: DEFAULT_TCP_PORT,
                id: pk2id(&server_key.public_key(SECP256K1)),
            };

            let unauthed_stream = UnauthedP2PStream::new(stream);
            let (p2p_stream, _) = unauthed_stream.handshake(server_hello).await.unwrap();
            let (mut eth_stream, _) = UnauthedEthStream::new(p2p_stream)
                .handshake(status_copy, fork_filter_clone)
                .await
                .unwrap();

            // use the stream to get the next message
            let message = eth_stream.next().await.unwrap().unwrap();
            assert_eq!(message, test_msg_clone);
        });

        // create the server pubkey
        let server_id = pk2id(&server_key.public_key(SECP256K1));

        let client_key = SecretKey::new(&mut rand_08::thread_rng());

        let outgoing = TcpStream::connect(local_addr).await.unwrap();
        let sink = ECIESStream::connect(outgoing, client_key, server_id).await.unwrap();

        let client_hello = HelloMessageWithProtocols {
            protocol_version: ProtocolVersion::V5,
            client_version: "bitcoind/1.0.0".to_string(),
            protocols: vec![EthVersion::Eth67.into()],
            port: DEFAULT_TCP_PORT,
            id: pk2id(&client_key.public_key(SECP256K1)),
        };

        let unauthed_stream = UnauthedP2PStream::new(sink);
        let (p2p_stream, _) = unauthed_stream.handshake(client_hello).await.unwrap();

        let (mut client_stream, _) = UnauthedEthStream::new(p2p_stream)
            .handshake(unified_status, fork_filter)
            .await
            .unwrap();

        client_stream.send(test_msg).await.unwrap();

        // make sure the server receives the message and asserts before ending the test
        handle.await.unwrap();
    }

    #[tokio::test]
    async fn handshake_should_timeout() {
        let genesis = B256::random();
        let fork_filter = ForkFilter::new(Head::default(), genesis, 0, Vec::new());

        let status = Status {
            version: EthVersion::Eth67,
            chain: NamedChain::Mainnet.into(),
            total_difficulty: U256::ZERO,
            blockhash: B256::random(),
            genesis,
            // Pass the current fork id.
            forkid: fork_filter.current(),
        };
        let unified_status = UnifiedStatus::from_message(StatusMessage::Legacy(status));

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();

        let status_clone = unified_status;
        let fork_filter_clone = fork_filter.clone();
        let _handle = tokio::spawn(async move {
            // Delay accepting the connection for longer than the client's timeout period
            tokio::time::sleep(Duration::from_secs(11)).await;
            // roughly based off of the design of tokio::net::TcpListener
            let (incoming, _) = listener.accept().await.unwrap();
            let stream = PassthroughCodec::default().framed(incoming);
            let (_, their_status) = UnauthedEthStream::new(stream)
                .handshake::<EthNetworkPrimitives>(status_clone, fork_filter_clone)
                .await
                .unwrap();

            // just make sure it equals our status (our status is a clone of their status)
            assert_eq!(their_status, status_clone);
        });

        let outgoing = TcpStream::connect(local_addr).await.unwrap();
        let sink = PassthroughCodec::default().framed(outgoing);

        // try to connect
        let handshake_result = UnauthedEthStream::new(sink)
            .handshake_with_timeout::<EthNetworkPrimitives>(
                unified_status,
                fork_filter,
                Duration::from_secs(1),
            )
            .await;

        // Assert that a timeout error occurred (match the variant, not its display string)
        assert!(matches!(handshake_result, Err(EthStreamError::StreamTimeout)));
    }

    #[tokio::test]
    async fn can_write_and_read_raw_capability() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();

        let test_msg = RawCapabilityMessage { id: 0x1234, payload: Bytes::from(vec![1, 2, 3, 4]) };

        let test_msg_clone = test_msg.clone();
        let handle = tokio::spawn(async move {
            let (incoming, _) = listener.accept().await.unwrap();
            let stream = PassthroughCodec::default().framed(incoming);
            let mut stream = EthStream::<_, EthNetworkPrimitives>::new(EthVersion::Eth67, stream);

            let bytes = stream.inner_mut().next().await.unwrap().unwrap();

            // Create a cursor to track position while decoding
            let mut id_bytes = &bytes[..];
            let decoded_id = <usize as Decodable>::decode(&mut id_bytes).unwrap();
            assert_eq!(decoded_id, test_msg_clone.id);

            // Get remaining bytes after ID decoding
            let remaining = id_bytes;
            assert_eq!(remaining, &test_msg_clone.payload[..]);
        });

        let outgoing = TcpStream::connect(local_addr).await.unwrap();
        let sink = PassthroughCodec::default().framed(outgoing);
        let mut client_stream = EthStream::<_, EthNetworkPrimitives>::new(EthVersion::Eth67, sink);

        client_stream.start_send_raw(test_msg).unwrap();
        client_stream.inner_mut().flush().await.unwrap();

        handle.await.unwrap();
    }

    #[tokio::test]
    async fn status_message_after_handshake_triggers_disconnect() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();

        let handle = tokio::spawn(async move {
            let (incoming, _) = listener.accept().await.unwrap();
            let stream = PassthroughCodec::default().framed(incoming);
            let mut stream = EthStream::<_, EthNetworkPrimitives>::new(EthVersion::Eth67, stream);

            // Try to send a Status message after handshake - this is a protocol breach
            let status = Status {
                version: EthVersion::Eth67,
                chain: NamedChain::Mainnet.into(),
                total_difficulty: U256::ZERO,
                blockhash: B256::random(),
                genesis: B256::random(),
                forkid: ForkFilter::new(Head::default(), B256::random(), 0, Vec::new()).current(),
            };
            let status_message =
                EthMessage::<EthNetworkPrimitives>::Status(StatusMessage::Legacy(status));

            // Sending it must be rejected with `StatusNotInHandshake`
            let result = stream.send(status_message).await;
            assert!(matches!(
                result,
                Err(EthStreamError::EthHandshakeError(EthHandshakeError::StatusNotInHandshake))
            ));
        });

        let outgoing = TcpStream::connect(local_addr).await.unwrap();
        let sink = PassthroughCodec::default().framed(outgoing);
        let mut client_stream = EthStream::<_, EthNetworkPrimitives>::new(EthVersion::Eth67, sink);

        // Send a valid message to keep the connection alive
        let test_msg = EthMessage::<EthNetworkPrimitives>::NewBlockHashes(
            vec![BlockHashNumber { hash: B256::random(), number: 5 }].into(),
        );
        client_stream.send(test_msg).await.unwrap();

        handle.await.unwrap();
    }
}