Skip to main content

reth_network/session/
active.rs

1//! Represents an established session.
2
3use core::sync::atomic::Ordering;
4use std::{
5    collections::VecDeque,
6    future::Future,
7    net::SocketAddr,
8    pin::Pin,
9    sync::{
10        atomic::{AtomicU64, AtomicUsize},
11        Arc,
12    },
13    task::{ready, Context, Poll},
14    time::{Duration, Instant},
15};
16
17use crate::{
18    message::{NewBlockMessage, PeerMessage, PeerResponse, PeerResponseResult},
19    session::{
20        conn::EthRlpxConnection,
21        handle::{ActiveSessionMessage, SessionCommand},
22        BlockRangeInfo, EthVersion, SessionId,
23    },
24};
25use alloy_eips::merge::EPOCH_SLOTS;
26use alloy_primitives::Sealable;
27use futures::{stream::Fuse, SinkExt, StreamExt};
28use metrics::{Counter, Gauge};
29use reth_eth_wire::{
30    errors::{EthHandshakeError, EthStreamError},
31    message::{EthBroadcastMessage, MessageError},
32    Capabilities, DisconnectP2P, DisconnectReason, EthMessage, NetworkPrimitives, NewBlockPayload,
33};
34use reth_eth_wire_types::{message::RequestPair, NewPooledTransactionHashes, RawCapabilityMessage};
35use reth_metrics::common::mpsc::MeteredPollSender;
36use reth_network_api::PeerRequest;
37use reth_network_p2p::error::RequestError;
38use reth_network_peers::PeerId;
39use reth_network_types::session::config::INITIAL_REQUEST_TIMEOUT;
40use reth_primitives_traits::Block;
41use rustc_hash::FxHashMap;
42use tokio::{
43    sync::{mpsc, mpsc::error::TrySendError, oneshot},
44    time::Interval,
45};
46use tokio_stream::wrappers::ReceiverStream;
47use tokio_util::sync::PollSender;
48use tracing::{debug, trace};
49
/// The recommended interval at which to check if a new range update should be sent to the remote
/// peer.
///
/// Updates are only sent when the block height has advanced by at least one epoch (32 blocks)
/// since the last update. The interval is set to one epoch duration in seconds (32 slots * 12s
/// per slot).
pub(super) const RANGE_UPDATE_INTERVAL: Duration = Duration::from_secs(EPOCH_SLOTS * 12);

// Constants for the adaptive request-timeout calculation (see `update_request_timeout`).

/// Minimum timeout value the adaptive request timeout can shrink to.
const MINIMUM_TIMEOUT: Duration = Duration::from_secs(2);

/// Maximum timeout value the adaptive request timeout can grow to.
const MAXIMUM_TIMEOUT: Duration = INITIAL_REQUEST_TIMEOUT;
/// How much the new measurements affect the current timeout (X percent)
const SAMPLE_IMPACT: f64 = 0.1;
/// Amount of RTTs before timeout
const TIMEOUT_SCALING: u32 = 3;

/// Restricts the number of queued outgoing messages for larger responses:
///  - Block Bodies
///  - Receipts
///  - Headers
///  - `PooledTransactions`
///
/// With proper softlimits in place (2MB) this targets 10MB (4+1 * 2MB) of outgoing response data.
///
/// This parameter serves as backpressure for reading additional requests from the remote.
/// Once we've queued up more responses than this, the session should prioritize message flushing
/// before reading any more messages from the remote peer, throttling the peer.
const MAX_QUEUED_OUTGOING_RESPONSES: usize = 4;
81
/// Soft limit for the total number of buffered outgoing broadcast items (e.g. transaction hashes).
///
/// Many small broadcast messages carrying a single tx hash each are equivalent in cost to one
/// message carrying many hashes. This limit counts individual items (hashes, transactions, blocks)
/// rather than messages, so that many small messages don't trigger aggressive drops unnecessarily.
const MAX_QUEUED_BROADCAST_ITEMS: usize = 4096;

/// Shared counter for in-flight broadcast items (tx hashes, transactions, blocks) across the
/// bounded command channel, unbounded overflow channel, and session outgoing queue.
///
/// Wrapped in a newtype so the backing storage can be changed later (e.g. to track memory) without
/// touching every call-site.
#[derive(Debug, Clone)]
pub(crate) struct BroadcastItemCounter(Arc<AtomicUsize>);

impl BroadcastItemCounter {
    /// Creates a new counter starting at zero.
    pub(crate) fn new() -> Self {
        Self(Arc::new(AtomicUsize::new(0)))
    }

    /// Returns the current count.
    pub(crate) fn get(&self) -> usize {
        self.0.load(Ordering::Relaxed)
    }

    /// Attempts to add `n` items. Returns `true` if under the limit, `false` if over (no change).
    ///
    /// The cap is checked against the count *before* this addition, so the counter may briefly
    /// exceed the limit by one batch — this is a soft limit by design.
    pub(crate) fn try_add(&self, n: usize) -> bool {
        // Optimistically add first; roll the addition back if we were already at/over the cap.
        let before = self.0.fetch_add(n, Ordering::Relaxed);
        if before < MAX_QUEUED_BROADCAST_ITEMS {
            return true
        }
        self.0.fetch_sub(n, Ordering::Relaxed);
        false
    }

    /// Subtracts `n` items from the counter.
    pub(crate) fn sub(&self, n: usize) {
        self.0.fetch_sub(n, Ordering::Relaxed);
    }
}
124
/// The type that advances an established session by listening for incoming messages (from local
/// node or read from connection) and emitting events back to the
/// [`SessionManager`](super::SessionManager).
///
/// It listens for
///    - incoming commands from the [`SessionManager`](super::SessionManager)
///    - incoming _internal_ requests/broadcasts via the request/command channel
///    - incoming requests/broadcasts _from remote_ via the connection
///    - responses for handled ETH requests received from the remote peer.
#[expect(dead_code)]
pub(crate) struct ActiveSession<N: NetworkPrimitives> {
    /// Keeps track of request ids: monotonically incremented via `next_id()` and used to pair
    /// outgoing requests with incoming responses.
    pub(crate) next_id: u64,
    /// The underlying connection.
    pub(crate) conn: EthRlpxConnection<N>,
    /// Identifier of the node we're connected to.
    pub(crate) remote_peer_id: PeerId,
    /// The address we're connected to.
    pub(crate) remote_addr: SocketAddr,
    /// All capabilities the peer announced
    pub(crate) remote_capabilities: Arc<Capabilities>,
    /// Internal identifier of this session
    pub(crate) session_id: SessionId,
    /// Incoming commands from the manager
    pub(crate) commands_rx: ReceiverStream<SessionCommand<N>>,
    /// Unbounded channel for commands that couldn't fit in the bounded channel (broadcast
    /// overflow) and for disconnect commands that must never be dropped.
    pub(crate) unbounded_rx: mpsc::UnboundedReceiver<SessionCommand<N>>,
    /// Counter for broadcast messages received via the unbounded overflow channel.
    pub(crate) unbounded_broadcast_msgs: Counter,
    /// Sink to send messages to the [`SessionManager`](super::SessionManager).
    pub(crate) to_session_manager: MeteredPollSender<ActiveSessionMessage<N>>,
    /// A message that needs to be delivered to the session manager
    pub(crate) pending_message_to_session: Option<ActiveSessionMessage<N>>,
    /// Incoming internal requests which are delegated to the remote peer.
    pub(crate) internal_request_rx: Fuse<ReceiverStream<PeerRequest<N>>>,
    /// All requests sent to the remote peer we're waiting on a response
    pub(crate) inflight_requests: FxHashMap<u64, InflightRequest<PeerRequest<N>>>,
    /// All requests that were sent by the remote peer and we're waiting on an internal response
    pub(crate) received_requests_from_remote: Vec<ReceivedRequest<N>>,
    /// Buffered messages that should be handled and sent to the peer.
    pub(crate) queued_outgoing: QueuedOutgoingMessages<N>,
    /// The maximum time we wait for a response from a peer.
    ///
    /// Stored in milliseconds; shared so it can be adjusted adaptively (see
    /// `update_request_timeout`).
    pub(crate) internal_request_timeout: Arc<AtomicU64>,
    /// Interval when to check for timed out requests.
    pub(crate) internal_request_timeout_interval: Interval,
    /// If an [`ActiveSession`] does not receive a response at all within this duration then it is
    /// considered a protocol violation and the session will initiate a drop.
    pub(crate) protocol_breach_request_timeout: Duration,
    /// Used to reserve a slot to guarantee that the termination message is delivered
    pub(crate) terminate_message:
        Option<(PollSender<ActiveSessionMessage<N>>, ActiveSessionMessage<N>)>,
    /// The eth69 range info for the remote peer.
    /// This is `None` for peers negotiating versions below `eth/69`.
    pub(crate) range_info: Option<BlockRangeInfo>,
    /// The eth69 range info for the local node (this node).
    /// This represents the range of blocks that this node can serve to other peers.
    pub(crate) local_range_info: BlockRangeInfo,
    /// Optional interval for sending periodic range updates to the remote peer (eth69+)
    /// The interval is set to one epoch duration (~6.4 minutes), but updates are only sent when
    /// the block height has advanced by at least one epoch (32 blocks) since the last update
    pub(crate) range_update_interval: Option<Interval>,
    /// The last latest block number we sent in a range update
    /// Used to avoid sending unnecessary updates when block height hasn't changed significantly
    pub(crate) last_sent_latest_block: Option<u64>,
}
191
impl<N: NetworkPrimitives> ActiveSession<N> {
    /// Returns `true` if the session is currently in the process of disconnecting
    fn is_disconnecting(&self) -> bool {
        self.conn.inner().is_disconnecting()
    }

