1use crate::{
2 capability::SharedCapabilities,
3 disconnect::CanDisconnect,
4 errors::{P2PHandshakeError, P2PStreamError},
5 pinger::{Pinger, PingerEvent},
6 DisconnectReason, HelloMessage, HelloMessageWithProtocols,
7};
8use alloy_primitives::{
9 bytes::{Buf, BufMut, Bytes, BytesMut},
10 hex,
11};
12use alloy_rlp::{Decodable, Encodable, Error as RlpError, EMPTY_LIST_CODE};
13use futures::{Sink, SinkExt, StreamExt};
14use pin_project::pin_project;
15use reth_codecs::add_arbitrary_tests;
16use reth_metrics::metrics::counter;
17use reth_primitives_traits::GotExpected;
18use std::{
19 collections::VecDeque,
20 future::Future,
21 io,
22 pin::Pin,
23 task::{ready, Context, Poll},
24 time::Duration,
25};
26use tokio_stream::Stream;
27use tracing::{debug, trace};
28
29#[cfg(feature = "serde")]
30use serde::{Deserialize, Serialize};
31
32const MAX_PAYLOAD_SIZE: usize = 16 * 1024 * 1024;
35
36pub const MAX_RESERVED_MESSAGE_ID: u8 = 0x0f;
39
40const MAX_P2P_MESSAGE_ID: u8 = P2PMessageID::Pong as u8;
42
43pub const HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(10);
46
47const PING_TIMEOUT: Duration = Duration::from_secs(15);
50
51const PING_INTERVAL: Duration = Duration::from_secs(60);
54
55const MAX_P2P_CAPACITY: usize = 2;
62
/// A transport that has not yet completed the `p2p` `Hello` handshake.
///
/// Call [`UnauthedP2PStream::handshake`] to exchange `Hello` messages and upgrade to a
/// [`P2PStream`].
#[pin_project]
#[derive(Debug)]
pub struct UnauthedP2PStream<S> {
    /// The wrapped transport. During the handshake, `BytesMut` frames are read from it and
    /// `Bytes` frames are written to it.
    #[pin]
    inner: S,
}
71
72impl<S> UnauthedP2PStream<S> {
73 pub const fn new(inner: S) -> Self {
75 Self { inner }
76 }
77
78 pub const fn inner(&self) -> &S {
80 &self.inner
81 }
82}
83
84impl<S> UnauthedP2PStream<S>
85where
86 S: Stream<Item = io::Result<BytesMut>> + Sink<Bytes, Error = io::Error> + Unpin,
87{
88 pub async fn handshake(
91 mut self,
92 hello: HelloMessageWithProtocols,
93 ) -> Result<(P2PStream<S>, HelloMessage), P2PStreamError> {
94 trace!(?hello, "sending p2p hello to peer");
95
96 self.inner.send(alloy_rlp::encode(P2PMessage::Hello(hello.message())).into()).await?;
98
99 let first_message_bytes = tokio::time::timeout(HANDSHAKE_TIMEOUT, self.inner.next())
100 .await
101 .or(Err(P2PStreamError::HandshakeError(P2PHandshakeError::Timeout)))?
102 .ok_or(P2PStreamError::HandshakeError(P2PHandshakeError::NoResponse))??;
103
104 if first_message_bytes.len() > MAX_PAYLOAD_SIZE {
108 return Err(P2PStreamError::MessageTooBig {
109 message_size: first_message_bytes.len(),
110 max_size: MAX_PAYLOAD_SIZE,
111 })
112 }
113
114 let their_hello = match P2PMessage::decode(&mut &first_message_bytes[..]) {
121 Ok(P2PMessage::Hello(hello)) => Ok(hello),
122 Ok(P2PMessage::Disconnect(reason)) => {
123 if matches!(reason, DisconnectReason::TooManyPeers) {
124 trace!(%reason, "Disconnected by peer during handshake");
126 } else {
127 debug!(%reason, "Disconnected by peer during handshake");
128 };
129 counter!("p2pstream.disconnected_errors").increment(1);
130 Err(P2PStreamError::HandshakeError(P2PHandshakeError::Disconnected(reason)))
131 }
132 Err(err) => {
133 debug!(%err, msg=%hex::encode(&first_message_bytes), "Failed to decode first message from peer");
134 Err(P2PStreamError::HandshakeError(err.into()))
135 }
136 Ok(msg) => {
137 debug!(?msg, "expected hello message but received another message");
138 Err(P2PStreamError::HandshakeError(P2PHandshakeError::NonHelloMessageInHandshake))
139 }
140 }?;
141
142 trace!(
143 hello=?their_hello,
144 "validating incoming p2p hello from peer"
145 );
146
147 if (hello.protocol_version as u8) != their_hello.protocol_version as u8 {
148 self.send_disconnect(DisconnectReason::IncompatibleP2PProtocolVersion).await?;
150 return Err(P2PStreamError::MismatchedProtocolVersion(GotExpected {
151 got: their_hello.protocol_version,
152 expected: hello.protocol_version,
153 }))
154 }
155
156 let capability_res =
158 SharedCapabilities::try_new(hello.protocols, their_hello.capabilities.clone());
159
160 let shared_capability = match capability_res {
161 Err(err) => {
162 self.send_disconnect(DisconnectReason::UselessPeer).await?;
164 Err(err)
165 }
166 Ok(cap) => Ok(cap),
167 }?;
168
169 let stream = P2PStream::new(self.inner, shared_capability);
170
171 Ok((stream, their_hello))
172 }
173}
174
175impl<S> UnauthedP2PStream<S>
176where
177 S: Sink<Bytes, Error = io::Error> + Unpin,
178{
179 pub async fn send_disconnect(
181 &mut self,
182 reason: DisconnectReason,
183 ) -> Result<(), P2PStreamError> {
184 trace!(
185 %reason,
186 "Sending disconnect message during the handshake",
187 );
188 self.inner
189 .send(Bytes::from(alloy_rlp::encode(P2PMessage::Disconnect(reason))))
190 .await
191 .map_err(P2PStreamError::Io)
192 }
193}
194
195impl<S> CanDisconnect<Bytes> for P2PStream<S>
196where
197 S: Sink<Bytes, Error = io::Error> + Unpin + Send + Sync,
198{
199 fn disconnect(
200 &mut self,
201 reason: DisconnectReason,
202 ) -> Pin<Box<dyn Future<Output = Result<(), P2PStreamError>> + Send + '_>> {
203 Box::pin(async move { self.disconnect(reason).await })
204 }
205}
206
/// A `p2p` stream that has completed the `Hello` handshake.
///
/// Inbound frames are snappy-decompressed:
/// - `p2p` pings are answered and pongs are reported to the pinger.
/// - Hello, disconnect and unknown reserved IDs surface as errors.
/// - Subprotocol messages are yielded with their ID shifted down by
///   `MAX_RESERVED_MESSAGE_ID + 1`.
///
/// Outbound items are snappy-compressed, have their ID shifted up by the same amount, and
/// are queued until the sink is flushed.
#[pin_project]
#[derive(Debug)]
pub struct P2PStream<S> {
    /// The underlying transport.
    #[pin]
    inner: S,

