1use crate::{
8 errors::{EthHandshakeError, EthStreamError},
9 handshake::EthereumEthHandshake,
10 message::{
11 EthBroadcastMessage, ProtocolBroadcastMessage, MAX_MESSAGE_SIZE,
12 TX_MEMORY_BUDGET_MULTIPLIER,
13 },
14 p2pstream::HANDSHAKE_TIMEOUT,
15 CanDisconnect, DisconnectReason, EthMessage, EthNetworkPrimitives, EthVersion, ProtocolMessage,
16 UnifiedStatus,
17};
18use alloy_primitives::bytes::{Bytes, BytesMut};
19use alloy_rlp::Encodable;
20use futures::{ready, Sink, SinkExt};
21use pin_project::pin_project;
22use reth_eth_wire_types::{EthMessageID, NetworkPrimitives, RawCapabilityMessage};
23use reth_ethereum_forks::ForkFilter;
24use std::{
25 future::Future,
26 pin::Pin,
27 task::{Context, Poll},
28 time::Duration,
29};
30use tokio::time::timeout;
31use tokio_stream::Stream;
32use tracing::{debug, trace};
33
/// An eth stream that has not yet performed the eth `Status` handshake.
///
/// It only exposes the handshake entry points. On success, the handshake hands the same
/// inner stream over to a new [`EthStream`].
#[pin_project]
#[derive(Debug)]
pub struct UnauthedEthStream<S> {
    /// The underlying stream the handshake is performed over.
    #[pin]
    inner: S,
}
42
43impl<S> UnauthedEthStream<S> {
44 pub const fn new(inner: S) -> Self {
46 Self { inner }
47 }
48
49 pub fn into_inner(self) -> S {
51 self.inner
52 }
53}
54
55impl<S, E> UnauthedEthStream<S>
56where
57 S: Stream<Item = Result<BytesMut, E>> + CanDisconnect<Bytes> + Send + Unpin,
58 EthStreamError: From<E> + From<<S as Sink<Bytes>>::Error>,
59{
60 pub async fn handshake<N: NetworkPrimitives>(
67 self,
68 status: UnifiedStatus,
69 fork_filter: ForkFilter,
70 ) -> Result<(EthStream<S, N>, UnifiedStatus), EthStreamError> {
71 self.handshake_with_timeout(status, fork_filter, HANDSHAKE_TIMEOUT).await
72 }
73
74 pub async fn handshake_with_timeout<N: NetworkPrimitives>(
76 self,
77 status: UnifiedStatus,
78 fork_filter: ForkFilter,
79 timeout_limit: Duration,
80 ) -> Result<(EthStream<S, N>, UnifiedStatus), EthStreamError> {
81 timeout(timeout_limit, Self::handshake_without_timeout(self, status, fork_filter))
82 .await
83 .map_err(|_| EthStreamError::StreamTimeout)?
84 }
85
86 pub async fn handshake_without_timeout<N: NetworkPrimitives>(
88 mut self,
89 status: UnifiedStatus,
90 fork_filter: ForkFilter,
91 ) -> Result<(EthStream<S, N>, UnifiedStatus), EthStreamError> {
92 trace!(
93 status = %status.into_message(),
94 "sending eth status to peer"
95 );
96 let their_status =
97 EthereumEthHandshake(&mut self.inner).eth_handshake(status, fork_filter).await?;
98
99 let stream = EthStream::new(status.version, self.inner);
102
103 Ok((stream, their_status))
104 }
105}
106
/// Transport-independent protocol state of an eth stream.
///
/// Holds the eth version and the limits applied when decoding incoming messages.
#[derive(Debug)]
pub struct EthStreamInner<N> {
    /// The eth protocol version used when decoding incoming messages.
    version: EthVersion,
    /// Maximum accepted length, in bytes, of an incoming encoded message.
    max_message_size: usize,
    /// Whether incoming `NewBlock` and `NewBlockHashes` messages are rejected.
    reject_block_announcements: bool,
    /// Marker for the network primitives type `N`; no `N` value is stored.
    _pd: std::marker::PhantomData<N>,
}
119
120impl<N> EthStreamInner<N>
121where
122 N: NetworkPrimitives,
123{
124 pub const fn new(version: EthVersion) -> Self {
126 Self::with_max_message_size(version, MAX_MESSAGE_SIZE)
127 }
128
129 pub const fn with_max_message_size(version: EthVersion, max_message_size: usize) -> Self {
131 Self {
132 version,
133 max_message_size,
134 reject_block_announcements: false,
135 _pd: std::marker::PhantomData,
136 }
137 }
138
139 #[inline]
141 pub const fn version(&self) -> EthVersion {
142 self.version
143 }
144
145 pub const fn set_reject_block_announcements(&mut self, reject: bool) {
148 self.reject_block_announcements = reject;
149 }
150
151 pub fn decode_message(&self, bytes: BytesMut) -> Result<EthMessage<N>, EthStreamError> {
153 if bytes.len() > self.max_message_size {
154 return Err(EthStreamError::MessageTooBig(bytes.len()));
155 }
156
157 if self.reject_block_announcements &&
158 let Some(&id) = bytes.first() &&
159 (id == EthMessageID::NewBlock.to_u8() || id == EthMessageID::NewBlockHashes.to_u8())
160 {
161 return Err(EthStreamError::UnsupportedMessage { message_id: id });
162 }
163
164 let msg = match ProtocolMessage::decode_message_with_tx_memory_budget(
165 self.version,
166 &mut bytes.as_ref(),
167 self.max_message_size * TX_MEMORY_BUDGET_MULTIPLIER,
168 ) {
169 Ok(m) => m,
170 Err(err) => {
171 let msg = if bytes.len() > 50 {
172 format!("{:02x?}...{:x?}", &bytes[..10], &bytes[bytes.len() - 10..])
173 } else {
174 format!("{bytes:02x?}")
175 };
176 debug!(
177 version=?self.version,
178 %msg,
179 "failed to decode protocol message"
180 );
181 return Err(EthStreamError::InvalidMessage(err));
182 }
183 };
184
185 if matches!(msg.message, EthMessage::Status(_)) {
186 return Err(EthStreamError::EthHandshakeError(EthHandshakeError::StatusNotInHandshake));
187 }
188
189 Ok(msg.message)
190 }
191
192 pub fn encode_message(&self, item: EthMessage<N>) -> Result<Bytes, EthStreamError> {
196 if matches!(item, EthMessage::Status(_)) {
197 return Err(EthStreamError::EthHandshakeError(EthHandshakeError::StatusNotInHandshake));
198 }
199
200 Ok(Bytes::from(alloy_rlp::encode(ProtocolMessage::from(item))))
201 }
202}
203
/// An eth stream that exchanges [`EthMessage`]s over an underlying byte stream.
///
/// It is typically obtained from [`UnauthedEthStream`] after a successful handshake.
/// Incoming frames are decoded and outgoing messages are RLP-encoded. `Status` messages are
/// rejected in both directions.
#[pin_project]
#[derive(Debug)]
pub struct EthStream<S, N = EthNetworkPrimitives> {
    /// Protocol state (version and decoding limits).
    eth: EthStreamInner<N>,
    /// The underlying byte stream/sink.
    #[pin]
    inner: S,
}
214
215impl<S, N: NetworkPrimitives> EthStream<S, N> {
216 #[inline]
219 pub const fn new(version: EthVersion, inner: S) -> Self {
220 Self::with_max_message_size(version, inner, MAX_MESSAGE_SIZE)
221 }
222
223 #[inline]
225 pub const fn with_max_message_size(
226 version: EthVersion,
227 inner: S,
228 max_message_size: usize,
229 ) -> Self {
230 Self { eth: EthStreamInner::with_max_message_size(version, max_message_size), inner }
231 }
232
233 #[inline]
235 pub const fn version(&self) -> EthVersion {
236 self.eth.version()
237 }
238
239 pub const fn set_reject_block_announcements(&mut self, reject: bool) {
242 self.eth.set_reject_block_announcements(reject);
243 }
244
245 #[inline]
247 pub const fn inner(&self) -> &S {
248 &self.inner
249 }
250
251 #[inline]
253 pub const fn inner_mut(&mut self) -> &mut S {
254 &mut self.inner
255 }
256
257 #[inline]
259 pub fn into_inner(self) -> S {
260 self.inner
261 }
262}
263
264impl<S, E, N> EthStream<S, N>
265where
266 S: Sink<Bytes, Error = E> + Unpin,
267 EthStreamError: From<E>,
268 N: NetworkPrimitives,
269{
270 pub fn start_send_broadcast(
272 &mut self,
273 item: EthBroadcastMessage<N>,
274 ) -> Result<(), EthStreamError> {
275 self.inner.start_send_unpin(Bytes::from(alloy_rlp::encode(
276 ProtocolBroadcastMessage::from(item),
277 )))?;
278
279 Ok(())
280 }
281
282 pub fn start_send_raw(&mut self, msg: RawCapabilityMessage) -> Result<(), EthStreamError> {
284 let mut bytes = Vec::with_capacity(msg.payload.len() + 1);
285 msg.id.encode(&mut bytes);
286 bytes.extend_from_slice(&msg.payload);
287
288 self.inner.start_send_unpin(bytes.into())?;
289 Ok(())
290 }
291}
292
293impl<S, E, N> Stream for EthStream<S, N>
294where
295 S: Stream<Item = Result<BytesMut, E>> + Unpin,
296 EthStreamError: From<E>,
297 N: NetworkPrimitives,
298{
299 type Item = Result<EthMessage<N>, EthStreamError>;
300
301 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
302 let this = self.project();
303 let res = ready!(this.inner.poll_next(cx));
304
305 match res {
306 Some(Ok(bytes)) => Poll::Ready(Some(this.eth.decode_message(bytes))),
307 Some(Err(err)) => Poll::Ready(Some(Err(err.into()))),
308 None => Poll::Ready(None),
309 }
310 }
311}
312
313impl<S, N> Sink<EthMessage<N>> for EthStream<S, N>
314where
315 S: CanDisconnect<Bytes> + Unpin,
316 EthStreamError: From<<S as Sink<Bytes>>::Error>,
317 N: NetworkPrimitives,
318{
319 type Error = EthStreamError;
320
321 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
322 self.project().inner.poll_ready(cx).map_err(Into::into)
323 }
324
325 fn start_send(self: Pin<&mut Self>, item: EthMessage<N>) -> Result<(), Self::Error> {
326 if matches!(item, EthMessage::Status(_)) {
327 let mut this = self.project();
330 let _disconnect_future = this.inner.disconnect(DisconnectReason::ProtocolBreach);
334 return Err(EthStreamError::EthHandshakeError(EthHandshakeError::StatusNotInHandshake))
335 }
336
337 self.project()
338 .inner
339 .start_send(Bytes::from(alloy_rlp::encode(ProtocolMessage::from(item))))?;
340
341 Ok(())
342 }
343
344 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
345 self.project().inner.poll_flush(cx).map_err(Into::into)
346 }
347
348 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
349 self.project().inner.poll_close(cx).map_err(Into::into)
350 }
351}
352
353impl<S, N> CanDisconnect<EthMessage<N>> for EthStream<S, N>
354where
355 S: CanDisconnect<Bytes> + Send,
356 EthStreamError: From<<S as Sink<Bytes>>::Error>,
357 N: NetworkPrimitives,
358{
359 fn disconnect(
360 &mut self,
361 reason: DisconnectReason,
362 ) -> Pin<Box<dyn Future<Output = Result<(), EthStreamError>> + Send + '_>> {
363 Box::pin(async move { self.inner.disconnect(reason).await.map_err(Into::into) })
364 }
365}
366
#[cfg(test)]
mod tests {
    use super::UnauthedEthStream;
    use crate::{
        broadcast::BlockHashNumber,
        errors::{EthHandshakeError, EthStreamError},
        ethstream::RawCapabilityMessage,
        hello::DEFAULT_TCP_PORT,
        p2pstream::UnauthedP2PStream,
        EthMessage, EthStream, EthVersion, HelloMessageWithProtocols, PassthroughCodec,
        ProtocolVersion, Status, StatusMessage,
    };
    use alloy_chains::NamedChain;
    use alloy_primitives::{bytes::Bytes, B256, U256};
    use alloy_rlp::Decodable;
    use futures::{SinkExt, StreamExt};
    use reth_ecies::stream::ECIESStream;
    use reth_eth_wire_types::{EthNetworkPrimitives, UnifiedStatus};
    use reth_ethereum_forks::{ForkFilter, Head};
    use reth_network_peers::pk2id;
    use secp256k1::{SecretKey, SECP256K1};
    use std::time::Duration;
    use tokio::net::{TcpListener, TcpStream};
    use tokio_util::codec::Decoder;