    /// Returns the next request id and advances the internal counter.
    const fn next_id(&mut self) -> u64 {
        let id = self.next_id;
        self.next_id += 1;
        id
    }

    /// Shrinks the capacity of the internal buffers.
    pub fn shrink_to_fit(&mut self) {
        self.received_requests_from_remote.shrink_to_fit();
        self.queued_outgoing.shrink_to_fit();
    }

    /// Returns how many responses we've currently queued up.
    ///
    /// Note: linear scan over the outgoing queue; the queue is bounded in practice.
    fn queued_response_count(&self) -> usize {
        self.queued_outgoing.messages.iter().filter(|m| m.is_response()).count()
    }

    /// Handle a message read from the connection.
    ///
    /// Returns an error if the message is considered to be in violation of the protocol.
    fn on_incoming_message(&mut self, msg: EthMessage<N>) -> OnIncomingMessageOutcome<N> {
        /// A macro that handles an incoming request
        /// This creates a new channel and tries to send the sender half to the session while
        /// storing the receiver half internally so the pending response can be polled.
        macro_rules! on_request {
            ($req:ident, $resp_item:ident, $req_item:ident) => {{
                let RequestPair { request_id, message: request } = $req;
                let (tx, response) = oneshot::channel();
                let received = ReceivedRequest {
                    request_id,
                    rx: PeerResponse::$resp_item { response },
                    received: Instant::now(),
                };
                self.received_requests_from_remote.push(received);
                self.try_emit_request(PeerMessage::EthRequest(PeerRequest::$req_item {
                    request,
                    response: tx,
                }))
                .into()
            }};
        }

        /// Processes a response received from the peer
        macro_rules! on_response {
            ($resp:ident, $item:ident) => {{
                let RequestPair { request_id, message } = $resp;
                if let Some(req) = self.inflight_requests.remove(&request_id) {
                    match req.request {
                        RequestState::Waiting(PeerRequest::$item { response, .. }) => {
                            trace!(peer_id=?self.remote_peer_id, ?request_id, "received response from peer");
                            let _ = response.send(Ok(message));
                            self.update_request_timeout(req.timestamp, Instant::now());
                        }
                        RequestState::Waiting(request) => {
                            // response variant doesn't match the request we sent under this id
                            request.send_bad_response();
                        }
                        RequestState::TimedOut => {
                            // request was already timed out internally, but still use the
                            // measured round-trip to adapt the timeout
                            self.update_request_timeout(req.timestamp, Instant::now());
                        }
                    }
                } else {
                    trace!(peer_id=?self.remote_peer_id, ?request_id, "received response to unknown request");
                    // we received a response to a request we never sent
                    self.on_bad_message();
                }

                OnIncomingMessageOutcome::Ok
            }};
        }

        match msg {
            message @ EthMessage::Status(_) => OnIncomingMessageOutcome::BadMessage {
                error: EthStreamError::EthHandshakeError(EthHandshakeError::StatusNotInHandshake),
                message,
            },
            EthMessage::NewBlockHashes(msg) => {
                self.try_emit_broadcast(PeerMessage::NewBlockHashes(msg)).into()
            }
            EthMessage::NewBlock(msg) => {
                let block = NewBlockMessage {
                    hash: msg.block().header().hash_slow(),
                    block: Arc::new(*msg),
                };
                self.try_emit_broadcast(PeerMessage::NewBlock(block)).into()
            }
            EthMessage::Transactions(msg) => {
                self.try_emit_broadcast(PeerMessage::ReceivedTransaction(msg)).into()
            }
            EthMessage::NewPooledTransactionHashes66(msg) => {
                self.try_emit_broadcast(PeerMessage::PooledTransactions(msg.into())).into()
            }
            EthMessage::NewPooledTransactionHashes68(msg) => {
                self.try_emit_broadcast(PeerMessage::PooledTransactions(msg.into())).into()
            }
            EthMessage::GetBlockHeaders(req) => {
                on_request!(req, BlockHeaders, GetBlockHeaders)
            }
            EthMessage::BlockHeaders(resp) => {
                on_response!(resp, GetBlockHeaders)
            }
            EthMessage::GetBlockBodies(req) => {
                on_request!(req, BlockBodies, GetBlockBodies)
            }
            EthMessage::BlockBodies(resp) => {
                on_response!(resp, GetBlockBodies)
            }
            EthMessage::GetPooledTransactions(req) => {
                on_request!(req, PooledTransactions, GetPooledTransactions)
            }
            EthMessage::PooledTransactions(resp) => {
                on_response!(resp, GetPooledTransactions)
            }
            EthMessage::GetNodeData(req) => {
                on_request!(req, NodeData, GetNodeData)
            }
            EthMessage::NodeData(resp) => {
                on_response!(resp, GetNodeData)
            }
            EthMessage::GetReceipts(req) => {
                // eth/69 introduced a different receipts response encoding for the same
                // request id, so dispatch on the negotiated version
                if self.conn.version() >= EthVersion::Eth69 {
                    on_request!(req, Receipts69, GetReceipts69)
                } else {
                    on_request!(req, Receipts, GetReceipts)
                }
            }
            EthMessage::GetReceipts70(req) => {
                on_request!(req, Receipts70, GetReceipts70)
            }
            EthMessage::Receipts(resp) => {
                on_response!(resp, GetReceipts)
            }
            EthMessage::Receipts69(resp) => {
                on_response!(resp, GetReceipts69)
            }
            EthMessage::Receipts70(resp) => {
                on_response!(resp, GetReceipts70)
            }
            EthMessage::GetBlockAccessLists(req) => {
                on_request!(req, BlockAccessLists, GetBlockAccessLists)
            }
            EthMessage::BlockAccessLists(resp) => {
                on_response!(resp, GetBlockAccessLists)
            }
            EthMessage::BlockRangeUpdate(msg) => {
                // Validate that earliest <= latest according to the spec
                if msg.earliest > msg.latest {
                    return OnIncomingMessageOutcome::BadMessage {
                        error: EthStreamError::InvalidMessage(MessageError::Other(format!(
                            "invalid block range: earliest ({}) > latest ({})",
                            msg.earliest, msg.latest
                        ))),
                        message: EthMessage::BlockRangeUpdate(msg),
                    };
                }

                // Validate that the latest hash is not zero
                if msg.latest_hash.is_zero() {
                    return OnIncomingMessageOutcome::BadMessage {
                        error: EthStreamError::InvalidMessage(MessageError::Other(
                            "invalid block range: latest_hash cannot be zero".to_string(),
                        )),
                        message: EthMessage::BlockRangeUpdate(msg),
                    };
                }

                // range_info is only `Some` for eth/69+ peers
                if let Some(range_info) = self.range_info.as_ref() {
                    range_info.update(msg.earliest, msg.latest, msg.latest_hash);
                }

                OnIncomingMessageOutcome::Ok
            }
            EthMessage::Other(bytes) => self.try_emit_broadcast(PeerMessage::Other(bytes)).into(),
        }
    }