    /// Snappy encoder for outbound messages.
    encoder: snap::raw::Encoder,

    /// Snappy decoder for inbound messages.
    decoder: snap::raw::Decoder,

    /// Ping scheduler, polled from `poll_ready`. Its `Ping` events queue a ping; any other
    /// ready outcome starts a `PingTimeout` disconnect.
    pinger: Pinger,

    /// Capabilities shared with the peer.
    shared_capabilities: SharedCapabilities,

    /// Encoded messages waiting to be written to `inner` by `poll_flush`.
    outgoing_messages: VecDeque<Bytes>,

    /// Queue length at which `poll_ready` reports back-pressure and `start_send` rejects
    /// items (`MAX_P2P_CAPACITY` by default).
    outgoing_message_buffer_capacity: usize,

    /// Set once `start_disconnect` has queued a disconnect. From then on `poll_next` yields
    /// `None`.
    disconnecting: bool,
}
260
261impl<S> P2PStream<S> {
262 pub fn new(inner: S, shared_capabilities: SharedCapabilities) -> Self {
266 Self {
267 inner,
268 encoder: snap::raw::Encoder::new(),
269 decoder: snap::raw::Decoder::new(),
270 pinger: Pinger::new(PING_INTERVAL, PING_TIMEOUT),
271 shared_capabilities,
272 outgoing_messages: VecDeque::new(),
273 outgoing_message_buffer_capacity: MAX_P2P_CAPACITY,
274 disconnecting: false,
275 }
276 }
277
278 pub const fn inner(&self) -> &S {
280 &self.inner
281 }
282
283 pub const fn set_outgoing_message_buffer_capacity(&mut self, capacity: usize) {
289 self.outgoing_message_buffer_capacity = capacity;
290 }
291
292 pub const fn shared_capabilities(&self) -> &SharedCapabilities {
297 &self.shared_capabilities
298 }
299
300 fn has_outgoing_capacity(&self) -> bool {
302 self.outgoing_messages.len() < self.outgoing_message_buffer_capacity
303 }
304
305 fn send_pong(&mut self) {
307 self.outgoing_messages.push_back(Bytes::from(alloy_rlp::encode(P2PMessage::Pong)));
308 }
309
310 pub fn send_ping(&mut self) {
312 self.outgoing_messages.push_back(Bytes::from(alloy_rlp::encode(P2PMessage::Ping)));
313 }
314}
315
/// Non-blocking disconnect operations for a `p2p` stream.
pub trait DisconnectP2P {
    /// Initiates a disconnect carrying `reason`.
    ///
    /// For [`P2PStream`] this clears the outgoing queue, queues a compressed `Disconnect`
    /// and marks the stream as disconnecting. The message is only written on the next
    /// flush or close.
    fn start_disconnect(&mut self, reason: DisconnectReason) -> Result<(), P2PStreamError>;

