Skip to main content

reth_network/session/
active.rs

1//! Represents an established session.
2
3use core::sync::atomic::Ordering;
4use std::{
5    collections::VecDeque,
6    future::Future,
7    net::SocketAddr,
8    pin::Pin,
9    sync::{
10        atomic::{AtomicU64, AtomicUsize},
11        Arc,
12    },
13    task::{ready, Context, Poll},
14    time::{Duration, Instant},
15};
16
17use crate::{
18    message::{NewBlockMessage, PeerMessage, PeerResponse, PeerResponseResult},
19    session::{
20        conn::EthRlpxConnection,
21        handle::{ActiveSessionMessage, SessionCommand},
22        BlockRangeInfo, EthVersion, SessionId,
23    },
24};
25use alloy_eips::merge::EPOCH_SLOTS;
26use alloy_primitives::Sealable;
27use futures::{stream::Fuse, SinkExt, StreamExt};
28use metrics::{Counter, Gauge};
29use reth_eth_wire::{
30    errors::{EthHandshakeError, EthStreamError},
31    message::{EthBroadcastMessage, MessageError},
32    Capabilities, DisconnectP2P, DisconnectReason, EthMessage, NetworkPrimitives, NewBlockPayload,
33};
34use reth_eth_wire_types::{message::RequestPair, NewPooledTransactionHashes, RawCapabilityMessage};
35use reth_metrics::common::mpsc::MeteredPollSender;
36use reth_network_api::PeerRequest;
37use reth_network_p2p::error::RequestError;
38use reth_network_peers::PeerId;
39use reth_network_types::session::config::INITIAL_REQUEST_TIMEOUT;
40use reth_primitives_traits::Block;
41use rustc_hash::FxHashMap;
42use tokio::{
43    sync::{mpsc, mpsc::error::TrySendError, oneshot},
44    time::Interval,
45};
46use tokio_stream::wrappers::ReceiverStream;
47use tokio_util::sync::PollSender;
48use tracing::{debug, trace};
49
50/// The recommended interval at which to check if a new range update should be sent to the remote
51/// peer.
52///
53/// Updates are only sent when the block height has advanced by at least one epoch (32 blocks)
54/// since the last update. The interval is set to one epoch duration in seconds.
55pub(super) const RANGE_UPDATE_INTERVAL: Duration = Duration::from_secs(EPOCH_SLOTS * 12);
56
57// Constants for timeout updating.
58
59/// Minimum timeout value
60const MINIMUM_TIMEOUT: Duration = Duration::from_secs(2);
61
62/// Maximum timeout value
63const MAXIMUM_TIMEOUT: Duration = INITIAL_REQUEST_TIMEOUT;
64/// How much the new measurements affect the current timeout (X percent)
65const SAMPLE_IMPACT: f64 = 0.1;
66/// Amount of RTTs before timeout
67const TIMEOUT_SCALING: u32 = 3;
68
69/// Restricts the number of queued outgoing messages for larger responses:
70///  - Block Bodies
71///  - Receipts
72///  - Headers
73///  - `PooledTransactions`
74///
75/// With proper softlimits in place (2MB) this targets 10MB (4+1 * 2MB) of outgoing response data.
76///
77/// This parameter serves as backpressure for reading additional requests from the remote.
78/// Once we've queued up more responses than this, the session should prioritize message flushing
79/// before reading any more messages from the remote peer, throttling the peer.
80const MAX_QUEUED_OUTGOING_RESPONSES: usize = 4;
81
82/// Minimum capacity to retain for buffered incoming requests from the remote peer.
83const MIN_RECEIVED_REQUESTS_CAPACITY: usize = 1;
84
/// Soft limit for the total number of buffered outgoing broadcast items (e.g. transaction hashes).
///
/// Many small broadcast messages carrying a single tx hash each are equivalent in cost to one
/// message carrying many hashes. This limit counts individual items (hashes, transactions, blocks)
/// rather than messages, so that many small messages don't trigger aggressive drops unnecessarily.
const MAX_QUEUED_BROADCAST_ITEMS: usize = 4096;

/// Shared counter for in-flight broadcast items (tx hashes, transactions, blocks) across the
/// bounded command channel, unbounded overflow channel, and session outgoing queue.
///
/// Cloning the counter yields another handle to the same underlying value.
///
/// Wrapped in a newtype so the backing storage can be changed later (e.g. to track memory) without
/// touching every call-site.
#[derive(Debug, Clone)]
pub(crate) struct BroadcastItemCounter(Arc<AtomicUsize>);

impl BroadcastItemCounter {
    /// Creates a new counter starting at zero.
    pub(crate) fn new() -> Self {
        Self(Arc::new(AtomicUsize::new(0)))
    }

    /// Returns the current count.
    pub(crate) fn get(&self) -> usize {
        self.0.load(Ordering::Relaxed)
    }

