// reth_eth_wire/ethstream.rs

1//! Ethereum protocol stream implementations.
2//!
3//! Provides stream types for the Ethereum wire protocol.
4//! It separates protocol logic [`EthStreamInner`] from transport concerns [`EthStream`].
5//! Handles handshaking, message processing, and RLP serialization.
6
7use crate::{
8    errors::{EthHandshakeError, EthStreamError},
9    handshake::EthereumEthHandshake,
10    message::{EthBroadcastMessage, ProtocolBroadcastMessage},
11    p2pstream::HANDSHAKE_TIMEOUT,
12    CanDisconnect, DisconnectReason, EthMessage, EthNetworkPrimitives, EthVersion, ProtocolMessage,
13    UnifiedStatus,
14};
15use alloy_primitives::bytes::{Bytes, BytesMut};
16use alloy_rlp::Encodable;
17use futures::{ready, Sink, SinkExt};
18use pin_project::pin_project;
19use reth_eth_wire_types::{NetworkPrimitives, RawCapabilityMessage};
20use reth_ethereum_forks::ForkFilter;
21use std::{
22    future::Future,
23    pin::Pin,
24    task::{Context, Poll},
25    time::Duration,
26};
27use tokio::time::timeout;
28use tokio_stream::Stream;
29use tracing::{debug, trace};
30
/// Largest eth protocol message, in bytes, that this stream will attempt to decode (10 MiB).
///
/// The value mirrors the limit used by go-ethereum:
/// <https://github.com/ethereum/go-ethereum/blob/30602163d5d8321fbc68afdcbbaf2362b2641bde/eth/protocols/eth/protocol.go#L50>
pub const MAX_MESSAGE_SIZE: usize = 10 << 20;
34
35/// An un-authenticated [`EthStream`]. This is consumed and returns a [`EthStream`] after the
36/// `Status` handshake is completed.
37#[pin_project]
38#[derive(Debug)]
39pub struct UnauthedEthStream<S> {
40    #[pin]
41    inner: S,
42}
43
44impl<S> UnauthedEthStream<S> {
45    /// Create a new `UnauthedEthStream` from a type `S` which implements `Stream` and `Sink`.
46    pub const fn new(inner: S) -> Self {
47        Self { inner }
48    }
49
50    /// Consumes the type and returns the wrapped stream
51    pub fn into_inner(self) -> S {
52        self.inner
53    }
54}
55
56impl<S, E> UnauthedEthStream<S>
57where
58    S: Stream<Item = Result<BytesMut, E>> + CanDisconnect<Bytes> + Send + Unpin,
59    EthStreamError: From<E> + From<<S as Sink<Bytes>>::Error>,
60{
61    /// Consumes the [`UnauthedEthStream`] and returns an [`EthStream`] after the `Status`
62    /// handshake is completed successfully. This also returns the `Status` message sent by the
63    /// remote peer.
64    ///
65    /// Caution: This expects that the [`UnifiedStatus`] has the proper eth version configured, with
66    /// ETH69 the initial status message changed.
67    pub async fn handshake<N: NetworkPrimitives>(
68        self,
69        status: UnifiedStatus,
70        fork_filter: ForkFilter,
71    ) -> Result<(EthStream<S, N>, UnifiedStatus), EthStreamError> {
72        self.handshake_with_timeout(status, fork_filter, HANDSHAKE_TIMEOUT).await
73    }
74
75    /// Wrapper around handshake which enforces a timeout.
76    pub async fn handshake_with_timeout<N: NetworkPrimitives>(
77        self,
78        status: UnifiedStatus,
79        fork_filter: ForkFilter,
80        timeout_limit: Duration,
81    ) -> Result<(EthStream<S, N>, UnifiedStatus), EthStreamError> {
82        timeout(timeout_limit, Self::handshake_without_timeout(self, status, fork_filter))
83            .await
84            .map_err(|_| EthStreamError::StreamTimeout)?
85    }
86
87    /// Handshake with no timeout
88    pub async fn handshake_without_timeout<N: NetworkPrimitives>(
89        mut self,
90        status: UnifiedStatus,
91        fork_filter: ForkFilter,
92    ) -> Result<(EthStream<S, N>, UnifiedStatus), EthStreamError> {
93        trace!(
94            status = %status.into_message(),
95            "sending eth status to peer"
96        );
97        let their_status =
98            EthereumEthHandshake(&mut self.inner).eth_handshake(status, fork_filter).await?;
99
100        // now we can create the `EthStream` because the peer has successfully completed
101        // the handshake
102        let stream = EthStream::new(status.version, self.inner);
103
104        Ok((stream, their_status))
105    }
106}
107
108/// Contains eth protocol specific logic for processing messages
109#[derive(Debug)]
110pub struct EthStreamInner<N> {
111    /// Negotiated eth version
112    version: EthVersion,
113    _pd: std::marker::PhantomData<N>,
114}
115
116impl<N> EthStreamInner<N>
117where
118    N: NetworkPrimitives,
119{
120    /// Creates a new [`EthStreamInner`] with the given eth version
121    pub const fn new(version: EthVersion) -> Self {
122        Self { version, _pd: std::marker::PhantomData }
123    }
124
125    /// Returns the eth version
126    #[inline]
127    pub const fn version(&self) -> EthVersion {
128        self.version
129    }
130
131    /// Decodes incoming bytes into an [`EthMessage`].
132    pub fn decode_message(&self, bytes: BytesMut) -> Result<EthMessage<N>, EthStreamError> {
133        if bytes.len() > MAX_MESSAGE_SIZE {
134            return Err(EthStreamError::MessageTooBig(bytes.len()));
135        }
136
137        let msg = match ProtocolMessage::decode_message(self.version, &mut bytes.as_ref()) {
138            Ok(m) => m,
139            Err(err) => {
140                let msg = if bytes.len() > 50 {
141                    format!("{:02x?}...{:x?}", &bytes[..10], &bytes[bytes.len() - 10..])
142                } else {
143                    format!("{bytes:02x?}")
144                };
145                debug!(
146                    version=?self.version,
147                    %msg,
148                    "failed to decode protocol message"
149                );
150                return Err(EthStreamError::InvalidMessage(err));
151            }
152        };
153
154        if matches!(msg.message, EthMessage::Status(_)) {
155            return Err(EthStreamError::EthHandshakeError(EthHandshakeError::StatusNotInHandshake));
156        }
157
158        Ok(msg.message)
159    }
160
161    /// Encodes an [`EthMessage`] to bytes.
162    ///
163    /// Validates that Status messages are not sent after handshake, enforcing protocol rules.
164    pub fn encode_message(&self, item: EthMessage<N>) -> Result<Bytes, EthStreamError> {
165        if matches!(item, EthMessage::Status(_)) {
166            return Err(EthStreamError::EthHandshakeError(EthHandshakeError::StatusNotInHandshake));
167        }
168
169        Ok(Bytes::from(alloy_rlp::encode(ProtocolMessage::from(item))))
170    }
171}
172
173/// An `EthStream` wraps over any `Stream` that yields bytes and makes it
174/// compatible with eth-networking protocol messages, which get RLP encoded/decoded.
175#[pin_project]
176#[derive(Debug)]
177pub struct EthStream<S, N = EthNetworkPrimitives> {
178    /// Eth-specific logic
179    eth: EthStreamInner<N>,
180    #[pin]
181    inner: S,
182}
183
184impl<S, N: NetworkPrimitives> EthStream<S, N> {
185    /// Creates a new unauthed [`EthStream`] from a provided stream. You will need
186    /// to manually handshake a peer.
187    #[inline]
188    pub const fn new(version: EthVersion, inner: S) -> Self {
189        Self { eth: EthStreamInner::new(version), inner }
190    }
191
192    /// Returns the eth version.
193    #[inline]
194    pub const fn version(&self) -> EthVersion {
195        self.eth.version()
196    }
197
198    /// Returns the underlying stream.
199    #[inline]
200    pub const fn inner(&self) -> &S {
201        &self.inner
202    }
203
204    /// Returns mutable access to the underlying stream.
205    #[inline]
206    pub const fn inner_mut(&mut self) -> &mut S {
207        &mut self.inner
208    }
209
210    /// Consumes this type and returns the wrapped stream.
211    #[inline]
212    pub fn into_inner(self) -> S {
213        self.inner
214    }
215}
216
217impl<S, E, N> EthStream<S, N>
218where
219    S: Sink<Bytes, Error = E> + Unpin,
220    EthStreamError: From<E>,
221    N: NetworkPrimitives,
222{
223    /// Same as [`Sink::start_send`] but accepts a [`EthBroadcastMessage`] instead.
224    pub fn start_send_broadcast(
225        &mut self,
226        item: EthBroadcastMessage<N>,
227    ) -> Result<(), EthStreamError> {
228        self.inner.start_send_unpin(Bytes::from(alloy_rlp::encode(
229            ProtocolBroadcastMessage::from(item),
230        )))?;
231
232        Ok(())
233    }
234
235    /// Sends a raw capability message directly over the stream
236    pub fn start_send_raw(&mut self, msg: RawCapabilityMessage) -> Result<(), EthStreamError> {
237        let mut bytes = Vec::with_capacity(msg.payload.len() + 1);
238        msg.id.encode(&mut bytes);
239        bytes.extend_from_slice(&msg.payload);
240
241        self.inner.start_send_unpin(bytes.into())?;
242        Ok(())
243    }
244}
245
246impl<S, E, N> Stream for EthStream<S, N>
247where
248    S: Stream<Item = Result<BytesMut, E>> + Unpin,
249    EthStreamError: From<E>,
250    N: NetworkPrimitives,
251{
252    type Item = Result<EthMessage<N>, EthStreamError>;
253
254    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
255        let this = self.project();
256        let res = ready!(this.inner.poll_next(cx));
257
258        match res {
259            Some(Ok(bytes)) => Poll::Ready(Some(this.eth.decode_message(bytes))),
260            Some(Err(err)) => Poll::Ready(Some(Err(err.into()))),
261            None => Poll::Ready(None),
262        }
263    }
264}
265
266impl<S, N> Sink<EthMessage<N>> for EthStream<S, N>
267where
268    S: CanDisconnect<Bytes> + Unpin,
269    EthStreamError: From<<S as Sink<Bytes>>::Error>,
270    N: NetworkPrimitives,
271{
272    type Error = EthStreamError;
273
274    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
275        self.project().inner.poll_ready(cx).map_err(Into::into)
276    }
277
278    fn start_send(self: Pin<&mut Self>, item: EthMessage<N>) -> Result<(), Self::Error> {
279        if matches!(item, EthMessage::Status(_)) {
280            // Attempt to disconnect the peer for protocol breach when trying to send Status
281            // message after handshake is complete
282            let mut this = self.project();
283            // We can't await the disconnect future here since this is a synchronous method,
284            // but we can start the disconnect process. The actual disconnect will be handled
285            // asynchronously by the caller or the stream's poll methods.
286            let _disconnect_future = this.inner.disconnect(DisconnectReason::ProtocolBreach);
287            return Err(EthStreamError::EthHandshakeError(EthHandshakeError::StatusNotInHandshake))
288        }
289
290        self.project()
291            .inner
292            .start_send(Bytes::from(alloy_rlp::encode(ProtocolMessage::from(item))))?;
293
294        Ok(())
295    }
296
297    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
298        self.project().inner.poll_flush(cx).map_err(Into::into)
299    }
300
301    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
302        self.project().inner.poll_close(cx).map_err(Into::into)
303    }
304}
305
306impl<S, N> CanDisconnect<EthMessage<N>> for EthStream<S, N>
307where
308    S: CanDisconnect<Bytes> + Send,
309    EthStreamError: From<<S as Sink<Bytes>>::Error>,
310    N: NetworkPrimitives,
311{
312    fn disconnect(
313        &mut self,
314        reason: DisconnectReason,
315    ) -> Pin<Box<dyn Future<Output = Result<(), EthStreamError>> + Send + '_>> {
316        Box::pin(async move { self.inner.disconnect(reason).await.map_err(Into::into) })
317    }
318}
319
#[cfg(test)]
mod tests {
    use super::UnauthedEthStream;
    use crate::{
        broadcast::BlockHashNumber,
        errors::{EthHandshakeError, EthStreamError},
        ethstream::RawCapabilityMessage,
        hello::DEFAULT_TCP_PORT,
        p2pstream::UnauthedP2PStream,
        EthMessage, EthStream, EthVersion, HelloMessageWithProtocols, PassthroughCodec,
        ProtocolVersion, Status, StatusMessage,
    };
    use alloy_chains::NamedChain;
    use alloy_primitives::{bytes::Bytes, B256, U256};
    use alloy_rlp::Decodable;
    use futures::{SinkExt, StreamExt};
    use reth_ecies::stream::ECIESStream;
    use reth_eth_wire_types::{EthNetworkPrimitives, UnifiedStatus};
    use reth_ethereum_forks::{ForkFilter, Head};
    use reth_network_peers::pk2id;
    use secp256k1::{SecretKey, SECP256K1};
    use std::time::Duration;
    use tokio::net::{TcpListener, TcpStream};
    use tokio_util::codec::Decoder;