    /// Returns whether a disconnect has been initiated.
    fn is_disconnecting(&self) -> bool;
}
325
326impl<S> DisconnectP2P for P2PStream<S> {
327 fn start_disconnect(&mut self, reason: DisconnectReason) -> Result<(), P2PStreamError> {
336 self.outgoing_messages.clear();
338 let disconnect = P2PMessage::Disconnect(reason);
339 let mut buf = Vec::with_capacity(disconnect.length());
340 disconnect.encode(&mut buf);
341
342 let mut compressed = vec![0u8; 1 + snap::raw::max_compress_len(buf.len() - 1)];
343 let compressed_size =
344 self.encoder.compress(&buf[1..], &mut compressed[1..]).map_err(|err| {
345 debug!(
346 %err,
347 msg=%hex::encode(&buf[1..]),
348 "error compressing disconnect"
349 );
350 err
351 })?;
352
353 compressed.truncate(compressed_size + 1);
356
357 compressed[0] = buf[0];
360
361 self.outgoing_messages.push_back(compressed.into());
362 self.disconnecting = true;
363 Ok(())
364 }
365
366 fn is_disconnecting(&self) -> bool {
367 self.disconnecting
368 }
369}
370
371impl<S> P2PStream<S>
372where
373 S: Sink<Bytes, Error = io::Error> + Unpin + Send,
374{
375 pub async fn disconnect(&mut self, reason: DisconnectReason) -> Result<(), P2PStreamError> {
380 self.start_disconnect(reason)?;
381 self.close().await
382 }
383}
384
/// Yields inbound subprotocol messages.
///
/// `p2p` control messages are consumed here:
/// - pings queue a pong;
/// - pongs are reported to the pinger;
/// - hello, disconnect and unknown reserved IDs are returned as errors.
impl<S> Stream for P2PStream<S>
where
    S: Stream<Item = io::Result<BytesMut>> + Sink<Bytes, Error = io::Error> + Unpin,
{
    type Item = Result<BytesMut, P2PStreamError>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();

        // Once a disconnect has been queued, no further messages are read.
        if this.disconnecting {
            return Poll::Ready(None)
        }

        // Keep reading while frames are available, so that control messages handled here
        // (e.g. pings) do not cause `Pending` while a later frame is already ready.
        while let Poll::Ready(res) = this.inner.poll_next_unpin(cx) {
            let bytes = match res {
                Some(Ok(bytes)) => bytes,
                Some(Err(err)) => return Poll::Ready(Some(Err(err.into()))),
                None => return Poll::Ready(None),
            };

            // Every frame must at least carry a message ID byte.
            if bytes.is_empty() {
                return Poll::Ready(Some(Err(P2PStreamError::EmptyProtocolMessage)))
            }

            let id = bytes[0];
            // A disconnect may arrive uncompressed or snappy-compressed.
            // Try the uncompressed form first. If that fails, fall through: the reason is
            // decoded again from the decompressed payload in the `Disconnect` arm below,
            // where a malformed reason becomes an error.
            if id == P2PMessageID::Disconnect as u8 {
                if let Ok(reason) = DisconnectReason::decode(&mut &bytes[1..]) {
                    return Poll::Ready(Some(Err(P2PStreamError::Disconnected(reason))))
                }
            }

            // Reject oversized messages using the length declared in the snappy header,
            // before allocating the output buffer.
            let decompressed_len = snap::raw::decompress_len(&bytes[1..])?;
            if decompressed_len > MAX_PAYLOAD_SIZE {
                return Poll::Ready(Some(Err(P2PStreamError::MessageTooBig {
                    message_size: decompressed_len,
                    max_size: MAX_PAYLOAD_SIZE,
                })))
            }

            // One extra leading byte is kept for the message ID.
            let mut decompress_buf = BytesMut::zeroed(decompressed_len + 1);

            this.decoder.decompress(&bytes[1..], &mut decompress_buf[1..]).map_err(|err| {
                debug!(
                    %err,
                    msg=%hex::encode(&bytes[1..]),
                    "error decompressing p2p message"
                );
                err
            })?;

            match id {
                _ if id == P2PMessageID::Ping as u8 => {
                    trace!("Received Ping, Sending Pong");
                    this.send_pong();
                    // The pong only sits in the outgoing queue. Wake the task; presumably
                    // this lets the owner re-poll so the pong gets flushed even if the sink
                    // side is not otherwise polled.
                    cx.waker().wake_by_ref();
                }
                _ if id == P2PMessageID::Hello as u8 => {
                    // A hello is only valid during the handshake.
                    return Poll::Ready(Some(Err(P2PStreamError::HandshakeError(
                        P2PHandshakeError::HelloNotInHandshake,
                    ))))
                }
                _ if id == P2PMessageID::Pong as u8 => {
                    // Report the pong to the pinger.
                    this.pinger.on_pong()?
                }
                _ if id == P2PMessageID::Disconnect as u8 => {
                    // Second attempt: decode the reason from the decompressed payload.
                    let reason = DisconnectReason::decode(&mut &decompress_buf[1..]).inspect_err(|err| {
                        debug!(
                            %err, msg=%hex::encode(&decompress_buf[1..]), "Failed to decode disconnect message from peer"
                        );
                    })?;
                    return Poll::Ready(Some(Err(P2PStreamError::Disconnected(reason))))
                }
                _ if id > MAX_P2P_MESSAGE_ID && id <= MAX_RESERVED_MESSAGE_ID => {
                    // Reserved for `p2p` but not assigned to any known message.
                    return Poll::Ready(Some(Err(P2PStreamError::UnknownReservedMessageId(id))))
                }
                _ => {
                    // Subprotocol message: shift the ID back below the `p2p` reserved range.
                    // Cannot underflow, since every ID <= MAX_RESERVED_MESSAGE_ID was handled
                    // by an earlier arm.
                    decompress_buf[0] = bytes[0] - MAX_RESERVED_MESSAGE_ID - 1;

                    return Poll::Ready(Some(Ok(decompress_buf)))
                }
            }
        }

        Poll::Pending
    }
}
532
533impl<S> Sink<Bytes> for P2PStream<S>
534where
535 S: Sink<Bytes, Error = io::Error> + Unpin,
536{
537 type Error = P2PStreamError;
538
539 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
540 let mut this = self.as_mut();
541
542 match this.pinger.poll_ping(cx) {
544 Poll::Pending => {}
545 Poll::Ready(Ok(PingerEvent::Ping)) => {
546 this.send_ping();
547 }
548 _ => {
549 this.start_disconnect(DisconnectReason::PingTimeout)?;
551
552 return Poll::Ready(Ok(()))
554 }
555 }
556
557 match this.inner.poll_ready_unpin(cx) {
558 Poll::Pending => {}
559 Poll::Ready(Err(err)) => return Poll::Ready(Err(P2PStreamError::Io(err))),
560 Poll::Ready(Ok(())) => {
561 let flushed = this.poll_flush(cx);
562 if flushed.is_ready() {
563 return flushed
564 }
565 }
566 }
567
568 if self.has_outgoing_capacity() {
569 Poll::Ready(Ok(()))
571 } else {
572 Poll::Pending
573 }
574 }
575
576 fn start_send(self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> {
577 if item.len() > MAX_PAYLOAD_SIZE {
578 return Err(P2PStreamError::MessageTooBig {
579 message_size: item.len(),
580 max_size: MAX_PAYLOAD_SIZE,
581 })
582 }
583
584 if item.is_empty() {
585 return Err(P2PStreamError::EmptyProtocolMessage)
587 }
588
589 if !self.has_outgoing_capacity() {
591 return Err(P2PStreamError::SendBufferFull)
592 }
593
594 let this = self.project();
595
596 let mut compressed = BytesMut::zeroed(1 + snap::raw::max_compress_len(item.len() - 1));
597 let compressed_size =
598 this.encoder.compress(&item[1..], &mut compressed[1..]).map_err(|err| {
599 debug!(
600 %err,
601 msg=%hex::encode(&item[1..]),
602 "error compressing p2p message"
603 );
604 err
605 })?;
606
607 compressed.truncate(compressed_size + 1);
610
611 compressed[0] = item[0] + MAX_RESERVED_MESSAGE_ID + 1;
614 this.outgoing_messages.push_back(compressed.freeze());
615
616 Ok(())
617 }
618
619 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
621 let mut this = self.project();
622 let poll_res = loop {
623 match this.inner.as_mut().poll_ready(cx) {
624 Poll::Pending => break Poll::Pending,
625 Poll::Ready(Err(err)) => break Poll::Ready(Err(err.into())),
626 Poll::Ready(Ok(())) => {
627 let Some(message) = this.outgoing_messages.pop_front() else {
628 break Poll::Ready(Ok(()))
629 };
630 if let Err(err) = this.inner.as_mut().start_send(message) {
631 break Poll::Ready(Err(err.into()))
632 }
633 }
634 }
635 };
636
637 ready!(this.inner.as_mut().poll_flush(cx))?;
638
639 poll_res
640 }
641
642 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
643 ready!(self.as_mut().poll_flush(cx))?;
644 ready!(self.project().inner.poll_close(cx))?;
645
646 Poll::Ready(Ok(()))
647 }
648}
649
/// A message of the base `p2p` protocol.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(any(test, feature = "arbitrary"), derive(arbitrary::Arbitrary))]
#[add_arbitrary_tests(rlp)]
pub enum P2PMessage {
    /// The first message exchanged in the handshake; its body is plain RLP (never
    /// snappy-compressed).
    Hello(HelloMessage),