    /// Both sides exchange identical statuses and each sees the other's status.
    #[tokio::test]
    async fn can_handshake() {
        let genesis = B256::random();
        let fork_filter = ForkFilter::new(Head::default(), genesis, 0, Vec::new());

        let status = Status {
            version: EthVersion::Eth67,
            chain: NamedChain::Mainnet.into(),
            total_difficulty: U256::ZERO,
            blockhash: B256::random(),
            genesis,
            forkid: fork_filter.current(),
        };
        let unified_status = UnifiedStatus::from_message(StatusMessage::Legacy(status));

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();

        let status_clone = unified_status;
        let fork_filter_clone = fork_filter.clone();
        let handle = tokio::spawn(async move {
            let (incoming, _) = listener.accept().await.unwrap();
            let stream = PassthroughCodec::default().framed(incoming);
            let (_, their_status) = UnauthedEthStream::new(stream)
                .handshake::<EthNetworkPrimitives>(status_clone, fork_filter_clone)
                .await
                .unwrap();

            // Both sides use the same status, so the peer's status must equal ours.
            assert_eq!(their_status, status_clone);
        });

        let outgoing = TcpStream::connect(local_addr).await.unwrap();
        let sink = PassthroughCodec::default().framed(outgoing);

        let (_, their_status) = UnauthedEthStream::new(sink)
            .handshake::<EthNetworkPrimitives>(unified_status, fork_filter)
            .await
            .unwrap();

        assert_eq!(their_status, unified_status);

        handle.await.unwrap();
    }