    /// Builds a legacy eth/67 mainnet `Status` for `genesis`.
    ///
    /// The status advertises the current fork id of `fork_filter`, a random block hash and the
    /// given total difficulty.
    fn eth67_status(
        genesis: B256,
        fork_filter: &ForkFilter,
        total_difficulty: U256,
    ) -> UnifiedStatus {
        let status = Status {
            version: EthVersion::Eth67,
            chain: NamedChain::Mainnet.into(),
            total_difficulty,
            blockhash: B256::random(),
            genesis,
            // Pass the current fork id.
            forkid: fork_filter.current(),
        };
        UnifiedStatus::from_message(StatusMessage::Legacy(status))
    }

    /// A small `NewBlockHashes` message used as the payload of the read/write tests.
    fn sample_new_block_hashes() -> EthMessage<EthNetworkPrimitives> {
        EthMessage::NewBlockHashes(
            vec![
                BlockHashNumber { hash: B256::random(), number: 5 },
                BlockHashNumber { hash: B256::random(), number: 6 },
            ]
            .into(),
        )
    }

    /// Runs the eth `Status` handshake over a local, unencrypted TCP connection on which both
    /// peers use `status` and `fork_filter`.
    ///
    /// Returns `(server_result, client_result)`. Each result is mapped to the status that side
    /// received from the other.
    async fn handshake_over_tcp(
        status: UnifiedStatus,
        fork_filter: ForkFilter,
    ) -> (Result<UnifiedStatus, EthStreamError>, Result<UnifiedStatus, EthStreamError>) {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();

        let server_fork_filter = fork_filter.clone();
        let server = tokio::spawn(async move {
            let (incoming, _) = listener.accept().await.unwrap();
            let stream = PassthroughCodec::default().framed(incoming);
            UnauthedEthStream::new(stream)
                .handshake::<EthNetworkPrimitives>(status, server_fork_filter)
                .await
                .map(|(_, their_status)| their_status)
        });

        let outgoing = TcpStream::connect(local_addr).await.unwrap();
        let sink = PassthroughCodec::default().framed(outgoing);
        let client = UnauthedEthStream::new(sink)
            .handshake::<EthNetworkPrimitives>(status, fork_filter)
            .await
            .map(|(_, their_status)| their_status);

        (server.await.unwrap(), client)
    }