    /// Request to end the session, carrying a [`DisconnectReason`].
    Disconnect(DisconnectReason),

    /// Liveness probe. A receiving [`P2PStream`] queues a [`P2PMessage::Pong`] in response.
    Ping,

    /// Reply to a [`P2PMessage::Ping`].
    Pong,
}
669
670impl P2PMessage {
671 pub const fn message_id(&self) -> P2PMessageID {
673 match self {
674 Self::Hello(_) => P2PMessageID::Hello,
675 Self::Disconnect(_) => P2PMessageID::Disconnect,
676 Self::Ping => P2PMessageID::Ping,
677 Self::Pong => P2PMessageID::Pong,
678 }
679 }
680}
681
682impl Encodable for P2PMessage {
683 fn encode(&self, out: &mut dyn BufMut) {
688 (self.message_id() as u8).encode(out);
689 match self {
690 Self::Hello(msg) => msg.encode(out),
691 Self::Disconnect(msg) => msg.encode(out),
692 Self::Ping => {
693 out.put_u8(0x01);
695 out.put_u8(0x00);
696 out.put_u8(EMPTY_LIST_CODE);
697 }
698 Self::Pong => {
699 out.put_u8(0x01);
701 out.put_u8(0x00);
702 out.put_u8(EMPTY_LIST_CODE);
703 }
704 }
705 }
706
707 fn length(&self) -> usize {
708 let payload_len = match self {
709 Self::Hello(msg) => msg.length(),
710 Self::Disconnect(msg) => msg.length(),
711 Self::Ping | Self::Pong => 3, };
714 payload_len + 1 }
716}
717
718impl Decodable for P2PMessage {
719 fn decode(buf: &mut &[u8]) -> alloy_rlp::Result<Self> {
726 fn advance_snappy_ping_pong_payload(buf: &mut &[u8]) -> alloy_rlp::Result<()> {
728 if buf.len() < 3 {
729 return Err(RlpError::InputTooShort)
730 }
731 if buf[..3] != [0x01, 0x00, EMPTY_LIST_CODE] {
732 return Err(RlpError::Custom("expected snappy payload"))
733 }
734 buf.advance(3);
735 Ok(())
736 }
737
738 let message_id = u8::decode(&mut &buf[..])?;
739 let id = P2PMessageID::try_from(message_id)
740 .or(Err(RlpError::Custom("unknown p2p message id")))?;
741 buf.advance(1);
742 match id {
743 P2PMessageID::Hello => Ok(Self::Hello(HelloMessage::decode(buf)?)),
744 P2PMessageID::Disconnect => Ok(Self::Disconnect(DisconnectReason::decode(buf)?)),
745 P2PMessageID::Ping => {
746 advance_snappy_ping_pong_payload(buf)?;
747 Ok(Self::Ping)
748 }
749 P2PMessageID::Pong => {
750 advance_snappy_ping_pong_payload(buf)?;
751 Ok(Self::Pong)
752 }
753 }
754 }
755}
756
/// Message IDs of the base `p2p` protocol.
///
/// IDs up to `MAX_RESERVED_MESSAGE_ID` are reserved for `p2p`; only these four are assigned.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum P2PMessageID {
    /// ID of [`P2PMessage::Hello`].
    Hello = 0x00,

    /// ID of [`P2PMessage::Disconnect`].
    Disconnect = 0x01,

    /// ID of [`P2PMessage::Ping`].
    Ping = 0x02,

    /// ID of [`P2PMessage::Pong`].
    Pong = 0x03,
}
772
773impl From<P2PMessage> for P2PMessageID {
774 fn from(msg: P2PMessage) -> Self {
775 match msg {
776 P2PMessage::Hello(_) => Self::Hello,
777 P2PMessage::Disconnect(_) => Self::Disconnect,
778 P2PMessage::Ping => Self::Ping,
779 P2PMessage::Pong => Self::Pong,
780 }
781 }
782}
783
784impl TryFrom<u8> for P2PMessageID {
785 type Error = P2PStreamError;
786
787 fn try_from(id: u8) -> Result<Self, Self::Error> {
788 match id {
789 0x00 => Ok(Self::Hello),
790 0x01 => Ok(Self::Disconnect),
791 0x02 => Ok(Self::Ping),
792 0x03 => Ok(Self::Pong),
793 _ => Err(P2PStreamError::UnknownReservedMessageId(id)),
794 }
795 }
796}
797
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{capability::SharedCapability, test_utils::eth_hello, EthVersion, ProtocolVersion};
    use tokio::net::{TcpListener, TcpStream};
    use tokio_util::codec::Decoder;

    /// One side disconnects with a reason after the handshake. The other side must observe
    /// `P2PStreamError::Disconnected` carrying the same reason.
    #[tokio::test]
    async fn test_can_disconnect() {
        reth_tracing::init_test_tracing();
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();

        let expected_disconnect = DisconnectReason::UselessPeer;

        // Server: accept, handshake, then disconnect with `expected_disconnect`.
        let handle = tokio::spawn(async move {
            let (incoming, _) = listener.accept().await.unwrap();
            let stream = crate::PassthroughCodec::default().framed(incoming);

            let (server_hello, _) = eth_hello();

            let (mut p2p_stream, _) =
                UnauthedP2PStream::new(stream).handshake(server_hello).await.unwrap();

            p2p_stream.disconnect(expected_disconnect).await.unwrap();
        });

        // Client: connect, handshake, and expect the disconnect as the next stream item.
        let outgoing = TcpStream::connect(local_addr).await.unwrap();
        let sink = crate::PassthroughCodec::default().framed(outgoing);

        let (client_hello, _) = eth_hello();

        let (mut p2p_stream, _) =
            UnauthedP2PStream::new(sink).handshake(client_hello).await.unwrap();

        let err = p2p_stream.next().await.unwrap().unwrap_err();
        match err {
            P2PStreamError::Disconnected(reason) => assert_eq!(reason, expected_disconnect),
            e => panic!("unexpected err: {e}"),
        }

        handle.await.unwrap();
    }

    /// Same as `test_can_disconnect`, but the server sends the disconnect as plain RLP
    /// without snappy compression.
    #[tokio::test]
    async fn test_can_disconnect_weird_disconnect_encoding() {
        reth_tracing::init_test_tracing();
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();

        let expected_disconnect = DisconnectReason::SubprotocolSpecific;

        let handle = tokio::spawn(async move {
            let (incoming, _) = listener.accept().await.unwrap();
            let stream = crate::PassthroughCodec::default().framed(incoming);

            let (server_hello, _) = eth_hello();

            let (mut p2p_stream, _) =
                UnauthedP2PStream::new(stream).handshake(server_hello).await.unwrap();

            // Bypass `start_disconnect` (which compresses): queue an uncompressed RLP
            // disconnect by hand and mark the stream as disconnecting.
            p2p_stream.outgoing_messages.clear();

            p2p_stream.outgoing_messages.push_back(Bytes::from(alloy_rlp::encode(
                P2PMessage::Disconnect(DisconnectReason::SubprotocolSpecific),
            )));
            p2p_stream.disconnecting = true;
            p2p_stream.close().await.unwrap();
        });

        let outgoing = TcpStream::connect(local_addr).await.unwrap();
        let sink = crate::PassthroughCodec::default().framed(outgoing);

        let (client_hello, _) = eth_hello();

        let (mut p2p_stream, _) =
            UnauthedP2PStream::new(sink).handshake(client_hello).await.unwrap();

        let err = p2p_stream.next().await.unwrap().unwrap_err();
        match err {
            P2PStreamError::Disconnected(reason) => assert_eq!(reason, expected_disconnect),
            e => panic!("unexpected err: {e}"),
        }

        handle.await.unwrap();
    }

    /// Both sides complete the handshake. The first shared capability is `eth/67`, placed
    /// directly above the `p2p` reserved ID range.
    #[tokio::test]
    async fn test_handshake_passthrough() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();

        let handle = tokio::spawn(async move {
            let (incoming, _) = listener.accept().await.unwrap();
            let stream = crate::PassthroughCodec::default().framed(incoming);

            let (server_hello, _) = eth_hello();

            let unauthed_stream = UnauthedP2PStream::new(stream);
            let (p2p_stream, _) = unauthed_stream.handshake(server_hello).await.unwrap();

            assert_eq!(
                *p2p_stream.shared_capabilities.iter_caps().next().unwrap(),
                SharedCapability::Eth {
                    version: EthVersion::Eth67,
                    offset: MAX_RESERVED_MESSAGE_ID + 1
                }
            );
        });

        let outgoing = TcpStream::connect(local_addr).await.unwrap();
        let sink = crate::PassthroughCodec::default().framed(outgoing);

        let (client_hello, _) = eth_hello();

        let unauthed_stream = UnauthedP2PStream::new(sink);
        let (p2p_stream, _) = unauthed_stream.handshake(client_hello).await.unwrap();

        assert_eq!(
            *p2p_stream.shared_capabilities.iter_caps().next().unwrap(),
            SharedCapability::Eth {
                version: EthVersion::Eth67,
                offset: MAX_RESERVED_MESSAGE_ID + 1
            }
        );

        handle.await.unwrap();
    }

    /// The client advertises `ProtocolVersion::V4`, which must differ from the server's
    /// version (checked by `assert_ne!`). Both handshakes must therefore fail with
    /// `MismatchedProtocolVersion`, each reporting its own version as `expected`.
    #[tokio::test]
    async fn test_handshake_disconnect() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();

        let handle = tokio::spawn(Box::pin(async move {
            let (incoming, _) = listener.accept().await.unwrap();
            let stream = crate::PassthroughCodec::default().framed(incoming);

            let (server_hello, _) = eth_hello();

            let unauthed_stream = UnauthedP2PStream::new(stream);
            match unauthed_stream.handshake(server_hello.clone()).await {
                Ok((_, hello)) => {
                    panic!("expected handshake to fail, instead got a successful Hello: {hello:?}")
                }
                Err(P2PStreamError::MismatchedProtocolVersion(GotExpected { got, expected })) => {
                    assert_ne!(expected, got);
                    assert_eq!(expected, server_hello.protocol_version);
                }
                Err(other_err) => {
                    panic!("expected mismatched protocol version error, got {other_err:?}")
                }
            }
        }));

        let outgoing = TcpStream::connect(local_addr).await.unwrap();
        let sink = crate::PassthroughCodec::default().framed(outgoing);

        let (mut client_hello, _) = eth_hello();

        // Force a version mismatch on the client side.
        client_hello.protocol_version = ProtocolVersion::V4;

        let unauthed_stream = UnauthedP2PStream::new(sink);
        match unauthed_stream.handshake(client_hello.clone()).await {
            Ok((_, hello)) => {
                panic!("expected handshake to fail, instead got a successful Hello: {hello:?}")
            }
            Err(P2PStreamError::MismatchedProtocolVersion(GotExpected { got, expected })) => {
                assert_ne!(expected, got);
                assert_eq!(expected, client_hello.protocol_version);
            }
            Err(other_err) => {
                panic!("expected mismatched protocol version error, got {other_err:?}")
            }
        }

        handle.await.unwrap();
    }

    /// Round-trips the fixed wire form of `Ping`: ID 0x02 plus the snappy-wrapped empty list.
    #[test]
    fn snappy_decode_encode_ping() {
        let snappy_ping = b"\x02\x01\0\xc0";
        let ping = P2PMessage::decode(&mut &snappy_ping[..]).unwrap();
        assert!(matches!(ping, P2PMessage::Ping));
        assert_eq!(alloy_rlp::encode(ping), &snappy_ping[..]);
    }

    /// Round-trips the fixed wire form of `Pong`: ID 0x03 plus the snappy-wrapped empty list.
    #[test]
    fn snappy_decode_encode_pong() {
        let snappy_pong = b"\x03\x01\0\xc0";
        let pong = P2PMessage::decode(&mut &snappy_pong[..]).unwrap();
        assert!(matches!(pong, P2PMessage::Pong));
        assert_eq!(alloy_rlp::encode(pong), &snappy_pong[..]);
    }
}