Skip to main content

reth_eth_wire/
ethstream.rs

1//! Ethereum protocol stream implementations.
2//!
3//! Provides stream types for the Ethereum wire protocol.
4//! It separates protocol logic [`EthStreamInner`] from transport concerns [`EthStream`].
5//! Handles handshaking, message processing, and RLP serialization.
6
7use crate::{
8    errors::{EthHandshakeError, EthStreamError},
9    handshake::EthereumEthHandshake,
10    message::{EthBroadcastMessage, ProtocolBroadcastMessage, MAX_MESSAGE_SIZE},
11    p2pstream::HANDSHAKE_TIMEOUT,
12    CanDisconnect, DisconnectReason, EthMessage, EthNetworkPrimitives, EthVersion, ProtocolMessage,
13    UnifiedStatus,
14};
15use alloy_primitives::bytes::{Bytes, BytesMut};
16use alloy_rlp::Encodable;
17use futures::{ready, Sink, SinkExt};
18use pin_project::pin_project;
19use reth_eth_wire_types::{NetworkPrimitives, RawCapabilityMessage};
20use reth_ethereum_forks::ForkFilter;
21use std::{
22    future::Future,
23    pin::Pin,
24    task::{Context, Poll},
25    time::Duration,
26};
27use tokio::time::timeout;
28use tokio_stream::Stream;
29use tracing::{debug, trace};
30
31/// An un-authenticated [`EthStream`]. This is consumed and returns a [`EthStream`] after the
32/// `Status` handshake is completed.
33#[pin_project]
34#[derive(Debug)]
35pub struct UnauthedEthStream<S> {
36    #[pin]
37    inner: S,
38}
39
40impl<S> UnauthedEthStream<S> {
41    /// Create a new `UnauthedEthStream` from a type `S` which implements `Stream` and `Sink`.
42    pub const fn new(inner: S) -> Self {
43        Self { inner }
44    }
45
46    /// Consumes the type and returns the wrapped stream
47    pub fn into_inner(self) -> S {
48        self.inner
49    }
50}
51
52impl<S, E> UnauthedEthStream<S>
53where
54    S: Stream<Item = Result<BytesMut, E>> + CanDisconnect<Bytes> + Send + Unpin,
55    EthStreamError: From<E> + From<<S as Sink<Bytes>>::Error>,
56{
57    /// Consumes the [`UnauthedEthStream`] and returns an [`EthStream`] after the `Status`
58    /// handshake is completed successfully. This also returns the `Status` message sent by the
59    /// remote peer.
60    ///
61    /// Caution: This expects that the [`UnifiedStatus`] has the proper eth version configured, with
62    /// ETH69 the initial status message changed.
63    pub async fn handshake<N: NetworkPrimitives>(
64        self,
65        status: UnifiedStatus,
66        fork_filter: ForkFilter,
67    ) -> Result<(EthStream<S, N>, UnifiedStatus), EthStreamError> {
68        self.handshake_with_timeout(status, fork_filter, HANDSHAKE_TIMEOUT).await
69    }
70
71    /// Wrapper around handshake which enforces a timeout.
72    pub async fn handshake_with_timeout<N: NetworkPrimitives>(
73        self,
74        status: UnifiedStatus,
75        fork_filter: ForkFilter,
76        timeout_limit: Duration,
77    ) -> Result<(EthStream<S, N>, UnifiedStatus), EthStreamError> {
78        timeout(timeout_limit, Self::handshake_without_timeout(self, status, fork_filter))
79            .await
80            .map_err(|_| EthStreamError::StreamTimeout)?
81    }
82
83    /// Handshake with no timeout
84    pub async fn handshake_without_timeout<N: NetworkPrimitives>(
85        mut self,
86        status: UnifiedStatus,
87        fork_filter: ForkFilter,
88    ) -> Result<(EthStream<S, N>, UnifiedStatus), EthStreamError> {
89        trace!(
90            status = %status.into_message(),
91            "sending eth status to peer"
92        );
93        let their_status =
94            EthereumEthHandshake(&mut self.inner).eth_handshake(status, fork_filter).await?;
95
96        // now we can create the `EthStream` because the peer has successfully completed
97        // the handshake
98        let stream = EthStream::new(status.version, self.inner);
99
100        Ok((stream, their_status))
101    }
102}
103
104/// Contains eth protocol specific logic for processing messages
105#[derive(Debug)]
106pub struct EthStreamInner<N> {
107    /// Negotiated eth version
108    version: EthVersion,
109    /// Maximum allowed ETH message size.
110    max_message_size: usize,
111    _pd: std::marker::PhantomData<N>,
112}
113
114impl<N> EthStreamInner<N>
115where
116    N: NetworkPrimitives,
117{
118    /// Creates a new [`EthStreamInner`] with the given eth version
119    pub const fn new(version: EthVersion) -> Self {
120        Self::with_max_message_size(version, MAX_MESSAGE_SIZE)
121    }
122
123    /// Creates a new [`EthStreamInner`] with the given eth version and message size limit.
124    pub const fn with_max_message_size(version: EthVersion, max_message_size: usize) -> Self {
125        Self { version, max_message_size, _pd: std::marker::PhantomData }
126    }
127
128    /// Returns the eth version
129    #[inline]
130    pub const fn version(&self) -> EthVersion {
131        self.version
132    }
133
134    /// Decodes incoming bytes into an [`EthMessage`].
135    pub fn decode_message(&self, bytes: BytesMut) -> Result<EthMessage<N>, EthStreamError> {
136        if bytes.len() > self.max_message_size {
137            return Err(EthStreamError::MessageTooBig(bytes.len()));
138        }
139
140        let msg = match ProtocolMessage::decode_message(self.version, &mut bytes.as_ref()) {
141            Ok(m) => m,
142            Err(err) => {
143                let msg = if bytes.len() > 50 {
144                    format!("{:02x?}...{:x?}", &bytes[..10], &bytes[bytes.len() - 10..])
145                } else {
146                    format!("{bytes:02x?}")
147                };
148                debug!(
149                    version=?self.version,
150                    %msg,
151                    "failed to decode protocol message"
152                );
153                return Err(EthStreamError::InvalidMessage(err));
154            }
155        };
156
157        if matches!(msg.message, EthMessage::Status(_)) {
158            return Err(EthStreamError::EthHandshakeError(EthHandshakeError::StatusNotInHandshake));
159        }
160
161        Ok(msg.message)
162    }
163
164    /// Encodes an [`EthMessage`] to bytes.
165    ///
166    /// Validates that Status messages are not sent after handshake, enforcing protocol rules.
167    pub fn encode_message(&self, item: EthMessage<N>) -> Result<Bytes, EthStreamError> {
168        if matches!(item, EthMessage::Status(_)) {
169            return Err(EthStreamError::EthHandshakeError(EthHandshakeError::StatusNotInHandshake));
170        }
171
172        Ok(Bytes::from(alloy_rlp::encode(ProtocolMessage::from(item))))
173    }
174}
175
176/// An `EthStream` wraps over any `Stream` that yields bytes and makes it
177/// compatible with eth-networking protocol messages, which get RLP encoded/decoded.
178#[pin_project]
179#[derive(Debug)]
180pub struct EthStream<S, N = EthNetworkPrimitives> {
181    /// Eth-specific logic
182    eth: EthStreamInner<N>,
183    #[pin]
184    inner: S,
185}
186
187impl<S, N: NetworkPrimitives> EthStream<S, N> {
188    /// Creates a new unauthed [`EthStream`] from a provided stream. You will need
189    /// to manually handshake a peer.
190    #[inline]
191    pub const fn new(version: EthVersion, inner: S) -> Self {
192        Self::with_max_message_size(version, inner, MAX_MESSAGE_SIZE)
193    }
194
195    /// Creates a new unauthed [`EthStream`] with a custom max message size.
196    #[inline]
197    pub const fn with_max_message_size(
198        version: EthVersion,
199        inner: S,
200        max_message_size: usize,
201    ) -> Self {
202        Self { eth: EthStreamInner::with_max_message_size(version, max_message_size), inner }
203    }
204
205    /// Returns the eth version.
206    #[inline]
207    pub const fn version(&self) -> EthVersion {
208        self.eth.version()
209    }
210
211    /// Returns the underlying stream.
212    #[inline]
213    pub const fn inner(&self) -> &S {
214        &self.inner
215    }
216
217    /// Returns mutable access to the underlying stream.
218    #[inline]
219    pub const fn inner_mut(&mut self) -> &mut S {
220        &mut self.inner
221    }
222
223    /// Consumes this type and returns the wrapped stream.
224    #[inline]
225    pub fn into_inner(self) -> S {
226        self.inner
227    }
228}
229
230impl<S, E, N> EthStream<S, N>
231where
232    S: Sink<Bytes, Error = E> + Unpin,
233    EthStreamError: From<E>,
234    N: NetworkPrimitives,
235{
236    /// Same as [`Sink::start_send`] but accepts a [`EthBroadcastMessage`] instead.
237    pub fn start_send_broadcast(
238        &mut self,
239        item: EthBroadcastMessage<N>,
240    ) -> Result<(), EthStreamError> {
241        self.inner.start_send_unpin(Bytes::from(alloy_rlp::encode(
242            ProtocolBroadcastMessage::from(item),
243        )))?;
244
245        Ok(())
246    }
247
248    /// Sends a raw capability message directly over the stream
249    pub fn start_send_raw(&mut self, msg: RawCapabilityMessage) -> Result<(), EthStreamError> {
250        let mut bytes = Vec::with_capacity(msg.payload.len() + 1);
251        msg.id.encode(&mut bytes);
252        bytes.extend_from_slice(&msg.payload);
253
254        self.inner.start_send_unpin(bytes.into())?;
255        Ok(())
256    }
257}
258
259impl<S, E, N> Stream for EthStream<S, N>
260where
261    S: Stream<Item = Result<BytesMut, E>> + Unpin,
262    EthStreamError: From<E>,
263    N: NetworkPrimitives,
264{
265    type Item = Result<EthMessage<N>, EthStreamError>;
266
267    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
268        let this = self.project();
269        let res = ready!(this.inner.poll_next(cx));
270
271        match res {
272            Some(Ok(bytes)) => Poll::Ready(Some(this.eth.decode_message(bytes))),
273            Some(Err(err)) => Poll::Ready(Some(Err(err.into()))),
274            None => Poll::Ready(None),
275        }
276    }
277}
278
279impl<S, N> Sink<EthMessage<N>> for EthStream<S, N>
280where
281    S: CanDisconnect<Bytes> + Unpin,
282    EthStreamError: From<<S as Sink<Bytes>>::Error>,
283    N: NetworkPrimitives,
284{
285    type Error = EthStreamError;
286
287    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
288        self.project().inner.poll_ready(cx).map_err(Into::into)
289    }
290
291    fn start_send(self: Pin<&mut Self>, item: EthMessage<N>) -> Result<(), Self::Error> {
292        if matches!(item, EthMessage::Status(_)) {
293            // Attempt to disconnect the peer for protocol breach when trying to send Status
294            // message after handshake is complete
295            let mut this = self.project();
296            // We can't await the disconnect future here since this is a synchronous method,
297            // but we can start the disconnect process. The actual disconnect will be handled
298            // asynchronously by the caller or the stream's poll methods.
299            let _disconnect_future = this.inner.disconnect(DisconnectReason::ProtocolBreach);
300            return Err(EthStreamError::EthHandshakeError(EthHandshakeError::StatusNotInHandshake))
301        }
302
303        self.project()
304            .inner
305            .start_send(Bytes::from(alloy_rlp::encode(ProtocolMessage::from(item))))?;
306
307        Ok(())
308    }
309
310    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
311        self.project().inner.poll_flush(cx).map_err(Into::into)
312    }
313
314    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
315        self.project().inner.poll_close(cx).map_err(Into::into)
316    }
317}
318
319impl<S, N> CanDisconnect<EthMessage<N>> for EthStream<S, N>
320where
321    S: CanDisconnect<Bytes> + Send,
322    EthStreamError: From<<S as Sink<Bytes>>::Error>,
323    N: NetworkPrimitives,
324{
325    fn disconnect(
326        &mut self,
327        reason: DisconnectReason,
328    ) -> Pin<Box<dyn Future<Output = Result<(), EthStreamError>> + Send + '_>> {
329        Box::pin(async move { self.inner.disconnect(reason).await.map_err(Into::into) })
330    }
331}
332
#[cfg(test)]
mod tests {
    use super::UnauthedEthStream;
    use crate::{
        broadcast::BlockHashNumber,
        errors::{EthHandshakeError, EthStreamError},
        ethstream::RawCapabilityMessage,
        hello::DEFAULT_TCP_PORT,
        p2pstream::UnauthedP2PStream,
        EthMessage, EthStream, EthVersion, HelloMessageWithProtocols, PassthroughCodec,
        ProtocolVersion, Status, StatusMessage,
    };
    use alloy_chains::NamedChain;
    use alloy_primitives::{bytes::Bytes, B256, U256};
    use alloy_rlp::Decodable;
    use futures::{SinkExt, StreamExt};
    use reth_ecies::stream::ECIESStream;
    use reth_eth_wire_types::{EthNetworkPrimitives, UnifiedStatus};
    use reth_ethereum_forks::{ForkFilter, Head};
    use reth_network_peers::pk2id;
    use secp256k1::{SecretKey, SECP256K1};
    use std::time::Duration;
    use tokio::net::{TcpListener, TcpStream};
    use tokio_util::codec::Decoder;