    /// Handle an internal peer request that will be sent to the remote.
    ///
    /// If the request isn't supported by the negotiated eth version, the requester is answered
    /// with an error instead of sending anything to the remote.
    fn on_internal_peer_request(&mut self, request: PeerRequest<N>, deadline: Instant) {
        let version = self.conn.version();
        if !Self::is_request_supported_for_version(&request, version) {
            debug!(
                target: "net",
                ?request,
                peer_id=?self.remote_peer_id,
                ?version,
                "Request not supported for negotiated eth version",
            );
            request.send_err_response(RequestError::UnsupportedCapability);
            return;
        }

        let request_id = self.next_id();
        trace!(?request, peer_id=?self.remote_peer_id, ?request_id, "sending request to peer");
        let msg = request.create_request_message(request_id).map_versioned(version);

        self.queued_outgoing.push_back(msg.into());
        // track the request so the matching response (or a timeout) can resolve it
        let req = InflightRequest {
            request: RequestState::Waiting(request),
            timestamp: Instant::now(),
            deadline,
        };
        self.inflight_requests.insert(request_id, req);
    }

    /// Returns whether the given request type is supported by the negotiated eth version.
    #[inline]
    fn is_request_supported_for_version(request: &PeerRequest<N>, version: EthVersion) -> bool {
        request.is_supported_by_eth_version(version)
    }

    /// Handle a message received from the internal network
    fn on_internal_peer_message(&mut self, msg: PeerMessage<N>) {
        match msg {
            PeerMessage::NewBlockHashes(msg) => {
                self.queued_outgoing.push_back(EthMessage::NewBlockHashes(msg).into());
            }
            PeerMessage::NewBlock(msg) => {
                self.queued_outgoing.push_back(EthBroadcastMessage::NewBlock(msg.block).into());
            }
            PeerMessage::PooledTransactions(msg) => {
                if msg.is_valid_for_version(self.conn.version()) {
                    self.queued_outgoing.push_pooled_hashes(msg);
                } else {
                    // the message is dropped, so release its items from the shared
                    // broadcast-item counter
                    self.queued_outgoing.broadcast_items.sub(msg.len());
                    debug!(target: "net", ?msg,  version=?self.conn.version(), "Message is invalid for connection version, skipping");
                }
            }
            PeerMessage::EthRequest(req) => {
                let deadline = self.request_deadline();
                self.on_internal_peer_request(req, deadline);
            }
            PeerMessage::SendTransactions(msg) => {
                self.queued_outgoing.push_back(EthBroadcastMessage::Transactions(msg).into());
            }
            PeerMessage::BlockRangeUpdated(_) => {}
            PeerMessage::ReceivedTransaction(_) => {
                unreachable!("Not emitted by network")
            }
            PeerMessage::Other(other) => {
                self.queued_outgoing.push_back(OutgoingMessage::Raw(other));
            }
        }
    }

    /// Returns the deadline timestamp at which the request times out
    fn request_deadline(&self) -> Instant {
        Instant::now() +
            Duration::from_millis(self.internal_request_timeout.load(Ordering::Relaxed))
    }

    /// Handle a Response to the peer
    ///
    /// This will queue the response to be sent to the peer; on encoding failure the response is
    /// dropped and only logged.
    fn handle_outgoing_response(&mut self, id: u64, resp: PeerResponseResult<N>) {
        match resp.try_into_message(id) {
            Ok(msg) => {
                self.queued_outgoing.push_back(msg.into());
            }
            Err(err) => {
                debug!(target: "net", %err, "Failed to respond to received request");
            }
        }
    }

    /// Send a message back to the [`SessionManager`](super::SessionManager).
    ///
    /// Returns the message if the bounded channel is currently unable to handle this message.
    #[expect(clippy::result_large_err)]
    fn try_emit_broadcast(&self, message: PeerMessage<N>) -> Result<(), ActiveSessionMessage<N>> {
        let Some(sender) = self.to_session_manager.inner().get_ref() else { return Ok(()) };

        match sender
            .try_send(ActiveSessionMessage::ValidMessage { peer_id: self.remote_peer_id, message })
        {
            Ok(_) => Ok(()),
            Err(err) => {
                trace!(
                    target: "net",
                    %err,
                    "no capacity for incoming broadcast",
                );
                match err {
                    TrySendError::Full(msg) => Err(msg),
                    // a closed channel means the manager is gone; handled elsewhere, so ignore
                    TrySendError::Closed(_) => Ok(()),
                }
            }
        }
    }

    /// Send a message back to the [`SessionManager`](super::SessionManager)
    /// covering both broadcasts and incoming requests.
    ///
    /// Returns the message if the bounded channel is currently unable to handle this message.
    #[expect(clippy::result_large_err)]
    fn try_emit_request(&self, message: PeerMessage<N>) -> Result<(), ActiveSessionMessage<N>> {
        let Some(sender) = self.to_session_manager.inner().get_ref() else { return Ok(()) };

        match sender
            .try_send(ActiveSessionMessage::ValidMessage { peer_id: self.remote_peer_id, message })
        {
            Ok(_) => Ok(()),
            Err(err) => {
                trace!(
                    target: "net",
                    %err,
                    "no capacity for incoming request",
                );
                match err {
                    TrySendError::Full(msg) => Err(msg),
                    TrySendError::Closed(_) => {
                        // Note: this would mean the `SessionManager` was dropped, which is already
                        // handled by checking if the command receiver channel has been closed.
                        Ok(())
                    }
                }
            }
        }
    }

    /// Notify the manager that the peer sent a bad message
    ///
    /// Best effort: the notification is dropped if the channel is full or closed.
    fn on_bad_message(&self) {
        let Some(sender) = self.to_session_manager.inner().get_ref() else { return };
        let _ = sender.try_send(ActiveSessionMessage::BadMessage { peer_id: self.remote_peer_id });
    }

    /// Report back that this session has been closed.
    fn emit_disconnect(&mut self, cx: &mut Context<'_>) -> Poll<()> {
        trace!(target: "net::session", remote_peer_id=?self.remote_peer_id, "emitting disconnect");
        let msg = ActiveSessionMessage::Disconnected {
            peer_id: self.remote_peer_id,
            remote_addr: self.remote_addr,
        };

        // stash the message so delivery is retried across polls until a slot is reserved
        self.terminate_message = Some((self.to_session_manager.inner().clone(), msg));
        self.poll_terminate_message(cx).expect("message is set")
    }

    /// Report back that this session has been closed due to an error
    fn close_on_error(&mut self, error: EthStreamError, cx: &mut Context<'_>) -> Poll<()> {
        let msg = ActiveSessionMessage::ClosedOnConnectionError {
            peer_id: self.remote_peer_id,
            remote_addr: self.remote_addr,
            error,
        };
        self.terminate_message = Some((self.to_session_manager.inner().clone(), msg));
        self.poll_terminate_message(cx).expect("message is set")
    }

    /// Starts the disconnect process
    fn start_disconnect(&mut self, reason: DisconnectReason) -> Result<(), EthStreamError> {
        Ok(self.conn.inner_mut().start_disconnect(reason)?)
    }

