reth_eth_wire/
handshake.rs1use crate::{
2 errors::{EthHandshakeError, EthStreamError, P2PStreamError},
3 ethstream::MAX_MESSAGE_SIZE,
4 CanDisconnect,
5};
6use bytes::{Bytes, BytesMut};
7use futures::{Sink, SinkExt, Stream};
8use reth_eth_wire_types::{
9 DisconnectReason, EthMessage, EthNetworkPrimitives, ProtocolMessage, StatusMessage,
10 UnifiedStatus,
11};
12use reth_ethereum_forks::ForkFilter;
13use reth_primitives_traits::GotExpected;
14use std::{fmt::Debug, future::Future, pin::Pin, time::Duration};
15use tokio::time::timeout;
16use tokio_stream::StreamExt;
17use tracing::{debug, trace};
18
19pub trait EthRlpxHandshake: Debug + Send + Sync + 'static {
21 fn handshake<'a>(
23 &'a self,
24 unauth: &'a mut dyn UnauthEth,
25 status: UnifiedStatus,
26 fork_filter: ForkFilter,
27 timeout_limit: Duration,
28 ) -> Pin<Box<dyn Future<Output = Result<UnifiedStatus, EthStreamError>> + 'a + Send>>;
29}
30
31pub trait UnauthEth:
33 Stream<Item = Result<BytesMut, P2PStreamError>>
34 + Sink<Bytes, Error = P2PStreamError>
35 + CanDisconnect<Bytes>
36 + Unpin
37 + Send
38{
39}
40
41impl<T> UnauthEth for T where
42 T: Stream<Item = Result<BytesMut, P2PStreamError>>
43 + Sink<Bytes, Error = P2PStreamError>
44 + CanDisconnect<Bytes>
45 + Unpin
46 + Send
47{
48}
49
/// The default [`EthRlpxHandshake`] implementation.
///
/// It runs [`EthereumEthHandshake::eth_handshake`] under a timeout.
#[derive(Clone, Debug, Default)]
#[non_exhaustive]
pub struct EthHandshake;
56
57impl EthRlpxHandshake for EthHandshake {
58 fn handshake<'a>(
59 &'a self,
60 unauth: &'a mut dyn UnauthEth,
61 status: UnifiedStatus,
62 fork_filter: ForkFilter,
63 timeout_limit: Duration,
64 ) -> Pin<Box<dyn Future<Output = Result<UnifiedStatus, EthStreamError>> + 'a + Send>> {
65 Box::pin(async move {
66 timeout(timeout_limit, EthereumEthHandshake(unauth).eth_handshake(status, fork_filter))
67 .await
68 .map_err(|_| EthStreamError::StreamTimeout)?
69 })
70 }
71}
72
/// Wrapper around a mutable borrow of the stream on which the `eth` `Status` exchange runs.
///
/// `S` may be unsized, for example `dyn UnauthEth`.
#[derive(Debug)]
pub struct EthereumEthHandshake<'a, S: ?Sized>(pub &'a mut S);
76
77impl<S: ?Sized, E> EthereumEthHandshake<'_, S>
78where
79 S: Stream<Item = Result<BytesMut, E>> + CanDisconnect<Bytes> + Send + Unpin,
80 EthStreamError: From<E> + From<<S as Sink<Bytes>>::Error>,
81{
82 pub async fn eth_handshake(
84 self,
85 unified_status: UnifiedStatus,
86 fork_filter: ForkFilter,
87 ) -> Result<UnifiedStatus, EthStreamError> {
88 let unauth = self.0;
89
90 let status = unified_status.into_message();
91
92 let status_msg = alloy_rlp::encode(ProtocolMessage::<EthNetworkPrimitives>::from(
94 EthMessage::Status(status),
95 ))
96 .into();
97 unauth.send(status_msg).await.map_err(EthStreamError::from)?;
98
99 let their_msg_res = unauth.next().await;
101 let their_msg = match their_msg_res {
102 Some(Ok(msg)) => msg,
103 Some(Err(e)) => return Err(EthStreamError::from(e)),
104 None => {
105 unauth
106 .disconnect(DisconnectReason::DisconnectRequested)
107 .await
108 .map_err(EthStreamError::from)?;
109 return Err(EthStreamError::EthHandshakeError(EthHandshakeError::NoResponse));
110 }
111 };
112
113 if their_msg.len() > MAX_MESSAGE_SIZE {
114 unauth
115 .disconnect(DisconnectReason::ProtocolBreach)
116 .await
117 .map_err(EthStreamError::from)?;
118 return Err(EthStreamError::MessageTooBig(their_msg.len()));
119 }
120
121 let version = status.version();
122 let their_status_message = match ProtocolMessage::<EthNetworkPrimitives>::decode_status(
123 version,
124 &mut their_msg.as_ref(),
125 ) {
126 Ok(status) => status,
127 Err(err) => {
128 debug!("decode error in eth handshake: msg={their_msg:x}");
129 unauth
130 .disconnect(DisconnectReason::ProtocolBreach)
131 .await
132 .map_err(EthStreamError::from)?;
133 return Err(EthStreamError::InvalidMessage(err));
134 }
135 };
136
137 trace!("Validating incoming ETH status from peer");
138
139 if status.genesis() != their_status_message.genesis() {
140 unauth
141 .disconnect(DisconnectReason::ProtocolBreach)
142 .await
143 .map_err(EthStreamError::from)?;
144 return Err(EthHandshakeError::MismatchedGenesis(
145 GotExpected { expected: status.genesis(), got: their_status_message.genesis() }
146 .into(),
147 )
148 .into());
149 }
150
151 if status.version() != their_status_message.version() {
152 unauth
153 .disconnect(DisconnectReason::ProtocolBreach)
154 .await
155 .map_err(EthStreamError::from)?;
156 return Err(EthHandshakeError::MismatchedProtocolVersion(GotExpected {
157 got: their_status_message.version(),
158 expected: status.version(),
159 })
160 .into());
161 }
162
163 if *status.chain() != *their_status_message.chain() {
164 unauth
165 .disconnect(DisconnectReason::ProtocolBreach)
166 .await
167 .map_err(EthStreamError::from)?;
168 return Err(EthHandshakeError::MismatchedChain(GotExpected {
169 got: *their_status_message.chain(),
170 expected: *status.chain(),
171 })
172 .into());
173 }
174
175 if let StatusMessage::Legacy(s) = &their_status_message &&
177 s.total_difficulty.bit_len() > 160
178 {
179 unauth
180 .disconnect(DisconnectReason::ProtocolBreach)
181 .await
182 .map_err(EthStreamError::from)?;
183 return Err(EthHandshakeError::TotalDifficultyBitLenTooLarge {
184 got: s.total_difficulty.bit_len(),
185 maximum: 160,
186 }
187 .into());
188 }
189
190 if let Err(err) = fork_filter
192 .validate(their_status_message.forkid())
193 .map_err(EthHandshakeError::InvalidFork)
194 {
195 unauth
196 .disconnect(DisconnectReason::ProtocolBreach)
197 .await
198 .map_err(EthStreamError::from)?;
199 return Err(err.into());
200 }
201
202 if let StatusMessage::Eth69(s) = &their_status_message {
203 if s.earliest > s.latest {
204 return Err(EthHandshakeError::EarliestBlockGreaterThanLatestBlock {
205 got: s.earliest,
206 latest: s.latest,
207 }
208 .into());
209 }
210
211 if s.blockhash.is_zero() {
212 return Err(EthHandshakeError::BlockhashZero.into());
213 }
214 }
215
216 Ok(UnifiedStatus::from_message(their_status_message))
217 }
218}