1use crate::{
2 capability::SharedCapabilities,
3 disconnect::CanDisconnect,
4 errors::{P2PHandshakeError, P2PStreamError},
5 pinger::{Pinger, PingerEvent},
6 DisconnectReason, HelloMessage, HelloMessageWithProtocols,
7};
8use alloy_primitives::{
9 bytes::{Buf, BufMut, Bytes, BytesMut},
10 hex,
11};
12use alloy_rlp::{Decodable, Encodable, Error as RlpError, EMPTY_LIST_CODE};
13use futures::{Sink, SinkExt, StreamExt};
14use pin_project::pin_project;
15use reth_codecs::add_arbitrary_tests;
16use reth_metrics::metrics::counter;
17use reth_primitives_traits::GotExpected;
18use std::{
19 collections::VecDeque,
20 future::Future,
21 io,
22 pin::Pin,
23 task::{ready, Context, Poll},
24 time::Duration,
25};
26use tokio_stream::Stream;
27use tracing::{debug, trace};
28
29#[cfg(feature = "serde")]
30use serde::{Deserialize, Serialize};
31
32const MAX_PAYLOAD_SIZE: usize = 16 * 1024 * 1024;
35
36pub const MAX_RESERVED_MESSAGE_ID: u8 = 0x0f;
39
40const MAX_P2P_MESSAGE_ID: u8 = P2PMessageID::Pong as u8;
42
43pub const HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(10);
46
47const PING_TIMEOUT: Duration = Duration::from_secs(15);
50
51const PING_INTERVAL: Duration = Duration::from_secs(60);
54
55const MAX_P2P_CAPACITY: usize = 2;
62
63#[pin_project]
66#[derive(Debug)]
67pub struct UnauthedP2PStream<S> {
68 #[pin]
69 inner: S,
70}
71
72impl<S> UnauthedP2PStream<S> {
73 pub const fn new(inner: S) -> Self {
75 Self { inner }
76 }
77
78 pub const fn inner(&self) -> &S {
80 &self.inner
81 }
82}
83
84impl<S> UnauthedP2PStream<S>
85where
86 S: Stream<Item = io::Result<BytesMut>> + Sink<Bytes, Error = io::Error> + Unpin,
87{
88 pub async fn handshake(
91 mut self,
92 hello: HelloMessageWithProtocols,
93 ) -> Result<(P2PStream<S>, HelloMessage), P2PStreamError> {
94 trace!(?hello, "sending p2p hello to peer");
95
96 self.inner.send(alloy_rlp::encode(P2PMessage::Hello(hello.message())).into()).await?;
98
99 let first_message_bytes = tokio::time::timeout(HANDSHAKE_TIMEOUT, self.inner.next())
100 .await
101 .or(Err(P2PStreamError::HandshakeError(P2PHandshakeError::Timeout)))?
102 .ok_or(P2PStreamError::HandshakeError(P2PHandshakeError::NoResponse))??;
103
104 if first_message_bytes.len() > MAX_PAYLOAD_SIZE {
107 return Err(P2PStreamError::MessageTooBig {
108 message_size: first_message_bytes.len(),
109 max_size: MAX_PAYLOAD_SIZE,
110 })
111 }
112
113 let their_hello = match P2PMessage::decode(&mut &first_message_bytes[..]) {
120 Ok(P2PMessage::Hello(hello)) => Ok(hello),
121 Ok(P2PMessage::Disconnect(reason)) => {
122 if matches!(reason, DisconnectReason::TooManyPeers) {
123 trace!(%reason, "Disconnected by peer during handshake");
125 } else {
126 debug!(%reason, "Disconnected by peer during handshake");
127 };
128 counter!("p2pstream.disconnected_errors").increment(1);
129 Err(P2PStreamError::HandshakeError(P2PHandshakeError::Disconnected(reason)))
130 }
131 Err(err) => {
132 debug!(%err, msg=%hex::encode(&first_message_bytes), "Failed to decode first message from peer");
133 Err(P2PStreamError::HandshakeError(err.into()))
134 }
135 Ok(msg) => {
136 debug!(?msg, "expected hello message but received another message");
137 Err(P2PStreamError::HandshakeError(P2PHandshakeError::NonHelloMessageInHandshake))
138 }
139 }?;
140
141 trace!(
142 hello=?their_hello,
143 "validating incoming p2p hello from peer"
144 );
145
146 if (hello.protocol_version as u8) != their_hello.protocol_version as u8 {
147 self.send_disconnect(DisconnectReason::IncompatibleP2PProtocolVersion).await?;
149 return Err(P2PStreamError::MismatchedProtocolVersion(GotExpected {
150 got: their_hello.protocol_version,
151 expected: hello.protocol_version,
152 }))
153 }
154
155 let capability_res =
157 SharedCapabilities::try_new(hello.protocols, their_hello.capabilities.clone());
158
159 let shared_capability = match capability_res {
160 Err(err) => {
161 self.send_disconnect(DisconnectReason::UselessPeer).await?;
163 Err(err)
164 }
165 Ok(cap) => Ok(cap),
166 }?;
167
168 let stream = P2PStream::new(self.inner, shared_capability);
169
170 Ok((stream, their_hello))
171 }
172}
173
174impl<S> UnauthedP2PStream<S>
175where
176 S: Sink<Bytes, Error = io::Error> + Unpin,
177{
178 pub async fn send_disconnect(
180 &mut self,
181 reason: DisconnectReason,
182 ) -> Result<(), P2PStreamError> {
183 trace!(
184 %reason,
185 "Sending disconnect message during the handshake",
186 );
187 self.inner
188 .send(Bytes::from(alloy_rlp::encode(P2PMessage::Disconnect(reason))))
189 .await
190 .map_err(P2PStreamError::Io)
191 }
192}
193
194impl<S> CanDisconnect<Bytes> for P2PStream<S>
195where
196 S: Sink<Bytes, Error = io::Error> + Unpin + Send + Sync,
197{
198 fn disconnect(
199 &mut self,
200 reason: DisconnectReason,
201 ) -> Pin<Box<dyn Future<Output = Result<(), P2PStreamError>> + Send + '_>> {
202 Box::pin(async move { self.disconnect(reason).await })
203 }
204}
205
206#[pin_project]
231#[derive(Debug)]
232pub struct P2PStream<S> {
233 #[pin]
234 inner: S,
235
236 encoder: snap::raw::Encoder,
238
239 decoder: snap::raw::Decoder,
241
242 pinger: Pinger,
244
245 shared_capabilities: SharedCapabilities,
247
248 outgoing_messages: VecDeque<Bytes>,
250
251 outgoing_message_buffer_capacity: usize,
254
255 disconnecting: bool,
258}
259
260impl<S> P2PStream<S> {
261 pub fn new(inner: S, shared_capabilities: SharedCapabilities) -> Self {
265 Self {
266 inner,
267 encoder: snap::raw::Encoder::new(),
268 decoder: snap::raw::Decoder::new(),
269 pinger: Pinger::new(PING_INTERVAL, PING_TIMEOUT),
270 shared_capabilities,
271 outgoing_messages: VecDeque::new(),
272 outgoing_message_buffer_capacity: MAX_P2P_CAPACITY,
273 disconnecting: false,
274 }
275 }
276
277 pub const fn inner(&self) -> &S {
279 &self.inner
280 }
281
282 pub const fn set_outgoing_message_buffer_capacity(&mut self, capacity: usize) {
288 self.outgoing_message_buffer_capacity = capacity;
289 }
290
291 pub const fn shared_capabilities(&self) -> &SharedCapabilities {
296 &self.shared_capabilities
297 }
298
299 fn has_outgoing_capacity(&self) -> bool {
301 self.outgoing_messages.len() < self.outgoing_message_buffer_capacity
302 }
303
304 fn send_pong(&mut self) {
306 self.outgoing_messages.push_back(Bytes::from(alloy_rlp::encode(P2PMessage::Pong)));
307 }
308
309 pub fn send_ping(&mut self) {
311 self.outgoing_messages.push_back(Bytes::from(alloy_rlp::encode(P2PMessage::Ping)));
312 }
313}
314
315pub trait DisconnectP2P {
318 fn start_disconnect(&mut self, reason: DisconnectReason) -> Result<(), P2PStreamError>;
320
321 fn is_disconnecting(&self) -> bool;
323}
324
325impl<S> DisconnectP2P for P2PStream<S> {
326 fn start_disconnect(&mut self, reason: DisconnectReason) -> Result<(), P2PStreamError> {
335 self.outgoing_messages.clear();
337 let disconnect = P2PMessage::Disconnect(reason);
338 let mut buf = Vec::with_capacity(disconnect.length());
339 disconnect.encode(&mut buf);
340
341 let mut compressed = vec![0u8; 1 + snap::raw::max_compress_len(buf.len() - 1)];
342 let compressed_size =
343 self.encoder.compress(&buf[1..], &mut compressed[1..]).map_err(|err| {
344 debug!(
345 %err,
346 msg=%hex::encode(&buf[1..]),
347 "error compressing disconnect"
348 );
349 err
350 })?;
351
352 compressed.truncate(compressed_size + 1);
355
356 compressed[0] = buf[0];
359
360 self.outgoing_messages.push_back(compressed.into());
361 self.disconnecting = true;
362 Ok(())
363 }
364
365 fn is_disconnecting(&self) -> bool {
366 self.disconnecting
367 }
368}
369
370impl<S> P2PStream<S>
371where
372 S: Sink<Bytes, Error = io::Error> + Unpin + Send,
373{
374 pub async fn disconnect(&mut self, reason: DisconnectReason) -> Result<(), P2PStreamError> {
379 self.start_disconnect(reason)?;
380 self.close().await
381 }
382}
383
384impl<S> Stream for P2PStream<S>
387where
388 S: Stream<Item = io::Result<BytesMut>> + Sink<Bytes, Error = io::Error> + Unpin,
389{
390 type Item = Result<BytesMut, P2PStreamError>;
391
392 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
393 let this = self.get_mut();
394
395 if this.disconnecting {
396 return Poll::Ready(None)
398 }
399
400 while let Poll::Ready(res) = this.inner.poll_next_unpin(cx) {
403 let bytes = match res {
404 Some(Ok(bytes)) => bytes,
405 Some(Err(err)) => return Poll::Ready(Some(Err(err.into()))),
406 None => return Poll::Ready(None),
407 };
408
409 if bytes.is_empty() {
410 return Poll::Ready(Some(Err(P2PStreamError::EmptyProtocolMessage)))
412 }
413
414 let id = bytes[0];
419 if id == P2PMessageID::Disconnect as u8 {
420 if let Ok(reason) = DisconnectReason::decode(&mut &bytes[1..]) {
432 return Poll::Ready(Some(Err(P2PStreamError::Disconnected(reason))))
433 }
434 }
435
436 let decompressed_len = snap::raw::decompress_len(&bytes[1..])?;
439 if decompressed_len > MAX_PAYLOAD_SIZE {
440 return Poll::Ready(Some(Err(P2PStreamError::MessageTooBig {
441 message_size: decompressed_len,
442 max_size: MAX_PAYLOAD_SIZE,
443 })))
444 }
445
446 let mut decompress_buf = BytesMut::zeroed(decompressed_len + 1);
449
450 this.decoder.decompress(&bytes[1..], &mut decompress_buf[1..]).map_err(|err| {
453 debug!(
454 %err,
455 msg=%hex::encode(&bytes[1..]),
456 "error decompressing p2p message"
457 );
458 err
459 })?;
460
461 match id {
462 _ if id == P2PMessageID::Ping as u8 => {
463 trace!("Received Ping, Sending Pong");
464 this.send_pong();
465 cx.waker().wake_by_ref();
468 }
469 _ if id == P2PMessageID::Hello as u8 => {
470 return Poll::Ready(Some(Err(P2PStreamError::HandshakeError(
473 P2PHandshakeError::HelloNotInHandshake,
474 ))))
475 }
476 _ if id == P2PMessageID::Pong as u8 => {
477 this.pinger.on_pong()?
479 }
480 _ if id == P2PMessageID::Disconnect as u8 => {
481 let reason = DisconnectReason::decode(&mut &decompress_buf[1..]).inspect_err(|err| {
487 debug!(
488 %err, msg=%hex::encode(&decompress_buf[1..]), "Failed to decode disconnect message from peer"
489 );
490 })?;
491 return Poll::Ready(Some(Err(P2PStreamError::Disconnected(reason))))
492 }
493 _ if id > MAX_P2P_MESSAGE_ID && id <= MAX_RESERVED_MESSAGE_ID => {
494 return Poll::Ready(Some(Err(P2PStreamError::UnknownReservedMessageId(id))))
496 }
497 _ => {
498 decompress_buf[0] = bytes[0] - MAX_RESERVED_MESSAGE_ID - 1;
522
523 return Poll::Ready(Some(Ok(decompress_buf)))
524 }
525 }
526 }
527
528 Poll::Pending
529 }
530}
531
532impl<S> Sink<Bytes> for P2PStream<S>
533where
534 S: Sink<Bytes, Error = io::Error> + Unpin,
535{
536 type Error = P2PStreamError;
537
538 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
539 let mut this = self.as_mut();
540
541 match this.pinger.poll_ping(cx) {
543 Poll::Pending => {}
544 Poll::Ready(Ok(PingerEvent::Ping)) => {
545 this.send_ping();
546 }
547 _ => {
548 this.start_disconnect(DisconnectReason::PingTimeout)?;
550
551 return Poll::Ready(Ok(()))
553 }
554 }
555
556 match this.inner.poll_ready_unpin(cx) {
557 Poll::Pending => {}
558 Poll::Ready(Err(err)) => return Poll::Ready(Err(P2PStreamError::Io(err))),
559 Poll::Ready(Ok(())) => {
560 let flushed = this.poll_flush(cx);
561 if flushed.is_ready() {
562 return flushed
563 }
564 }
565 }
566
567 if self.has_outgoing_capacity() {
568 Poll::Ready(Ok(()))
570 } else {
571 Poll::Pending
572 }
573 }
574
575 fn start_send(self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> {
576 if item.len() > MAX_PAYLOAD_SIZE {
577 return Err(P2PStreamError::MessageTooBig {
578 message_size: item.len(),
579 max_size: MAX_PAYLOAD_SIZE,
580 })
581 }
582
583 if item.is_empty() {
584 return Err(P2PStreamError::EmptyProtocolMessage)
586 }
587
588 if !self.has_outgoing_capacity() {
590 return Err(P2PStreamError::SendBufferFull)
591 }
592
593 let this = self.project();
594
595 let mut compressed = BytesMut::zeroed(1 + snap::raw::max_compress_len(item.len() - 1));
596 let compressed_size =
597 this.encoder.compress(&item[1..], &mut compressed[1..]).map_err(|err| {
598 debug!(
599 %err,
600 msg=%hex::encode(&item[1..]),
601 "error compressing p2p message"
602 );
603 err
604 })?;
605
606 compressed.truncate(compressed_size + 1);
609
610 compressed[0] = item[0] + MAX_RESERVED_MESSAGE_ID + 1;
613 this.outgoing_messages.push_back(compressed.freeze());
614
615 Ok(())
616 }
617
618 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
620 let mut this = self.project();
621 let poll_res = loop {
622 match this.inner.as_mut().poll_ready(cx) {
623 Poll::Pending => break Poll::Pending,
624 Poll::Ready(Err(err)) => break Poll::Ready(Err(err.into())),
625 Poll::Ready(Ok(())) => {
626 let Some(message) = this.outgoing_messages.pop_front() else {
627 break Poll::Ready(Ok(()))
628 };
629 if let Err(err) = this.inner.as_mut().start_send(message) {
630 break Poll::Ready(Err(err.into()))
631 }
632 }
633 }
634 };
635
636 ready!(this.inner.as_mut().poll_flush(cx))?;
637
638 poll_res
639 }
640
641 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
642 ready!(self.as_mut().poll_flush(cx))?;
643 ready!(self.project().inner.poll_close(cx))?;
644
645 Poll::Ready(Ok(()))
646 }
647}
648
649#[derive(Debug, Clone, PartialEq, Eq)]
651#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
652#[cfg_attr(any(test, feature = "arbitrary"), derive(arbitrary::Arbitrary))]
653#[add_arbitrary_tests(rlp)]
654pub enum P2PMessage {
655 Hello(HelloMessage),
657
658 Disconnect(DisconnectReason),
661
662 Ping,
664
665 Pong,
667}
668
669impl P2PMessage {
670 pub const fn message_id(&self) -> P2PMessageID {
672 match self {
673 Self::Hello(_) => P2PMessageID::Hello,
674 Self::Disconnect(_) => P2PMessageID::Disconnect,
675 Self::Ping => P2PMessageID::Ping,
676 Self::Pong => P2PMessageID::Pong,
677 }
678 }
679}
680
681impl Encodable for P2PMessage {
682 fn encode(&self, out: &mut dyn BufMut) {
687 (self.message_id() as u8).encode(out);
688 match self {
689 Self::Hello(msg) => msg.encode(out),
690 Self::Disconnect(msg) => msg.encode(out),
691 Self::Ping => {
692 out.put_u8(0x01);
694 out.put_u8(0x00);
695 out.put_u8(EMPTY_LIST_CODE);
696 }
697 Self::Pong => {
698 out.put_u8(0x01);
700 out.put_u8(0x00);
701 out.put_u8(EMPTY_LIST_CODE);
702 }
703 }
704 }
705
706 fn length(&self) -> usize {
707 let payload_len = match self {
708 Self::Hello(msg) => msg.length(),
709 Self::Disconnect(msg) => msg.length(),
710 Self::Ping | Self::Pong => 3, };
713 payload_len + 1 }
715}
716
717impl Decodable for P2PMessage {
718 fn decode(buf: &mut &[u8]) -> alloy_rlp::Result<Self> {
725 fn advance_snappy_ping_pong_payload(buf: &mut &[u8]) -> alloy_rlp::Result<()> {
727 if buf.len() < 3 {
728 return Err(RlpError::InputTooShort)
729 }
730 if buf[..3] != [0x01, 0x00, EMPTY_LIST_CODE] {
731 return Err(RlpError::Custom("expected snappy payload"))
732 }
733 buf.advance(3);
734 Ok(())
735 }
736
737 let message_id = u8::decode(&mut &buf[..])?;
738 let id = P2PMessageID::try_from(message_id)
739 .or(Err(RlpError::Custom("unknown p2p message id")))?;
740 buf.advance(1);
741 match id {
742 P2PMessageID::Hello => Ok(Self::Hello(HelloMessage::decode(buf)?)),
743 P2PMessageID::Disconnect => Ok(Self::Disconnect(DisconnectReason::decode(buf)?)),
744 P2PMessageID::Ping => {
745 advance_snappy_ping_pong_payload(buf)?;
746 Ok(Self::Ping)
747 }
748 P2PMessageID::Pong => {
749 advance_snappy_ping_pong_payload(buf)?;
750 Ok(Self::Pong)
751 }
752 }
753 }
754}
755
/// Wire ids of the base `p2p` messages.
///
/// Ids up to `MAX_RESERVED_MESSAGE_ID` belong to `p2p`; only these four are defined.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum P2PMessageID {
    /// Id of the handshake `Hello` message.
    Hello = 0x00,

    /// Id of the `Disconnect` message.
    Disconnect = 0x01,

    /// Id of the `Ping` message.
    Ping = 0x02,

    /// Id of the `Pong` message.
    Pong = 0x03,
}
771
772impl From<P2PMessage> for P2PMessageID {
773 fn from(msg: P2PMessage) -> Self {
774 match msg {
775 P2PMessage::Hello(_) => Self::Hello,
776 P2PMessage::Disconnect(_) => Self::Disconnect,
777 P2PMessage::Ping => Self::Ping,
778 P2PMessage::Pong => Self::Pong,
779 }
780 }
781}
782
783impl TryFrom<u8> for P2PMessageID {
784 type Error = P2PStreamError;
785
786 fn try_from(id: u8) -> Result<Self, Self::Error> {
787 match id {
788 0x00 => Ok(Self::Hello),
789 0x01 => Ok(Self::Disconnect),
790 0x02 => Ok(Self::Ping),
791 0x03 => Ok(Self::Pong),
792 _ => Err(P2PStreamError::UnknownReservedMessageId(id)),
793 }
794 }
795}
796
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{capability::SharedCapability, test_utils::eth_hello, EthVersion, ProtocolVersion};
    use tokio::net::{TcpListener, TcpStream};
    use tokio_util::codec::Decoder;

    /// The server disconnects right after the handshake. The client must read that reason back
    /// as `P2PStreamError::Disconnected`.
    #[tokio::test]
    async fn test_can_disconnect() {
        reth_tracing::init_test_tracing();
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let server_addr = listener.local_addr().unwrap();

        let expected_disconnect = DisconnectReason::UselessPeer;

        let server = tokio::spawn(async move {
            let (socket, _) = listener.accept().await.unwrap();
            let framed = crate::PassthroughCodec::default().framed(socket);
            let (server_hello, _) = eth_hello();

            let (mut p2p_stream, _) =
                UnauthedP2PStream::new(framed).handshake(server_hello).await.unwrap();

            p2p_stream.disconnect(expected_disconnect).await.unwrap();
        });

        let socket = TcpStream::connect(server_addr).await.unwrap();
        let framed = crate::PassthroughCodec::default().framed(socket);
        let (client_hello, _) = eth_hello();

        let (mut p2p_stream, _) =
            UnauthedP2PStream::new(framed).handshake(client_hello).await.unwrap();

        match p2p_stream.next().await.unwrap().unwrap_err() {
            P2PStreamError::Disconnected(reason) => assert_eq!(reason, expected_disconnect),
            e => panic!("unexpected err: {e}"),
        }

        server.await.unwrap();
    }

    /// The server writes an *uncompressed* disconnect frame. The client must still decode the
    /// reason.
    #[tokio::test]
    async fn test_can_disconnect_weird_disconnect_encoding() {
        reth_tracing::init_test_tracing();
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let server_addr = listener.local_addr().unwrap();

        let expected_disconnect = DisconnectReason::SubprotocolSpecific;

        let server = tokio::spawn(async move {
            let (socket, _) = listener.accept().await.unwrap();
            let framed = crate::PassthroughCodec::default().framed(socket);
            let (server_hello, _) = eth_hello();

            let (mut p2p_stream, _) =
                UnauthedP2PStream::new(framed).handshake(server_hello).await.unwrap();

            // Bypass `start_disconnect` so the frame goes out without snappy compression.
            p2p_stream.outgoing_messages.clear();
            let raw_disconnect =
                alloy_rlp::encode(P2PMessage::Disconnect(DisconnectReason::SubprotocolSpecific));
            p2p_stream.outgoing_messages.push_back(Bytes::from(raw_disconnect));
            p2p_stream.disconnecting = true;
            p2p_stream.close().await.unwrap();
        });

        let socket = TcpStream::connect(server_addr).await.unwrap();
        let framed = crate::PassthroughCodec::default().framed(socket);
        let (client_hello, _) = eth_hello();

        let (mut p2p_stream, _) =
            UnauthedP2PStream::new(framed).handshake(client_hello).await.unwrap();

        match p2p_stream.next().await.unwrap().unwrap_err() {
            P2PStreamError::Disconnected(reason) => assert_eq!(reason, expected_disconnect),
            e => panic!("unexpected err: {e}"),
        }

        server.await.unwrap();
    }

    /// Both sides complete the handshake and agree on `eth/67` at the first subprotocol offset.
    #[tokio::test]
    async fn test_handshake_passthrough() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let server_addr = listener.local_addr().unwrap();

        let server = tokio::spawn(async move {
            let (socket, _) = listener.accept().await.unwrap();
            let framed = crate::PassthroughCodec::default().framed(socket);
            let (server_hello, _) = eth_hello();

            let (p2p_stream, _) =
                UnauthedP2PStream::new(framed).handshake(server_hello).await.unwrap();

            let first_cap = p2p_stream.shared_capabilities.iter_caps().next().unwrap();
            assert_eq!(
                *first_cap,
                SharedCapability::Eth {
                    version: EthVersion::Eth67,
                    offset: MAX_RESERVED_MESSAGE_ID + 1
                }
            );
        });

        let socket = TcpStream::connect(server_addr).await.unwrap();
        let framed = crate::PassthroughCodec::default().framed(socket);
        let (client_hello, _) = eth_hello();

        let (p2p_stream, _) =
            UnauthedP2PStream::new(framed).handshake(client_hello).await.unwrap();

        let first_cap = p2p_stream.shared_capabilities.iter_caps().next().unwrap();
        assert_eq!(
            *first_cap,
            SharedCapability::Eth {
                version: EthVersion::Eth67,
                offset: MAX_RESERVED_MESSAGE_ID + 1
            }
        );

        server.await.unwrap();
    }

    /// The client advertises `V4` while the server uses the default version. Both sides must
    /// fail with `MismatchedProtocolVersion`.
    #[tokio::test]
    async fn test_handshake_disconnect() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let server_addr = listener.local_addr().unwrap();

        let server = tokio::spawn(Box::pin(async move {
            let (socket, _) = listener.accept().await.unwrap();
            let framed = crate::PassthroughCodec::default().framed(socket);
            let (server_hello, _) = eth_hello();

            let outcome = UnauthedP2PStream::new(framed).handshake(server_hello.clone()).await;
            match outcome {
                Ok((_, hello)) => {
                    panic!("expected handshake to fail, instead got a successful Hello: {hello:?}")
                }
                Err(P2PStreamError::MismatchedProtocolVersion(GotExpected { got, expected })) => {
                    assert_ne!(expected, got);
                    assert_eq!(expected, server_hello.protocol_version);
                }
                Err(other_err) => {
                    panic!("expected mismatched protocol version error, got {other_err:?}")
                }
            }
        }));

        let socket = TcpStream::connect(server_addr).await.unwrap();
        let framed = crate::PassthroughCodec::default().framed(socket);

        let (mut client_hello, _) = eth_hello();
        // Force a version the server does not speak.
        client_hello.protocol_version = ProtocolVersion::V4;

        let outcome = UnauthedP2PStream::new(framed).handshake(client_hello.clone()).await;
        match outcome {
            Ok((_, hello)) => {
                panic!("expected handshake to fail, instead got a successful Hello: {hello:?}")
            }
            Err(P2PStreamError::MismatchedProtocolVersion(GotExpected { got, expected })) => {
                assert_ne!(expected, got);
                assert_eq!(expected, client_hello.protocol_version);
            }
            Err(other_err) => {
                panic!("expected mismatched protocol version error, got {other_err:?}")
            }
        }

        server.await.unwrap();
    }

    /// `Ping` round-trips through its fixed snappy encoding.
    #[test]
    fn snappy_decode_encode_ping() {
        let wire = b"\x02\x01\0\xc0";
        let decoded = P2PMessage::decode(&mut &wire[..]).unwrap();
        assert!(matches!(decoded, P2PMessage::Ping));
        assert_eq!(alloy_rlp::encode(decoded), &wire[..]);
    }

    /// `Pong` round-trips through its fixed snappy encoding.
    #[test]
    fn snappy_decode_encode_pong() {
        let wire = b"\x03\x01\0\xc0";
        let decoded = P2PMessage::decode(&mut &wire[..]).unwrap();
        assert!(matches!(decoded, P2PMessage::Pong));
        assert_eq!(alloy_rlp::encode(decoded), &wire[..]);
    }
}