    /// Flushes the disconnect message and emits the corresponding message
    fn poll_disconnect(&mut self, cx: &mut Context<'_>) -> Poll<()> {
        debug_assert!(self.is_disconnecting(), "not disconnecting");

        // try to flush out the remaining Disconnect message before closing the connection
        let _ = ready!(self.conn.poll_close_unpin(cx));
        self.emit_disconnect(cx)
    }

    /// Attempts to disconnect by sending the given disconnect reason
    fn try_disconnect(&mut self, reason: DisconnectReason, cx: &mut Context<'_>) -> Poll<()> {
        match self.start_disconnect(reason) {
            Ok(()) => {
                // we're done
                self.poll_disconnect(cx)
            }
            Err(err) => {
                debug!(target: "net::session", %err, remote_peer_id=?self.remote_peer_id, "could not send disconnect");
                self.close_on_error(err, cx)
            }
        }
    }

    /// Checks for _internally_ timed out requests.
    ///
    /// If a request misses its deadline, then it is timed out internally.
    /// If a request misses the `protocol_breach_request_timeout` then this session is considered in
    /// protocol violation and will close.
    ///
    /// Returns `true` if a peer missed the `protocol_breach_request_timeout`, in which case the
    /// session should be terminated.
    #[must_use]
    fn check_timed_out_requests(&mut self, now: Instant) -> bool {
        for (id, req) in &mut self.inflight_requests {
            if req.is_timed_out(now) {
                if req.is_waiting() {
                    // first miss: mark the request as timed out but keep tracking it
                    debug!(target: "net::session", ?id, remote_peer_id=?self.remote_peer_id, "timed out outgoing request");
                    req.timeout();
                } else if now - req.timestamp > self.protocol_breach_request_timeout {
                    // already timed out once and still no response: protocol breach
                    return true
                }
            }
        }

        false
    }

    /// Updates the request timeout with a request's timestamps
    fn update_request_timeout(&mut self, sent: Instant, received: Instant) {
        let elapsed = received.saturating_duration_since(sent);

        let current = Duration::from_millis(self.internal_request_timeout.load(Ordering::Relaxed));
        let request_timeout = calculate_new_timeout(current, elapsed);
        self.internal_request_timeout.store(request_timeout.as_millis() as u64, Ordering::Relaxed);
        // keep the timeout-check interval in sync with the adapted timeout
        self.internal_request_timeout_interval = tokio::time::interval(request_timeout);
    }

