1use crate::{
8 errors::{EthHandshakeError, EthStreamError},
9 handshake::EthereumEthHandshake,
10 message::{EthBroadcastMessage, ProtocolBroadcastMessage},
11 p2pstream::HANDSHAKE_TIMEOUT,
12 CanDisconnect, DisconnectReason, EthMessage, EthNetworkPrimitives, EthVersion, ProtocolMessage,
13 Status,
14};
15use alloy_primitives::bytes::{Bytes, BytesMut};
16use alloy_rlp::Encodable;
17use futures::{ready, Sink, SinkExt};
18use pin_project::pin_project;
19use reth_eth_wire_types::{NetworkPrimitives, RawCapabilityMessage};
20use reth_ethereum_forks::ForkFilter;
21use std::{
22 future::Future,
23 pin::Pin,
24 task::{Context, Poll},
25 time::Duration,
26};
27use tokio::time::timeout;
28use tokio_stream::Stream;
29use tracing::{debug, trace};
30
/// Upper bound, in bytes, on the size of a single encoded `eth` message accepted by
/// `EthStreamInner::decode_message`: 10 MiB.
pub const MAX_MESSAGE_SIZE: usize = 10 << 20;

/// Size limit of 500 KiB.
///
/// NOTE(review): judging by its name this bounds the encoded `Status` message; it is not
/// referenced by the code visible in this file, so confirm against its users.
pub(crate) const MAX_STATUS_SIZE: usize = 500 << 10;
37
/// An `eth` stream on which the status handshake has not been performed yet.
///
/// It only wraps the underlying transport `S`. The `handshake*` methods consume this wrapper
/// and, on success, produce an [`EthStream`] over the same transport.
#[pin_project]
#[derive(Debug)]
pub struct UnauthedEthStream<S> {
    /// The wrapped transport.
    #[pin]
    inner: S,
}
46
47impl<S> UnauthedEthStream<S> {
48 pub const fn new(inner: S) -> Self {
50 Self { inner }
51 }
52
53 pub fn into_inner(self) -> S {
55 self.inner
56 }
57}
58
59impl<S, E> UnauthedEthStream<S>
60where
61 S: Stream<Item = Result<BytesMut, E>> + CanDisconnect<Bytes> + Send + Unpin,
62 EthStreamError: From<E> + From<<S as Sink<Bytes>>::Error>,
63{
64 pub async fn handshake<N: NetworkPrimitives>(
68 self,
69 status: Status,
70 fork_filter: ForkFilter,
71 ) -> Result<(EthStream<S, N>, Status), EthStreamError> {
72 self.handshake_with_timeout(status, fork_filter, HANDSHAKE_TIMEOUT).await
73 }
74
75 pub async fn handshake_with_timeout<N: NetworkPrimitives>(
77 self,
78 status: Status,
79 fork_filter: ForkFilter,
80 timeout_limit: Duration,
81 ) -> Result<(EthStream<S, N>, Status), EthStreamError> {
82 timeout(timeout_limit, Self::handshake_without_timeout(self, status, fork_filter))
83 .await
84 .map_err(|_| EthStreamError::StreamTimeout)?
85 }
86
87 pub async fn handshake_without_timeout<N: NetworkPrimitives>(
89 mut self,
90 status: Status,
91 fork_filter: ForkFilter,
92 ) -> Result<(EthStream<S, N>, Status), EthStreamError> {
93 trace!(
94 %status,
95 "sending eth status to peer"
96 );
97 EthereumEthHandshake(&mut self.inner).eth_handshake(status, fork_filter).await?;
98
99 let stream = EthStream::new(status.version, self.inner);
102
103 Ok((stream, status))
104 }
105}
106
/// Transport-independent encoder/decoder for `eth` protocol messages.
///
/// It carries only the protocol version used for decoding; the network primitives `N` select
/// the concrete [`EthMessage`] type that is produced and consumed.
#[derive(Debug)]
pub struct EthStreamInner<N> {
    /// The `eth` protocol version handed to the message decoder.
    version: EthVersion,
    /// Marker tying this value to `N` without storing an `N`.
    _pd: std::marker::PhantomData<N>,
}
114
115impl<N> EthStreamInner<N>
116where
117 N: NetworkPrimitives,
118{
119 pub const fn new(version: EthVersion) -> Self {
121 Self { version, _pd: std::marker::PhantomData }
122 }
123
124 #[inline]
126 pub const fn version(&self) -> EthVersion {
127 self.version
128 }
129
130 pub fn decode_message(&self, bytes: BytesMut) -> Result<EthMessage<N>, EthStreamError> {
132 if bytes.len() > MAX_MESSAGE_SIZE {
133 return Err(EthStreamError::MessageTooBig(bytes.len()));
134 }
135
136 let msg = match ProtocolMessage::decode_message(self.version, &mut bytes.as_ref()) {
137 Ok(m) => m,
138 Err(err) => {
139 let msg = if bytes.len() > 50 {
140 format!("{:02x?}...{:x?}", &bytes[..10], &bytes[bytes.len() - 10..])
141 } else {
142 format!("{bytes:02x?}")
143 };
144 debug!(
145 version=?self.version,
146 %msg,
147 "failed to decode protocol message"
148 );
149 return Err(EthStreamError::InvalidMessage(err));
150 }
151 };
152
153 if matches!(msg.message, EthMessage::Status(_)) {
154 return Err(EthStreamError::EthHandshakeError(EthHandshakeError::StatusNotInHandshake));
155 }
156
157 Ok(msg.message)
158 }
159
160 pub fn encode_message(&self, item: EthMessage<N>) -> Result<Bytes, EthStreamError> {
164 if matches!(item, EthMessage::Status(_)) {
165 return Err(EthStreamError::EthHandshakeError(EthHandshakeError::StatusNotInHandshake));
166 }
167
168 Ok(Bytes::from(alloy_rlp::encode(ProtocolMessage::from(item))))
169 }
170}
171
/// A stream/sink of [`EthMessage`]s layered over a byte transport `S`.
///
/// Incoming frames from `inner` are decoded with `eth`, and outgoing messages are RLP-encoded
/// before being handed to `inner`. `N` defaults to [`EthNetworkPrimitives`].
#[pin_project]
#[derive(Debug)]
pub struct EthStream<S, N = EthNetworkPrimitives> {
    /// Version-aware message codec.
    eth: EthStreamInner<N>,
    /// The underlying byte transport.
    #[pin]
    inner: S,
}
182
183impl<S, N: NetworkPrimitives> EthStream<S, N> {
184 #[inline]
187 pub const fn new(version: EthVersion, inner: S) -> Self {
188 Self { eth: EthStreamInner::new(version), inner }
189 }
190
191 #[inline]
193 pub const fn version(&self) -> EthVersion {
194 self.eth.version()
195 }
196
197 #[inline]
199 pub const fn inner(&self) -> &S {
200 &self.inner
201 }
202
203 #[inline]
205 pub const fn inner_mut(&mut self) -> &mut S {
206 &mut self.inner
207 }
208
209 #[inline]
211 pub fn into_inner(self) -> S {
212 self.inner
213 }
214}
215
216impl<S, E, N> EthStream<S, N>
217where
218 S: Sink<Bytes, Error = E> + Unpin,
219 EthStreamError: From<E>,
220 N: NetworkPrimitives,
221{
222 pub fn start_send_broadcast(
224 &mut self,
225 item: EthBroadcastMessage<N>,
226 ) -> Result<(), EthStreamError> {
227 self.inner.start_send_unpin(Bytes::from(alloy_rlp::encode(
228 ProtocolBroadcastMessage::from(item),
229 )))?;
230
231 Ok(())
232 }
233
234 pub fn start_send_raw(&mut self, msg: RawCapabilityMessage) -> Result<(), EthStreamError> {
236 let mut bytes = Vec::with_capacity(msg.payload.len() + 1);
237 msg.id.encode(&mut bytes);
238 bytes.extend_from_slice(&msg.payload);
239
240 self.inner.start_send_unpin(bytes.into())?;
241 Ok(())
242 }
243}
244
245impl<S, E, N> Stream for EthStream<S, N>
246where
247 S: Stream<Item = Result<BytesMut, E>> + Unpin,
248 EthStreamError: From<E>,
249 N: NetworkPrimitives,
250{
251 type Item = Result<EthMessage<N>, EthStreamError>;
252
253 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
254 let this = self.project();
255 let res = ready!(this.inner.poll_next(cx));
256
257 match res {
258 Some(Ok(bytes)) => Poll::Ready(Some(this.eth.decode_message(bytes))),
259 Some(Err(err)) => Poll::Ready(Some(Err(err.into()))),
260 None => Poll::Ready(None),
261 }
262 }
263}
264
265impl<S, N> Sink<EthMessage<N>> for EthStream<S, N>
266where
267 S: CanDisconnect<Bytes> + Unpin,
268 EthStreamError: From<<S as Sink<Bytes>>::Error>,
269 N: NetworkPrimitives,
270{
271 type Error = EthStreamError;
272
273 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
274 self.project().inner.poll_ready(cx).map_err(Into::into)
275 }
276
277 fn start_send(self: Pin<&mut Self>, item: EthMessage<N>) -> Result<(), Self::Error> {
278 if matches!(item, EthMessage::Status(_)) {
279 return Err(EthStreamError::EthHandshakeError(EthHandshakeError::StatusNotInHandshake))
289 }
290
291 self.project()
292 .inner
293 .start_send(Bytes::from(alloy_rlp::encode(ProtocolMessage::from(item))))?;
294
295 Ok(())
296 }
297
298 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
299 self.project().inner.poll_flush(cx).map_err(Into::into)
300 }
301
302 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
303 self.project().inner.poll_close(cx).map_err(Into::into)
304 }
305}
306
307impl<S, N> CanDisconnect<EthMessage<N>> for EthStream<S, N>
308where
309 S: CanDisconnect<Bytes> + Send,
310 EthStreamError: From<<S as Sink<Bytes>>::Error>,
311 N: NetworkPrimitives,
312{
313 fn disconnect(
314 &mut self,
315 reason: DisconnectReason,
316 ) -> Pin<Box<dyn Future<Output = Result<(), EthStreamError>> + Send + '_>> {
317 Box::pin(async move { self.inner.disconnect(reason).await.map_err(Into::into) })
318 }
319}
320
321#[cfg(test)]
322mod tests {
323 use super::UnauthedEthStream;
324 use crate::{
325 broadcast::BlockHashNumber,
326 errors::{EthHandshakeError, EthStreamError},
327 ethstream::RawCapabilityMessage,
328 hello::DEFAULT_TCP_PORT,
329 p2pstream::UnauthedP2PStream,
330 EthMessage, EthStream, EthVersion, HelloMessageWithProtocols, PassthroughCodec,
331 ProtocolVersion, Status,
332 };
333 use alloy_chains::NamedChain;
334 use alloy_primitives::{bytes::Bytes, B256, U256};
335 use alloy_rlp::Decodable;
336 use futures::{SinkExt, StreamExt};
337 use reth_ecies::stream::ECIESStream;
338 use reth_eth_wire_types::EthNetworkPrimitives;
339 use reth_ethereum_forks::{ForkFilter, Head};
340 use reth_network_peers::pk2id;
341 use secp256k1::{SecretKey, SECP256K1};
342 use std::time::Duration;
343 use tokio::net::{TcpListener, TcpStream};
344 use tokio_util::codec::Decoder;
345
346 #[tokio::test]
347 async fn can_handshake() {
348 let genesis = B256::random();
349 let fork_filter = ForkFilter::new(Head::default(), genesis, 0, Vec::new());
350
351 let status = Status {
352 version: EthVersion::Eth67,
353 chain: NamedChain::Mainnet.into(),
354 total_difficulty: U256::ZERO,
355 blockhash: B256::random(),
356 genesis,
357 forkid: fork_filter.current(),
359 };
360
361 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
362 let local_addr = listener.local_addr().unwrap();
363
364 let status_clone = status;
365 let fork_filter_clone = fork_filter.clone();
366 let handle = tokio::spawn(async move {
367 let (incoming, _) = listener.accept().await.unwrap();
369 let stream = PassthroughCodec::default().framed(incoming);
370 let (_, their_status) = UnauthedEthStream::new(stream)
371 .handshake::<EthNetworkPrimitives>(status_clone, fork_filter_clone)
372 .await
373 .unwrap();
374
375 assert_eq!(their_status, status_clone);
377 });
378
379 let outgoing = TcpStream::connect(local_addr).await.unwrap();
380 let sink = PassthroughCodec::default().framed(outgoing);
381
382 let (_, their_status) = UnauthedEthStream::new(sink)
384 .handshake::<EthNetworkPrimitives>(status, fork_filter)
385 .await
386 .unwrap();
387
388 assert_eq!(their_status, status);
390
391 handle.await.unwrap();
393 }
394
395 #[tokio::test]
396 async fn pass_handshake_on_low_td_bitlen() {
397 let genesis = B256::random();
398 let fork_filter = ForkFilter::new(Head::default(), genesis, 0, Vec::new());
399
400 let status = Status {
401 version: EthVersion::Eth67,
402 chain: NamedChain::Mainnet.into(),
403 total_difficulty: U256::from(2).pow(U256::from(100)) - U256::from(1),
404 blockhash: B256::random(),
405 genesis,
406 forkid: fork_filter.current(),
408 };
409
410 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
411 let local_addr = listener.local_addr().unwrap();
412
413 let status_clone = status;
414 let fork_filter_clone = fork_filter.clone();
415 let handle = tokio::spawn(async move {
416 let (incoming, _) = listener.accept().await.unwrap();
418 let stream = PassthroughCodec::default().framed(incoming);
419 let (_, their_status) = UnauthedEthStream::new(stream)
420 .handshake::<EthNetworkPrimitives>(status_clone, fork_filter_clone)
421 .await
422 .unwrap();
423
424 assert_eq!(their_status, status_clone);
426 });
427
428 let outgoing = TcpStream::connect(local_addr).await.unwrap();
429 let sink = PassthroughCodec::default().framed(outgoing);
430
431 let (_, their_status) = UnauthedEthStream::new(sink)
433 .handshake::<EthNetworkPrimitives>(status, fork_filter)
434 .await
435 .unwrap();
436
437 assert_eq!(their_status, status);
439
440 handle.await.unwrap();
442 }
443
444 #[tokio::test]
445 async fn fail_handshake_on_high_td_bitlen() {
446 let genesis = B256::random();
447 let fork_filter = ForkFilter::new(Head::default(), genesis, 0, Vec::new());
448
449 let status = Status {
450 version: EthVersion::Eth67,
451 chain: NamedChain::Mainnet.into(),
452 total_difficulty: U256::from(2).pow(U256::from(164)),
453 blockhash: B256::random(),
454 genesis,
455 forkid: fork_filter.current(),
457 };
458
459 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
460 let local_addr = listener.local_addr().unwrap();
461
462 let status_clone = status;
463 let fork_filter_clone = fork_filter.clone();
464 let handle = tokio::spawn(async move {
465 let (incoming, _) = listener.accept().await.unwrap();
467 let stream = PassthroughCodec::default().framed(incoming);
468 let handshake_res = UnauthedEthStream::new(stream)
469 .handshake::<EthNetworkPrimitives>(status_clone, fork_filter_clone)
470 .await;
471
472 assert!(matches!(
474 handshake_res,
475 Err(EthStreamError::EthHandshakeError(
476 EthHandshakeError::TotalDifficultyBitLenTooLarge { got: 165, maximum: 160 }
477 ))
478 ));
479 });
480
481 let outgoing = TcpStream::connect(local_addr).await.unwrap();
482 let sink = PassthroughCodec::default().framed(outgoing);
483
484 let handshake_res = UnauthedEthStream::new(sink)
486 .handshake::<EthNetworkPrimitives>(status, fork_filter)
487 .await;
488
489 assert!(matches!(
491 handshake_res,
492 Err(EthStreamError::EthHandshakeError(
493 EthHandshakeError::TotalDifficultyBitLenTooLarge { got: 165, maximum: 160 }
494 ))
495 ));
496
497 handle.await.unwrap();
499 }
500
501 #[tokio::test]
502 async fn can_write_and_read_cleartext() {
503 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
504 let local_addr = listener.local_addr().unwrap();
505 let test_msg = EthMessage::<EthNetworkPrimitives>::NewBlockHashes(
506 vec![
507 BlockHashNumber { hash: B256::random(), number: 5 },
508 BlockHashNumber { hash: B256::random(), number: 6 },
509 ]
510 .into(),
511 );
512
513 let test_msg_clone = test_msg.clone();
514 let handle = tokio::spawn(async move {
515 let (incoming, _) = listener.accept().await.unwrap();
517 let stream = PassthroughCodec::default().framed(incoming);
518 let mut stream = EthStream::new(EthVersion::Eth67, stream);
519
520 let message = stream.next().await.unwrap().unwrap();
522 assert_eq!(message, test_msg_clone);
523 });
524
525 let outgoing = TcpStream::connect(local_addr).await.unwrap();
526 let sink = PassthroughCodec::default().framed(outgoing);
527 let mut client_stream = EthStream::new(EthVersion::Eth67, sink);
528
529 client_stream.send(test_msg).await.unwrap();
530
531 handle.await.unwrap();
533 }
534
535 #[tokio::test]
536 async fn can_write_and_read_ecies() {
537 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
538 let local_addr = listener.local_addr().unwrap();
539 let server_key = SecretKey::new(&mut rand_08::thread_rng());
540 let test_msg = EthMessage::<EthNetworkPrimitives>::NewBlockHashes(
541 vec![
542 BlockHashNumber { hash: B256::random(), number: 5 },
543 BlockHashNumber { hash: B256::random(), number: 6 },
544 ]
545 .into(),
546 );
547
548 let test_msg_clone = test_msg.clone();
549 let handle = tokio::spawn(async move {
550 let (incoming, _) = listener.accept().await.unwrap();
552 let stream = ECIESStream::incoming(incoming, server_key).await.unwrap();
553 let mut stream = EthStream::new(EthVersion::Eth67, stream);
554
555 let message = stream.next().await.unwrap().unwrap();
557 assert_eq!(message, test_msg_clone);
558 });
559
560 let server_id = pk2id(&server_key.public_key(SECP256K1));
562
563 let client_key = SecretKey::new(&mut rand_08::thread_rng());
564
565 let outgoing = TcpStream::connect(local_addr).await.unwrap();
566 let outgoing = ECIESStream::connect(outgoing, client_key, server_id).await.unwrap();
567 let mut client_stream = EthStream::new(EthVersion::Eth67, outgoing);
568
569 client_stream.send(test_msg).await.unwrap();
570
571 handle.await.unwrap();
573 }
574
575 #[tokio::test(flavor = "multi_thread")]
576 async fn ethstream_over_p2p() {
577 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
580 let local_addr = listener.local_addr().unwrap();
581 let server_key = SecretKey::new(&mut rand_08::thread_rng());
582 let test_msg = EthMessage::<EthNetworkPrimitives>::NewBlockHashes(
583 vec![
584 BlockHashNumber { hash: B256::random(), number: 5 },
585 BlockHashNumber { hash: B256::random(), number: 6 },
586 ]
587 .into(),
588 );
589
590 let genesis = B256::random();
591 let fork_filter = ForkFilter::new(Head::default(), genesis, 0, Vec::new());
592
593 let status = Status {
594 version: EthVersion::Eth67,
595 chain: NamedChain::Mainnet.into(),
596 total_difficulty: U256::ZERO,
597 blockhash: B256::random(),
598 genesis,
599 forkid: fork_filter.current(),
601 };
602
603 let status_copy = status;
604 let fork_filter_clone = fork_filter.clone();
605 let test_msg_clone = test_msg.clone();
606 let handle = tokio::spawn(async move {
607 let (incoming, _) = listener.accept().await.unwrap();
609 let stream = ECIESStream::incoming(incoming, server_key).await.unwrap();
610
611 let server_hello = HelloMessageWithProtocols {
612 protocol_version: ProtocolVersion::V5,
613 client_version: "bitcoind/1.0.0".to_string(),
614 protocols: vec![EthVersion::Eth67.into()],
615 port: DEFAULT_TCP_PORT,
616 id: pk2id(&server_key.public_key(SECP256K1)),
617 };
618
619 let unauthed_stream = UnauthedP2PStream::new(stream);
620 let (p2p_stream, _) = unauthed_stream.handshake(server_hello).await.unwrap();
621 let (mut eth_stream, _) = UnauthedEthStream::new(p2p_stream)
622 .handshake(status_copy, fork_filter_clone)
623 .await
624 .unwrap();
625
626 let message = eth_stream.next().await.unwrap().unwrap();
628 assert_eq!(message, test_msg_clone);
629 });
630
631 let server_id = pk2id(&server_key.public_key(SECP256K1));
633
634 let client_key = SecretKey::new(&mut rand_08::thread_rng());
635
636 let outgoing = TcpStream::connect(local_addr).await.unwrap();
637 let sink = ECIESStream::connect(outgoing, client_key, server_id).await.unwrap();
638
639 let client_hello = HelloMessageWithProtocols {
640 protocol_version: ProtocolVersion::V5,
641 client_version: "bitcoind/1.0.0".to_string(),
642 protocols: vec![EthVersion::Eth67.into()],
643 port: DEFAULT_TCP_PORT,
644 id: pk2id(&client_key.public_key(SECP256K1)),
645 };
646
647 let unauthed_stream = UnauthedP2PStream::new(sink);
648 let (p2p_stream, _) = unauthed_stream.handshake(client_hello).await.unwrap();
649
650 let (mut client_stream, _) =
651 UnauthedEthStream::new(p2p_stream).handshake(status, fork_filter).await.unwrap();
652
653 client_stream.send(test_msg).await.unwrap();
654
655 handle.await.unwrap();
657 }
658
659 #[tokio::test]
660 async fn handshake_should_timeout() {
661 let genesis = B256::random();
662 let fork_filter = ForkFilter::new(Head::default(), genesis, 0, Vec::new());
663
664 let status = Status {
665 version: EthVersion::Eth67,
666 chain: NamedChain::Mainnet.into(),
667 total_difficulty: U256::ZERO,
668 blockhash: B256::random(),
669 genesis,
670 forkid: fork_filter.current(),
672 };
673
674 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
675 let local_addr = listener.local_addr().unwrap();
676
677 let status_clone = status;
678 let fork_filter_clone = fork_filter.clone();
679 let _handle = tokio::spawn(async move {
680 tokio::time::sleep(Duration::from_secs(11)).await;
682 let (incoming, _) = listener.accept().await.unwrap();
684 let stream = PassthroughCodec::default().framed(incoming);
685 let (_, their_status) = UnauthedEthStream::new(stream)
686 .handshake::<EthNetworkPrimitives>(status_clone, fork_filter_clone)
687 .await
688 .unwrap();
689
690 assert_eq!(their_status, status_clone);
692 });
693
694 let outgoing = TcpStream::connect(local_addr).await.unwrap();
695 let sink = PassthroughCodec::default().framed(outgoing);
696
697 let handshake_result = UnauthedEthStream::new(sink)
699 .handshake_with_timeout::<EthNetworkPrimitives>(
700 status,
701 fork_filter,
702 Duration::from_secs(1),
703 )
704 .await;
705
706 assert!(
708 matches!(handshake_result, Err(e) if e.to_string() == EthStreamError::StreamTimeout.to_string())
709 );
710 }
711
712 #[tokio::test]
713 async fn can_write_and_read_raw_capability() {
714 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
715 let local_addr = listener.local_addr().unwrap();
716
717 let test_msg = RawCapabilityMessage { id: 0x1234, payload: Bytes::from(vec![1, 2, 3, 4]) };
718
719 let test_msg_clone = test_msg.clone();
720 let handle = tokio::spawn(async move {
721 let (incoming, _) = listener.accept().await.unwrap();
722 let stream = PassthroughCodec::default().framed(incoming);
723 let mut stream = EthStream::<_, EthNetworkPrimitives>::new(EthVersion::Eth67, stream);
724
725 let bytes = stream.inner_mut().next().await.unwrap().unwrap();
726
727 let mut id_bytes = &bytes[..];
729 let decoded_id = <usize as Decodable>::decode(&mut id_bytes).unwrap();
730 assert_eq!(decoded_id, test_msg_clone.id);
731
732 let remaining = id_bytes;
734 assert_eq!(remaining, &test_msg_clone.payload[..]);
735 });
736
737 let outgoing = TcpStream::connect(local_addr).await.unwrap();
738 let sink = PassthroughCodec::default().framed(outgoing);
739 let mut client_stream = EthStream::<_, EthNetworkPrimitives>::new(EthVersion::Eth67, sink);
740
741 client_stream.start_send_raw(test_msg).unwrap();
742 client_stream.inner_mut().flush().await.unwrap();
743
744 handle.await.unwrap();
745 }
746}