reth_eth_wire/handshake.rs
1use crate::{
2 errors::{EthHandshakeError, EthStreamError, P2PStreamError},
3 ethstream::MAX_STATUS_SIZE,
4 CanDisconnect,
5};
6use bytes::{Bytes, BytesMut};
7use derive_more::with_trait::Debug;
8use futures::{Sink, SinkExt, Stream};
9use reth_eth_wire_types::{
10 DisconnectReason, EthMessage, EthNetworkPrimitives, ProtocolMessage, Status,
11};
12use reth_ethereum_forks::ForkFilter;
13use reth_primitives_traits::GotExpected;
14use std::{future::Future, pin::Pin, time::Duration};
15use tokio::time::timeout;
16use tokio_stream::StreamExt;
17use tracing::{debug, trace};
18
19pub trait EthRlpxHandshake: Debug + Send + Sync + 'static {
21 fn handshake<'a>(
23 &'a self,
24 unauth: &'a mut dyn UnauthEth,
25 status: Status,
26 fork_filter: ForkFilter,
27 timeout_limit: Duration,
28 ) -> Pin<Box<dyn Future<Output = Result<Status, EthStreamError>> + 'a + Send>>;
29}
30
31pub trait UnauthEth:
33 Stream<Item = Result<BytesMut, P2PStreamError>>
34 + Sink<Bytes, Error = P2PStreamError>
35 + CanDisconnect<Bytes>
36 + Unpin
37 + Send
38{
39}
40
41impl<T> UnauthEth for T where
42 T: Stream<Item = Result<BytesMut, P2PStreamError>>
43 + Sink<Bytes, Error = P2PStreamError>
44 + CanDisconnect<Bytes>
45 + Unpin
46 + Send
47{
48}
49
/// Default [`EthRlpxHandshake`] implementation.
///
/// Runs [`EthereumEthHandshake::eth_handshake`] under the caller-supplied timeout.
#[derive(Clone, Debug, Default)]
#[non_exhaustive]
pub struct EthHandshake;
56
57impl EthRlpxHandshake for EthHandshake {
58 fn handshake<'a>(
59 &'a self,
60 unauth: &'a mut dyn UnauthEth,
61 status: Status,
62 fork_filter: ForkFilter,
63 timeout_limit: Duration,
64 ) -> Pin<Box<dyn Future<Output = Result<Status, EthStreamError>> + 'a + Send>> {
65 Box::pin(async move {
66 timeout(timeout_limit, EthereumEthHandshake(unauth).eth_handshake(status, fork_filter))
67 .await
68 .map_err(|_| EthStreamError::StreamTimeout)?
69 })
70 }
71}
72
/// Drives the `eth` status handshake over a mutably borrowed stream `S`.
///
/// The wrapper is consumed by [`EthereumEthHandshake::eth_handshake`]. `S` may be unsized,
/// which lets a `&mut dyn UnauthEth` be wrapped directly.
#[derive(Debug)]
pub struct EthereumEthHandshake<'a, S>(pub &'a mut S)
where
    S: ?Sized;
76
77impl<S: ?Sized, E> EthereumEthHandshake<'_, S>
78where
79 S: Stream<Item = Result<BytesMut, E>> + CanDisconnect<Bytes> + Send + Unpin,
80 EthStreamError: From<E> + From<<S as Sink<Bytes>>::Error>,
81{
82 pub async fn eth_handshake(
84 self,
85 status: Status,
86 fork_filter: ForkFilter,
87 ) -> Result<Status, EthStreamError> {
88 let unauth = self.0;
89 let status_msg =
91 alloy_rlp::encode(ProtocolMessage::<EthNetworkPrimitives>::from(EthMessage::<
92 EthNetworkPrimitives,
93 >::Status(
94 status
95 )))
96 .into();
97 unauth.send(status_msg).await.map_err(EthStreamError::from)?;
98
99 let their_msg_res = unauth.next().await;
101 let their_msg = match their_msg_res {
102 Some(Ok(msg)) => msg,
103 Some(Err(e)) => return Err(EthStreamError::from(e)),
104 None => {
105 unauth
106 .disconnect(DisconnectReason::DisconnectRequested)
107 .await
108 .map_err(EthStreamError::from)?;
109 return Err(EthStreamError::EthHandshakeError(EthHandshakeError::NoResponse));
110 }
111 };
112
113 if their_msg.len() > MAX_STATUS_SIZE {
114 unauth
115 .disconnect(DisconnectReason::ProtocolBreach)
116 .await
117 .map_err(EthStreamError::from)?;
118 return Err(EthStreamError::MessageTooBig(their_msg.len()));
119 }
120
121 let version = status.version;
122 let msg = match ProtocolMessage::<EthNetworkPrimitives>::decode_message(
123 version,
124 &mut their_msg.as_ref(),
125 ) {
126 Ok(m) => m,
127 Err(err) => {
128 debug!("decode error in eth handshake: msg={their_msg:x}");
129 unauth
130 .disconnect(DisconnectReason::DisconnectRequested)
131 .await
132 .map_err(EthStreamError::from)?;
133 return Err(EthStreamError::InvalidMessage(err));
134 }
135 };
136
137 match msg.message {
139 EthMessage::Status(their_status) => {
140 trace!("Validating incoming ETH status from peer");
141
142 if status.genesis != their_status.genesis {
143 unauth
144 .disconnect(DisconnectReason::ProtocolBreach)
145 .await
146 .map_err(EthStreamError::from)?;
147 return Err(EthHandshakeError::MismatchedGenesis(
148 GotExpected { expected: status.genesis, got: their_status.genesis }.into(),
149 )
150 .into());
151 }
152
153 if status.version != their_status.version {
154 unauth
155 .disconnect(DisconnectReason::ProtocolBreach)
156 .await
157 .map_err(EthStreamError::from)?;
158 return Err(EthHandshakeError::MismatchedProtocolVersion(GotExpected {
159 got: their_status.version,
160 expected: status.version,
161 })
162 .into());
163 }
164
165 if status.chain != their_status.chain {
166 unauth
167 .disconnect(DisconnectReason::ProtocolBreach)
168 .await
169 .map_err(EthStreamError::from)?;
170 return Err(EthHandshakeError::MismatchedChain(GotExpected {
171 got: their_status.chain,
172 expected: status.chain,
173 })
174 .into());
175 }
176
177 if status.total_difficulty.bit_len() > 160 {
179 unauth
180 .disconnect(DisconnectReason::ProtocolBreach)
181 .await
182 .map_err(EthStreamError::from)?;
183 return Err(EthHandshakeError::TotalDifficultyBitLenTooLarge {
184 got: status.total_difficulty.bit_len(),
185 maximum: 160,
186 }
187 .into());
188 }
189
190 if let Err(err) = fork_filter
192 .validate(their_status.forkid)
193 .map_err(EthHandshakeError::InvalidFork)
194 {
195 unauth
196 .disconnect(DisconnectReason::ProtocolBreach)
197 .await
198 .map_err(EthStreamError::from)?;
199 return Err(err.into());
200 }
201
202 Ok(their_status)
203 }
204 _ => {
205 unauth
206 .disconnect(DisconnectReason::ProtocolBreach)
207 .await
208 .map_err(EthStreamError::from)?;
209 Err(EthStreamError::EthHandshakeError(
210 EthHandshakeError::NonStatusMessageInHandshake,
211 ))
212 }
213 }
214 }
215}