    /// If a termination message is queued this will try to send it
    ///
    /// Returns `None` if no termination message is pending; otherwise the poll result of the
    /// delivery attempt.
    fn poll_terminate_message(&mut self, cx: &mut Context<'_>) -> Option<Poll<()>> {
        let (mut tx, msg) = self.terminate_message.take()?;
        match tx.poll_reserve(cx) {
            Poll::Pending => {
                // no slot available yet; put the message back and retry on the next poll
                self.terminate_message = Some((tx, msg));
                return Some(Poll::Pending)
            }
            Poll::Ready(Ok(())) => {
                let _ = tx.send_item(msg);
            }
            Poll::Ready(Err(_)) => {
                // channel closed
            }
        }
        // terminate the task
        Some(Poll::Ready(()))
    }
}
627
628impl<N: NetworkPrimitives> Future for ActiveSession<N> {
629    type Output = ();
630
631    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
632        let this = self.get_mut();
633
634        // if the session is terminate we have to send the termination message before we can close
635        if let Some(terminate) = this.poll_terminate_message(cx) {
636            return terminate
637        }
638
639        if this.is_disconnecting() {
640            return this.poll_disconnect(cx)
641        }
642
643        // The receive loop can be CPU intensive since it involves message decoding which could take
644        // up a lot of resources and increase latencies for other sessions if not yielded manually.
645        // If the budget is exhausted we manually yield back control to the (coop) scheduler. This
646        // manual yield point should prevent situations where polling appears to be frozen. See also <https://tokio.rs/blog/2020-04-preemption>
647        // And tokio's docs on cooperative scheduling <https://docs.rs/tokio/latest/tokio/task/#cooperative-scheduling>
648        let mut budget = 4;
649
650        // The main poll loop that drives the session
651        'main: loop {
652            let mut progress = false;
653
654            // we prioritize incoming commands sent from the session manager
655            loop {
656                match this.commands_rx.poll_next_unpin(cx) {
657                    Poll::Pending => break,
658                    Poll::Ready(None) => {
659                        // this is only possible when the manager was dropped, in which case we also
660                        // terminate this session
661                        return Poll::Ready(())
662                    }
663                    Poll::Ready(Some(cmd)) => {
664                        progress = true;
665                        match cmd {
666                            SessionCommand::Disconnect { reason } => {
667                                debug!(
668                                    target: "net::session",
669                                    ?reason,
670                                    remote_peer_id=?this.remote_peer_id,
671                                    "Received disconnect command for session"
672                                );
673                                let reason =
674                                    reason.unwrap_or(DisconnectReason::DisconnectRequested);
675
676                                return this.try_disconnect(reason, cx)
677                            }
678                            SessionCommand::Message(msg) => {
679                                this.on_internal_peer_message(msg);
680                            }
681                        }
682                    }
683                }
684            }
685
686            // Drain the unbounded channel (broadcast overflow + disconnect commands)
687            while let Poll::Ready(Some(cmd)) = this.unbounded_rx.poll_recv(cx) {
688                progress = true;
689                match cmd {
690                    SessionCommand::Message(msg) => {
691                        this.unbounded_broadcast_msgs.increment(1);
692                        this.on_internal_peer_message(msg);
693                    }
694                    SessionCommand::Disconnect { reason } => {
695                        let reason = reason.unwrap_or(DisconnectReason::DisconnectRequested);
696                        return this.try_disconnect(reason, cx)
697                    }
698                }
699            }
700
701            let deadline = this.request_deadline();
702
703            while let Poll::Ready(Some(req)) = this.internal_request_rx.poll_next_unpin(cx) {
704                progress = true;
705                this.on_internal_peer_request(req, deadline);
706            }
707
708            // Advance all active requests.
709            // We remove each request one by one and add them back.
710            for idx in (0..this.received_requests_from_remote.len()).rev() {
711                let mut req = this.received_requests_from_remote.swap_remove(idx);
712                match req.rx.poll(cx) {
713                    Poll::Pending => {
714                        // not ready yet
715                        this.received_requests_from_remote.push(req);
716                    }
717                    Poll::Ready(resp) => {
718                        this.handle_outgoing_response(req.request_id, resp);
719                    }
720                }
721            }
722
723            // Send messages by advancing the sink and queuing in buffered messages
724            while this.conn.poll_ready_unpin(cx).is_ready() {
725                if let Some(msg) = this.queued_outgoing.pop_front() {
726                    progress = true;
727                    let res = match msg {
728                        OutgoingMessage::Eth(msg) => this.conn.start_send_unpin(msg),
729                        OutgoingMessage::Broadcast(msg) => this.conn.start_send_broadcast(msg),
730                        OutgoingMessage::Raw(msg) => this.conn.start_send_raw(msg),
731                    };
732                    if let Err(err) = res {
733                        debug!(target: "net::session", %err, remote_peer_id=?this.remote_peer_id, "failed to send message");
734                        // notify the manager
735                        return this.close_on_error(err, cx)
736                    }
737                } else {
738                    // no more messages to send over the wire
739                    break
740                }
741            }
742
743            // read incoming messages from the wire
744            'receive: loop {
745                // ensure we still have enough budget for another iteration
746                budget -= 1;
747                if budget == 0 {
748                    // make sure we're woken up again
749                    cx.waker().wake_by_ref();
750                    break 'main
751                }
752
753                // try to resend the pending message that we could not send because the channel was
754                // full. [`PollSender`] will ensure that we're woken up again when the channel is
755                // ready to receive the message, and will only error if the channel is closed.
756                if let Some(msg) = this.pending_message_to_session.take() {
757                    match this.to_session_manager.poll_reserve(cx) {
758                        Poll::Ready(Ok(_)) => {
759                            let _ = this.to_session_manager.send_item(msg);
760                        }
761                        Poll::Ready(Err(_)) => return Poll::Ready(()),
762                        Poll::Pending => {
763                            this.pending_message_to_session = Some(msg);
764                            break 'receive
765                        }
766                    };
767                }
768
769                // check whether we should throttle incoming messages
770                if this.received_requests_from_remote.len() > MAX_QUEUED_OUTGOING_RESPONSES {
771                    // we're currently waiting for the responses to the peer's requests which aren't
772                    // queued as outgoing yet
773                    //
774                    // Note: we don't need to register the waker here because we polled the requests
775                    // above
776                    break 'receive
777                }
778
779                // we also need to check if we have multiple responses queued up
780                if this.queued_outgoing.messages.len() > MAX_QUEUED_OUTGOING_RESPONSES &&
781                    this.queued_response_count() > MAX_QUEUED_OUTGOING_RESPONSES
782                {
783                    // if we've queued up more responses than allowed, we don't poll for new
784                    // messages and break the receive loop early
785                    //
786                    // Note: we don't need to register the waker here because we still have
787                    // queued messages and the sink impl registered the waker because we've
788                    // already advanced it to `Pending` earlier
789                    break 'receive
790                }
791
792                match this.conn.poll_next_unpin(cx) {
793                    Poll::Pending => break,
794                    Poll::Ready(None) => {
795                        if this.is_disconnecting() {
796                            break
797                        }
798                        debug!(target: "net::session", remote_peer_id=?this.remote_peer_id, "eth stream completed");
799                        return this.emit_disconnect(cx)
800                    }
801                    Poll::Ready(Some(res)) => {
802                        match res {
803                            Ok(msg) => {
804                                trace!(target: "net::session", msg_id=?msg.message_id(), remote_peer_id=?this.remote_peer_id, "received eth message");
805                                // decode and handle message
806                                match this.on_incoming_message(msg) {
807                                    OnIncomingMessageOutcome::Ok => {
808                                        // handled successfully
809                                        progress = true;
810                                    }
811                                    OnIncomingMessageOutcome::BadMessage { error, message } => {
812                                        debug!(target: "net::session", %error, msg=?message, remote_peer_id=?this.remote_peer_id, "received invalid protocol message");
813                                        this.on_bad_message();
814                                        return this
815                                            .try_disconnect(DisconnectReason::ProtocolBreach, cx)
816                                    }
817                                    OnIncomingMessageOutcome::NoCapacity(msg) => {
818                                        // failed to send due to lack of capacity
819                                        this.pending_message_to_session = Some(msg);
820                                    }
821                                }
822                            }
823                            Err(err) => {
824                                debug!(target: "net::session", %err, remote_peer_id=?this.remote_peer_id, "failed to receive message");
825                                if err.is_protocol_breach() {
826                                    this.on_bad_message();
827                                    return this.try_disconnect(DisconnectReason::ProtocolBreach, cx)
828                                }
829                                return this.close_on_error(err, cx)
830                            }
831                        }
832                    }
833                }
834            }
835
836            if !progress {
837                break 'main
838            }
839        }
840
841        if let Some(interval) = &mut this.range_update_interval {
842            // Check if we should send a range update based on block height changes
843            while interval.poll_tick(cx).is_ready() {
844                let current_latest = this.local_range_info.latest();
845                let should_send = if let Some(last_sent) = this.last_sent_latest_block {
846                    // Only send if block height has advanced by at least one epoch (32 blocks)
847                    current_latest.saturating_sub(last_sent) >= EPOCH_SLOTS
848                } else {
849                    true // First update, always send
850                };
851
852                if should_send {
853                    this.queued_outgoing.push_back(
854                        EthMessage::BlockRangeUpdate(this.local_range_info.to_message()).into(),
855                    );
856                    this.last_sent_latest_block = Some(current_latest);
857                }
858            }
859        }
860
861        while this.internal_request_timeout_interval.poll_tick(cx).is_ready() {
862            // check for timed out requests
863            if this.check_timed_out_requests(Instant::now()) &&
864                let Poll::Ready(Ok(_)) = this.to_session_manager.poll_reserve(cx)
865            {
866                let msg = ActiveSessionMessage::ProtocolBreach { peer_id: this.remote_peer_id };
867                this.pending_message_to_session = Some(msg);
868            }
869        }
870
871        this.shrink_to_fit();
872
873        Poll::Pending
874    }
875}
876
/// Tracks a request received from the peer
///
/// Held while the internal handler produces the response that will be written back
/// over the wire; polled and drained in the session's poll loop.
pub(crate) struct ReceivedRequest<N: NetworkPrimitives> {
    /// Protocol Identifier
    ///
    /// Echoed back in the response so the remote can pair it with its request.
    request_id: u64,
    /// Receiver half of the channel that's supposed to receive the proper response.
    rx: PeerResponse<N>,
    /// Timestamp when we read this msg from the wire.
    #[expect(dead_code)]
    received: Instant,
}
887
/// A request that waits for a response from the peer
pub(crate) struct InflightRequest<R> {
    /// Request we sent to peer and the internal response channel
    request: RequestState<R>,
    /// Instant when the request was sent
    timestamp: Instant,
    /// Time limit for the response; after this instant the request is considered
    /// timed out (see `is_timed_out`).
    deadline: Instant,
}
897
898// === impl InflightRequest ===
899
900impl<N: NetworkPrimitives> InflightRequest<PeerRequest<N>> {
901    /// Returns true if the request is timedout
902    #[inline]
903    fn is_timed_out(&self, now: Instant) -> bool {
904        now > self.deadline
905    }
906
907    /// Returns true if we're still waiting for a response
908    #[inline]
909    const fn is_waiting(&self) -> bool {
910        matches!(self.request, RequestState::Waiting(_))
911    }
912
913    /// This will timeout the request by sending an error response to the internal channel
914    fn timeout(&mut self) {
915        let mut req = RequestState::TimedOut;
916        std::mem::swap(&mut self.request, &mut req);
917
918        if let RequestState::Waiting(req) = req {
919            req.send_err_response(RequestError::Timeout);
920        }
921    }
922}
923
/// All outcome variants when handling an incoming message
enum OnIncomingMessageOutcome<N: NetworkPrimitives> {
    /// Message successfully handled.
    Ok,
    /// Message is considered to be in violation of the protocol
    ///
    /// Carries the underlying error and the offending message for diagnostics.
    BadMessage { error: EthStreamError, message: EthMessage<N> },
    /// Currently no capacity to handle the message
    ///
    /// The message is handed back so the session can buffer it and retry once the
    /// session-manager channel has capacity again.
    NoCapacity(ActiveSessionMessage<N>),
}
933
934impl<N: NetworkPrimitives> From<Result<(), ActiveSessionMessage<N>>>
935    for OnIncomingMessageOutcome<N>
936{
937    fn from(res: Result<(), ActiveSessionMessage<N>>) -> Self {
938        match res {
939            Ok(_) => Self::Ok,
940            Err(msg) => Self::NoCapacity(msg),
941        }
942    }
943}
944
/// State of a request we sent to the remote peer.
enum RequestState<R> {
    /// Waiting for the response
    Waiting(R),
    /// Request already timed out
    TimedOut,
}
951
/// Outgoing messages that can be sent over the wire.
#[derive(Debug)]
pub(crate) enum OutgoingMessage<N: NetworkPrimitives> {
    /// A message that is owned.
    Eth(EthMessage<N>),
    /// A message that may be shared by multiple sessions.
    Broadcast(EthBroadcastMessage<N>),
    /// A raw capability message
    Raw(RawCapabilityMessage),
}
962
963impl<N: NetworkPrimitives> OutgoingMessage<N> {
964    /// Returns true if this is a response.
965    const fn is_response(&self) -> bool {
966        match self {
967            Self::Eth(msg) => msg.is_response(),
968            _ => false,
969        }
970    }
971
972    /// Returns the number of broadcast items in this message.
973    ///
974    /// For transaction hash announcements this is the number of hashes, for full transaction
975    /// broadcasts it is the number of transactions, and for blocks it is 1.
976    /// Request/response messages return 0.
977    fn broadcast_item_count(&self) -> usize {
978        match self {
979            Self::Eth(msg) => match msg {
980                EthMessage::NewBlockHashes(h) => h.len(),
981                EthMessage::NewPooledTransactionHashes66(h) => h.len(),
982                EthMessage::NewPooledTransactionHashes68(h) => h.hashes.len(),
983                _ => 0,
984            },
985            Self::Broadcast(msg) => match msg {
986                EthBroadcastMessage::NewBlock(_) => 1,
987                EthBroadcastMessage::Transactions(txs) => txs.len(),
988            },
989            Self::Raw(_) => 0,
990        }
991    }
992
993    /// Tries to merge pooled transaction hash announcements into this message, consuming the
994    /// incoming hashes. Returns `Some(incoming)` back if the variants don't match.
995    fn try_merge_hashes(
996        &mut self,
997        incoming: NewPooledTransactionHashes,
998    ) -> Option<NewPooledTransactionHashes> {
999        let Self::Eth(eth) = self else { return Some(incoming) };
1000        match (eth, incoming) {
1001            (
1002                EthMessage::NewPooledTransactionHashes66(existing),
1003                NewPooledTransactionHashes::Eth66(inc),
1004            ) => {
1005                existing.extend(inc);
1006                None
1007            }
1008            (
1009                EthMessage::NewPooledTransactionHashes68(existing),
1010                NewPooledTransactionHashes::Eth68(inc),
1011            ) => {
1012                existing.hashes.extend(inc.hashes);
1013                existing.sizes.extend(inc.sizes);
1014                existing.types.extend(inc.types);
1015                None
1016            }
1017            (_, incoming) => Some(incoming),
1018        }
1019    }
1020}
1021
1022impl<N: NetworkPrimitives> From<EthMessage<N>> for OutgoingMessage<N> {
1023    fn from(value: EthMessage<N>) -> Self {
1024        Self::Eth(value)
1025    }
1026}
1027
1028impl<N: NetworkPrimitives> From<EthBroadcastMessage<N>> for OutgoingMessage<N> {
1029    fn from(value: EthBroadcastMessage<N>) -> Self {
1030        Self::Broadcast(value)
1031    }
1032}
1033
1034/// Calculates a new timeout using an updated estimation of the RTT
1035#[inline]
1036fn calculate_new_timeout(current_timeout: Duration, estimated_rtt: Duration) -> Duration {
1037    let new_timeout = estimated_rtt.mul_f64(SAMPLE_IMPACT) * TIMEOUT_SCALING;
1038
1039    // this dampens sudden changes by taking a weighted mean of the old and new values
1040    let smoothened_timeout = current_timeout.mul_f64(1.0 - SAMPLE_IMPACT) + new_timeout;
1041
1042    smoothened_timeout.clamp(MINIMUM_TIMEOUT, MAXIMUM_TIMEOUT)
1043}
1044
/// A helper struct that wraps the queue of outgoing messages with broadcast-aware tracking.
///
/// Tracks both the total number of queued messages (via a metric gauge) and the total number of
/// broadcast items (tx hashes, transactions, blocks) via a shared atomic counter. The atomic
/// counter is shared with [`ActiveSessionHandle`](super::handle::ActiveSessionHandle) so the
/// [`SessionManager`](super::SessionManager) can apply size-based backpressure.
pub(crate) struct QueuedOutgoingMessages<N: NetworkPrimitives> {
    /// FIFO buffer of messages waiting to be written to the wire.
    messages: VecDeque<OutgoingMessage<N>>,
    /// Gauge metric mirroring the number of queued messages.
    count: Gauge,
    /// Shared counter of buffered broadcast items for size-based backpressure.
    broadcast_items: BroadcastItemCounter,
}
1057
1058impl<N: NetworkPrimitives> QueuedOutgoingMessages<N> {
1059    pub(crate) const fn new(metric: Gauge, broadcast_items: BroadcastItemCounter) -> Self {
1060        Self { messages: VecDeque::new(), count: metric, broadcast_items }
1061    }
1062
1063    pub(crate) fn push_back(&mut self, message: OutgoingMessage<N>) {
1064        self.messages.push_back(message);
1065        self.count.increment(1);
1066    }
1067
1068    pub(crate) fn pop_front(&mut self) -> Option<OutgoingMessage<N>> {
1069        self.messages.pop_front().inspect(|msg| {
1070            self.count.decrement(1);
1071            let items = msg.broadcast_item_count();
1072            if items > 0 {
1073                self.broadcast_items.sub(items);
1074            }
1075        })
1076    }
1077
1078    /// Pushes a pooled transaction hash announcement, merging into the last queued message if
1079    /// it is the same variant (eth66 or eth68).
1080    pub(crate) fn push_pooled_hashes(&mut self, msg: NewPooledTransactionHashes) {
1081        let msg = if let Some(last) = self.messages.back_mut() {
1082            match last.try_merge_hashes(msg) {
1083                None => return,
1084                Some(msg) => msg,
1085            }
1086        } else {
1087            msg
1088        };
1089        self.messages.push_back(EthMessage::from(msg).into());
1090        self.count.increment(1);
1091    }
1092
1093    pub(crate) fn shrink_to_fit(&mut self) {
1094        self.messages.shrink_to_fit();
1095    }
1096}
1097
1098impl<N: NetworkPrimitives> Drop for QueuedOutgoingMessages<N> {
1099    fn drop(&mut self) {
1100        // Ensure gauge is decremented for any remaining items to avoid metric leak on teardown.
1101        let remaining = self.messages.len();
1102        if remaining > 0 {
1103            self.count.decrement(remaining as f64);
1104        }
1105    }
1106}
1107
1108#[cfg(test)]
1109mod tests {
1110    use super::*;
1111    use crate::session::{handle::PendingSessionEvent, start_pending_incoming_session};
1112    use alloy_eips::eip2124::ForkFilter;
1113    use reth_chainspec::MAINNET;
1114    use reth_ecies::stream::ECIESStream;
1115    use reth_eth_wire::{
1116        handshake::EthHandshake, EthNetworkPrimitives, EthStream, GetBlockAccessLists,
1117        GetBlockBodies, HelloMessageWithProtocols, P2PStream, StatusBuilder, UnauthedEthStream,
1118        UnauthedP2PStream, UnifiedStatus,
1119    };
1120    use reth_eth_wire_types::{message::MAX_MESSAGE_SIZE, EthMessageID, RawCapabilityMessage};
1121    use reth_ethereum_forks::EthereumHardfork;
1122    use reth_network_peers::pk2id;
1123    use reth_network_types::session::config::PROTOCOL_BREACH_REQUEST_TIMEOUT;
1124    use secp256k1::{SecretKey, SECP256K1};
1125    use tokio::{
1126        net::{TcpListener, TcpStream},
1127        sync::mpsc,
1128    };
1129
    /// Returns a testing `HelloMessage` built for the given secret key.
    fn eth_hello(server_key: &SecretKey) -> HelloMessageWithProtocols {
        HelloMessageWithProtocols::builder(pk2id(&server_key.public_key(SECP256K1))).build()
    }
1134
    /// Test helper that performs the server side of the handshake and hands out
    /// connected [`ActiveSession`]s plus the channels used to drive them.
    struct SessionBuilder<N: NetworkPrimitives = EthNetworkPrimitives> {
        _remote_capabilities: Arc<Capabilities>,
        /// Sender half that sessions use to emit [`ActiveSessionMessage`]s.
        active_session_tx: mpsc::Sender<ActiveSessionMessage<N>>,
        /// Receiver half the tests inspect for session events.
        active_session_rx: ReceiverStream<ActiveSessionMessage<N>>,
        /// Command senders for every session created so far.
        to_sessions: Vec<mpsc::Sender<SessionCommand<N>>>,
        secret_key: SecretKey,
        local_peer_id: PeerId,
        hello: HelloMessageWithProtocols,
        status: UnifiedStatus,
        fork_filter: ForkFilter,
        /// Monotonic counter used to assign [`SessionId`]s.
        next_id: usize,
    }
1147
    impl<N: NetworkPrimitives> SessionBuilder<N> {
        /// Returns the next unique [`SessionId`].
        fn next_id(&mut self) -> SessionId {
            let id = self.next_id;
            self.next_id += 1;
            SessionId(id)
        }

        /// Connects a new Eth stream and executes the given closure with that established stream
        fn with_client_stream<F, O>(
            &self,
            local_addr: SocketAddr,
            f: F,
        ) -> Pin<Box<dyn Future<Output = ()> + Send>>
        where
            F: FnOnce(EthStream<P2PStream<ECIESStream<TcpStream>>, N>) -> O + Send + 'static,
            O: Future<Output = ()> + Send + Sync,
        {
            let mut status = self.status;
            let fork_filter = self.fork_filter.clone();
            let local_peer_id = self.local_peer_id;
            let mut hello = self.hello.clone();
            // the client side uses a fresh key so it connects as a distinct peer
            let key = SecretKey::new(&mut rand_08::thread_rng());
            hello.id = pk2id(&key.public_key(SECP256K1));
            Box::pin(async move {
                let outgoing = TcpStream::connect(local_addr).await.unwrap();
                let sink = ECIESStream::connect(outgoing, key, local_peer_id).await.unwrap();

                // p2p handshake first, then the eth handshake on top of it
                let (p2p_stream, _) = UnauthedP2PStream::new(sink).handshake(hello).await.unwrap();

                // use the negotiated eth version for the status message
                let eth_version = p2p_stream.shared_capabilities().eth_version().unwrap();
                status.set_eth_version(eth_version);

                let (client_stream, _) = UnauthedEthStream::new(p2p_stream)
                    .handshake(status, fork_filter)
                    .await
                    .unwrap();
                f(client_stream).await
            })
        }

        /// Performs the server side of the handshake for the given incoming TCP stream and
        /// returns the established [`ActiveSession`].
        async fn connect_incoming(&mut self, stream: TcpStream) -> ActiveSession<N> {
            let remote_addr = stream.local_addr().unwrap();
            let session_id = self.next_id();
            let (_disconnect_tx, disconnect_rx) = oneshot::channel();
            let (pending_sessions_tx, pending_sessions_rx) = mpsc::channel(1);

            // drive the pending-session handshake on a separate task
            tokio::task::spawn(start_pending_incoming_session(
                Arc::new(EthHandshake::default()),
                MAX_MESSAGE_SIZE,
                disconnect_rx,
                session_id,
                stream,
                pending_sessions_tx,
                remote_addr,
                self.secret_key,
                self.hello.clone(),
                self.status,
                self.fork_filter.clone(),
                Default::default(),
            ));

            let mut stream = ReceiverStream::new(pending_sessions_rx);

            match stream.next().await.unwrap() {
                PendingSessionEvent::Established {
                    session_id,
                    remote_addr,
                    peer_id,
                    capabilities,
                    conn,
                    ..
                } => {
                    let (_to_session_tx, messages_rx) = mpsc::channel(10);
                    let (commands_to_session, commands_rx) = mpsc::channel(10);
                    let (_unbounded_tx, unbounded_rx) = mpsc::unbounded_channel();
                    let poll_sender = PollSender::new(self.active_session_tx.clone());

                    self.to_sessions.push(commands_to_session);

                    // assemble the session with noop metrics and default timeouts
                    ActiveSession {
                        next_id: 0,
                        remote_peer_id: peer_id,
                        remote_addr,
                        remote_capabilities: Arc::clone(&capabilities),
                        session_id,
                        commands_rx: ReceiverStream::new(commands_rx),
                        unbounded_rx,
                        unbounded_broadcast_msgs: Counter::noop(),
                        to_session_manager: MeteredPollSender::new(
                            poll_sender,
                            "network_active_session",
                        ),
                        pending_message_to_session: None,
                        internal_request_rx: ReceiverStream::new(messages_rx).fuse(),
                        inflight_requests: Default::default(),
                        conn,
                        queued_outgoing: QueuedOutgoingMessages::new(
                            Gauge::noop(),
                            BroadcastItemCounter::new(),
                        ),
                        received_requests_from_remote: Default::default(),
                        internal_request_timeout_interval: tokio::time::interval(
                            INITIAL_REQUEST_TIMEOUT,
                        ),
                        internal_request_timeout: Arc::new(AtomicU64::new(
                            INITIAL_REQUEST_TIMEOUT.as_millis() as u64,
                        )),
                        protocol_breach_request_timeout: PROTOCOL_BREACH_REQUEST_TIMEOUT,
                        terminate_message: None,
                        range_info: None,
                        local_range_info: BlockRangeInfo::new(
                            0,
                            1000,
                            alloy_primitives::B256::ZERO,
                        ),
                        range_update_interval: None,
                        last_sent_latest_block: None,
                    }
                }
                ev => {
                    panic!("unexpected message {ev:?}")
                }
            }
        }
    }
1273
    impl Default for SessionBuilder {
        fn default() -> Self {
            let (active_session_tx, active_session_rx) = mpsc::channel(100);

            // fresh server identity per builder
            let (secret_key, pk) = SECP256K1.generate_keypair(&mut rand_08::thread_rng());
            let local_peer_id = pk2id(&pk);

            Self {
                next_id: 0,
                _remote_capabilities: Arc::new(Capabilities::from(vec![])),
                active_session_tx,
                active_session_rx: ReceiverStream::new(active_session_rx),
                to_sessions: vec![],
                hello: eth_hello(&secret_key),
                secret_key,
                local_peer_id,
                status: StatusBuilder::default().build(),
                fork_filter: MAINNET
                    .hardfork_fork_filter(EthereumHardfork::Frontier)
                    .expect("The Frontier fork filter should exist on mainnet"),
            }
        }
    }
1297
    /// Starting a disconnect on the session delivers the chosen reason to the remote peer.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_disconnect() {
        let mut builder = SessionBuilder::default();

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();

        let expected_disconnect = DisconnectReason::UselessPeer;

        // client side: the next stream item should be the disconnect error with our reason
        let fut = builder.with_client_stream(local_addr, async move |mut client_stream| {
            let msg = client_stream.next().await.unwrap().unwrap_err();
            assert_eq!(msg.as_disconnected().unwrap(), expected_disconnect);
        });

        // server side: establish the session and immediately start the disconnect
        tokio::task::spawn(async move {
            let (incoming, _) = listener.accept().await.unwrap();
            let mut session = builder.connect_incoming(incoming).await;

            session.start_disconnect(expected_disconnect).unwrap();
            session.await
        });

        fut.await;
    }
