// reth_eth_wire/handshake.rs
use crate::{
    errors::{EthHandshakeError, EthStreamError, P2PStreamError},
    ethstream::MAX_STATUS_SIZE,
    CanDisconnect,
};
use bytes::{Bytes, BytesMut};
use futures::{Sink, SinkExt, Stream};
use reth_eth_wire_types::{
    DisconnectReason, EthMessage, EthNetworkPrimitives, ProtocolMessage, StatusMessage,
    UnifiedStatus,
};
use reth_ethereum_forks::ForkFilter;
use reth_primitives_traits::GotExpected;
use std::{fmt::Debug, future::Future, pin::Pin, time::Duration};
use tokio::time::timeout;
use tokio_stream::StreamExt;
use tracing::{debug, trace};
19pub trait EthRlpxHandshake: Debug + Send + Sync + 'static {
21 fn handshake<'a>(
23 &'a self,
24 unauth: &'a mut dyn UnauthEth,
25 status: UnifiedStatus,
26 fork_filter: ForkFilter,
27 timeout_limit: Duration,
28 ) -> Pin<Box<dyn Future<Output = Result<UnifiedStatus, EthStreamError>> + 'a + Send>>;
29}
30
31pub trait UnauthEth:
33 Stream<Item = Result<BytesMut, P2PStreamError>>
34 + Sink<Bytes, Error = P2PStreamError>
35 + CanDisconnect<Bytes>
36 + Unpin
37 + Send
38{
39}
40
41impl<T> UnauthEth for T where
42 T: Stream<Item = Result<BytesMut, P2PStreamError>>
43 + Sink<Bytes, Error = P2PStreamError>
44 + CanDisconnect<Bytes>
45 + Unpin
46 + Send
47{
48}
49
/// The default [`EthRlpxHandshake`] implementation.
///
/// It runs the Ethereum `eth` status exchange and enforces the caller-supplied timeout.
///
/// The type is `#[non_exhaustive]`, so code outside this crate cannot build it with a
/// struct literal and must use [`Default`] instead.
#[derive(Clone, Debug, Default)]
#[non_exhaustive]
pub struct EthHandshake;
57impl EthRlpxHandshake for EthHandshake {
58 fn handshake<'a>(
59 &'a self,
60 unauth: &'a mut dyn UnauthEth,
61 status: UnifiedStatus,
62 fork_filter: ForkFilter,
63 timeout_limit: Duration,
64 ) -> Pin<Box<dyn Future<Output = Result<UnifiedStatus, EthStreamError>> + 'a + Send>> {
65 Box::pin(async move {
66 timeout(timeout_limit, EthereumEthHandshake(unauth).eth_handshake(status, fork_filter))
67 .await
68 .map_err(|_| EthStreamError::StreamTimeout)?
69 })
70 }
71}
72
/// Runs the Ethereum `eth` status handshake over a mutably borrowed stream `S`.
///
/// `S` may be unsized (for example `dyn UnauthEth`). The exchange itself is implemented by
/// the `eth_handshake` method.
#[derive(Debug)]
pub struct EthereumEthHandshake<'stream, S: ?Sized>(pub &'stream mut S);
77impl<S: ?Sized, E> EthereumEthHandshake<'_, S>
78where
79 S: Stream<Item = Result<BytesMut, E>> + CanDisconnect<Bytes> + Send + Unpin,
80 EthStreamError: From<E> + From<<S as Sink<Bytes>>::Error>,
81{
82 pub async fn eth_handshake(
84 self,
85 unified_status: UnifiedStatus,
86 fork_filter: ForkFilter,
87 ) -> Result<UnifiedStatus, EthStreamError> {
88 let unauth = self.0;
89
90 let status = unified_status.into_message();
91
92 let status_msg = alloy_rlp::encode(ProtocolMessage::<EthNetworkPrimitives>::from(
94 EthMessage::Status(status),
95 ))
96 .into();
97 unauth.send(status_msg).await.map_err(EthStreamError::from)?;
98
99 let their_msg_res = unauth.next().await;
101 let their_msg = match their_msg_res {
102 Some(Ok(msg)) => msg,
103 Some(Err(e)) => return Err(EthStreamError::from(e)),
104 None => {
105 unauth
106 .disconnect(DisconnectReason::DisconnectRequested)
107 .await
108 .map_err(EthStreamError::from)?;
109 return Err(EthStreamError::EthHandshakeError(EthHandshakeError::NoResponse));
110 }
111 };
112
113 if their_msg.len() > MAX_STATUS_SIZE {
114 unauth
115 .disconnect(DisconnectReason::ProtocolBreach)
116 .await
117 .map_err(EthStreamError::from)?;
118 return Err(EthStreamError::MessageTooBig(their_msg.len()));
119 }
120
121 let version = status.version();
122 let msg = match ProtocolMessage::<EthNetworkPrimitives>::decode_message(
123 version,
124 &mut their_msg.as_ref(),
125 ) {
126 Ok(m) => m,
127 Err(err) => {
128 debug!("decode error in eth handshake: msg={their_msg:x}");
129 unauth
130 .disconnect(DisconnectReason::DisconnectRequested)
131 .await
132 .map_err(EthStreamError::from)?;
133 return Err(EthStreamError::InvalidMessage(err));
134 }
135 };
136
137 match msg.message {
139 EthMessage::Status(their_status_message) => {
140 trace!("Validating incoming ETH status from peer");
141
142 if status.genesis() != their_status_message.genesis() {
143 unauth
144 .disconnect(DisconnectReason::ProtocolBreach)
145 .await
146 .map_err(EthStreamError::from)?;
147 return Err(EthHandshakeError::MismatchedGenesis(
148 GotExpected {
149 expected: status.genesis(),
150 got: their_status_message.genesis(),
151 }
152 .into(),
153 )
154 .into());
155 }
156
157 if status.version() != their_status_message.version() {
158 unauth
159 .disconnect(DisconnectReason::ProtocolBreach)
160 .await
161 .map_err(EthStreamError::from)?;
162 return Err(EthHandshakeError::MismatchedProtocolVersion(GotExpected {
163 got: their_status_message.version(),
164 expected: status.version(),
165 })
166 .into());
167 }
168
169 if *status.chain() != *their_status_message.chain() {
170 unauth
171 .disconnect(DisconnectReason::ProtocolBreach)
172 .await
173 .map_err(EthStreamError::from)?;
174 return Err(EthHandshakeError::MismatchedChain(GotExpected {
175 got: *their_status_message.chain(),
176 expected: *status.chain(),
177 })
178 .into());
179 }
180
181 if let StatusMessage::Legacy(s) = status {
183 if s.total_difficulty.bit_len() > 160 {
184 unauth
185 .disconnect(DisconnectReason::ProtocolBreach)
186 .await
187 .map_err(EthStreamError::from)?;
188 return Err(EthHandshakeError::TotalDifficultyBitLenTooLarge {
189 got: s.total_difficulty.bit_len(),
190 maximum: 160,
191 }
192 .into());
193 }
194 }
195
196 if let Err(err) = fork_filter
198 .validate(their_status_message.forkid())
199 .map_err(EthHandshakeError::InvalidFork)
200 {
201 unauth
202 .disconnect(DisconnectReason::ProtocolBreach)
203 .await
204 .map_err(EthStreamError::from)?;
205 return Err(err.into());
206 }
207
208 Ok(UnifiedStatus::from_message(their_status_message))
209 }
210 _ => {
211 unauth
212 .disconnect(DisconnectReason::ProtocolBreach)
213 .await
214 .map_err(EthStreamError::from)?;
215 Err(EthStreamError::EthHandshakeError(
216 EthHandshakeError::NonStatusMessageInHandshake,
217 ))
218 }
219 }
220 }
221}