    /// A total difficulty of 2^100 - 1 (100 bits) is accepted.
    #[tokio::test]
    async fn pass_handshake_on_low_td_bitlen() {
        let genesis = B256::random();
        let fork_filter = ForkFilter::new(Head::default(), genesis, 0, Vec::new());

        let status = Status {
            version: EthVersion::Eth67,
            chain: NamedChain::Mainnet.into(),
            total_difficulty: U256::from(2).pow(U256::from(100)) - U256::from(1),
            blockhash: B256::random(),
            genesis,
            forkid: fork_filter.current(),
        };
        let unified_status = UnifiedStatus::from_message(StatusMessage::Legacy(status));

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();

        let status_clone = unified_status;
        let fork_filter_clone = fork_filter.clone();
        let handle = tokio::spawn(async move {
            let (incoming, _) = listener.accept().await.unwrap();
            let stream = PassthroughCodec::default().framed(incoming);
            let (_, their_status) = UnauthedEthStream::new(stream)
                .handshake::<EthNetworkPrimitives>(status_clone, fork_filter_clone)
                .await
                .unwrap();

            assert_eq!(their_status, status_clone);
        });

        let outgoing = TcpStream::connect(local_addr).await.unwrap();
        let sink = PassthroughCodec::default().framed(outgoing);

        let (_, their_status) = UnauthedEthStream::new(sink)
            .handshake::<EthNetworkPrimitives>(unified_status, fork_filter)
            .await
            .unwrap();

        assert_eq!(their_status, unified_status);

        handle.await.unwrap();
    }