1322
    /// Sending an invalid eth message makes the session disconnect with `ProtocolBreach`.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_invalid_message_disconnects_with_protocol_breach() {
        let mut builder = SessionBuilder::default();

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();

        let fut = builder.with_client_stream(local_addr, async move |mut client_stream| {
            // send a raw PooledTransactions payload the session should reject
            client_stream
                .start_send_raw(RawCapabilityMessage::eth(
                    EthMessageID::PooledTransactions,
                    vec![0xc0].into(),
                ))
                .unwrap();
            client_stream.flush().await.unwrap();

            // the session answers by disconnecting with ProtocolBreach
            let msg = client_stream.next().await.unwrap().unwrap_err();
            assert_eq!(msg.as_disconnected(), Some(DisconnectReason::ProtocolBreach));
        });

        let (tx, rx) = oneshot::channel();

        tokio::task::spawn(async move {
            let (incoming, _) = listener.accept().await.unwrap();
            let session = builder.connect_incoming(incoming).await;
            session.await;

            // signal that the session future ran to completion
            tx.send(()).unwrap();
        });

        fut.await;
        rx.await.unwrap();
    }
1356
    /// When the client drops its side of the connection, the session future terminates on
    /// its own.
    #[tokio::test(flavor = "multi_thread")]
    async fn handle_dropped_stream() {
        let mut builder = SessionBuilder::default();

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();

        // client connects, immediately drops its stream, then idles
        let fut = builder.with_client_stream(local_addr, async move |client_stream| {
            drop(client_stream);
            tokio::time::sleep(Duration::from_secs(1)).await
        });

        let (tx, rx) = oneshot::channel();

        tokio::task::spawn(async move {
            let (incoming, _) = listener.accept().await.unwrap();
            let session = builder.connect_incoming(incoming).await;
            // the session future must complete once the remote stream is gone
            session.await;

            tx.send(()).unwrap();
        });

        tokio::task::spawn(fut);

        rx.await.unwrap();
    }
