reth_eth_wire/
handshake.rs
1use crate::{
2 errors::{EthHandshakeError, EthStreamError, P2PStreamError},
3 ethstream::MAX_STATUS_SIZE,
4 CanDisconnect,
5};
6use bytes::{Bytes, BytesMut};
7use futures::{Sink, SinkExt, Stream};
8use reth_eth_wire_types::{
9 DisconnectReason, EthMessage, EthNetworkPrimitives, ProtocolMessage, Status, StatusMessage,
10};
11use reth_ethereum_forks::ForkFilter;
12use reth_primitives_traits::GotExpected;
13use std::{fmt::Debug, future::Future, pin::Pin, time::Duration};
14use tokio::time::timeout;
15use tokio_stream::StreamExt;
16use tracing::{debug, trace};
17
18pub trait EthRlpxHandshake: Debug + Send + Sync + 'static {
20 fn handshake<'a>(
22 &'a self,
23 unauth: &'a mut dyn UnauthEth,
24 status: Status,
25 fork_filter: ForkFilter,
26 timeout_limit: Duration,
27 ) -> Pin<Box<dyn Future<Output = Result<Status, EthStreamError>> + 'a + Send>>;
28}
29
30pub trait UnauthEth:
32 Stream<Item = Result<BytesMut, P2PStreamError>>
33 + Sink<Bytes, Error = P2PStreamError>
34 + CanDisconnect<Bytes>
35 + Unpin
36 + Send
37{
38}
39
40impl<T> UnauthEth for T where
41 T: Stream<Item = Result<BytesMut, P2PStreamError>>
42 + Sink<Bytes, Error = P2PStreamError>
43 + CanDisconnect<Bytes>
44 + Unpin
45 + Send
46{
47}
48
/// Stateless handshake marker type.
///
/// It carries no data; its behaviour comes from the trait implementations attached to it.
#[derive(Clone, Debug, Default)]
#[non_exhaustive]
pub struct EthHandshake;
55
56impl EthRlpxHandshake for EthHandshake {
57 fn handshake<'a>(
58 &'a self,
59 unauth: &'a mut dyn UnauthEth,
60 status: Status,
61 fork_filter: ForkFilter,
62 timeout_limit: Duration,
63 ) -> Pin<Box<dyn Future<Output = Result<Status, EthStreamError>> + 'a + Send>> {
64 Box::pin(async move {
65 timeout(timeout_limit, EthereumEthHandshake(unauth).eth_handshake(status, fork_filter))
66 .await
67 .map_err(|_| EthStreamError::StreamTimeout)?
68 })
69 }
70}
71
/// Thin wrapper around a mutable borrow of a (possibly unsized) stream `S`.
///
/// The wrapped reference is public, so the wrapper can be built directly from a `&mut S`.
#[derive(Debug)]
pub struct EthereumEthHandshake<'a, S: ?Sized>(
    /// The borrowed stream.
    pub &'a mut S,
);
75
76impl<S: ?Sized, E> EthereumEthHandshake<'_, S>
77where
78 S: Stream<Item = Result<BytesMut, E>> + CanDisconnect<Bytes> + Send + Unpin,
79 EthStreamError: From<E> + From<<S as Sink<Bytes>>::Error>,
80{
81 pub async fn eth_handshake(
83 self,
84 status: Status,
85 fork_filter: ForkFilter,
86 ) -> Result<Status, EthStreamError> {
87 let unauth = self.0;
88 let status_msg =
90 alloy_rlp::encode(ProtocolMessage::<EthNetworkPrimitives>::from(EthMessage::<
91 EthNetworkPrimitives,
92 >::Status(
93 StatusMessage::Legacy(status),
94 )))
95 .into();
96 unauth.send(status_msg).await.map_err(EthStreamError::from)?;
97
98 let their_msg_res = unauth.next().await;
100 let their_msg = match their_msg_res {
101 Some(Ok(msg)) => msg,
102 Some(Err(e)) => return Err(EthStreamError::from(e)),
103 None => {
104 unauth
105 .disconnect(DisconnectReason::DisconnectRequested)
106 .await
107 .map_err(EthStreamError::from)?;
108 return Err(EthStreamError::EthHandshakeError(EthHandshakeError::NoResponse));
109 }
110 };
111
112 if their_msg.len() > MAX_STATUS_SIZE {
113 unauth
114 .disconnect(DisconnectReason::ProtocolBreach)
115 .await
116 .map_err(EthStreamError::from)?;
117 return Err(EthStreamError::MessageTooBig(their_msg.len()));
118 }
119
120 let version = status.version;
121 let msg = match ProtocolMessage::<EthNetworkPrimitives>::decode_message(
122 version,
123 &mut their_msg.as_ref(),
124 ) {
125 Ok(m) => m,
126 Err(err) => {
127 debug!("decode error in eth handshake: msg={their_msg:x}");
128 unauth
129 .disconnect(DisconnectReason::DisconnectRequested)
130 .await
131 .map_err(EthStreamError::from)?;
132 return Err(EthStreamError::InvalidMessage(err));
133 }
134 };
135
136 match msg.message {
138 EthMessage::Status(their_status_message) => {
139 trace!("Validating incoming ETH status from peer");
140
141 if status.genesis != their_status_message.genesis() {
142 unauth
143 .disconnect(DisconnectReason::ProtocolBreach)
144 .await
145 .map_err(EthStreamError::from)?;
146 return Err(EthHandshakeError::MismatchedGenesis(
147 GotExpected {
148 expected: status.genesis,
149 got: their_status_message.genesis(),
150 }
151 .into(),
152 )
153 .into());
154 }
155
156 if status.version != their_status_message.version() {
157 unauth
158 .disconnect(DisconnectReason::ProtocolBreach)
159 .await
160 .map_err(EthStreamError::from)?;
161 return Err(EthHandshakeError::MismatchedProtocolVersion(GotExpected {
162 got: their_status_message.version(),
163 expected: status.version,
164 })
165 .into());
166 }
167
168 if status.chain != *their_status_message.chain() {
169 unauth
170 .disconnect(DisconnectReason::ProtocolBreach)
171 .await
172 .map_err(EthStreamError::from)?;
173 return Err(EthHandshakeError::MismatchedChain(GotExpected {
174 got: *their_status_message.chain(),
175 expected: status.chain,
176 })
177 .into());
178 }
179
180 if status.total_difficulty.bit_len() > 160 {
182 unauth
183 .disconnect(DisconnectReason::ProtocolBreach)
184 .await
185 .map_err(EthStreamError::from)?;
186 return Err(EthHandshakeError::TotalDifficultyBitLenTooLarge {
187 got: status.total_difficulty.bit_len(),
188 maximum: 160,
189 }
190 .into());
191 }
192
193 if let Err(err) = fork_filter
195 .validate(their_status_message.forkid())
196 .map_err(EthHandshakeError::InvalidFork)
197 {
198 unauth
199 .disconnect(DisconnectReason::ProtocolBreach)
200 .await
201 .map_err(EthStreamError::from)?;
202 return Err(err.into());
203 }
204
205 Ok(their_status_message.to_legacy())
206 }
207 _ => {
208 unauth
209 .disconnect(DisconnectReason::ProtocolBreach)
210 .await
211 .map_err(EthStreamError::from)?;
212 Err(EthStreamError::EthHandshakeError(
213 EthHandshakeError::NonStatusMessageInHandshake,
214 ))
215 }
216 }
217 }
218}