reth_eth_wire/
handshake.rs

1use crate::{
2    errors::{EthHandshakeError, EthStreamError, P2PStreamError},
3    ethstream::MAX_STATUS_SIZE,
4    CanDisconnect,
5};
6use bytes::{Bytes, BytesMut};
7use derive_more::with_trait::Debug;
8use futures::{Sink, SinkExt, Stream};
9use reth_eth_wire_types::{
10    DisconnectReason, EthMessage, EthNetworkPrimitives, ProtocolMessage, Status,
11};
12use reth_ethereum_forks::ForkFilter;
13use reth_primitives_traits::GotExpected;
14use std::{future::Future, pin::Pin, time::Duration};
15use tokio::time::timeout;
16use tokio_stream::StreamExt;
17use tracing::{debug, trace};
18
19/// A trait that knows how to perform the P2P handshake.
20pub trait EthRlpxHandshake: Debug + Send + Sync + 'static {
21    /// Perform the P2P handshake for the `eth` protocol.
22    fn handshake<'a>(
23        &'a self,
24        unauth: &'a mut dyn UnauthEth,
25        status: Status,
26        fork_filter: ForkFilter,
27        timeout_limit: Duration,
28    ) -> Pin<Box<dyn Future<Output = Result<Status, EthStreamError>> + 'a + Send>>;
29}
30
31/// An unauthenticated stream that can send and receive messages.
32pub trait UnauthEth:
33    Stream<Item = Result<BytesMut, P2PStreamError>>
34    + Sink<Bytes, Error = P2PStreamError>
35    + CanDisconnect<Bytes>
36    + Unpin
37    + Send
38{
39}
40
41impl<T> UnauthEth for T where
42    T: Stream<Item = Result<BytesMut, P2PStreamError>>
43        + Sink<Bytes, Error = P2PStreamError>
44        + CanDisconnect<Bytes>
45        + Unpin
46        + Send
47{
48}
49
/// Default [`EthRlpxHandshake`] implementation: the regular Ethereum `eth` RLPx status
/// exchange, bounded by the caller-supplied timeout.
///
/// This is a stateless unit type; `#[non_exhaustive]` keeps room for adding configuration
/// later without breaking downstream crates.
#[derive(Clone, Debug, Default)]
#[non_exhaustive]
pub struct EthHandshake;
56
57impl EthRlpxHandshake for EthHandshake {
58    fn handshake<'a>(
59        &'a self,
60        unauth: &'a mut dyn UnauthEth,
61        status: Status,
62        fork_filter: ForkFilter,
63        timeout_limit: Duration,
64    ) -> Pin<Box<dyn Future<Output = Result<Status, EthStreamError>> + 'a + Send>> {
65        Box::pin(async move {
66            timeout(timeout_limit, EthereumEthHandshake(unauth).eth_handshake(status, fork_filter))
67                .await
68                .map_err(|_| EthStreamError::StreamTimeout)?
69        })
70    }
71}
72
/// Performs the Ethereum-specific `eth` protocol handshake on a borrowed stream.
///
/// The wrapped value is a mutable borrow of the (possibly unsized, e.g. `dyn UnauthEth`)
/// stream the handshake messages are exchanged on.
#[derive(Debug)]
pub struct EthereumEthHandshake<'a, S>(pub &'a mut S)
where
    S: ?Sized;