    /// Builds an eth/67 mainnet status for `genesis` that advertises `fork_filter`'s current fork
    /// id and the given total difficulty.
    fn eth67_status(
        genesis: B256,
        fork_filter: &ForkFilter,
        total_difficulty: U256,
    ) -> UnifiedStatus {
        let status = Status {
            version: EthVersion::Eth67,
            chain: NamedChain::Mainnet.into(),
            total_difficulty,
            blockhash: B256::random(),
            genesis,
            // Pass the current fork id.
            forkid: fork_filter.current(),
        };
        UnifiedStatus::from_message(StatusMessage::Legacy(status))
    }

    /// A `NewBlockHashes` message with two random entries, used as payload in the I/O tests.
    fn new_block_hashes_msg() -> EthMessage<EthNetworkPrimitives> {
        EthMessage::<EthNetworkPrimitives>::NewBlockHashes(
            vec![
                BlockHashNumber { hash: B256::random(), number: 5 },
                BlockHashNumber { hash: B256::random(), number: 6 },
            ]
            .into(),
        )
    }

    #[tokio::test]
    async fn can_handshake() {
        let genesis = B256::random();
        let fork_filter = ForkFilter::new(Head::default(), genesis, 0, Vec::new());
        let unified_status = eth67_status(genesis, &fork_filter, U256::ZERO);

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();

        let status_clone = unified_status;
        let fork_filter_clone = fork_filter.clone();
        let handle = tokio::spawn(async move {
            // roughly based off of the design of tokio::net::TcpListener
            let (incoming, _) = listener.accept().await.unwrap();
            let stream = PassthroughCodec::default().framed(incoming);
            let (_, their_status) = UnauthedEthStream::new(stream)
                .handshake::<EthNetworkPrimitives>(status_clone, fork_filter_clone)
                .await
                .unwrap();

            // just make sure it equals our status (our status is a clone of their status)
            assert_eq!(their_status, status_clone);
        });

        let outgoing = TcpStream::connect(local_addr).await.unwrap();
        let sink = PassthroughCodec::default().framed(outgoing);

        // try to connect
        let (_, their_status) = UnauthedEthStream::new(sink)
            .handshake::<EthNetworkPrimitives>(unified_status, fork_filter)
            .await
            .unwrap();

        // their status is a clone of our status, these should be equal
        assert_eq!(their_status, unified_status);

        // wait for it to finish
        handle.await.unwrap();
    }