    #[tokio::test]
    async fn can_handshake() {
        let genesis = B256::random();
        let fork_filter = ForkFilter::new(Head::default(), genesis, 0, Vec::new());
        let status = eth67_status(genesis, &fork_filter, U256::ZERO);

        let (server, client) = handshake_over_tcp(status, fork_filter).await;

        // Both peers send the same status, so each must receive exactly what it sent.
        assert_eq!(server.unwrap(), status);
        assert_eq!(client.unwrap(), status);
    }

    #[tokio::test]
    async fn pass_handshake_on_low_td_bitlen() {
        let genesis = B256::random();
        let fork_filter = ForkFilter::new(Head::default(), genesis, 0, Vec::new());
        // 2^100 - 1 has a bit length of 100, below the 160-bit maximum.
        let total_difficulty = U256::from(2).pow(U256::from(100)) - U256::from(1);
        let status = eth67_status(genesis, &fork_filter, total_difficulty);

        let (server, client) = handshake_over_tcp(status, fork_filter).await;

        // The handshake succeeds on both sides and each peer sees the shared status.
        assert_eq!(server.unwrap(), status);
        assert_eq!(client.unwrap(), status);
    }

    #[tokio::test]
    async fn fail_handshake_on_high_td_bitlen() {
        let genesis = B256::random();
        let fork_filter = ForkFilter::new(Head::default(), genesis, 0, Vec::new());
        // 2^164 has a bit length of 165, above the 160-bit maximum.
        let total_difficulty = U256::from(2).pow(U256::from(164));
        let status = eth67_status(genesis, &fork_filter, total_difficulty);

        let (server, client) = handshake_over_tcp(status, fork_filter).await;

        // Each side must reject the other's (identical) status because its td is too large.
        for result in [server, client] {
            assert!(matches!(
                result,
                Err(EthStreamError::EthHandshakeError(
                    EthHandshakeError::TotalDifficultyBitLenTooLarge { got: 165, maximum: 160 }
                ))
            ));
        }
    }