76
77impl<S: ?Sized, E> EthereumEthHandshake<'_, S>
78where
79    S: Stream<Item = Result<BytesMut, E>> + CanDisconnect<Bytes> + Send + Unpin,
80    EthStreamError: From<E> + From<<S as Sink<Bytes>>::Error>,
81{
82    /// Performs the `eth` rlpx protocol handshake using the given input stream.
83    pub async fn eth_handshake(
84        self,
85        status: Status,
86        fork_filter: ForkFilter,
87    ) -> Result<Status, EthStreamError> {
88        let unauth = self.0;
89        // Send our status message
90        let status_msg =
91            alloy_rlp::encode(ProtocolMessage::<EthNetworkPrimitives>::from(EthMessage::<
92                EthNetworkPrimitives,
93            >::Status(
94                status
95            )))
96            .into();
97        unauth.send(status_msg).await.map_err(EthStreamError::from)?;
98
99        // Receive peer's response
100        let their_msg_res = unauth.next().await;
101        let their_msg = match their_msg_res {
102            Some(Ok(msg)) => msg,
103            Some(Err(e)) => return Err(EthStreamError::from(e)),
104            None => {
105                unauth
106                    .disconnect(DisconnectReason::DisconnectRequested)
107                    .await
108                    .map_err(EthStreamError::from)?;
109                return Err(EthStreamError::EthHandshakeError(EthHandshakeError::NoResponse));
110            }
111        };
112
113        if their_msg.len() > MAX_STATUS_SIZE {
114            unauth
115                .disconnect(DisconnectReason::ProtocolBreach)
116                .await
117                .map_err(EthStreamError::from)?;
118            return Err(EthStreamError::MessageTooBig(their_msg.len()));
119        }
120
121        let version = status.version;
122        let msg = match ProtocolMessage::<EthNetworkPrimitives>::decode_message(
123            version,
124            &mut their_msg.as_ref(),
125        ) {
126            Ok(m) => m,
127            Err(err) => {
128                debug!("decode error in eth handshake: msg={their_msg:x}");
129                unauth
130                    .disconnect(DisconnectReason::DisconnectRequested)
131                    .await
132                    .map_err(EthStreamError::from)?;
133                return Err(EthStreamError::InvalidMessage(err));
134            }
135        };
136
137        // Validate peer response
138        match msg.message {
139            EthMessage::Status(their_status) => {
140                trace!("Validating incoming ETH status from peer");
141
142                if status.genesis != their_status.genesis {
143                    unauth
144                        .disconnect(DisconnectReason::ProtocolBreach)
145                        .await
146                        .map_err(EthStreamError::from)?;
147                    return Err(EthHandshakeError::MismatchedGenesis(
148                        GotExpected { expected: status.genesis, got: their_status.genesis }.into(),
149                    )
150                    .into());
151                }
152
153                if status.version != their_status.version {
154                    unauth
155                        .disconnect(DisconnectReason::ProtocolBreach)
156                        .await
157                        .map_err(EthStreamError::from)?;
158                    return Err(EthHandshakeError::MismatchedProtocolVersion(GotExpected {
159                        got: their_status.version,
160                        expected: status.version,
161                    })
162                    .into());
163                }
164
165                if status.chain != their_status.chain {
166                    unauth
167                        .disconnect(DisconnectReason::ProtocolBreach)
168                        .await
169                        .map_err(EthStreamError::from)?;
170                    return Err(EthHandshakeError::MismatchedChain(GotExpected {
171                        got: their_status.chain,
172                        expected: status.chain,
173                    })
174                    .into());
175                }
176
177                // Ensure total difficulty is reasonable
178                if status.total_difficulty.bit_len() > 160 {
179                    unauth
180                        .disconnect(DisconnectReason::ProtocolBreach)
181                        .await
182                        .map_err(EthStreamError::from)?;
183                    return Err(EthHandshakeError::TotalDifficultyBitLenTooLarge {
184                        got: status.total_difficulty.bit_len(),
185                        maximum: 160,
186                    }
187                    .into());
188                }
189
190                // Fork validation
191                if let Err(err) = fork_filter
192                    .validate(their_status.forkid)
193                    .map_err(EthHandshakeError::InvalidFork)
194                {
195                    unauth
196                        .disconnect(DisconnectReason::ProtocolBreach)
197                        .await
198                        .map_err(EthStreamError::from)?;
199                    return Err(err.into());
200                }
201
202                Ok(their_status)
203            }
204            _ => {
205                unauth
206                    .disconnect(DisconnectReason::ProtocolBreach)
207                    .await
208                    .map_err(EthStreamError::from)?;
209                Err(EthStreamError::EthHandshakeError(
210                    EthHandshakeError::NonStatusMessageInHandshake,
211                ))
212            }
213        }
214    }
215}