    /// Attempts to add `n` items.
    ///
    /// Returns `true` and adds `n` if the count *before* the addition is below
    /// [`MAX_QUEUED_BROADCAST_ITEMS`]. Because this is a soft limit, a single accepted call may
    /// push the count past the limit.
    ///
    /// Returns `false` and leaves the counter untouched if the limit has already been reached, or
    /// if adding `n` would overflow `usize`.
    pub(crate) fn try_add(&self, n: usize) -> bool {
        // A single compare-and-swap loop decides and applies the addition atomically. A rejected
        // call therefore never becomes visible to `get` or to concurrent `try_add` callers, which
        // an add-then-rollback sequence cannot guarantee.
        self.0
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
                if current >= MAX_QUEUED_BROADCAST_ITEMS {
                    return None
                }
                // `None` on overflow rejects the addition instead of wrapping the counter.
                current.checked_add(n)
            })
            .is_ok()
    }

    /// Subtracts `n` items from the counter.
    ///
    /// Callers are expected to only subtract items that were previously added; the subtraction
    /// wraps around if `n` exceeds the current count.
    pub(crate) fn sub(&self, n: usize) {
        self.0.fetch_sub(n, Ordering::Relaxed);
    }
}
127
128/// The type that advances an established session by listening for incoming messages (from local
129/// node or read from connection) and emitting events back to the
130/// [`SessionManager`](super::SessionManager).
131///
132/// It listens for
133///    - incoming commands from the [`SessionManager`](super::SessionManager)
134///    - incoming _internal_ requests/broadcasts via the request/command channel
135///    - incoming requests/broadcasts _from remote_ via the connection
136///    - responses for handled ETH requests received from the remote peer.
137#[expect(dead_code)]
138pub(crate) struct ActiveSession<N: NetworkPrimitives> {
139    /// Keeps track of request ids.
140    pub(crate) next_id: u64,
141    /// The underlying connection.
142    pub(crate) conn: EthRlpxConnection<N>,
143    /// Identifier of the node we're connected to.
144    pub(crate) remote_peer_id: PeerId,
145    /// The address we're connected to.
146    pub(crate) remote_addr: SocketAddr,
147    /// All capabilities the peer announced
148    pub(crate) remote_capabilities: Arc<Capabilities>,
149    /// Internal identifier of this session
150    pub(crate) session_id: SessionId,
151    /// Incoming commands from the manager
152    pub(crate) commands_rx: ReceiverStream<SessionCommand<N>>,
153    /// Unbounded channel for commands that couldn't fit in the bounded channel (broadcast
154    /// overflow) and for disconnect commands that must never be dropped.
155    pub(crate) unbounded_rx: mpsc::UnboundedReceiver<SessionCommand<N>>,
156    /// Counter for broadcast messages received via the unbounded overflow channel.
157    pub(crate) unbounded_broadcast_msgs: Counter,
158    /// Sink to send messages to the [`SessionManager`](super::SessionManager).
159    pub(crate) to_session_manager: MeteredPollSender<ActiveSessionMessage<N>>,
160    /// A message that needs to be delivered to the session manager
161    pub(crate) pending_message_to_session: Option<ActiveSessionMessage<N>>,
162    /// Incoming internal requests which are delegated to the remote peer.
163    pub(crate) internal_request_rx: Fuse<ReceiverStream<PeerRequest<N>>>,
164    /// All requests sent to the remote peer we're waiting on a response
165    pub(crate) inflight_requests: FxHashMap<u64, InflightRequest<PeerRequest<N>>>,
166    /// All requests that were sent by the remote peer and we're waiting on an internal response
167    pub(crate) received_requests_from_remote: Vec<ReceivedRequest<N>>,
168    /// Buffered messages that should be handled and sent to the peer.
169    pub(crate) queued_outgoing: QueuedOutgoingMessages<N>,
170    /// The maximum time we wait for a response from a peer.
171    pub(crate) internal_request_timeout: Arc<AtomicU64>,
172    /// Interval when to check for timed out requests.
173    pub(crate) internal_request_timeout_interval: Interval,
174    /// If an [`ActiveSession`] does not receive a response at all within this duration then it is
175    /// considered a protocol violation and the session will initiate a drop.
176    pub(crate) protocol_breach_request_timeout: Duration,
177    /// Used to reserve a slot to guarantee that the termination message is delivered
178    pub(crate) terminate_message:
179        Option<(PollSender<ActiveSessionMessage<N>>, ActiveSessionMessage<N>)>,
180    /// The eth69 range info for the remote peer.
181    /// This is `None` for peers negotiating versions below `eth/69`.
182    pub(crate) range_info: Option<BlockRangeInfo>,
183    /// The eth69 range info for the local node (this node).
184    /// This represents the range of blocks that this node can serve to other peers.
185    pub(crate) local_range_info: BlockRangeInfo,
186    /// Optional interval for sending periodic range updates to the remote peer (eth69+)
187    /// The interval is set to one epoch duration (~6.4 minutes), but updates are only sent when
188    /// the block height has advanced by at least one epoch (32 blocks) since the last update
189    pub(crate) range_update_interval: Option<Interval>,
190    /// The last latest block number we sent in a range update
191    /// Used to avoid sending unnecessary updates when block height hasn't changed significantly
192    pub(crate) last_sent_latest_block: Option<u64>,
193}
194
195impl<N: NetworkPrimitives> ActiveSession<N> {
196    /// Returns `true` if the session is currently in the process of disconnecting
197    fn is_disconnecting(&self) -> bool {
198        self.conn.inner().is_disconnecting()
199    }
200
201    /// Returns the next request id
202    const fn next_id(&mut self) -> u64 {
203        let id = self.next_id;
204        self.next_id += 1;
205        id
206    }
207
208    /// Shrinks the capacity of the internal buffers.
209    pub fn shrink_to_fit(&mut self) {
210        self.received_requests_from_remote.shrink_to(MIN_RECEIVED_REQUESTS_CAPACITY);
211        self.queued_outgoing.shrink_to(MAX_QUEUED_OUTGOING_RESPONSES);
212    }
213
214    /// Handle a message read from the connection.
215    ///
216    /// Returns an error if the message is considered to be in violation of the protocol.
217    fn on_incoming_message(&mut self, msg: EthMessage<N>) -> OnIncomingMessageOutcome<N> {
218        /// A macro that handles an incoming request
219        /// This creates a new channel and tries to send the sender half to the session while
220        /// storing the receiver half internally so the pending response can be polled.
221        macro_rules! on_request {
222            ($req:ident, $resp_item:ident, $req_item:ident) => {{
223                let RequestPair { request_id, message: request } = $req;
224                let (tx, response) = oneshot::channel();
225                let received = ReceivedRequest {
226                    request_id,
227                    rx: PeerResponse::$resp_item { response },
228                    received: Instant::now(),
229                };
230                self.received_requests_from_remote.push(received);
231                self.try_emit_request(PeerMessage::EthRequest(PeerRequest::$req_item {
232                    request,
233                    response: tx,
234                }))
235                .into()
236            }};
237        }
238
239        /// Processes a response received from the peer
240        macro_rules! on_response {
241            ($resp:ident, $item:ident) => {{
242                let RequestPair { request_id, message } = $resp;
243                if let Some(req) = self.inflight_requests.remove(&request_id) {
244                    match req.request {
245                        RequestState::Waiting(PeerRequest::$item { response, .. }) => {
246                            trace!(peer_id=?self.remote_peer_id, ?request_id, "received response from peer");
247                            let _ = response.send(Ok(message));
248                            self.update_request_timeout(req.timestamp, Instant::now());
249                        }
250                        RequestState::Waiting(request) => {
251                            request.send_bad_response();
252                        }
253                        RequestState::TimedOut => {
254                            // request was already timed out internally
255                            self.update_request_timeout(req.timestamp, Instant::now());
256                        }
257                    }
258                } else {
259                    trace!(peer_id=?self.remote_peer_id, ?request_id, "received response to unknown request");
260                    // we received a response to a request we never sent
261                    self.on_bad_message();
262                }
263
264                OnIncomingMessageOutcome::Ok
265            }};
266        }
267
268        match msg {
269            message @ EthMessage::Status(_) => OnIncomingMessageOutcome::BadMessage {
270                error: EthStreamError::EthHandshakeError(EthHandshakeError::StatusNotInHandshake),
271                message,
272            },
273            EthMessage::NewBlockHashes(msg) => {
274                self.try_emit_broadcast(PeerMessage::NewBlockHashes(msg)).into()
275            }
276            EthMessage::NewBlock(msg) => {
277                let block = NewBlockMessage {
278                    hash: msg.block().header().hash_slow(),
279                    block: Arc::new(*msg),
280                };
281                self.try_emit_broadcast(PeerMessage::NewBlock(block)).into()
282            }
283            EthMessage::Transactions(msg) => {
284                self.try_emit_broadcast(PeerMessage::ReceivedTransaction(msg)).into()
285            }
286            EthMessage::NewPooledTransactionHashes66(msg) => {
287                self.try_emit_broadcast(PeerMessage::PooledTransactions(msg.into())).into()
288            }
289            EthMessage::NewPooledTransactionHashes68(msg) => {
290                self.try_emit_broadcast(PeerMessage::PooledTransactions(msg.into())).into()
291            }
292            EthMessage::NewPooledTransactionHashes72(msg) => {
293                self.try_emit_broadcast(PeerMessage::PooledTransactions(msg.into())).into()
294            }
295            EthMessage::GetBlockHeaders(req) => {
296                on_request!(req, BlockHeaders, GetBlockHeaders)
297            }
298            EthMessage::BlockHeaders(resp) => {
299                on_response!(resp, GetBlockHeaders)
300            }
301            EthMessage::GetBlockBodies(req) => {
302                on_request!(req, BlockBodies, GetBlockBodies)
303            }
304            EthMessage::BlockBodies(resp) => {
305                on_response!(resp, GetBlockBodies)
306            }
307            EthMessage::GetPooledTransactions(req) => {
308                on_request!(req, PooledTransactions, GetPooledTransactions)
309            }
310            EthMessage::PooledTransactions(resp) => {
311                on_response!(resp, GetPooledTransactions)
312            }
313            EthMessage::GetNodeData(req) => {
314                on_request!(req, NodeData, GetNodeData)
315            }
316            EthMessage::NodeData(resp) => {
317                on_response!(resp, GetNodeData)
318            }
319            EthMessage::GetReceipts(req) => {
320                if self.conn.version() >= EthVersion::Eth69 {
321                    on_request!(req, Receipts69, GetReceipts69)
322                } else {
323                    on_request!(req, Receipts, GetReceipts)
324                }
325            }
326            EthMessage::GetReceipts70(req) => {
327                on_request!(req, Receipts70, GetReceipts70)
328            }
329            EthMessage::Receipts(resp) => {
330                on_response!(resp, GetReceipts)
331            }
332            EthMessage::Receipts69(resp) => {
333                on_response!(resp, GetReceipts69)
334            }
335            EthMessage::Receipts70(resp) => {
336                on_response!(resp, GetReceipts70)
337            }
338            EthMessage::GetBlockAccessLists(req) => {
339                on_request!(req, BlockAccessLists, GetBlockAccessLists)
340            }
341            EthMessage::BlockAccessLists(resp) => {
342                on_response!(resp, GetBlockAccessLists)
343            }
344            EthMessage::Cells(resp) => {
345                on_response!(resp, GetCells)
346            }
347            EthMessage::BlockRangeUpdate(msg) => {
348                // Validate that earliest <= latest according to the spec
349                if msg.earliest > msg.latest {
350                    return OnIncomingMessageOutcome::BadMessage {
351                        error: EthStreamError::InvalidMessage(MessageError::Other(format!(
352                            "invalid block range: earliest ({}) > latest ({})",
353                            msg.earliest, msg.latest
354                        ))),
355                        message: EthMessage::BlockRangeUpdate(msg),
356                    };
357                }
358
359                // Validate that the latest hash is not zero
360                if msg.latest_hash.is_zero() {
361                    return OnIncomingMessageOutcome::BadMessage {
362                        error: EthStreamError::InvalidMessage(MessageError::Other(
363                            "invalid block range: latest_hash cannot be zero".to_string(),
364                        )),
365                        message: EthMessage::BlockRangeUpdate(msg),
366                    };
367                }
368
369                if let Some(range_info) = self.range_info.as_ref() {
370                    range_info.update(msg.earliest, msg.latest, msg.latest_hash);
371                }
372
373                OnIncomingMessageOutcome::Ok
374            }
375            EthMessage::GetCells(resp) => {
376                on_request!(resp, Cells, GetCells)
377            }
378            EthMessage::Other(bytes) => self.try_emit_broadcast(PeerMessage::Other(bytes)).into(),
379        }
380    }
381
382    /// Handle an internal peer request that will be sent to the remote.
383    fn on_internal_peer_request(&mut self, request: PeerRequest<N>, deadline: Instant) {
384        let version = self.conn.version();
385        if !Self::is_request_supported_for_version(&request, version) {
386            debug!(
387                target: "net",
388                ?request,
389                peer_id=?self.remote_peer_id,
390                ?version,
391                "Request not supported for negotiated eth version",
392            );
393            request.send_err_response(RequestError::UnsupportedCapability);
394            return;
395        }
396
397        let request_id = self.next_id();
398        trace!(?request, peer_id=?self.remote_peer_id, ?request_id, "sending request to peer");
399        let msg = request.create_request_message(request_id).map_versioned(version);
400
401        self.queued_outgoing.push_back(msg.into());
402        let req = InflightRequest {
403            request: RequestState::Waiting(request),
404            timestamp: Instant::now(),
405            deadline,
406        };
407        self.inflight_requests.insert(request_id, req);
408    }
409
410    #[inline]
411    fn is_request_supported_for_version(request: &PeerRequest<N>, version: EthVersion) -> bool {
412        request.is_supported_by_eth_version(version)
413    }
414
415    /// Handle a message received from the internal network
416    fn on_internal_peer_message(&mut self, msg: PeerMessage<N>) {
417        match msg {
418            PeerMessage::NewBlockHashes(msg) => {
419                self.queued_outgoing.push_back(EthMessage::NewBlockHashes(msg).into());
420            }
421            PeerMessage::NewBlock(msg) => {
422                self.queued_outgoing.push_back(EthBroadcastMessage::NewBlock(msg.block).into());
423            }
424            PeerMessage::PooledTransactions(msg) => {
425                if msg.is_valid_for_version(self.conn.version()) {
426                    self.queued_outgoing.push_pooled_hashes(msg);
427                } else {
428                    self.queued_outgoing.broadcast_items.sub(msg.len());
429                    debug!(target: "net", ?msg,  version=?self.conn.version(), "Message is invalid for connection version, skipping");
430                }
431            }
432            PeerMessage::EthRequest(req) => {
433                let deadline = self.request_deadline();
434                self.on_internal_peer_request(req, deadline);
435            }
436            PeerMessage::SendTransactions(msg) => {
437                self.queued_outgoing.push_back(EthBroadcastMessage::Transactions(msg).into());
438            }
439            PeerMessage::BlockRangeUpdated(_) => {}
440            PeerMessage::ReceivedTransaction(_) => {
441                unreachable!("Not emitted by network")
442            }
443            PeerMessage::Other(other) => {
444                self.queued_outgoing.push_back(OutgoingMessage::Raw(other));
445            }
446        }
447    }
448
449    /// Returns the deadline timestamp at which the request times out
450    fn request_deadline(&self) -> Instant {
451        Instant::now() +
452            Duration::from_millis(self.internal_request_timeout.load(Ordering::Relaxed))
453    }
454
455    /// Handle a Response to the peer
456    ///
457    /// This will queue the response to be sent to the peer
458    fn handle_outgoing_response(&mut self, id: u64, resp: PeerResponseResult<N>) {
459        match resp.try_into_message(id) {
460            Ok(msg) => {
461                self.queued_outgoing.push_back(msg.into());
462            }
463            Err(err) => {
464                debug!(target: "net", %err, "Failed to respond to received request");
465            }
466        }
467    }
468
469    /// Send a message back to the [`SessionManager`](super::SessionManager).
470    ///
471    /// Returns the message if the bounded channel is currently unable to handle this message.
472    #[expect(clippy::result_large_err)]
473    fn try_emit_broadcast(&self, message: PeerMessage<N>) -> Result<(), ActiveSessionMessage<N>> {
474        let Some(sender) = self.to_session_manager.inner().get_ref() else { return Ok(()) };
475
476        match sender
477            .try_send(ActiveSessionMessage::ValidMessage { peer_id: self.remote_peer_id, message })
478        {
479            Ok(_) => Ok(()),
480            Err(err) => {
481                trace!(
482                    target: "net",
483                    %err,
484                    "no capacity for incoming broadcast",
485                );
486                match err {
487                    TrySendError::Full(msg) => Err(msg),
488                    TrySendError::Closed(_) => Ok(()),
489                }
490            }
491        }
492    }
493
494    /// Send a message back to the [`SessionManager`](super::SessionManager)
495    /// covering both broadcasts and incoming requests.
496    ///
497    /// Returns the message if the bounded channel is currently unable to handle this message.
498    #[expect(clippy::result_large_err)]
499    fn try_emit_request(&self, message: PeerMessage<N>) -> Result<(), ActiveSessionMessage<N>> {
500        let Some(sender) = self.to_session_manager.inner().get_ref() else { return Ok(()) };
501
502        match sender
503            .try_send(ActiveSessionMessage::ValidMessage { peer_id: self.remote_peer_id, message })
504        {
505            Ok(_) => Ok(()),
506            Err(err) => {
507                trace!(
508                    target: "net",
509                    %err,
510                    "no capacity for incoming request",
511                );
512                match err {
513                    TrySendError::Full(msg) => Err(msg),
514                    TrySendError::Closed(_) => {
515                        // Note: this would mean the `SessionManager` was dropped, which is already
516                        // handled by checking if the command receiver channel has been closed.
517                        Ok(())
518                    }
519                }
520            }
521        }
522    }
523
524    /// Notify the manager that the peer sent a bad message
525    fn on_bad_message(&self) {
526        let Some(sender) = self.to_session_manager.inner().get_ref() else { return };
527        let _ = sender.try_send(ActiveSessionMessage::BadMessage { peer_id: self.remote_peer_id });
528    }
529
530    /// Report back that this session has been closed.
531    fn emit_disconnect(&mut self, cx: &mut Context<'_>) -> Poll<()> {
532        trace!(target: "net::session", remote_peer_id=?self.remote_peer_id, "emitting disconnect");
533        let msg = ActiveSessionMessage::Disconnected {
534            peer_id: self.remote_peer_id,
535            remote_addr: self.remote_addr,
536        };
537
538        self.terminate_message = Some((self.to_session_manager.inner().clone(), msg));
539        self.poll_terminate_message(cx).expect("message is set")
540    }
541
542    /// Report back that this session has been closed due to an error
543    fn close_on_error(&mut self, error: EthStreamError, cx: &mut Context<'_>) -> Poll<()> {
544        let msg = ActiveSessionMessage::ClosedOnConnectionError {
545            peer_id: self.remote_peer_id,
546            remote_addr: self.remote_addr,
547            error,
548        };
549        self.terminate_message = Some((self.to_session_manager.inner().clone(), msg));
550        self.poll_terminate_message(cx).expect("message is set")
551    }
552
553    /// Starts the disconnect process
554    fn start_disconnect(&mut self, reason: DisconnectReason) -> Result<(), EthStreamError> {
555        Ok(self.conn.inner_mut().start_disconnect(reason)?)
556    }
557
558    /// Flushes the disconnect message and emits the corresponding message
559    fn poll_disconnect(&mut self, cx: &mut Context<'_>) -> Poll<()> {
560        debug_assert!(self.is_disconnecting(), "not disconnecting");
561
562        // try to close the flush out the remaining Disconnect message
563        let _ = ready!(self.conn.poll_close_unpin(cx));
564        self.emit_disconnect(cx)
565    }
566
567    /// Attempts to disconnect by sending the given disconnect reason
568    fn try_disconnect(&mut self, reason: DisconnectReason, cx: &mut Context<'_>) -> Poll<()> {
569        match self.start_disconnect(reason) {
570            Ok(()) => {
571                // we're done
572                self.poll_disconnect(cx)
573            }
574            Err(err) => {
575                debug!(target: "net::session", %err, remote_peer_id=?self.remote_peer_id, "could not send disconnect");
576                self.close_on_error(err, cx)
577            }
578        }
579    }
580
581    /// Checks for _internally_ timed out requests.
582    ///
583    /// If a requests misses its deadline, then it is timed out internally.
584    /// If a request misses the `protocol_breach_request_timeout` then this session is considered in
585    /// protocol violation and will close.
586    ///
587    /// Returns `true` if a peer missed the `protocol_breach_request_timeout`, in which case the
588    /// session should be terminated.
589    #[must_use]
590    fn check_timed_out_requests(&mut self, now: Instant) -> bool {
591        for (id, req) in &mut self.inflight_requests {
592            if req.is_timed_out(now) {
593                if req.is_waiting() {
594                    debug!(target: "net::session", ?id, remote_peer_id=?self.remote_peer_id, "timed out outgoing request");
595                    req.timeout();
596                } else if now - req.timestamp > self.protocol_breach_request_timeout {
597                    return true
598                }
599            }
600        }
601
602        false
603    }
604
605    /// Updates the request timeout with a request's timestamps
606    fn update_request_timeout(&mut self, sent: Instant, received: Instant) {
607        let elapsed = received.saturating_duration_since(sent);
608
609        let current = Duration::from_millis(self.internal_request_timeout.load(Ordering::Relaxed));
610        let request_timeout = calculate_new_timeout(current, elapsed);
611        self.internal_request_timeout.store(request_timeout.as_millis() as u64, Ordering::Relaxed);
612        self.internal_request_timeout_interval = tokio::time::interval(request_timeout);
613    }
614
615    /// If a termination message is queued this will try to send it
616    fn poll_terminate_message(&mut self, cx: &mut Context<'_>) -> Option<Poll<()>> {
617        let (mut tx, msg) = self.terminate_message.take()?;
618        match tx.poll_reserve(cx) {
619            Poll::Pending => {
620                self.terminate_message = Some((tx, msg));
621                return Some(Poll::Pending)
622            }
623            Poll::Ready(Ok(())) => {
624                let _ = tx.send_item(msg);
625            }
626            Poll::Ready(Err(_)) => {
627                // channel closed
628            }
629        }
630        // terminate the task
631        Some(Poll::Ready(()))
632    }
633}
634
635impl<N: NetworkPrimitives> Future for ActiveSession<N> {
636    type Output = ();
637
638    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
639        let this = self.get_mut();
640
641        // if the session is terminate we have to send the termination message before we can close
642        if let Some(terminate) = this.poll_terminate_message(cx) {
643            return terminate
644        }
645
646        if this.is_disconnecting() {
647            return this.poll_disconnect(cx)
648        }
649
650        // The receive loop can be CPU intensive since it involves message decoding which could take
651        // up a lot of resources and increase latencies for other sessions if not yielded manually.
652        // If the budget is exhausted we manually yield back control to the (coop) scheduler. This
653        // manual yield point should prevent situations where polling appears to be frozen. See also <https://tokio.rs/blog/2020-04-preemption>
654        // And tokio's docs on cooperative scheduling <https://docs.rs/tokio/latest/tokio/task/#cooperative-scheduling>
655        let mut budget = 4;
656
657        // The main poll loop that drives the session
658        'main: loop {
659            let mut progress = false;
660
661            // we prioritize incoming commands sent from the session manager
662            loop {
663                match this.commands_rx.poll_next_unpin(cx) {
664                    Poll::Pending => break,
665                    Poll::Ready(None) => {
666                        // this is only possible when the manager was dropped, in which case we also
667                        // terminate this session
668                        return Poll::Ready(())
669                    }
670                    Poll::Ready(Some(cmd)) => {
671                        progress = true;
672                        match cmd {
673                            SessionCommand::Disconnect { reason } => {
674                                debug!(
675                                    target: "net::session",
676                                    ?reason,
677                                    remote_peer_id=?this.remote_peer_id,
678                                    "Received disconnect command for session"
679                                );
680                                let reason =
681                                    reason.unwrap_or(DisconnectReason::DisconnectRequested);
682
683                                return this.try_disconnect(reason, cx)
684                            }
685                            SessionCommand::Message(msg) => {
686                                this.on_internal_peer_message(msg);
687                            }
688                        }
689                    }
690                }
691            }
692
693            // Drain the unbounded channel (broadcast overflow + disconnect commands)
694            while let Poll::Ready(Some(cmd)) = this.unbounded_rx.poll_recv(cx) {
695                progress = true;
696                match cmd {
697                    SessionCommand::Message(msg) => {
698                        this.unbounded_broadcast_msgs.increment(1);
699                        this.on_internal_peer_message(msg);
700                    }
701                    SessionCommand::Disconnect { reason } => {
702                        let reason = reason.unwrap_or(DisconnectReason::DisconnectRequested);
703                        return this.try_disconnect(reason, cx)
704                    }
705                }
706            }
707
708            let deadline = this.request_deadline();
709
710            while let Poll::Ready(Some(req)) = this.internal_request_rx.poll_next_unpin(cx) {
711                progress = true;
712                this.on_internal_peer_request(req, deadline);
713            }
714
715            // Advance all active requests.
716            // We remove each request one by one and add them back.
717            for idx in (0..this.received_requests_from_remote.len()).rev() {
718                let mut req = this.received_requests_from_remote.swap_remove(idx);
719                match req.rx.poll(cx) {
720                    Poll::Pending => {
721                        // not ready yet
722                        this.received_requests_from_remote.push(req);
723                    }
724                    Poll::Ready(resp) => {
725                        this.handle_outgoing_response(req.request_id, resp);
726                    }
727                }
728            }
729
730            // Send messages by advancing the sink and queuing in buffered messages
731            while this.conn.poll_ready_unpin(cx).is_ready() {
732                if let Some(msg) = this.queued_outgoing.pop_front() {
733                    progress = true;
734                    let res = match msg {
735                        OutgoingMessage::Eth(msg) => this.conn.start_send_unpin(msg),
736                        OutgoingMessage::Broadcast(msg) => this.conn.start_send_broadcast(msg),
737                        OutgoingMessage::Raw(msg) => this.conn.start_send_raw(msg),
738                    };
739                    if let Err(err) = res {
740                        debug!(target: "net::session", %err, remote_peer_id=?this.remote_peer_id, "failed to send message");
741                        // notify the manager
742                        return this.close_on_error(err, cx)
743                    }
744                } else {
745                    // no more messages to send over the wire
746                    break
747                }
748            }
749
750            // read incoming messages from the wire
751            'receive: loop {
752                // ensure we still have enough budget for another iteration
753                budget -= 1;
754                if budget == 0 {
755                    // make sure we're woken up again
756                    cx.waker().wake_by_ref();
757                    break 'main
758                }
759
760                // try to resend the pending message that we could not send because the channel was
761                // full. [`PollSender`] will ensure that we're woken up again when the channel is
762                // ready to receive the message, and will only error if the channel is closed.
763                if let Some(msg) = this.pending_message_to_session.take() {
764                    match this.to_session_manager.poll_reserve(cx) {
765                        Poll::Ready(Ok(_)) => {
766                            let _ = this.to_session_manager.send_item(msg);
767                        }
768                        Poll::Ready(Err(_)) => return Poll::Ready(()),
769                        Poll::Pending => {
770                            this.pending_message_to_session = Some(msg);
771                            break 'receive
772                        }
773                    };
774                }
775
776                // check whether we should throttle incoming messages
777                if this.received_requests_from_remote.len() > MAX_QUEUED_OUTGOING_RESPONSES {
778                    // we're currently waiting for the responses to the peer's requests which aren't
779                    // queued as outgoing yet
780                    //
781                    // Note: we don't need to register the waker here because we polled the requests
782                    // above
783                    break 'receive
784                }
785
786                // we also need to check if we have multiple responses queued up
787                if this.queued_outgoing.response_count() > MAX_QUEUED_OUTGOING_RESPONSES {
788                    // if we've queued up more responses than allowed, we don't poll for new
789                    // messages and break the receive loop early
790                    //
791                    // Note: we don't need to register the waker here because we still have
792                    // queued messages and the sink impl registered the waker because we've
793                    // already advanced it to `Pending` earlier
794                    break 'receive
795                }
796
797                match this.conn.poll_next_unpin(cx) {
798                    Poll::Pending => break,
799                    Poll::Ready(None) => {
800                        if this.is_disconnecting() {
801                            break
802                        }
803                        debug!(target: "net::session", remote_peer_id=?this.remote_peer_id, "eth stream completed");
804                        return this.emit_disconnect(cx)
805                    }
806                    Poll::Ready(Some(res)) => {
807                        match res {
808                            Ok(msg) => {
809                                trace!(target: "net::session", msg_id=?msg.message_id(), remote_peer_id=?this.remote_peer_id, "received eth message");
810                                // decode and handle message
811                                match this.on_incoming_message(msg) {
812                                    OnIncomingMessageOutcome::Ok => {
813                                        // handled successfully
814                                        progress = true;
815                                    }
816                                    OnIncomingMessageOutcome::BadMessage { error, message } => {
817                                        debug!(target: "net::session", %error, msg=?message, remote_peer_id=?this.remote_peer_id, "received invalid protocol message");
818                                        this.on_bad_message();
819                                        return this
820                                            .try_disconnect(DisconnectReason::ProtocolBreach, cx)
821                                    }
822                                    OnIncomingMessageOutcome::NoCapacity(msg) => {
823                                        // failed to send due to lack of capacity
824                                        this.pending_message_to_session = Some(msg);
825                                    }
826                                }
827                            }
828                            Err(err) => {
829                                debug!(target: "net::session", %err, remote_peer_id=?this.remote_peer_id, "failed to receive message");
830                                if err.is_protocol_breach() {
831                                    this.on_bad_message();
832                                    return this.try_disconnect(DisconnectReason::ProtocolBreach, cx)
833                                }
834                                return this.close_on_error(err, cx)
835                            }
836                        }
837                    }
838                }
839            }
840
841            if !progress {
842                break 'main
843            }
844        }
845
846        if let Some(interval) = &mut this.range_update_interval {
847            // Check if we should send a range update based on block height changes
848            while interval.poll_tick(cx).is_ready() {
849                let current_latest = this.local_range_info.latest();
850                let should_send = if let Some(last_sent) = this.last_sent_latest_block {
851                    // Only send if block height has advanced by at least one epoch (32 blocks)
852                    current_latest.saturating_sub(last_sent) >= EPOCH_SLOTS
853                } else {
854                    true // First update, always send
855                };
856
857                if should_send {
858                    this.queued_outgoing.push_back(
859                        EthMessage::BlockRangeUpdate(this.local_range_info.to_message()).into(),
860                    );
861                    this.last_sent_latest_block = Some(current_latest);
862                }
863            }
864        }
865
866        while this.internal_request_timeout_interval.poll_tick(cx).is_ready() {
867            // check for timed out requests
868            if this.check_timed_out_requests(Instant::now()) &&
869                let Poll::Ready(Ok(_)) = this.to_session_manager.poll_reserve(cx)
870            {
871                let msg = ActiveSessionMessage::ProtocolBreach { peer_id: this.remote_peer_id };
872                this.pending_message_to_session = Some(msg);
873            }
874        }
875
876        this.shrink_to_fit();
877
878        Poll::Pending
879    }
880}
881
882/// Tracks a request received from the peer
883pub(crate) struct ReceivedRequest<N: NetworkPrimitives> {
884    /// Protocol Identifier
885    request_id: u64,
886    /// Receiver half of the channel that's supposed to receive the proper response.
887    rx: PeerResponse<N>,
888    /// Timestamp when we read this msg from the wire.
889    #[expect(dead_code)]
890    received: Instant,
891}
892
893/// A request that waits for a response from the peer
894pub(crate) struct InflightRequest<R> {
895    /// Request we sent to peer and the internal response channel
896    request: RequestState<R>,
897    /// Instant when the request was sent
898    timestamp: Instant,
899    /// Time limit for the response
900    deadline: Instant,
901}
902
903// === impl InflightRequest ===
904
905impl<N: NetworkPrimitives> InflightRequest<PeerRequest<N>> {
906    /// Returns true if the request is timedout
907    #[inline]
908    fn is_timed_out(&self, now: Instant) -> bool {
909        now > self.deadline
910    }
911
912    /// Returns true if we're still waiting for a response
913    #[inline]
914    const fn is_waiting(&self) -> bool {
915        matches!(self.request, RequestState::Waiting(_))
916    }
917
918    /// This will timeout the request by sending an error response to the internal channel
919    fn timeout(&mut self) {
920        let mut req = RequestState::TimedOut;
921        std::mem::swap(&mut self.request, &mut req);
922
923        if let RequestState::Waiting(req) = req {
924            req.send_err_response(RequestError::Timeout);
925        }
926    }
927}
928
929/// All outcome variants when handling an incoming message
930enum OnIncomingMessageOutcome<N: NetworkPrimitives> {
931    /// Message successfully handled.
932    Ok,
933    /// Message is considered to be in violation of the protocol
934    BadMessage { error: EthStreamError, message: EthMessage<N> },
935    /// Currently no capacity to handle the message
936    NoCapacity(ActiveSessionMessage<N>),
937}
938
939impl<N: NetworkPrimitives> From<Result<(), ActiveSessionMessage<N>>>
940    for OnIncomingMessageOutcome<N>
941{
942    fn from(res: Result<(), ActiveSessionMessage<N>>) -> Self {
943        match res {
944            Ok(_) => Self::Ok,
945            Err(msg) => Self::NoCapacity(msg),
946        }
947    }
948}
949
/// Lifecycle state of a request that was sent to the peer.
enum RequestState<R> {
    /// A response is still outstanding; holds the request so it can be answered later.
    Waiting(R),
    /// The request has already been timed out.
    TimedOut,
}
956
957/// Outgoing messages that can be sent over the wire.
958#[derive(Debug)]
959pub(crate) enum OutgoingMessage<N: NetworkPrimitives> {
960    /// A message that is owned.
961    Eth(EthMessage<N>),
962    /// A message that may be shared by multiple sessions.
963    Broadcast(EthBroadcastMessage<N>),
964    /// A raw capability message
965    Raw(RawCapabilityMessage),
966}
967
968impl<N: NetworkPrimitives> OutgoingMessage<N> {
969    /// Returns true if this is a response.
970    const fn is_response(&self) -> bool {
971        match self {
972            Self::Eth(msg) => msg.is_response(),
973            _ => false,
974        }
975    }
976
977    /// Returns the number of broadcast items in this message.
978    ///
979    /// For transaction hash announcements this is the number of hashes, for full transaction
980    /// broadcasts it is the number of transactions, and for blocks it is 1.
981    /// Request/response messages return 0.
982    fn broadcast_item_count(&self) -> usize {
983        match self {
984            Self::Eth(msg) => match msg {
985                EthMessage::NewBlockHashes(h) => h.len(),
986                EthMessage::NewPooledTransactionHashes66(h) => h.len(),
987                EthMessage::NewPooledTransactionHashes68(h) => h.hashes.len(),
988                EthMessage::NewPooledTransactionHashes72(h) => h.hashes.len(),
989                _ => 0,
990            },
991            Self::Broadcast(msg) => match msg {
992                EthBroadcastMessage::NewBlock(_) => 1,
993                EthBroadcastMessage::Transactions(txs) => txs.len(),
994            },
995            Self::Raw(_) => 0,
996        }
997    }
998
999    /// Tries to merge pooled transaction hash announcements into this message, consuming the
1000    /// incoming hashes. Returns `Some(incoming)` back if the variants don't match.
1001    fn try_merge_hashes(
1002        &mut self,
1003        incoming: NewPooledTransactionHashes,
1004    ) -> Option<NewPooledTransactionHashes> {
1005        let Self::Eth(eth) = self else { return Some(incoming) };
1006        match (eth, incoming) {
1007            (
1008                EthMessage::NewPooledTransactionHashes66(existing),
1009                NewPooledTransactionHashes::Eth66(inc),
1010            ) => {
1011                existing.extend(inc);
1012                None
1013            }
1014            (
1015                EthMessage::NewPooledTransactionHashes68(existing),
1016                NewPooledTransactionHashes::Eth68(inc),
1017            ) => {
1018                existing.hashes.extend(inc.hashes);
1019                existing.sizes.extend(inc.sizes);
1020                existing.types.extend(inc.types);
1021                None
1022            }
1023            (
1024                EthMessage::NewPooledTransactionHashes72(existing),
1025                NewPooledTransactionHashes::Eth72(inc),
1026            ) => {
1027                existing.hashes.extend(inc.hashes);
1028                existing.sizes.extend(inc.sizes);
1029                existing.types.extend(inc.types);
1030                None
1031            }
1032            (_, incoming) => Some(incoming),
1033        }
1034    }
1035}
1036
1037impl<N: NetworkPrimitives> From<EthMessage<N>> for OutgoingMessage<N> {
1038    fn from(value: EthMessage<N>) -> Self {
1039        Self::Eth(value)
1040    }
1041}
1042
1043impl<N: NetworkPrimitives> From<EthBroadcastMessage<N>> for OutgoingMessage<N> {
1044    fn from(value: EthBroadcastMessage<N>) -> Self {
1045        Self::Broadcast(value)
1046    }
1047}
1048
1049/// Calculates a new timeout using an updated estimation of the RTT
1050#[inline]
1051fn calculate_new_timeout(current_timeout: Duration, estimated_rtt: Duration) -> Duration {
1052    let new_timeout = estimated_rtt.mul_f64(SAMPLE_IMPACT) * TIMEOUT_SCALING;
1053
1054    // this dampens sudden changes by taking a weighted mean of the old and new values
1055    let smoothened_timeout = current_timeout.mul_f64(1.0 - SAMPLE_IMPACT) + new_timeout;
1056
1057    smoothened_timeout.clamp(MINIMUM_TIMEOUT, MAXIMUM_TIMEOUT)
1058}
1059
1060/// A helper struct that wraps the queue of outgoing messages with broadcast-aware tracking.
1061///
1062/// Tracks both the total number of queued messages (via a metric gauge) and the total number of
1063/// broadcast items (tx hashes, transactions, blocks) via a shared atomic counter. The atomic
1064/// counter is shared with [`ActiveSessionHandle`](super::handle::ActiveSessionHandle) so the
1065/// [`SessionManager`](super::SessionManager) can apply size-based backpressure.
1066pub(crate) struct QueuedOutgoingMessages<N: NetworkPrimitives> {
1067    messages: VecDeque<OutgoingMessage<N>>,
1068    /// Number of queued response messages, tracked separately so the session can apply
1069    /// backpressure on incoming requests without scanning the whole queue.
1070    queued_responses: usize,
1071    count: Gauge,
1072    /// Shared counter of buffered broadcast items for size-based backpressure.
1073    broadcast_items: BroadcastItemCounter,
1074}
1075
1076impl<N: NetworkPrimitives> QueuedOutgoingMessages<N> {
1077    pub(crate) const fn new(metric: Gauge, broadcast_items: BroadcastItemCounter) -> Self {
1078        Self { messages: VecDeque::new(), queued_responses: 0, count: metric, broadcast_items }
1079    }
1080
1081    /// Returns the number of queued response messages.
1082    pub(crate) const fn response_count(&self) -> usize {
1083        self.queued_responses
1084    }
1085
1086    pub(crate) fn push_back(&mut self, message: OutgoingMessage<N>) {
1087        self.queued_responses += message.is_response() as usize;
1088        self.messages.push_back(message);
1089        self.count.increment(1);
1090    }
1091
1092    pub(crate) fn pop_front(&mut self) -> Option<OutgoingMessage<N>> {
1093        self.messages.pop_front().inspect(|msg| {
1094            self.count.decrement(1);
1095            self.queued_responses -= msg.is_response() as usize;
1096            let items = msg.broadcast_item_count();
1097            if items > 0 {
1098                self.broadcast_items.sub(items);
1099            }
1100        })
1101    }
1102
1103    /// Pushes a pooled transaction hash announcement, merging into the last queued message if
1104    /// it is the same variant (eth66, eth68, or eth72).
1105    pub(crate) fn push_pooled_hashes(&mut self, msg: NewPooledTransactionHashes) {
1106        let msg = if let Some(last) = self.messages.back_mut() {
1107            match last.try_merge_hashes(msg) {
1108                None => return,
1109                Some(msg) => msg,
1110            }
1111        } else {
1112            msg
1113        };
1114        self.messages.push_back(EthMessage::from(msg).into());
1115        self.count.increment(1);
1116    }
1117
1118    pub(crate) fn shrink_to(&mut self, min_capacity: usize) {
1119        self.messages.shrink_to(min_capacity);
1120    }
1121}
1122
1123impl<N: NetworkPrimitives> Drop for QueuedOutgoingMessages<N> {
1124    fn drop(&mut self) {
1125        // Ensure gauge is decremented for any remaining items to avoid metric leak on teardown.
1126        let remaining = self.messages.len();
1127        if remaining > 0 {
1128            self.count.decrement(remaining as f64);
1129        }
1130    }
1131}
1132
1133#[cfg(test)]
1134mod tests {
1135    use super::*;
1136    use crate::session::{handle::PendingSessionEvent, start_pending_incoming_session};
1137    use alloy_eips::eip2124::ForkFilter;
1138    use reth_chainspec::MAINNET;
1139    use reth_ecies::stream::ECIESStream;
1140    use reth_eth_wire::{
1141        handshake::EthHandshake, EthNetworkPrimitives, EthStream, GetBlockAccessLists,
1142        GetBlockBodies, HelloMessageWithProtocols, P2PStream, StatusBuilder, UnauthedEthStream,
1143        UnauthedP2PStream, UnifiedStatus,
1144    };
1145    use reth_eth_wire_types::{
1146        message::MAX_MESSAGE_SIZE, EthMessageID, NewPooledTransactionHashes72, RawCapabilityMessage,
1147    };
1148    use reth_ethereum_forks::EthereumHardfork;
1149    use reth_network_peers::pk2id;
1150    use reth_network_types::session::config::PROTOCOL_BREACH_REQUEST_TIMEOUT;
1151    use secp256k1::{SecretKey, SECP256K1};
1152    use tokio::{
1153        net::{TcpListener, TcpStream},
1154        sync::mpsc,
1155    };
1156
1157    /// Returns a testing `HelloMessage` and new secretkey
1158    fn eth_hello(server_key: &SecretKey) -> HelloMessageWithProtocols {
1159        HelloMessageWithProtocols::builder(pk2id(&server_key.public_key(SECP256K1))).build()
1160    }
1161
1162    struct SessionBuilder<N: NetworkPrimitives = EthNetworkPrimitives> {
1163        _remote_capabilities: Arc<Capabilities>,
1164        active_session_tx: mpsc::Sender<ActiveSessionMessage<N>>,
1165        active_session_rx: ReceiverStream<ActiveSessionMessage<N>>,
1166        to_sessions: Vec<mpsc::Sender<SessionCommand<N>>>,
1167        secret_key: SecretKey,
1168        local_peer_id: PeerId,
1169        hello: HelloMessageWithProtocols,
1170        status: UnifiedStatus,
1171        fork_filter: ForkFilter,
1172        next_id: usize,
1173    }
1174
1175    impl<N: NetworkPrimitives> SessionBuilder<N> {
1176        fn next_id(&mut self) -> SessionId {
1177            let id = self.next_id;
1178            self.next_id += 1;
1179            SessionId(id)
1180        }
1181
1182        /// Connects a new Eth stream and executes the given closure with that established stream
1183        fn with_client_stream<F, O>(
1184            &self,
1185            local_addr: SocketAddr,
1186            f: F,
1187        ) -> Pin<Box<dyn Future<Output = ()> + Send>>
1188        where
1189            F: FnOnce(EthStream<P2PStream<ECIESStream<TcpStream>>, N>) -> O + Send + 'static,
1190            O: Future<Output = ()> + Send + Sync,
1191        {
1192            let mut status = self.status;
1193            let fork_filter = self.fork_filter.clone();
1194            let local_peer_id = self.local_peer_id;
1195            let mut hello = self.hello.clone();
1196            let key = SecretKey::new(&mut rand_08::thread_rng());
1197            hello.id = pk2id(&key.public_key(SECP256K1));
1198            Box::pin(async move {
1199                let outgoing = TcpStream::connect(local_addr).await.unwrap();
1200                let sink = ECIESStream::connect(outgoing, key, local_peer_id).await.unwrap();
1201
1202                let (p2p_stream, _) = UnauthedP2PStream::new(sink).handshake(hello).await.unwrap();
1203
1204                let eth_version = p2p_stream.shared_capabilities().eth_version().unwrap();
1205                status.set_eth_version(eth_version);
1206
1207                let (client_stream, _) = UnauthedEthStream::new(p2p_stream)
1208                    .handshake(status, fork_filter)
1209                    .await
1210                    .unwrap();
1211                f(client_stream).await
1212            })
1213        }
1214
1215        async fn connect_incoming(&mut self, stream: TcpStream) -> ActiveSession<N> {
1216            let remote_addr = stream.local_addr().unwrap();
1217            let session_id = self.next_id();
1218            let (_disconnect_tx, disconnect_rx) = oneshot::channel();
1219            let (pending_sessions_tx, pending_sessions_rx) = mpsc::channel(1);
1220
1221            tokio::task::spawn(start_pending_incoming_session(
1222                Arc::new(EthHandshake::default()),
1223                MAX_MESSAGE_SIZE,
1224                disconnect_rx,
1225                session_id,
1226                stream,
1227                pending_sessions_tx,
1228                remote_addr,
1229                self.secret_key,
1230                self.hello.clone(),
1231                self.status,
1232                self.fork_filter.clone(),
1233                Default::default(),
1234            ));
1235
1236            let mut stream = ReceiverStream::new(pending_sessions_rx);
1237
1238            match stream.next().await.unwrap() {
1239                PendingSessionEvent::Established {
1240                    session_id,
1241                    remote_addr,
1242                    peer_id,
1243                    capabilities,
1244                    conn,
1245                    ..
1246                } => {
1247                    let (_to_session_tx, messages_rx) = mpsc::channel(10);
1248                    let (commands_to_session, commands_rx) = mpsc::channel(10);
1249                    let (_unbounded_tx, unbounded_rx) = mpsc::unbounded_channel();
1250                    let poll_sender = PollSender::new(self.active_session_tx.clone());
1251
1252                    self.to_sessions.push(commands_to_session);
1253
1254                    ActiveSession {
1255                        next_id: 0,
1256                        remote_peer_id: peer_id,
1257                        remote_addr,
1258                        remote_capabilities: Arc::clone(&capabilities),
1259                        session_id,
1260                        commands_rx: ReceiverStream::new(commands_rx),
1261                        unbounded_rx,
1262                        unbounded_broadcast_msgs: Counter::noop(),
1263                        to_session_manager: MeteredPollSender::new(
1264                            poll_sender,
1265                            "network_active_session",
1266                        ),
1267                        pending_message_to_session: None,
1268                        internal_request_rx: ReceiverStream::new(messages_rx).fuse(),
1269                        inflight_requests: Default::default(),
1270                        conn,
1271                        queued_outgoing: QueuedOutgoingMessages::new(
1272                            Gauge::noop(),
1273                            BroadcastItemCounter::new(),
1274                        ),
1275                        received_requests_from_remote: Default::default(),
1276                        internal_request_timeout_interval: tokio::time::interval(
1277                            INITIAL_REQUEST_TIMEOUT,
1278                        ),
1279                        internal_request_timeout: Arc::new(AtomicU64::new(
1280                            INITIAL_REQUEST_TIMEOUT.as_millis() as u64,
1281                        )),
1282                        protocol_breach_request_timeout: PROTOCOL_BREACH_REQUEST_TIMEOUT,
1283                        terminate_message: None,
1284                        range_info: None,
1285                        local_range_info: BlockRangeInfo::new(
1286                            0,
1287                            1000,
1288                            alloy_primitives::B256::ZERO,
1289                        ),
1290                        range_update_interval: None,
1291                        last_sent_latest_block: None,
1292                    }
1293                }
1294                ev => {
1295                    panic!("unexpected message {ev:?}")
1296                }
1297            }
1298        }
1299    }
1300
1301    impl Default for SessionBuilder {
1302        fn default() -> Self {
1303            let (active_session_tx, active_session_rx) = mpsc::channel(100);
1304
1305            let (secret_key, pk) = SECP256K1.generate_keypair(&mut rand_08::thread_rng());
1306            let local_peer_id = pk2id(&pk);
1307
1308            Self {
1309                next_id: 0,
1310                _remote_capabilities: Arc::new(Capabilities::from(vec![])),
1311                active_session_tx,
1312                active_session_rx: ReceiverStream::new(active_session_rx),
1313                to_sessions: vec![],
1314                hello: eth_hello(&secret_key),
1315                secret_key,
1316                local_peer_id,
1317                status: StatusBuilder::default().build(),
1318                fork_filter: MAINNET
1319                    .hardfork_fork_filter(EthereumHardfork::Frontier)
1320                    .expect("The Frontier fork filter should exist on mainnet"),
1321            }
1322        }
1323    }
1324
1325    #[tokio::test(flavor = "multi_thread")]
1326    async fn test_disconnect() {
1327        let mut builder = SessionBuilder::default();
1328
1329        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1330        let local_addr = listener.local_addr().unwrap();
1331
1332        let expected_disconnect = DisconnectReason::UselessPeer;
1333
1334        let fut = builder.with_client_stream(local_addr, async move |mut client_stream| {
1335            let msg = client_stream.next().await.unwrap().unwrap_err();
1336            assert_eq!(msg.as_disconnected().unwrap(), expected_disconnect);
1337        });
1338
1339        tokio::task::spawn(async move {
1340            let (incoming, _) = listener.accept().await.unwrap();
1341            let mut session = builder.connect_incoming(incoming).await;
1342
1343            session.start_disconnect(expected_disconnect).unwrap();
1344            session.await
1345        });
1346
1347        fut.await;
1348    }
1349
1350    #[tokio::test(flavor = "multi_thread")]
1351    async fn test_invalid_message_disconnects_with_protocol_breach() {
1352        let mut builder = SessionBuilder::default();
1353
1354        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1355        let local_addr = listener.local_addr().unwrap();
1356
1357        let fut = builder.with_client_stream(local_addr, async move |mut client_stream| {
1358            client_stream
1359                .start_send_raw(RawCapabilityMessage::eth(
1360                    EthMessageID::PooledTransactions,
1361                    vec![0xc0].into(),
1362                ))
1363                .unwrap();
1364            client_stream.flush().await.unwrap();
1365
1366            let msg = client_stream.next().await.unwrap().unwrap_err();
1367            assert_eq!(msg.as_disconnected(), Some(DisconnectReason::ProtocolBreach));
1368        });
1369
1370        let (tx, rx) = oneshot::channel();
1371
1372        tokio::task::spawn(async move {
1373            let (incoming, _) = listener.accept().await.unwrap();
1374            let session = builder.connect_incoming(incoming).await;
1375            session.await;
1376
1377            tx.send(()).unwrap();
1378        });
1379
1380        fut.await;
1381        rx.await.unwrap();
1382    }
1383
1384    #[tokio::test(flavor = "multi_thread")]
1385    async fn handle_dropped_stream() {
1386        let mut builder = SessionBuilder::default();
1387
1388        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1389        let local_addr = listener.local_addr().unwrap();
1390
1391        let fut = builder.with_client_stream(local_addr, async move |client_stream| {
1392            drop(client_stream);
1393            tokio::time::sleep(Duration::from_secs(1)).await
1394        });
1395
1396        let (tx, rx) = oneshot::channel();
1397
1398        tokio::task::spawn(async move {
1399            let (incoming, _) = listener.accept().await.unwrap();
1400            let session = builder.connect_incoming(incoming).await;
1401            session.await;
1402
1403            tx.send(()).unwrap();
1404        });
1405
1406        tokio::task::spawn(fut);
1407
1408        rx.await.unwrap();
1409    }
1410
1411    #[tokio::test(flavor = "multi_thread")]
1412    async fn test_send_many_messages() {
1413        reth_tracing::init_test_tracing();
1414        let mut builder = SessionBuilder::default();
1415
1416        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1417        let local_addr = listener.local_addr().unwrap();
1418
1419        let num_messages = 100;
1420
1421        let fut = builder.with_client_stream(local_addr, async move |mut client_stream| {
1422            for _ in 0..num_messages {
1423                client_stream
1424                    .send(EthMessage::NewPooledTransactionHashes66(Vec::new().into()))
1425                    .await
1426                    .unwrap();
1427            }
1428        });
1429
1430        let (tx, rx) = oneshot::channel();
1431
1432        tokio::task::spawn(async move {
1433            let (incoming, _) = listener.accept().await.unwrap();
1434            let session = builder.connect_incoming(incoming).await;
1435            session.await;
1436
1437            tx.send(()).unwrap();
1438        });
1439
1440        tokio::task::spawn(fut);
1441
1442        rx.await.unwrap();
1443    }
1444
1445    #[tokio::test(flavor = "multi_thread")]
1446    async fn test_request_timeout() {
1447        reth_tracing::init_test_tracing();
1448
1449        let mut builder = SessionBuilder::default();
1450
1451        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1452        let local_addr = listener.local_addr().unwrap();
1453
1454        let request_timeout = Duration::from_millis(100);
1455        let drop_timeout = Duration::from_millis(1500);
1456
1457        let fut = builder.with_client_stream(local_addr, async move |client_stream| {
1458            let _client_stream = client_stream;
1459            tokio::time::sleep(drop_timeout * 60).await;
1460        });
1461        tokio::task::spawn(fut);
1462
1463        let (incoming, _) = listener.accept().await.unwrap();
1464        let mut session = builder.connect_incoming(incoming).await;
1465        session
1466            .internal_request_timeout
1467            .store(request_timeout.as_millis() as u64, Ordering::Relaxed);
1468        session.protocol_breach_request_timeout = drop_timeout;
1469        session.internal_request_timeout_interval =
1470            tokio::time::interval_at(tokio::time::Instant::now(), request_timeout);
1471        let (tx, rx) = oneshot::channel();
1472        let req = PeerRequest::GetBlockBodies { request: GetBlockBodies(vec![]), response: tx };
1473        session.on_internal_peer_request(req, Instant::now());
1474        tokio::spawn(session);
1475
1476        let err = rx.await.unwrap().unwrap_err();
1477        assert_eq!(err, RequestError::Timeout);
1478
1479        // wait for protocol breach error
1480        let msg = builder.active_session_rx.next().await.unwrap();
1481        match msg {
1482            ActiveSessionMessage::ProtocolBreach { .. } => {}
1483            ev => unreachable!("{ev:?}"),
1484        }
1485    }
1486
1487    #[test]
1488    fn eth72_pooled_hashes_count_broadcast_items() {
1489        let hashes =
1490            vec![alloy_primitives::B256::repeat_byte(1), alloy_primitives::B256::repeat_byte(2)];
1491        let msg: OutgoingMessage<EthNetworkPrimitives> =
1492            EthMessage::NewPooledTransactionHashes72(NewPooledTransactionHashes72 {
1493                types: vec![0; hashes.len()],
1494                sizes: vec![1; hashes.len()],
1495                hashes,
1496                cell_mask: None,
1497            })
1498            .into();
1499
1500        assert_eq!(2, msg.broadcast_item_count());
1501    }
1502
1503    #[test]
1504    fn test_reject_bal_request_for_eth70() {
1505        let (tx, _rx) = oneshot::channel();
1506        let request: PeerRequest<EthNetworkPrimitives> =
1507            PeerRequest::GetBlockAccessLists { request: GetBlockAccessLists(vec![]), response: tx };
1508
1509        assert!(!ActiveSession::<EthNetworkPrimitives>::is_request_supported_for_version(
1510            &request,
1511            EthVersion::Eth70
1512        ));
1513        assert!(ActiveSession::<EthNetworkPrimitives>::is_request_supported_for_version(
1514            &request,
1515            EthVersion::Eth71
1516        ));
1517    }
1518
1519    #[tokio::test(flavor = "multi_thread")]
1520    async fn test_keep_alive() {
1521        let mut builder = SessionBuilder::default();
1522
1523        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1524        let local_addr = listener.local_addr().unwrap();
1525
1526        let fut = builder.with_client_stream(local_addr, async move |mut client_stream| {
1527            let _ = tokio::time::timeout(Duration::from_secs(5), client_stream.next()).await;
1528            client_stream.into_inner().disconnect(DisconnectReason::UselessPeer).await.unwrap();
1529        });
1530
1531        let (tx, rx) = oneshot::channel();
1532
1533        tokio::task::spawn(async move {
1534            let (incoming, _) = listener.accept().await.unwrap();
1535            let session = builder.connect_incoming(incoming).await;
1536            session.await;
1537
1538            tx.send(()).unwrap();
1539        });
1540
1541        tokio::task::spawn(fut);
1542
1543        rx.await.unwrap();
1544    }
1545
1546    // This tests that incoming messages are delivered when there's capacity.
1547    #[tokio::test(flavor = "multi_thread")]
1548    async fn test_send_at_capacity() {
1549        let mut builder = SessionBuilder::default();
1550
1551        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1552        let local_addr = listener.local_addr().unwrap();
1553
1554        let fut = builder.with_client_stream(local_addr, async move |mut client_stream| {
1555            client_stream
1556                .send(EthMessage::NewPooledTransactionHashes68(Default::default()))
1557                .await
1558                .unwrap();
1559            let _ = tokio::time::timeout(Duration::from_secs(100), client_stream.next()).await;
1560        });
1561        tokio::task::spawn(fut);
1562
1563        let (incoming, _) = listener.accept().await.unwrap();
1564        let session = builder.connect_incoming(incoming).await;
1565
1566        // fill the entire message buffer with an unrelated message
1567        let mut num_fill_messages = 0;
1568        loop {
1569            if builder
1570                .active_session_tx
1571                .try_send(ActiveSessionMessage::ProtocolBreach { peer_id: PeerId::random() })
1572                .is_err()
1573            {
1574                break
1575            }
1576            num_fill_messages += 1;
1577        }
1578
1579        tokio::task::spawn(async move {
1580            session.await;
1581        });
1582
1583        tokio::time::sleep(Duration::from_millis(100)).await;
1584
1585        for _ in 0..num_fill_messages {
1586            let message = builder.active_session_rx.next().await.unwrap();
1587            match message {
1588                ActiveSessionMessage::ProtocolBreach { .. } => {}
1589                ev => unreachable!("{ev:?}"),
1590            }
1591        }
1592
1593        let message = builder.active_session_rx.next().await.unwrap();
1594        match message {
1595            ActiveSessionMessage::ValidMessage {
1596                message: PeerMessage::PooledTransactions(_),
1597                ..
1598            } => {}
1599            _ => unreachable!(),
1600        }
1601    }
1602
1603    #[test]
1604    fn timeout_calculation_sanity_tests() {
1605        let rtt = Duration::from_secs(5);
1606        // timeout for an RTT of `rtt`
1607        let timeout = rtt * TIMEOUT_SCALING;
1608
1609        // if rtt hasn't changed, timeout shouldn't change
1610        assert_eq!(calculate_new_timeout(timeout, rtt), timeout);
1611
1612        // if rtt changed, the new timeout should change less than it
1613        assert!(calculate_new_timeout(timeout, rtt / 2) < timeout);
1614        assert!(calculate_new_timeout(timeout, rtt / 2) > timeout / 2);
1615        assert!(calculate_new_timeout(timeout, rtt * 2) > timeout);
1616        assert!(calculate_new_timeout(timeout, rtt * 2) < timeout * 2);
1617    }
1618}