    #[tokio::test]
    async fn can_write_and_read_cleartext() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();
        let test_msg = sample_new_block_hashes();

        let test_msg_clone = test_msg.clone();
        let handle = tokio::spawn(async move {
            let (incoming, _) = listener.accept().await.unwrap();
            let stream = PassthroughCodec::default().framed(incoming);
            let mut stream = EthStream::new(EthVersion::Eth67, stream);

            // use the stream to get the next message
            let message = stream.next().await.unwrap().unwrap();
            assert_eq!(message, test_msg_clone);
        });

        let outgoing = TcpStream::connect(local_addr).await.unwrap();
        let sink = PassthroughCodec::default().framed(outgoing);
        let mut client_stream = EthStream::new(EthVersion::Eth67, sink);

        client_stream.send(test_msg).await.unwrap();

        // make sure the server receives the message and asserts before ending the test
        handle.await.unwrap();
    }

    #[tokio::test]
    async fn can_write_and_read_ecies() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();
        let server_key = SecretKey::new(&mut rand_08::thread_rng());
        let test_msg = sample_new_block_hashes();

        let test_msg_clone = test_msg.clone();
        let handle = tokio::spawn(async move {
            let (incoming, _) = listener.accept().await.unwrap();
            let stream = ECIESStream::incoming(incoming, server_key).await.unwrap();
            let mut stream = EthStream::new(EthVersion::Eth67, stream);

            // use the stream to get the next message
            let message = stream.next().await.unwrap().unwrap();
            assert_eq!(message, test_msg_clone);
        });

        // create the server pubkey
        let server_id = pk2id(&server_key.public_key(SECP256K1));

        let client_key = SecretKey::new(&mut rand_08::thread_rng());

        let outgoing = TcpStream::connect(local_addr).await.unwrap();
        let outgoing = ECIESStream::connect(outgoing, client_key, server_id).await.unwrap();
        let mut client_stream = EthStream::new(EthVersion::Eth67, outgoing);

        client_stream.send(test_msg).await.unwrap();

        // make sure the server receives the message and asserts before ending the test
        handle.await.unwrap();
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn ethstream_over_p2p() {
        // Create a p2p stream and server, then confirm that the two are authed.
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();
        let server_key = SecretKey::new(&mut rand_08::thread_rng());
        let test_msg = sample_new_block_hashes();

        let genesis = B256::random();
        let fork_filter = ForkFilter::new(Head::default(), genesis, 0, Vec::new());
        let status = eth67_status(genesis, &fork_filter, U256::ZERO);

        let fork_filter_clone = fork_filter.clone();
        let test_msg_clone = test_msg.clone();
        let handle = tokio::spawn(async move {
            let (incoming, _) = listener.accept().await.unwrap();
            let stream = ECIESStream::incoming(incoming, server_key).await.unwrap();

            let server_hello = HelloMessageWithProtocols {
                protocol_version: ProtocolVersion::V5,
                client_version: "bitcoind/1.0.0".to_string(),
                protocols: vec![EthVersion::Eth67.into()],
                port: DEFAULT_TCP_PORT,
                id: pk2id(&server_key.public_key(SECP256K1)),
            };

            let unauthed_stream = UnauthedP2PStream::new(stream);
            let (p2p_stream, _) = unauthed_stream.handshake(server_hello).await.unwrap();
            let (mut eth_stream, _) = UnauthedEthStream::new(p2p_stream)
                .handshake(status, fork_filter_clone)
                .await
                .unwrap();

            // use the stream to get the next message
            let message = eth_stream.next().await.unwrap().unwrap();
            assert_eq!(message, test_msg_clone);
        });

        // create the server pubkey
        let server_id = pk2id(&server_key.public_key(SECP256K1));

        let client_key = SecretKey::new(&mut rand_08::thread_rng());

        let outgoing = TcpStream::connect(local_addr).await.unwrap();
        let sink = ECIESStream::connect(outgoing, client_key, server_id).await.unwrap();

        let client_hello = HelloMessageWithProtocols {
            protocol_version: ProtocolVersion::V5,
            client_version: "bitcoind/1.0.0".to_string(),
            protocols: vec![EthVersion::Eth67.into()],
            port: DEFAULT_TCP_PORT,
            id: pk2id(&client_key.public_key(SECP256K1)),
        };

        let unauthed_stream = UnauthedP2PStream::new(sink);
        let (p2p_stream, _) = unauthed_stream.handshake(client_hello).await.unwrap();

        let (mut client_stream, _) =
            UnauthedEthStream::new(p2p_stream).handshake(status, fork_filter).await.unwrap();

        client_stream.send(test_msg).await.unwrap();

        // make sure the server receives the message and asserts before ending the test
        handle.await.unwrap();
    }

    #[tokio::test]
    async fn handshake_should_timeout() {
        let genesis = B256::random();
        let fork_filter = ForkFilter::new(Head::default(), genesis, 0, Vec::new());
        let status = eth67_status(genesis, &fork_filter, U256::ZERO);

        // The listener stays open but never accepts. The OS still completes the TCP connection,
        // yet nobody ever answers the client's `Status`, so the handshake has to time out.
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();

        let outgoing = TcpStream::connect(local_addr).await.unwrap();
        let sink = PassthroughCodec::default().framed(outgoing);

        let handshake_result = UnauthedEthStream::new(sink)
            .handshake_with_timeout::<EthNetworkPrimitives>(
                status,
                fork_filter,
                Duration::from_secs(1),
            )
            .await;

        assert!(matches!(handshake_result, Err(EthStreamError::StreamTimeout)));

        // Keep the listener alive until the assertion above has run.
        drop(listener);
    }

    #[tokio::test]
    async fn can_write_and_read_raw_capability() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();

        let test_msg = RawCapabilityMessage { id: 0x1234, payload: Bytes::from(vec![1, 2, 3, 4]) };

        let test_msg_clone = test_msg.clone();
        let handle = tokio::spawn(async move {
            let (incoming, _) = listener.accept().await.unwrap();
            let stream = PassthroughCodec::default().framed(incoming);
            let mut stream = EthStream::<_, EthNetworkPrimitives>::new(EthVersion::Eth67, stream);

            // Read the raw frame straight from the transport, bypassing eth decoding.
            let bytes = stream.inner_mut().next().await.unwrap().unwrap();

            // The frame is the RLP-encoded id followed by the untouched payload. Decoding the
            // id advances `rest` past it.
            let mut rest = &bytes[..];
            let decoded_id = <usize as Decodable>::decode(&mut rest).unwrap();
            assert_eq!(decoded_id, test_msg_clone.id);
            assert_eq!(rest, &test_msg_clone.payload[..]);
        });

        let outgoing = TcpStream::connect(local_addr).await.unwrap();
        let sink = PassthroughCodec::default().framed(outgoing);
        let mut client_stream = EthStream::<_, EthNetworkPrimitives>::new(EthVersion::Eth67, sink);

        client_stream.start_send_raw(test_msg).unwrap();
        client_stream.inner_mut().flush().await.unwrap();

        handle.await.unwrap();
    }

    #[tokio::test]
    async fn status_message_after_handshake_triggers_disconnect() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();

        let handle = tokio::spawn(async move {
            let (incoming, _) = listener.accept().await.unwrap();
            let stream = PassthroughCodec::default().framed(incoming);
            let mut stream = EthStream::<_, EthNetworkPrimitives>::new(EthVersion::Eth67, stream);

            // Sending a `Status` once the stream is established is a protocol breach.
            let status = Status {
                version: EthVersion::Eth67,
                chain: NamedChain::Mainnet.into(),
                total_difficulty: U256::ZERO,
                blockhash: B256::random(),
                genesis: B256::random(),
                forkid: ForkFilter::new(Head::default(), B256::random(), 0, Vec::new()).current(),
            };
            let status_message =
                EthMessage::<EthNetworkPrimitives>::Status(StatusMessage::Legacy(status));

            // The send must be rejected with `StatusNotInHandshake`. Only the returned error is
            // checked here; the state of the underlying transport is not inspected.
            let result = stream.send(status_message).await;
            assert!(matches!(
                result,
                Err(EthStreamError::EthHandshakeError(EthHandshakeError::StatusNotInHandshake))
            ));
        });

        let outgoing = TcpStream::connect(local_addr).await.unwrap();
        let sink = PassthroughCodec::default().framed(outgoing);
        let mut client_stream = EthStream::<_, EthNetworkPrimitives>::new(EthVersion::Eth67, sink);

        // Send a valid message to keep the connection alive
        let test_msg = EthMessage::<EthNetworkPrimitives>::NewBlockHashes(
            vec![BlockHashNumber { hash: B256::random(), number: 5 }].into(),
        );
        client_stream.send(test_msg).await.unwrap();

        handle.await.unwrap();
    }
}