1383
    /// The session drains a burst of incoming broadcast messages without erroring.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_send_many_messages() {
        reth_tracing::init_test_tracing();
        let mut builder = SessionBuilder::default();

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();

        let num_messages = 100;

        let fut = builder.with_client_stream(local_addr, async move |mut client_stream| {
            // flood the session with empty tx-hash announcements
            for _ in 0..num_messages {
                client_stream
                    .send(EthMessage::NewPooledTransactionHashes66(Vec::new().into()))
                    .await
                    .unwrap();
            }
        });

        let (tx, rx) = oneshot::channel();

        tokio::task::spawn(async move {
            let (incoming, _) = listener.accept().await.unwrap();
            let session = builder.connect_incoming(incoming).await;
            session.await;

            tx.send(()).unwrap();
        });

        tokio::task::spawn(fut);

        rx.await.unwrap();
    }
1417
    /// An internal request that never receives an answer first fails with
    /// `RequestError::Timeout` and eventually escalates to a `ProtocolBreach` event.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_request_timeout() {
        reth_tracing::init_test_tracing();

        let mut builder = SessionBuilder::default();

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();

        let request_timeout = Duration::from_millis(100);
        let drop_timeout = Duration::from_millis(1500);

        // the client never responds, it just keeps the connection open
        let fut = builder.with_client_stream(local_addr, async move |client_stream| {
            let _client_stream = client_stream;
            tokio::time::sleep(drop_timeout * 60).await;
        });
        tokio::task::spawn(fut);

        let (incoming, _) = listener.accept().await.unwrap();
        let mut session = builder.connect_incoming(incoming).await;
        // shrink the timeouts so the test completes quickly
        session
            .internal_request_timeout
            .store(request_timeout.as_millis() as u64, Ordering::Relaxed);
        session.protocol_breach_request_timeout = drop_timeout;
        session.internal_request_timeout_interval =
            tokio::time::interval_at(tokio::time::Instant::now(), request_timeout);
        let (tx, rx) = oneshot::channel();
        let req = PeerRequest::GetBlockBodies { request: GetBlockBodies(vec![]), response: tx };
        session.on_internal_peer_request(req, Instant::now());
        tokio::spawn(session);

        // the internal requester is notified of the timeout first
        let err = rx.await.unwrap().unwrap_err();
        assert_eq!(err, RequestError::Timeout);

        // wait for protocol breach error
        let msg = builder.active_session_rx.next().await.unwrap();
        match msg {
            ActiveSessionMessage::ProtocolBreach { .. } => {}
            ev => unreachable!("{ev:?}"),
        }
    }