    /// A total difficulty of 2^164 (165 bits) exceeds the 160-bit maximum on both sides.
    #[tokio::test]
    async fn fail_handshake_on_high_td_bitlen() {
        let genesis = B256::random();
        let fork_filter = ForkFilter::new(Head::default(), genesis, 0, Vec::new());

        let status = Status {
            version: EthVersion::Eth67,
            chain: NamedChain::Mainnet.into(),
            total_difficulty: U256::from(2).pow(U256::from(164)),
            blockhash: B256::random(),
            genesis,
            forkid: fork_filter.current(),
        };
        let unified_status = UnifiedStatus::from_message(StatusMessage::Legacy(status));

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();

        let status_clone = unified_status;
        let fork_filter_clone = fork_filter.clone();
        let handle = tokio::spawn(async move {
            let (incoming, _) = listener.accept().await.unwrap();
            let stream = PassthroughCodec::default().framed(incoming);
            let handshake_res = UnauthedEthStream::new(stream)
                .handshake::<EthNetworkPrimitives>(status_clone, fork_filter_clone)
                .await;

            assert!(matches!(
                handshake_res,
                Err(EthStreamError::EthHandshakeError(
                    EthHandshakeError::TotalDifficultyBitLenTooLarge { got: 165, maximum: 160 }
                ))
            ));
        });

        let outgoing = TcpStream::connect(local_addr).await.unwrap();
        let sink = PassthroughCodec::default().framed(outgoing);

        let handshake_res = UnauthedEthStream::new(sink)
            .handshake::<EthNetworkPrimitives>(unified_status, fork_filter)
            .await;

        assert!(matches!(
            handshake_res,
            Err(EthStreamError::EthHandshakeError(
                EthHandshakeError::TotalDifficultyBitLenTooLarge { got: 165, maximum: 160 }
            ))
        ));

        handle.await.unwrap();
    }

