// reth_eth_wire/handshake.rs

1use crate::{
2    errors::{EthHandshakeError, EthStreamError, P2PStreamError},
3    ethstream::MAX_STATUS_SIZE,
4    CanDisconnect,
5};
6use bytes::{Bytes, BytesMut};
7use futures::{Sink, SinkExt, Stream};
8use reth_eth_wire_types::{
9    DisconnectReason, EthMessage, EthNetworkPrimitives, ProtocolMessage, Status, StatusMessage,
10};
11use reth_ethereum_forks::ForkFilter;
12use reth_primitives_traits::GotExpected;
13use std::{fmt::Debug, future::Future, pin::Pin, time::Duration};
14use tokio::time::timeout;
15use tokio_stream::StreamExt;
16use tracing::{debug, trace};
17
18/// A trait that knows how to perform the P2P handshake.
19pub trait EthRlpxHandshake: Debug + Send + Sync + 'static {
20    /// Perform the P2P handshake for the `eth` protocol.
21    fn handshake<'a>(
22        &'a self,
23        unauth: &'a mut dyn UnauthEth,
24        status: Status,
25        fork_filter: ForkFilter,
26        timeout_limit: Duration,
27    ) -> Pin<Box<dyn Future<Output = Result<Status, EthStreamError>> + 'a + Send>>;
28}
29
30/// An unauthenticated stream that can send and receive messages.
31pub trait UnauthEth:
32    Stream<Item = Result<BytesMut, P2PStreamError>>
33    + Sink<Bytes, Error = P2PStreamError>
34    + CanDisconnect<Bytes>
35    + Unpin
36    + Send
37{
38}
39
40impl<T> UnauthEth for T where
41    T: Stream<Item = Result<BytesMut, P2PStreamError>>
42        + Sink<Bytes, Error = P2PStreamError>
43        + CanDisconnect<Bytes>
44        + Unpin
45        + Send
46{
47}
48
/// Stock [`EthRlpxHandshake`] implementation that runs the regular Ethereum `eth`
/// RLPx handshake.
///
/// Because the type is `#[non_exhaustive]`, code outside this crate cannot build it
/// with a plain `EthHandshake` expression; use [`Default::default`] instead.
#[derive(Clone, Debug, Default)]
#[non_exhaustive]
pub struct EthHandshake;
55
56impl EthRlpxHandshake for EthHandshake {
57    fn handshake<'a>(
58        &'a self,
59        unauth: &'a mut dyn UnauthEth,
60        status: Status,
61        fork_filter: ForkFilter,
62        timeout_limit: Duration,
63    ) -> Pin<Box<dyn Future<Output = Result<Status, EthStreamError>> + 'a + Send>> {
64        Box::pin(async move {
65            timeout(timeout_limit, EthereumEthHandshake(unauth).eth_handshake(status, fork_filter))
66                .await
67                .map_err(|_| EthStreamError::StreamTimeout)?
68        })
69    }
70}
71
/// Runs the Ethereum-specific `eth` protocol handshake over a mutably borrowed stream.
///
/// The wrapped stream may be unsized (for example `dyn UnauthEth`). The handshake is
/// started with [`EthereumEthHandshake::eth_handshake`], which consumes the wrapper.
#[derive(Debug)]
pub struct EthereumEthHandshake<'stream, S: ?Sized>(pub &'stream mut S);
75
76impl<S: ?Sized, E> EthereumEthHandshake<'_, S>
77where
78    S: Stream<Item = Result<BytesMut, E>> + CanDisconnect<Bytes> + Send + Unpin,
79    EthStreamError: From<E> + From<<S as Sink<Bytes>>::Error>,
80{
81    /// Performs the `eth` rlpx protocol handshake using the given input stream.
82    pub async fn eth_handshake(
83        self,
84        status: Status,
85        fork_filter: ForkFilter,
86    ) -> Result<Status, EthStreamError> {
87        let unauth = self.0;
88        // Send our status message
89        let status_msg =
90            alloy_rlp::encode(ProtocolMessage::<EthNetworkPrimitives>::from(EthMessage::<
91                EthNetworkPrimitives,
92            >::Status(
93                StatusMessage::Legacy(status),
94            )))
95            .into();
96        unauth.send(status_msg).await.map_err(EthStreamError::from)?;
97
98        // Receive peer's response
99        let their_msg_res = unauth.next().await;
100        let their_msg = match their_msg_res {
101            Some(Ok(msg)) => msg,
102            Some(Err(e)) => return Err(EthStreamError::from(e)),
103            None => {
104                unauth
105                    .disconnect(DisconnectReason::DisconnectRequested)
106                    .await
107                    .map_err(EthStreamError::from)?;
108                return Err(EthStreamError::EthHandshakeError(EthHandshakeError::NoResponse));
109            }
110        };
111
112        if their_msg.len() > MAX_STATUS_SIZE {
113            unauth
114                .disconnect(DisconnectReason::ProtocolBreach)
115                .await
116                .map_err(EthStreamError::from)?;
117            return Err(EthStreamError::MessageTooBig(their_msg.len()));
118        }
119
120        let version = status.version;
121        let msg = match ProtocolMessage::<EthNetworkPrimitives>::decode_message(
122            version,
123            &mut their_msg.as_ref(),
124        ) {
125            Ok(m) => m,
126            Err(err) => {
127                debug!("decode error in eth handshake: msg={their_msg:x}");
128                unauth
129                    .disconnect(DisconnectReason::DisconnectRequested)
130                    .await
131                    .map_err(EthStreamError::from)?;
132                return Err(EthStreamError::InvalidMessage(err));
133            }
134        };
135
136        // Validate peer response
137        match msg.message {
138            EthMessage::Status(their_status_message) => {
139                trace!("Validating incoming ETH status from peer");
140
141                if status.genesis != their_status_message.genesis() {
142                    unauth
143                        .disconnect(DisconnectReason::ProtocolBreach)
144                        .await
145                        .map_err(EthStreamError::from)?;
146                    return Err(EthHandshakeError::MismatchedGenesis(
147                        GotExpected {
148                            expected: status.genesis,
149                            got: their_status_message.genesis(),
150                        }
151                        .into(),
152                    )
153                    .into());
154                }
155
156                if status.version != their_status_message.version() {
157                    unauth
158                        .disconnect(DisconnectReason::ProtocolBreach)
159                        .await
160                        .map_err(EthStreamError::from)?;
161                    return Err(EthHandshakeError::MismatchedProtocolVersion(GotExpected {
162                        got: their_status_message.version(),
163                        expected: status.version,
164                    })
165                    .into());
166                }
167
168                if status.chain != *their_status_message.chain() {
169                    unauth
170                        .disconnect(DisconnectReason::ProtocolBreach)
171                        .await
172                        .map_err(EthStreamError::from)?;
173                    return Err(EthHandshakeError::MismatchedChain(GotExpected {
174                        got: *their_status_message.chain(),
175                        expected: status.chain,
176                    })
177                    .into());
178                }
179
180                // Ensure total difficulty is reasonable
181                if status.total_difficulty.bit_len() > 160 {
182                    unauth
183                        .disconnect(DisconnectReason::ProtocolBreach)
184                        .await
185                        .map_err(EthStreamError::from)?;
186                    return Err(EthHandshakeError::TotalDifficultyBitLenTooLarge {
187                        got: status.total_difficulty.bit_len(),
188                        maximum: 160,
189                    }
190                    .into());
191                }
192
193                // Fork validation
194                if let Err(err) = fork_filter
195                    .validate(their_status_message.forkid())
196                    .map_err(EthHandshakeError::InvalidFork)
197                {
198                    unauth
199                        .disconnect(DisconnectReason::ProtocolBreach)
200                        .await
201                        .map_err(EthStreamError::from)?;
202                    return Err(err.into());
203                }
204
205                Ok(their_status_message.to_legacy())
206            }
207            _ => {
208                unauth
209                    .disconnect(DisconnectReason::ProtocolBreach)
210                    .await
211                    .map_err(EthStreamError::from)?;
212                Err(EthStreamError::EthHandshakeError(
213                    EthHandshakeError::NonStatusMessageInHandshake,
214                ))
215            }
216        }
217    }
218}