1459
1460    #[test]
1461    fn test_reject_bal_request_for_eth70() {
1462        let (tx, _rx) = oneshot::channel();
1463        let request: PeerRequest<EthNetworkPrimitives> =
1464            PeerRequest::GetBlockAccessLists { request: GetBlockAccessLists(vec![]), response: tx };
1465
1466        assert!(!ActiveSession::<EthNetworkPrimitives>::is_request_supported_for_version(
1467            &request,
1468            EthVersion::Eth70
1469        ));
1470        assert!(ActiveSession::<EthNetworkPrimitives>::is_request_supported_for_version(
1471            &request,
1472            EthVersion::Eth71
1473        ));
1474    }
1475
1476    #[tokio::test(flavor = "multi_thread")]
1477    async fn test_keep_alive() {
1478        let mut builder = SessionBuilder::default();
1479
1480        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1481        let local_addr = listener.local_addr().unwrap();
1482
1483        let fut = builder.with_client_stream(local_addr, async move |mut client_stream| {
1484            let _ = tokio::time::timeout(Duration::from_secs(5), client_stream.next()).await;
1485            client_stream.into_inner().disconnect(DisconnectReason::UselessPeer).await.unwrap();
1486        });
1487
1488        let (tx, rx) = oneshot::channel();
1489
1490        tokio::task::spawn(async move {
1491            let (incoming, _) = listener.accept().await.unwrap();
1492            let session = builder.connect_incoming(incoming).await;
1493            session.await;
1494
1495            tx.send(()).unwrap();
1496        });
1497
1498        tokio::task::spawn(fut);
1499
1500        rx.await.unwrap();
1501    }
1502
    // This tests that incoming messages are delivered when there's capacity.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_send_at_capacity() {
        let mut builder = SessionBuilder::default();

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();

        // Client side: send a single announcement, then keep the connection open so
        // the session stays alive while the channel to the manager is saturated.
        let fut = builder.with_client_stream(local_addr, async move |mut client_stream| {
            client_stream
                .send(EthMessage::NewPooledTransactionHashes68(Default::default()))
                .await
                .unwrap();
            let _ = tokio::time::timeout(Duration::from_secs(100), client_stream.next()).await;
        });
        tokio::task::spawn(fut);

        let (incoming, _) = listener.accept().await.unwrap();
        let session = builder.connect_incoming(incoming).await;

        // fill the entire message buffer with an unrelated message
        let mut num_fill_messages = 0;
        loop {
            if builder
                .active_session_tx
                .try_send(ActiveSessionMessage::ProtocolBreach { peer_id: PeerId::random() })
                .is_err()
            {
                break
            }
            num_fill_messages += 1;
        }

        // Drive the session while the channel is full; it must hold on to the incoming
        // client message rather than drop it.
        tokio::task::spawn(async move {
            session.await;
        });

        // Give the session time to receive the client's message while at capacity.
        tokio::time::sleep(Duration::from_millis(100)).await;

        // Drain the filler messages, freeing capacity on the channel.
        for _ in 0..num_fill_messages {
            let message = builder.active_session_rx.next().await.unwrap();
            match message {
                ActiveSessionMessage::ProtocolBreach { .. } => {}
                ev => unreachable!("{ev:?}"),
            }
        }

        // Now that there is capacity, the client's announcement must come through.
        let message = builder.active_session_rx.next().await.unwrap();
        match message {
            ActiveSessionMessage::ValidMessage {
                message: PeerMessage::PooledTransactions(_),
                ..
            } => {}
            _ => unreachable!(),
        }
    }
1559
1560    #[test]
1561    fn timeout_calculation_sanity_tests() {
1562        let rtt = Duration::from_secs(5);
1563        // timeout for an RTT of `rtt`
1564        let timeout = rtt * TIMEOUT_SCALING;
1565
1566        // if rtt hasn't changed, timeout shouldn't change
1567        assert_eq!(calculate_new_timeout(timeout, rtt), timeout);
1568
1569        // if rtt changed, the new timeout should change less than it
1570        assert!(calculate_new_timeout(timeout, rtt / 2) < timeout);
1571        assert!(calculate_new_timeout(timeout, rtt / 2) > timeout / 2);
1572        assert!(calculate_new_timeout(timeout, rtt * 2) > timeout);
1573        assert!(calculate_new_timeout(timeout, rtt * 2) < timeout * 2);
1574    }
1575}