    #[tokio::test]
    async fn pass_handshake_on_low_td_bitlen() {
        let genesis = B256::random();
        let fork_filter = ForkFilter::new(Head::default(), genesis, 0, Vec::new());
        let unified_status = eth67_status(
            genesis,
            &fork_filter,
            U256::from(2).pow(U256::from(100)) - U256::from(1),
        );

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();

        let status_clone = unified_status;
        let fork_filter_clone = fork_filter.clone();
        let handle = tokio::spawn(async move {
            // roughly based off of the design of tokio::net::TcpListener
            let (incoming, _) = listener.accept().await.unwrap();
            let stream = PassthroughCodec::default().framed(incoming);
            let (_, their_status) = UnauthedEthStream::new(stream)
                .handshake::<EthNetworkPrimitives>(status_clone, fork_filter_clone)
                .await
                .unwrap();

            // just make sure it equals our status, and that the handshake succeeded
            assert_eq!(their_status, status_clone);
        });

        let outgoing = TcpStream::connect(local_addr).await.unwrap();
        let sink = PassthroughCodec::default().framed(outgoing);

        // try to connect
        let (_, their_status) = UnauthedEthStream::new(sink)
            .handshake::<EthNetworkPrimitives>(unified_status, fork_filter)
            .await
            .unwrap();

        // their status is a clone of our status, these should be equal
        assert_eq!(their_status, unified_status);

        // await the other handshake
        handle.await.unwrap();
    }