    /// A message sent over a plain TCP framed stream is decoded unchanged.
    #[tokio::test]
    async fn can_write_and_read_cleartext() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();
        let test_msg = EthMessage::<EthNetworkPrimitives>::NewBlockHashes(
            vec![
                BlockHashNumber { hash: B256::random(), number: 5 },
                BlockHashNumber { hash: B256::random(), number: 6 },
            ]
            .into(),
        );

        let test_msg_clone = test_msg.clone();
        let handle = tokio::spawn(async move {
            let (incoming, _) = listener.accept().await.unwrap();
            let stream = PassthroughCodec::default().framed(incoming);
            let mut stream = EthStream::new(EthVersion::Eth67, stream);

            let message = stream.next().await.unwrap().unwrap();
            assert_eq!(message, test_msg_clone);
        });

        let outgoing = TcpStream::connect(local_addr).await.unwrap();
        let sink = PassthroughCodec::default().framed(outgoing);
        let mut client_stream = EthStream::new(EthVersion::Eth67, sink);

        client_stream.send(test_msg).await.unwrap();

        handle.await.unwrap();
    }

    /// A message sent over an ECIES stream is decoded unchanged.
    #[tokio::test]
    async fn can_write_and_read_ecies() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();
        let server_key = SecretKey::new(&mut rand_08::thread_rng());
        let test_msg = EthMessage::<EthNetworkPrimitives>::NewBlockHashes(
            vec![
                BlockHashNumber { hash: B256::random(), number: 5 },
                BlockHashNumber { hash: B256::random(), number: 6 },
            ]
            .into(),
        );

        let test_msg_clone = test_msg.clone();
        let handle = tokio::spawn(async move {
            let (incoming, _) = listener.accept().await.unwrap();
            let stream = ECIESStream::incoming(incoming, server_key).await.unwrap();
            let mut stream = EthStream::new(EthVersion::Eth67, stream);

            let message = stream.next().await.unwrap().unwrap();
            assert_eq!(message, test_msg_clone);
        });

        let server_id = pk2id(&server_key.public_key(SECP256K1));

        let client_key = SecretKey::new(&mut rand_08::thread_rng());

        let outgoing = TcpStream::connect(local_addr).await.unwrap();
        let outgoing = ECIESStream::connect(outgoing, client_key, server_id).await.unwrap();
        let mut client_stream = EthStream::new(EthVersion::Eth67, outgoing);

        client_stream.send(test_msg).await.unwrap();

        handle.await.unwrap();
    }

    /// Full stack: ECIES, then the p2p handshake, then the eth handshake, then a message.
    #[tokio::test(flavor = "multi_thread")]
    async fn ethstream_over_p2p() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();
        let server_key = SecretKey::new(&mut rand_08::thread_rng());
        let test_msg = EthMessage::<EthNetworkPrimitives>::NewBlockHashes(
            vec![
                BlockHashNumber { hash: B256::random(), number: 5 },
                BlockHashNumber { hash: B256::random(), number: 6 },
            ]
            .into(),
        );

        let genesis = B256::random();
        let fork_filter = ForkFilter::new(Head::default(), genesis, 0, Vec::new());

        let status = Status {
            version: EthVersion::Eth67,
            chain: NamedChain::Mainnet.into(),
            total_difficulty: U256::ZERO,
            blockhash: B256::random(),
            genesis,
            forkid: fork_filter.current(),
        };
        let unified_status = UnifiedStatus::from_message(StatusMessage::Legacy(status));

        let status_copy = unified_status;
        let fork_filter_clone = fork_filter.clone();
        let test_msg_clone = test_msg.clone();
        let handle = tokio::spawn(async move {
            let (incoming, _) = listener.accept().await.unwrap();
            let stream = ECIESStream::incoming(incoming, server_key).await.unwrap();

            let server_hello = HelloMessageWithProtocols {
                protocol_version: ProtocolVersion::V5,
                client_version: "bitcoind/1.0.0".to_string(),
                protocols: vec![EthVersion::Eth67.into()],
                port: DEFAULT_TCP_PORT,
                id: pk2id(&server_key.public_key(SECP256K1)),
            };

            let unauthed_stream = UnauthedP2PStream::new(stream);
            let (p2p_stream, _) = unauthed_stream.handshake(server_hello).await.unwrap();
            let (mut eth_stream, _) = UnauthedEthStream::new(p2p_stream)
                .handshake(status_copy, fork_filter_clone)
                .await
                .unwrap();

            let message = eth_stream.next().await.unwrap().unwrap();
            assert_eq!(message, test_msg_clone);
        });

        let server_id = pk2id(&server_key.public_key(SECP256K1));

        let client_key = SecretKey::new(&mut rand_08::thread_rng());

        let outgoing = TcpStream::connect(local_addr).await.unwrap();
        let sink = ECIESStream::connect(outgoing, client_key, server_id).await.unwrap();

        let client_hello = HelloMessageWithProtocols {
            protocol_version: ProtocolVersion::V5,
            client_version: "bitcoind/1.0.0".to_string(),
            protocols: vec![EthVersion::Eth67.into()],
            port: DEFAULT_TCP_PORT,
            id: pk2id(&client_key.public_key(SECP256K1)),
        };

        let unauthed_stream = UnauthedP2PStream::new(sink);
        let (p2p_stream, _) = unauthed_stream.handshake(client_hello).await.unwrap();

        let (mut client_stream, _) = UnauthedEthStream::new(p2p_stream)
            .handshake(unified_status, fork_filter)
            .await
            .unwrap();

        client_stream.send(test_msg).await.unwrap();

        handle.await.unwrap();
    }

    /// The handshake fails with `StreamTimeout` when the peer never responds in time.
    #[tokio::test]
    async fn handshake_should_timeout() {
        let genesis = B256::random();
        let fork_filter = ForkFilter::new(Head::default(), genesis, 0, Vec::new());

        let status = Status {
            version: EthVersion::Eth67,
            chain: NamedChain::Mainnet.into(),
            total_difficulty: U256::ZERO,
            blockhash: B256::random(),
            genesis,
            forkid: fork_filter.current(),
        };
        let unified_status = UnifiedStatus::from_message(StatusMessage::Legacy(status));

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();

        let status_clone = unified_status;
        let fork_filter_clone = fork_filter.clone();
        let _handle = tokio::spawn(async move {
            // Delay well past the client's 1s timeout so the client never gets a reply in time.
            tokio::time::sleep(Duration::from_secs(11)).await;
            let (incoming, _) = listener.accept().await.unwrap();
            let stream = PassthroughCodec::default().framed(incoming);
            let (_, their_status) = UnauthedEthStream::new(stream)
                .handshake::<EthNetworkPrimitives>(status_clone, fork_filter_clone)
                .await
                .unwrap();

            assert_eq!(their_status, status_clone);
        });

        let outgoing = TcpStream::connect(local_addr).await.unwrap();
        let sink = PassthroughCodec::default().framed(outgoing);

        let handshake_result = UnauthedEthStream::new(sink)
            .handshake_with_timeout::<EthNetworkPrimitives>(
                unified_status,
                fork_filter,
                Duration::from_secs(1),
            )
            .await;

        // Match on the variant directly rather than comparing rendered error strings.
        assert!(matches!(handshake_result, Err(EthStreamError::StreamTimeout)));
    }

    /// A raw capability message is written as the RLP-encoded id followed by the payload.
    #[tokio::test]
    async fn can_write_and_read_raw_capability() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();

        let test_msg = RawCapabilityMessage { id: 0x1234, payload: Bytes::from(vec![1, 2, 3, 4]) };

        let test_msg_clone = test_msg.clone();
        let handle = tokio::spawn(async move {
            let (incoming, _) = listener.accept().await.unwrap();
            let stream = PassthroughCodec::default().framed(incoming);
            let mut stream = EthStream::<_, EthNetworkPrimitives>::new(EthVersion::Eth67, stream);

            let bytes = stream.inner_mut().next().await.unwrap().unwrap();

            // Decode the id from the front of the frame; what remains is the payload.
            let mut id_bytes = &bytes[..];
            let decoded_id = <usize as Decodable>::decode(&mut id_bytes).unwrap();
            assert_eq!(decoded_id, test_msg_clone.id);

            let remaining = id_bytes;
            assert_eq!(remaining, &test_msg_clone.payload[..]);
        });

        let outgoing = TcpStream::connect(local_addr).await.unwrap();
        let sink = PassthroughCodec::default().framed(outgoing);
        let mut client_stream = EthStream::<_, EthNetworkPrimitives>::new(EthVersion::Eth67, sink);

        client_stream.start_send_raw(test_msg).unwrap();
        client_stream.inner_mut().flush().await.unwrap();

        handle.await.unwrap();
    }

    /// Sending a `Status` after the handshake is rejected with `StatusNotInHandshake`.
    #[tokio::test]
    async fn status_message_after_handshake_triggers_disconnect() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();

        let handle = tokio::spawn(async move {
            let (incoming, _) = listener.accept().await.unwrap();
            let stream = PassthroughCodec::default().framed(incoming);
            let mut stream = EthStream::<_, EthNetworkPrimitives>::new(EthVersion::Eth67, stream);

            let status = Status {
                version: EthVersion::Eth67,
                chain: NamedChain::Mainnet.into(),
                total_difficulty: U256::ZERO,
                blockhash: B256::random(),
                genesis: B256::random(),
                forkid: ForkFilter::new(Head::default(), B256::random(), 0, Vec::new()).current(),
            };
            let status_message =
                EthMessage::<EthNetworkPrimitives>::Status(StatusMessage::Legacy(status));

            let result = stream.send(status_message).await;
            assert!(result.is_err());
            assert!(matches!(
                result.unwrap_err(),
                EthStreamError::EthHandshakeError(EthHandshakeError::StatusNotInHandshake)
            ));
        });

        let outgoing = TcpStream::connect(local_addr).await.unwrap();
        let sink = PassthroughCodec::default().framed(outgoing);
        let mut client_stream = EthStream::<_, EthNetworkPrimitives>::new(EthVersion::Eth67, sink);

        let test_msg = EthMessage::<EthNetworkPrimitives>::NewBlockHashes(
            vec![BlockHashNumber { hash: B256::random(), number: 5 }].into(),
        );
        client_stream.send(test_msg).await.unwrap();

        handle.await.unwrap();
    }
}