reth_ress_protocol/
connection.rs

1use crate::{GetHeaders, NodeType, RessMessage, RessProtocolMessage, RessProtocolProvider};
2use alloy_consensus::Header;
3use alloy_primitives::{bytes::BytesMut, BlockHash, Bytes, B256};
4use futures::{stream::FuturesUnordered, Stream, StreamExt};
5use reth_eth_wire::{message::RequestPair, multiplex::ProtocolConnection};
6use reth_ethereum_primitives::BlockBody;
7use reth_network_api::{test_utils::PeersHandle, PeerId, ReputationChangeKind};
8use reth_storage_errors::ProviderResult;
9use std::{
10    collections::HashMap,
11    future::Future,
12    pin::Pin,
13    sync::{
14        atomic::{AtomicU64, Ordering},
15        Arc,
16    },
17    task::{Context, Poll},
18};
19use tokio::sync::oneshot;
20use tokio_stream::wrappers::UnboundedReceiverStream;
21use tracing::*;
22
23/// The connection handler for the custom `RLPx` protocol.
24#[derive(Debug)]
25pub struct RessProtocolConnection<P> {
26    /// Provider.
27    provider: P,
28    /// The type of this node..
29    node_type: NodeType,
30    /// Peers handle.
31    peers_handle: PeersHandle,
32    /// Peer ID.
33    peer_id: PeerId,
34    /// Protocol connection.
35    conn: ProtocolConnection,
36    /// Stream of incoming commands.
37    commands: UnboundedReceiverStream<RessPeerRequest>,
38    /// The total number of active connections.
39    active_connections: Arc<AtomicU64>,
40    /// Flag indicating whether the node type was sent to the peer.
41    node_type_sent: bool,
42    /// Flag indicating whether this stream has previously been terminated.
43    terminated: bool,
44    /// Incremental counter for request ids.
45    next_id: u64,
46    /// Collection of inflight requests.
47    inflight_requests: HashMap<u64, RessPeerRequest>,
48    /// Pending witness responses.
49    pending_witnesses: FuturesUnordered<WitnessFut>,
50}
51
52impl<P> RessProtocolConnection<P> {
53    /// Create new connection.
54    pub fn new(
55        provider: P,
56        node_type: NodeType,
57        peers_handle: PeersHandle,
58        peer_id: PeerId,
59        conn: ProtocolConnection,
60        commands: UnboundedReceiverStream<RessPeerRequest>,
61        active_connections: Arc<AtomicU64>,
62    ) -> Self {
63        Self {
64            provider,
65            node_type,
66            peers_handle,
67            peer_id,
68            conn,
69            commands,
70            active_connections,
71            node_type_sent: false,
72            terminated: false,
73            next_id: 0,
74            inflight_requests: HashMap::default(),
75            pending_witnesses: FuturesUnordered::new(),
76        }
77    }
78
79    /// Returns the next request id
80    const fn next_id(&mut self) -> u64 {
81        let id = self.next_id;
82        self.next_id += 1;
83        id
84    }
85
86    /// Report bad message from current peer.
87    fn report_bad_message(&self) {
88        self.peers_handle.reputation_change(self.peer_id, ReputationChangeKind::BadMessage);
89    }
90
91    fn on_command(&mut self, command: RessPeerRequest) -> RessProtocolMessage {
92        let next_id = self.next_id();
93        let message = match &command {
94            RessPeerRequest::GetHeaders { request, .. } => {
95                RessProtocolMessage::get_headers(next_id, *request)
96            }
97            RessPeerRequest::GetBlockBodies { request, .. } => {
98                RessProtocolMessage::get_block_bodies(next_id, request.clone())
99            }
100            RessPeerRequest::GetWitness { block_hash, .. } => {
101                RessProtocolMessage::get_witness(next_id, *block_hash)
102            }
103            RessPeerRequest::GetBytecode { code_hash, .. } => {
104                RessProtocolMessage::get_bytecode(next_id, *code_hash)
105            }
106        };
107        self.inflight_requests.insert(next_id, command);
108        message
109    }
110}
111
112impl<P> RessProtocolConnection<P>
113where
114    P: RessProtocolProvider + Clone + 'static,
115{
116    fn on_headers_request(&self, request: GetHeaders) -> Vec<Header> {
117        match self.provider.headers(request) {
118            Ok(headers) => headers,
119            Err(error) => {
120                trace!(target: "ress::net::connection", peer_id = %self.peer_id, ?request, %error, "error retrieving headers");
121                Default::default()
122            }
123        }
124    }
125
126    fn on_block_bodies_request(&self, request: Vec<B256>) -> Vec<BlockBody> {
127        match self.provider.block_bodies(request.clone()) {
128            Ok(bodies) => bodies,
129            Err(error) => {
130                trace!(target: "ress::net::connection", peer_id = %self.peer_id, ?request, %error, "error retrieving block bodies");
131                Default::default()
132            }
133        }
134    }
135
136    fn on_bytecode_request(&self, code_hash: B256) -> Bytes {
137        match self.provider.bytecode(code_hash) {
138            Ok(Some(bytecode)) => bytecode,
139            Ok(None) => {
140                trace!(target: "ress::net::connection", peer_id = %self.peer_id, %code_hash, "bytecode not found");
141                Default::default()
142            }
143            Err(error) => {
144                trace!(target: "ress::net::connection", peer_id = %self.peer_id, %code_hash, %error, "error retrieving bytecode");
145                Default::default()
146            }
147        }
148    }
149
150    fn on_witness_response(
151        &self,
152        request: RequestPair<B256>,
153        witness_result: ProviderResult<Vec<Bytes>>,
154    ) -> RessProtocolMessage {
155        let peer_id = self.peer_id;
156        let block_hash = request.message;
157        let witness = match witness_result {
158            Ok(witness) => {
159                trace!(target: "ress::net::connection", %peer_id, %block_hash, len = witness.len(), "witness found");
160                witness
161            }
162            Err(error) => {
163                trace!(target: "ress::net::connection", %peer_id, %block_hash, %error, "error retrieving witness");
164                Default::default()
165            }
166        };
167        RessProtocolMessage::witness(request.request_id, witness)
168    }
169
170    fn on_ress_message(&mut self, msg: RessProtocolMessage) -> OnRessMessageOutcome {
171        match msg.message {
172            RessMessage::NodeType(node_type) => {
173                if !self.node_type.is_valid_connection(&node_type) {
174                    // Note types are not compatible, terminate the connection.
175                    return OnRessMessageOutcome::Terminate;
176                }
177            }
178            RessMessage::GetHeaders(req) => {
179                let request = req.message;
180                trace!(target: "ress::net::connection", peer_id = %self.peer_id, ?request, "serving headers");
181                let header = self.on_headers_request(request);
182                let response = RessProtocolMessage::headers(req.request_id, header);
183                return OnRessMessageOutcome::Response(response.encoded());
184            }
185            RessMessage::GetBlockBodies(req) => {
186                let request = req.message;
187                trace!(target: "ress::net::connection", peer_id = %self.peer_id, ?request, "serving block bodies");
188                let bodies = self.on_block_bodies_request(request);
189                let response = RessProtocolMessage::block_bodies(req.request_id, bodies);
190                return OnRessMessageOutcome::Response(response.encoded());
191            }
192            RessMessage::GetBytecode(req) => {
193                let code_hash = req.message;
194                trace!(target: "ress::net::connection", peer_id = %self.peer_id, %code_hash, "serving bytecode");
195                let bytecode = self.on_bytecode_request(code_hash);
196                let response = RessProtocolMessage::bytecode(req.request_id, bytecode);
197                return OnRessMessageOutcome::Response(response.encoded());
198            }
199            RessMessage::GetWitness(req) => {
200                let block_hash = req.message;
201                trace!(target: "ress::net::connection", peer_id = %self.peer_id, %block_hash, "serving witness");
202                let provider = self.provider.clone();
203                self.pending_witnesses.push(Box::pin(async move {
204                    let result = provider.witness(block_hash).await;
205                    (req, result)
206                }));
207            }
208            RessMessage::Headers(res) => {
209                if let Some(RessPeerRequest::GetHeaders { tx, .. }) =
210                    self.inflight_requests.remove(&res.request_id)
211                {
212                    let _ = tx.send(res.message);
213                } else {
214                    self.report_bad_message();
215                }
216            }
217            RessMessage::BlockBodies(res) => {
218                if let Some(RessPeerRequest::GetBlockBodies { tx, .. }) =
219                    self.inflight_requests.remove(&res.request_id)
220                {
221                    let _ = tx.send(res.message);
222                } else {
223                    self.report_bad_message();
224                }
225            }
226            RessMessage::Bytecode(res) => {
227                if let Some(RessPeerRequest::GetBytecode { tx, .. }) =
228                    self.inflight_requests.remove(&res.request_id)
229                {
230                    let _ = tx.send(res.message);
231                } else {
232                    self.report_bad_message();
233                }
234            }
235            RessMessage::Witness(res) => {
236                if let Some(RessPeerRequest::GetWitness { tx, .. }) =
237                    self.inflight_requests.remove(&res.request_id)
238                {
239                    let _ = tx.send(res.message);
240                } else {
241                    self.report_bad_message();
242                }
243            }
244        };
245        OnRessMessageOutcome::None
246    }
247}
248
249impl<P> Drop for RessProtocolConnection<P> {
250    fn drop(&mut self) {
251        let _ = self
252            .active_connections
253            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |c| Some(c.saturating_sub(1)));
254    }
255}
256
257impl<P> Stream for RessProtocolConnection<P>
258where
259    P: RessProtocolProvider + Clone + Unpin + 'static,
260{
261    type Item = BytesMut;
262
263    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
264        let this = self.get_mut();
265
266        if this.terminated {
267            return Poll::Ready(None)
268        }
269
270        if !this.node_type_sent {
271            this.node_type_sent = true;
272            return Poll::Ready(Some(RessProtocolMessage::node_type(this.node_type).encoded()))
273        }
274
275        'conn: loop {
276            if let Poll::Ready(Some(cmd)) = this.commands.poll_next_unpin(cx) {
277                let message = this.on_command(cmd);
278                let encoded = message.encoded();
279                trace!(target: "ress::net::connection", peer_id = %this.peer_id, ?message, encoded = alloy_primitives::hex::encode(&encoded), "Sending peer command");
280                return Poll::Ready(Some(encoded));
281            }
282
283            if let Poll::Ready(Some((request, witness_result))) =
284                this.pending_witnesses.poll_next_unpin(cx)
285            {
286                let response = this.on_witness_response(request, witness_result);
287                return Poll::Ready(Some(response.encoded()));
288            }
289
290            if let Poll::Ready(maybe_msg) = this.conn.poll_next_unpin(cx) {
291                let Some(next) = maybe_msg else { break 'conn };
292                let msg = match RessProtocolMessage::decode_message(&mut &next[..]) {
293                    Ok(msg) => {
294                        trace!(target: "ress::net::connection", peer_id = %this.peer_id, message = ?msg.message_type, "Processing message");
295                        msg
296                    }
297                    Err(error) => {
298                        trace!(target: "ress::net::connection", peer_id = %this.peer_id, %error, "Error decoding peer message");
299                        this.report_bad_message();
300                        continue;
301                    }
302                };
303
304                match this.on_ress_message(msg) {
305                    OnRessMessageOutcome::Response(bytes) => return Poll::Ready(Some(bytes)),
306                    OnRessMessageOutcome::Terminate => break 'conn,
307                    OnRessMessageOutcome::None => {}
308                };
309
310                continue;
311            }
312
313            return Poll::Pending;
314        }
315
316        // Terminating the connection.
317        this.terminated = true;
318        Poll::Ready(None)
319    }
320}
321
322type WitnessFut =
323    Pin<Box<dyn Future<Output = (RequestPair<B256>, ProviderResult<Vec<Bytes>>)> + Send>>;
324
325/// Ress peer request.
326#[derive(Debug)]
327pub enum RessPeerRequest {
328    /// Get block headers.
329    GetHeaders {
330        /// The request for block headers.
331        request: GetHeaders,
332        /// The sender for the response.
333        tx: oneshot::Sender<Vec<Header>>,
334    },
335    /// Get block bodies.
336    GetBlockBodies {
337        /// The request for block bodies.
338        request: Vec<BlockHash>,
339        /// The sender for the response.
340        tx: oneshot::Sender<Vec<BlockBody>>,
341    },
342    /// Get bytecode for specific code hash
343    GetBytecode {
344        /// Target code hash that we want to get bytecode for.
345        code_hash: B256,
346        /// The sender for the response.
347        tx: oneshot::Sender<Bytes>,
348    },
349    /// Get witness for specific block.
350    GetWitness {
351        /// Target block hash that we want to get witness for.
352        block_hash: BlockHash,
353        /// The sender for the response.
354        tx: oneshot::Sender<Vec<Bytes>>,
355    },
356}
357
358enum OnRessMessageOutcome {
359    /// Response to send to the peer.
360    Response(BytesMut),
361    /// Terminate the connection.
362    Terminate,
363    /// No action.
364    None,
365}