    #[tokio::test]
    async fn fail_handshake_on_high_td_bitlen() {
        let genesis = B256::random();
        let fork_filter = ForkFilter::new(Head::default(), genesis, 0, Vec::new());
        let unified_status =
            eth67_status(genesis, &fork_filter, U256::from(2).pow(U256::from(164)));

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();

        let status_clone = unified_status;
        let fork_filter_clone = fork_filter.clone();
        let handle = tokio::spawn(async move {
            // roughly based off of the design of tokio::net::TcpListener
            let (incoming, _) = listener.accept().await.unwrap();
            let stream = PassthroughCodec::default().framed(incoming);
            let handshake_res = UnauthedEthStream::new(stream)
                .handshake::<EthNetworkPrimitives>(status_clone, fork_filter_clone)
                .await;

            // make sure the handshake fails due to td too high
            assert!(matches!(
                handshake_res,
                Err(EthStreamError::EthHandshakeError(
                    EthHandshakeError::TotalDifficultyBitLenTooLarge { got: 165, maximum: 160 }
                ))
            ));
        });

        let outgoing = TcpStream::connect(local_addr).await.unwrap();
        let sink = PassthroughCodec::default().framed(outgoing);

        // try to connect
        let handshake_res = UnauthedEthStream::new(sink)
            .handshake::<EthNetworkPrimitives>(unified_status, fork_filter)
            .await;

        // this handshake should also fail due to td too high
        assert!(matches!(
            handshake_res,
            Err(EthStreamError::EthHandshakeError(
                EthHandshakeError::TotalDifficultyBitLenTooLarge { got: 165, maximum: 160 }
            ))
        ));

        // await the other handshake
        handle.await.unwrap();
    }

