reth_eth_wire/
handshake.rs

1use crate::{
2    errors::{EthHandshakeError, EthStreamError, P2PStreamError},
3    ethstream::MAX_STATUS_SIZE,
4    CanDisconnect,
5};
6use bytes::{Bytes, BytesMut};
7use futures::{Sink, SinkExt, Stream};
8use reth_eth_wire_types::{
9    DisconnectReason, EthMessage, EthNetworkPrimitives, ProtocolMessage, StatusMessage,
10    UnifiedStatus,
11};
12use reth_ethereum_forks::ForkFilter;
13use reth_primitives_traits::GotExpected;
14use std::{fmt::Debug, future::Future, pin::Pin, time::Duration};
15use tokio::time::timeout;
16use tokio_stream::StreamExt;
17use tracing::{debug, trace};
18
/// A trait that knows how to perform the P2P handshake.
///
/// The trait is object safe: the single method returns a boxed, `Send` future that borrows
/// both `self` and the stream for `'a`, so implementations can be stored and invoked as
/// `dyn EthRlpxHandshake`.
pub trait EthRlpxHandshake: Debug + Send + Sync + 'static {
    /// Perform the P2P handshake for the `eth` protocol.
    ///
    /// * `unauth` - the not-yet-authenticated stream the handshake messages are exchanged on.
    /// * `status` - our own status to announce to the peer.
    /// * `fork_filter` - filter for validating the fork id announced by the peer.
    /// * `timeout_limit` - time budget for the handshake; NOTE(review): how the limit is
    ///   enforced is up to the implementation.
    ///
    /// On success the future resolves to a [`UnifiedStatus`]; presumably this is the status
    /// received from the peer, but that is decided by the implementation. On failure it
    /// resolves to an [`EthStreamError`].
    fn handshake<'a>(
        &'a self,
        unauth: &'a mut dyn UnauthEth,
        status: UnifiedStatus,
        fork_filter: ForkFilter,
        timeout_limit: Duration,
    ) -> Pin<Box<dyn Future<Output = Result<UnifiedStatus, EthStreamError>> + 'a + Send>>;
}
30
31/// An unauthenticated stream that can send and receive messages.
32pub trait UnauthEth:
33    Stream<Item = Result<BytesMut, P2PStreamError>>
34    + Sink<Bytes, Error = P2PStreamError>
35    + CanDisconnect<Bytes>
36    + Unpin
37    + Send
38{
39}
40
41impl<T> UnauthEth for T where
42    T: Stream<Item = Result<BytesMut, P2PStreamError>>
43        + Sink<Bytes, Error = P2PStreamError>
44        + CanDisconnect<Bytes>
45        + Unpin
46        + Send
47{
48}
49
50/// The Ethereum P2P handshake.
51///
52/// This performs the regular ethereum `eth` rlpx handshake.
53#[derive(Debug, Default, Clone)]
54#[non_exhaustive]
55pub struct EthHandshake;
56
57impl EthRlpxHandshake for EthHandshake {
58    fn handshake<'a>(
59        &'a self,
60        unauth: &'a mut dyn UnauthEth,
61        status: UnifiedStatus,
62        fork_filter: ForkFilter,
63        timeout_limit: Duration,
64    ) -> Pin<Box<dyn Future<Output = Result<UnifiedStatus, EthStreamError>> + 'a + Send>> {
65        Box::pin(async move {
66            timeout(timeout_limit, EthereumEthHandshake(unauth).eth_handshake(status, fork_filter))
67                .await
68                .map_err(|_| EthStreamError::StreamTimeout)?
69        })
70    }
71}
72
73/// A type that performs the ethereum specific `eth` protocol handshake.
74#[derive(Debug)]
75pub struct EthereumEthHandshake<'a, S: ?Sized>(pub &'a mut S);
76
77impl<S: ?Sized, E> EthereumEthHandshake<'_, S>
78where
79    S: Stream<Item = Result<BytesMut, E>> + CanDisconnect<Bytes> + Send + Unpin,
80    EthStreamError: From<E> + From<<S as Sink<Bytes>>::Error>,
81{
82    /// Performs the `eth` rlpx protocol handshake using the given input stream.
83    pub async fn eth_handshake(
84        self,
85        unified_status: UnifiedStatus,
86        fork_filter: ForkFilter,
87    ) -> Result<UnifiedStatus, EthStreamError> {
88        let unauth = self.0;
89
90        let status = unified_status.into_message();
91
92        // Send our status message
93        let status_msg = alloy_rlp::encode(ProtocolMessage::<EthNetworkPrimitives>::from(
94            EthMessage::Status(status),
95        ))
96        .into();
97        unauth.send(status_msg).await.map_err(EthStreamError::from)?;
98
99        // Receive peer's response
100        let their_msg_res = unauth.next().await;
101        let their_msg = match their_msg_res {
102            Some(Ok(msg)) => msg,
103            Some(Err(e)) => return Err(EthStreamError::from(e)),
104            None => {
105                unauth
106                    .disconnect(DisconnectReason::DisconnectRequested)
107                    .await
108                    .map_err(EthStreamError::from)?;
109                return Err(EthStreamError::EthHandshakeError(EthHandshakeError::NoResponse));
110            }
111        };
112
113        if their_msg.len() > MAX_STATUS_SIZE {
114            unauth
115                .disconnect(DisconnectReason::ProtocolBreach)
116                .await
117                .map_err(EthStreamError::from)?;
118            return Err(EthStreamError::MessageTooBig(their_msg.len()));
119        }
120
121        let version = status.version();
122        let msg = match ProtocolMessage::<EthNetworkPrimitives>::decode_message(
123            version,
124            &mut their_msg.as_ref(),
125        ) {
126            Ok(m) => m,
127            Err(err) => {
128                debug!("decode error in eth handshake: msg={their_msg:x}");
129                unauth
130                    .disconnect(DisconnectReason::DisconnectRequested)
131                    .await
132                    .map_err(EthStreamError::from)?;
133                return Err(EthStreamError::InvalidMessage(err));
134            }
135        };
136
137        // Validate peer response
138        match msg.message {
139            EthMessage::Status(their_status_message) => {
140                trace!("Validating incoming ETH status from peer");
141
142                if status.genesis() != their_status_message.genesis() {
143                    unauth
144                        .disconnect(DisconnectReason::ProtocolBreach)
145                        .await
146                        .map_err(EthStreamError::from)?;
147                    return Err(EthHandshakeError::MismatchedGenesis(
148                        GotExpected {
149                            expected: status.genesis(),
150                            got: their_status_message.genesis(),
151                        }
152                        .into(),
153                    )
154                    .into());
155                }
156
157                if status.version() != their_status_message.version() {
158                    unauth
159                        .disconnect(DisconnectReason::ProtocolBreach)
160                        .await
161                        .map_err(EthStreamError::from)?;
162                    return Err(EthHandshakeError::MismatchedProtocolVersion(GotExpected {
163                        got: their_status_message.version(),
164                        expected: status.version(),
165                    })
166                    .into());
167                }
168
169                if *status.chain() != *their_status_message.chain() {
170                    unauth
171                        .disconnect(DisconnectReason::ProtocolBreach)
172                        .await
173                        .map_err(EthStreamError::from)?;
174                    return Err(EthHandshakeError::MismatchedChain(GotExpected {
175                        got: *their_status_message.chain(),
176                        expected: *status.chain(),
177                    })
178                    .into());
179                }
180
181                // Ensure total difficulty is reasonable
182                if let StatusMessage::Legacy(s) = status {
183                    if s.total_difficulty.bit_len() > 160 {
184                        unauth
185                            .disconnect(DisconnectReason::ProtocolBreach)
186                            .await
187                            .map_err(EthStreamError::from)?;
188                        return Err(EthHandshakeError::TotalDifficultyBitLenTooLarge {
189                            got: s.total_difficulty.bit_len(),
190                            maximum: 160,
191                        }
192                        .into());
193                    }
194                }
195
196                // Fork validation
197                if let Err(err) = fork_filter
198                    .validate(their_status_message.forkid())
199                    .map_err(EthHandshakeError::InvalidFork)
200                {
201                    unauth
202                        .disconnect(DisconnectReason::ProtocolBreach)
203                        .await
204                        .map_err(EthStreamError::from)?;
205                    return Err(err.into());
206                }
207
208                Ok(UnifiedStatus::from_message(their_status_message))
209            }
210            _ => {
211                unauth
212                    .disconnect(DisconnectReason::ProtocolBreach)
213                    .await
214                    .map_err(EthStreamError::from)?;
215                Err(EthStreamError::EthHandshakeError(
216                    EthHandshakeError::NonStatusMessageInHandshake,
217                ))
218            }
219        }
220    }
221}