1use crate::{GetHeaders, NodeType, RessMessage, RessProtocolMessage, RessProtocolProvider};
2use alloy_consensus::Header;
3use alloy_primitives::{bytes::BytesMut, BlockHash, Bytes, B256};
4use futures::{stream::FuturesUnordered, Stream, StreamExt};
5use reth_eth_wire::{message::RequestPair, multiplex::ProtocolConnection};
6use reth_ethereum_primitives::BlockBody;
7use reth_network_api::{test_utils::PeersHandle, PeerId, ReputationChangeKind};
8use reth_storage_errors::ProviderResult;
9use std::{
10 collections::HashMap,
11 future::Future,
12 pin::Pin,
13 sync::{
14 atomic::{AtomicU64, Ordering},
15 Arc,
16 },
17 task::{Context, Poll},
18};
19use tokio::sync::oneshot;
20use tokio_stream::wrappers::UnboundedReceiverStream;
21use tracing::*;
22
23#[derive(Debug)]
25pub struct RessProtocolConnection<P> {
26 provider: P,
28 node_type: NodeType,
30 peers_handle: PeersHandle,
32 peer_id: PeerId,
34 conn: ProtocolConnection,
36 commands: UnboundedReceiverStream<RessPeerRequest>,
38 active_connections: Arc<AtomicU64>,
40 node_type_sent: bool,
42 terminated: bool,
44 next_id: u64,
46 inflight_requests: HashMap<u64, RessPeerRequest>,
48 pending_witnesses: FuturesUnordered<WitnessFut>,
50}
51
52impl<P> RessProtocolConnection<P> {
53 pub fn new(
55 provider: P,
56 node_type: NodeType,
57 peers_handle: PeersHandle,
58 peer_id: PeerId,
59 conn: ProtocolConnection,
60 commands: UnboundedReceiverStream<RessPeerRequest>,
61 active_connections: Arc<AtomicU64>,
62 ) -> Self {
63 Self {
64 provider,
65 node_type,
66 peers_handle,
67 peer_id,
68 conn,
69 commands,
70 active_connections,
71 node_type_sent: false,
72 terminated: false,
73 next_id: 0,
74 inflight_requests: HashMap::default(),
75 pending_witnesses: FuturesUnordered::new(),
76 }
77 }
78
79 const fn next_id(&mut self) -> u64 {
81 let id = self.next_id;
82 self.next_id += 1;
83 id
84 }
85
86 fn report_bad_message(&self) {
88 self.peers_handle.reputation_change(self.peer_id, ReputationChangeKind::BadMessage);
89 }
90
91 fn on_command(&mut self, command: RessPeerRequest) -> RessProtocolMessage {
92 let next_id = self.next_id();
93 let message = match &command {
94 RessPeerRequest::GetHeaders { request, .. } => {
95 RessProtocolMessage::get_headers(next_id, *request)
96 }
97 RessPeerRequest::GetBlockBodies { request, .. } => {
98 RessProtocolMessage::get_block_bodies(next_id, request.clone())
99 }
100 RessPeerRequest::GetWitness { block_hash, .. } => {
101 RessProtocolMessage::get_witness(next_id, *block_hash)
102 }
103 RessPeerRequest::GetBytecode { code_hash, .. } => {
104 RessProtocolMessage::get_bytecode(next_id, *code_hash)
105 }
106 };
107 self.inflight_requests.insert(next_id, command);
108 message
109 }
110}
111
112impl<P> RessProtocolConnection<P>
113where
114 P: RessProtocolProvider + Clone + 'static,
115{
116 fn on_headers_request(&self, request: GetHeaders) -> Vec<Header> {
117 match self.provider.headers(request) {
118 Ok(headers) => headers,
119 Err(error) => {
120 trace!(target: "ress::net::connection", peer_id = %self.peer_id, ?request, %error, "error retrieving headers");
121 Default::default()
122 }
123 }
124 }
125
126 fn on_block_bodies_request(&self, request: Vec<B256>) -> Vec<BlockBody> {
127 match self.provider.block_bodies(request.clone()) {
128 Ok(bodies) => bodies,
129 Err(error) => {
130 trace!(target: "ress::net::connection", peer_id = %self.peer_id, ?request, %error, "error retrieving block bodies");
131 Default::default()
132 }
133 }
134 }
135
136 fn on_bytecode_request(&self, code_hash: B256) -> Bytes {
137 match self.provider.bytecode(code_hash) {
138 Ok(Some(bytecode)) => bytecode,
139 Ok(None) => {
140 trace!(target: "ress::net::connection", peer_id = %self.peer_id, %code_hash, "bytecode not found");
141 Default::default()
142 }
143 Err(error) => {
144 trace!(target: "ress::net::connection", peer_id = %self.peer_id, %code_hash, %error, "error retrieving bytecode");
145 Default::default()
146 }
147 }
148 }
149
150 fn on_witness_response(
151 &self,
152 request: RequestPair<B256>,
153 witness_result: ProviderResult<Vec<Bytes>>,
154 ) -> RessProtocolMessage {
155 let peer_id = self.peer_id;
156 let block_hash = request.message;
157 let witness = match witness_result {
158 Ok(witness) => {
159 trace!(target: "ress::net::connection", %peer_id, %block_hash, len = witness.len(), "witness found");
160 witness
161 }
162 Err(error) => {
163 trace!(target: "ress::net::connection", %peer_id, %block_hash, %error, "error retrieving witness");
164 Default::default()
165 }
166 };
167 RessProtocolMessage::witness(request.request_id, witness)
168 }
169
170 fn on_ress_message(&mut self, msg: RessProtocolMessage) -> OnRessMessageOutcome {
171 match msg.message {
172 RessMessage::NodeType(node_type) => {
173 if !self.node_type.is_valid_connection(&node_type) {
174 return OnRessMessageOutcome::Terminate;
176 }
177 }
178 RessMessage::GetHeaders(req) => {
179 let request = req.message;
180 trace!(target: "ress::net::connection", peer_id = %self.peer_id, ?request, "serving headers");
181 let header = self.on_headers_request(request);
182 let response = RessProtocolMessage::headers(req.request_id, header);
183 return OnRessMessageOutcome::Response(response.encoded());
184 }
185 RessMessage::GetBlockBodies(req) => {
186 let request = req.message;
187 trace!(target: "ress::net::connection", peer_id = %self.peer_id, ?request, "serving block bodies");
188 let bodies = self.on_block_bodies_request(request);
189 let response = RessProtocolMessage::block_bodies(req.request_id, bodies);
190 return OnRessMessageOutcome::Response(response.encoded());
191 }
192 RessMessage::GetBytecode(req) => {
193 let code_hash = req.message;
194 trace!(target: "ress::net::connection", peer_id = %self.peer_id, %code_hash, "serving bytecode");
195 let bytecode = self.on_bytecode_request(code_hash);
196 let response = RessProtocolMessage::bytecode(req.request_id, bytecode);
197 return OnRessMessageOutcome::Response(response.encoded());
198 }
199 RessMessage::GetWitness(req) => {
200 let block_hash = req.message;
201 trace!(target: "ress::net::connection", peer_id = %self.peer_id, %block_hash, "serving witness");
202 let provider = self.provider.clone();
203 self.pending_witnesses.push(Box::pin(async move {
204 let result = provider.witness(block_hash).await;
205 (req, result)
206 }));
207 }
208 RessMessage::Headers(res) => {
209 if let Some(RessPeerRequest::GetHeaders { tx, .. }) =
210 self.inflight_requests.remove(&res.request_id)
211 {
212 let _ = tx.send(res.message);
213 } else {
214 self.report_bad_message();
215 }
216 }
217 RessMessage::BlockBodies(res) => {
218 if let Some(RessPeerRequest::GetBlockBodies { tx, .. }) =
219 self.inflight_requests.remove(&res.request_id)
220 {
221 let _ = tx.send(res.message);
222 } else {
223 self.report_bad_message();
224 }
225 }
226 RessMessage::Bytecode(res) => {
227 if let Some(RessPeerRequest::GetBytecode { tx, .. }) =
228 self.inflight_requests.remove(&res.request_id)
229 {
230 let _ = tx.send(res.message);
231 } else {
232 self.report_bad_message();
233 }
234 }
235 RessMessage::Witness(res) => {
236 if let Some(RessPeerRequest::GetWitness { tx, .. }) =
237 self.inflight_requests.remove(&res.request_id)
238 {
239 let _ = tx.send(res.message);
240 } else {
241 self.report_bad_message();
242 }
243 }
244 };
245 OnRessMessageOutcome::None
246 }
247}
248
249impl<P> Drop for RessProtocolConnection<P> {
250 fn drop(&mut self) {
251 let _ = self
252 .active_connections
253 .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |c| Some(c.saturating_sub(1)));
254 }
255}
256
257impl<P> Stream for RessProtocolConnection<P>
258where
259 P: RessProtocolProvider + Clone + Unpin + 'static,
260{
261 type Item = BytesMut;
262
263 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
264 let this = self.get_mut();
265
266 if this.terminated {
267 return Poll::Ready(None)
268 }
269
270 if !this.node_type_sent {
271 this.node_type_sent = true;
272 return Poll::Ready(Some(RessProtocolMessage::node_type(this.node_type).encoded()))
273 }
274
275 'conn: loop {
276 if let Poll::Ready(Some(cmd)) = this.commands.poll_next_unpin(cx) {
277 let message = this.on_command(cmd);
278 let encoded = message.encoded();
279 trace!(target: "ress::net::connection", peer_id = %this.peer_id, ?message, encoded = alloy_primitives::hex::encode(&encoded), "Sending peer command");
280 return Poll::Ready(Some(encoded));
281 }
282
283 if let Poll::Ready(Some((request, witness_result))) =
284 this.pending_witnesses.poll_next_unpin(cx)
285 {
286 let response = this.on_witness_response(request, witness_result);
287 return Poll::Ready(Some(response.encoded()));
288 }
289
290 if let Poll::Ready(maybe_msg) = this.conn.poll_next_unpin(cx) {
291 let Some(next) = maybe_msg else { break 'conn };
292 let msg = match RessProtocolMessage::decode_message(&mut &next[..]) {
293 Ok(msg) => {
294 trace!(target: "ress::net::connection", peer_id = %this.peer_id, message = ?msg.message_type, "Processing message");
295 msg
296 }
297 Err(error) => {
298 trace!(target: "ress::net::connection", peer_id = %this.peer_id, %error, "Error decoding peer message");
299 this.report_bad_message();
300 continue;
301 }
302 };
303
304 match this.on_ress_message(msg) {
305 OnRessMessageOutcome::Response(bytes) => return Poll::Ready(Some(bytes)),
306 OnRessMessageOutcome::Terminate => break 'conn,
307 OnRessMessageOutcome::None => {}
308 };
309
310 continue;
311 }
312
313 return Poll::Pending;
314 }
315
316 this.terminated = true;
318 Poll::Ready(None)
319 }
320}
321
322type WitnessFut =
323 Pin<Box<dyn Future<Output = (RequestPair<B256>, ProviderResult<Vec<Bytes>>)> + Send>>;
324
325#[derive(Debug)]
327pub enum RessPeerRequest {
328 GetHeaders {
330 request: GetHeaders,
332 tx: oneshot::Sender<Vec<Header>>,
334 },
335 GetBlockBodies {
337 request: Vec<BlockHash>,
339 tx: oneshot::Sender<Vec<BlockBody>>,
341 },
342 GetBytecode {
344 code_hash: B256,
346 tx: oneshot::Sender<Bytes>,
348 },
349 GetWitness {
351 block_hash: BlockHash,
353 tx: oneshot::Sender<Vec<Bytes>>,
355 },
356}
357
358enum OnRessMessageOutcome {
359 Response(BytesMut),
361 Terminate,
363 None,
365}