    #[tokio::test]
    async fn can_write_and_read_cleartext() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();
        let test_msg = new_block_hashes_msg();

        let test_msg_clone = test_msg.clone();
        let handle = tokio::spawn(async move {
            // roughly based off of the design of tokio::net::TcpListener
            let (incoming, _) = listener.accept().await.unwrap();
            let stream = PassthroughCodec::default().framed(incoming);
            let mut stream = EthStream::new(EthVersion::Eth67, stream);

            // use the stream to get the next message
            let message = stream.next().await.unwrap().unwrap();
            assert_eq!(message, test_msg_clone);
        });

        let outgoing = TcpStream::connect(local_addr).await.unwrap();
        let sink = PassthroughCodec::default().framed(outgoing);
        let mut client_stream = EthStream::new(EthVersion::Eth67, sink);

        client_stream.send(test_msg).await.unwrap();

        // make sure the server receives the message and asserts before ending the test
        handle.await.unwrap();
    }

    #[tokio::test]
    async fn can_write_and_read_ecies() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();
        let server_key = SecretKey::new(&mut rand_08::thread_rng());
        let test_msg = new_block_hashes_msg();

        let test_msg_clone = test_msg.clone();
        let handle = tokio::spawn(async move {
            // roughly based off of the design of tokio::net::TcpListener
            let (incoming, _) = listener.accept().await.unwrap();
            let stream = ECIESStream::incoming(incoming, server_key).await.unwrap();
            let mut stream = EthStream::new(EthVersion::Eth67, stream);

            // use the stream to get the next message
            let message = stream.next().await.unwrap().unwrap();
            assert_eq!(message, test_msg_clone);
        });

        // create the server pubkey
        let server_id = pk2id(&server_key.public_key(SECP256K1));

        let client_key = SecretKey::new(&mut rand_08::thread_rng());

        let outgoing = TcpStream::connect(local_addr).await.unwrap();
        let outgoing = ECIESStream::connect(outgoing, client_key, server_id).await.unwrap();
        let mut client_stream = EthStream::new(EthVersion::Eth67, outgoing);

        client_stream.send(test_msg).await.unwrap();

        // make sure the server receives the message and asserts before ending the test
        handle.await.unwrap();
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn ethstream_over_p2p() {
        // create a p2p stream and server, then confirm that the two are authed
        // create tcpstream
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();
        let server_key = SecretKey::new(&mut rand_08::thread_rng());
        let test_msg = new_block_hashes_msg();

        let genesis = B256::random();
        let fork_filter = ForkFilter::new(Head::default(), genesis, 0, Vec::new());
        let unified_status = eth67_status(genesis, &fork_filter, U256::ZERO);

        let status_copy = unified_status;
        let fork_filter_clone = fork_filter.clone();
        let test_msg_clone = test_msg.clone();
        let handle = tokio::spawn(async move {
            // roughly based off of the design of tokio::net::TcpListener
            let (incoming, _) = listener.accept().await.unwrap();
            let stream = ECIESStream::incoming(incoming, server_key).await.unwrap();

            let server_hello = HelloMessageWithProtocols {
                protocol_version: ProtocolVersion::V5,
                client_version: "bitcoind/1.0.0".to_string(),
                protocols: vec![EthVersion::Eth67.into()],
                port: DEFAULT_TCP_PORT,
                id: pk2id(&server_key.public_key(SECP256K1)),
            };

            let unauthed_stream = UnauthedP2PStream::new(stream);
            let (p2p_stream, _) = unauthed_stream.handshake(server_hello).await.unwrap();
            let (mut eth_stream, _) = UnauthedEthStream::new(p2p_stream)
                .handshake(status_copy, fork_filter_clone)
                .await
                .unwrap();

            // use the stream to get the next message
            let message = eth_stream.next().await.unwrap().unwrap();
            assert_eq!(message, test_msg_clone);
        });

        // create the server pubkey
        let server_id = pk2id(&server_key.public_key(SECP256K1));

        let client_key = SecretKey::new(&mut rand_08::thread_rng());

        let outgoing = TcpStream::connect(local_addr).await.unwrap();
        let sink = ECIESStream::connect(outgoing, client_key, server_id).await.unwrap();

        let client_hello = HelloMessageWithProtocols {
            protocol_version: ProtocolVersion::V5,
            client_version: "bitcoind/1.0.0".to_string(),
            protocols: vec![EthVersion::Eth67.into()],
            port: DEFAULT_TCP_PORT,
            id: pk2id(&client_key.public_key(SECP256K1)),
        };

        let unauthed_stream = UnauthedP2PStream::new(sink);
        let (p2p_stream, _) = unauthed_stream.handshake(client_hello).await.unwrap();

        let (mut client_stream, _) = UnauthedEthStream::new(p2p_stream)
            .handshake(unified_status, fork_filter)
            .await
            .unwrap();

        client_stream.send(test_msg).await.unwrap();

        // make sure the server receives the message and asserts before ending the test
        handle.await.unwrap();
    }

    #[tokio::test]
    async fn handshake_should_timeout() {
        let genesis = B256::random();
        let fork_filter = ForkFilter::new(Head::default(), genesis, 0, Vec::new());
        let unified_status = eth67_status(genesis, &fork_filter, U256::ZERO);

        // Bind a listener but never accept on it: the TCP connect completes via the kernel
        // backlog, yet no `Status` is ever sent back, so the client handshake must time out.
        // The listener stays in scope until the end of the test.
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();

        let outgoing = TcpStream::connect(local_addr).await.unwrap();
        let sink = PassthroughCodec::default().framed(outgoing);

        // try to connect
        let handshake_result = UnauthedEthStream::new(sink)
            .handshake_with_timeout::<EthNetworkPrimitives>(
                unified_status,
                fork_filter,
                Duration::from_secs(1),
            )
            .await;

        // Assert that a timeout error occurred
        assert!(matches!(handshake_result, Err(EthStreamError::StreamTimeout)));
    }

    #[tokio::test]
    async fn can_write_and_read_raw_capability() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();

        let test_msg = RawCapabilityMessage { id: 0x1234, payload: Bytes::from(vec![1, 2, 3, 4]) };

        let test_msg_clone = test_msg.clone();
        let handle = tokio::spawn(async move {
            let (incoming, _) = listener.accept().await.unwrap();
            let stream = PassthroughCodec::default().framed(incoming);
            let mut stream = EthStream::<_, EthNetworkPrimitives>::new(EthVersion::Eth67, stream);

            let bytes = stream.inner_mut().next().await.unwrap().unwrap();

            // Create a cursor to track position while decoding
            let mut id_bytes = &bytes[..];
            let decoded_id = <usize as Decodable>::decode(&mut id_bytes).unwrap();
            assert_eq!(decoded_id, test_msg_clone.id);

            // Get remaining bytes after ID decoding
            let remaining = id_bytes;
            assert_eq!(remaining, &test_msg_clone.payload[..]);
        });

        let outgoing = TcpStream::connect(local_addr).await.unwrap();
        let sink = PassthroughCodec::default().framed(outgoing);
        let mut client_stream = EthStream::<_, EthNetworkPrimitives>::new(EthVersion::Eth67, sink);

        client_stream.start_send_raw(test_msg).unwrap();
        client_stream.inner_mut().flush().await.unwrap();

        handle.await.unwrap();
    }

    #[tokio::test]
    async fn status_message_after_handshake_is_rejected() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();

        let handle = tokio::spawn(async move {
            let (incoming, _) = listener.accept().await.unwrap();
            let stream = PassthroughCodec::default().framed(incoming);
            let mut stream = EthStream::<_, EthNetworkPrimitives>::new(EthVersion::Eth67, stream);

            // Sending a Status message on an already-established EthStream must be refused.
            let status = Status {
                version: EthVersion::Eth67,
                chain: NamedChain::Mainnet.into(),
                total_difficulty: U256::ZERO,
                blockhash: B256::random(),
                genesis: B256::random(),
                forkid: ForkFilter::new(Head::default(), B256::random(), 0, Vec::new()).current(),
            };
            let status_message =
                EthMessage::<EthNetworkPrimitives>::Status(StatusMessage::Legacy(status));

            let result = stream.send(status_message).await;
            assert!(matches!(
                result,
                Err(EthStreamError::EthHandshakeError(EthHandshakeError::StatusNotInHandshake))
            ));
        });

        let outgoing = TcpStream::connect(local_addr).await.unwrap();
        let sink = PassthroughCodec::default().framed(outgoing);
        let mut client_stream = EthStream::<_, EthNetworkPrimitives>::new(EthVersion::Eth67, sink);

        // Send a valid message to keep the connection alive
        let test_msg = EthMessage::<EthNetworkPrimitives>::NewBlockHashes(
            vec![BlockHashNumber { hash: B256::random(), number: 5 }].into(),
        );
        client_stream.send(test_msg).await.unwrap();

        handle.await.unwrap();
    }
}