
reth_network/session/active.rs

//! Represents an established session.

use core::sync::atomic::Ordering;
use std::{
    collections::VecDeque,
    future::Future,
    net::SocketAddr,
    pin::Pin,
    sync::{
        atomic::{AtomicU64, AtomicUsize},
        Arc,
    },
    task::{ready, Context, Poll},
    time::{Duration, Instant},
};

use crate::{
    message::{NewBlockMessage, PeerMessage, PeerResponse, PeerResponseResult},
    session::{
        conn::EthRlpxConnection,
        handle::{ActiveSessionMessage, SessionCommand},
        BlockRangeInfo, EthVersion, SessionId,
    },
};
use alloy_eips::merge::EPOCH_SLOTS;
use alloy_primitives::Sealable;
use futures::{stream::Fuse, SinkExt, StreamExt};
use metrics::{Counter, Gauge};
use reth_eth_wire::{
    errors::{EthHandshakeError, EthStreamError},
    message::{EthBroadcastMessage, MessageError},
    Capabilities, DisconnectP2P, DisconnectReason, EthMessage, NetworkPrimitives, NewBlockPayload,
};
use reth_eth_wire_types::{message::RequestPair, NewPooledTransactionHashes, RawCapabilityMessage};
use reth_metrics::common::mpsc::MeteredPollSender;
use reth_network_api::PeerRequest;
use reth_network_p2p::error::RequestError;
use reth_network_peers::PeerId;
use reth_network_types::session::config::INITIAL_REQUEST_TIMEOUT;
use reth_primitives_traits::Block;
use rustc_hash::FxHashMap;
use tokio::{
    sync::{mpsc, mpsc::error::TrySendError, oneshot},
    time::Interval,
};
use tokio_stream::wrappers::ReceiverStream;
use tokio_util::sync::PollSender;
use tracing::{debug, trace};

/// The recommended interval at which to check if a new range update should be sent to the remote
/// peer.
///
/// Updates are only sent when the block height has advanced by at least one epoch (32 blocks)
/// since the last update. The interval is set to one epoch duration in seconds.
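///
/// With mainnet timing (32-slot epochs, 12-second slots) this works out to
/// `EPOCH_SLOTS * 12 = 32 * 12 = 384` seconds, i.e. roughly 6.4 minutes.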
pub(super) const RANGE_UPDATE_INTERVAL: Duration = Duration::from_secs(EPOCH_SLOTS * 12);

// Constants for timeout updating.

/// Minimum timeout value
const MINIMUM_TIMEOUT: Duration = Duration::from_secs(2);

/// Maximum timeout value
const MAXIMUM_TIMEOUT: Duration = INITIAL_REQUEST_TIMEOUT;
/// How much a new measurement affects the current timeout, as a fraction (here 10%)
const SAMPLE_IMPACT: f64 = 0.1;
/// Amount of RTTs before timeout
const TIMEOUT_SCALING: u32 = 3;

/// Restricts the number of queued outgoing messages for larger responses:
///  - Block Bodies
///  - Receipts
///  - Headers
///  - `PooledTransactions`
///
/// With proper soft limits in place (2MB per response) this targets 10MB ((4 + 1) * 2MB) of
/// outgoing response data.
///
/// This parameter serves as backpressure for reading additional requests from the remote.
/// Once we've queued up more responses than this, the session should prioritize message flushing
/// before reading any more messages from the remote peer, throttling the peer.
const MAX_QUEUED_OUTGOING_RESPONSES: usize = 4;

/// Minimum capacity to retain for buffered incoming requests from the remote peer.
const MIN_RECEIVED_REQUESTS_CAPACITY: usize = 1;

/// Soft limit for the total number of buffered outgoing broadcast items (e.g. transaction hashes).
///
/// Many small broadcast messages carrying a single tx hash each are equivalent in cost to one
/// message carrying many hashes. This limit counts individual items (hashes, transactions, blocks)
/// rather than messages, so that many small messages don't trigger aggressive drops unnecessarily.
const MAX_QUEUED_BROADCAST_ITEMS: usize = 4096;

/// Shared counter for in-flight broadcast items (tx hashes, transactions, blocks) across the
/// bounded command channel, unbounded overflow channel, and session outgoing queue.
///
/// Wrapped in a newtype so the backing storage can be changed later (e.g. to track memory) without
/// touching every call-site.
#[derive(Debug, Clone)]
pub(crate) struct BroadcastItemCounter(Arc<AtomicUsize>);

impl BroadcastItemCounter {
    /// Creates a new counter starting at zero.
    pub(crate) fn new() -> Self {
        Self(Arc::new(AtomicUsize::new(0)))
    }

    /// Returns the current count.
    pub(crate) fn get(&self) -> usize {
        self.0.load(Ordering::Relaxed)
    }

    /// Attempts to add `n` items. Returns `true` if under the limit, `false` if over (no change).
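    ///
    /// Illustrative accounting flow (a sketch, not a doctest):
    /// ```ignore
    /// let counter = BroadcastItemCounter::new();
    /// assert!(counter.try_add(10)); // 10 items now count against the limit
    /// counter.sub(10);              // the items were flushed to the wire
    /// assert_eq!(counter.get(), 0);
    /// ```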
    pub(crate) fn try_add(&self, n: usize) -> bool {
        let prev = self.0.fetch_add(n, Ordering::Relaxed);
        if prev >= MAX_QUEUED_BROADCAST_ITEMS {
            self.0.fetch_sub(n, Ordering::Relaxed);
            false
        } else {
            true
        }
    }

    /// Subtracts `n` items from the counter.
    pub(crate) fn sub(&self, n: usize) {
        self.0.fetch_sub(n, Ordering::Relaxed);
    }
}

/// The type that advances an established session by listening for incoming messages (from local
/// node or read from connection) and emitting events back to the
/// [`SessionManager`](super::SessionManager).
///
/// It listens for
///    - incoming commands from the [`SessionManager`](super::SessionManager)
///    - incoming _internal_ requests/broadcasts via the request/command channel
///    - incoming requests/broadcasts _from remote_ via the connection
///    - responses for handled ETH requests received from the remote peer.
#[expect(dead_code)]
pub(crate) struct ActiveSession<N: NetworkPrimitives> {
    /// Keeps track of request ids.
    pub(crate) next_id: u64,
    /// The underlying connection.
    pub(crate) conn: EthRlpxConnection<N>,
    /// Identifier of the node we're connected to.
    pub(crate) remote_peer_id: PeerId,
    /// The address we're connected to.
    pub(crate) remote_addr: SocketAddr,
    /// All capabilities the peer announced
    pub(crate) remote_capabilities: Arc<Capabilities>,
    /// Internal identifier of this session
    pub(crate) session_id: SessionId,
    /// Incoming commands from the manager
    pub(crate) commands_rx: ReceiverStream<SessionCommand<N>>,
    /// Unbounded channel for commands that couldn't fit in the bounded channel (broadcast
    /// overflow) and for disconnect commands that must never be dropped.
    pub(crate) unbounded_rx: mpsc::UnboundedReceiver<SessionCommand<N>>,
    /// Counter for broadcast messages received via the unbounded overflow channel.
    pub(crate) unbounded_broadcast_msgs: Counter,
    /// Sink to send messages to the [`SessionManager`](super::SessionManager).
    pub(crate) to_session_manager: MeteredPollSender<ActiveSessionMessage<N>>,
    /// A message that needs to be delivered to the session manager
    pub(crate) pending_message_to_session: Option<ActiveSessionMessage<N>>,
    /// Incoming internal requests which are delegated to the remote peer.
    pub(crate) internal_request_rx: Fuse<ReceiverStream<PeerRequest<N>>>,
    /// All requests sent to the remote peer for which we're awaiting a response
    pub(crate) inflight_requests: FxHashMap<u64, InflightRequest<PeerRequest<N>>>,
    /// All requests sent by the remote peer for which we're awaiting an internal response
    pub(crate) received_requests_from_remote: Vec<ReceivedRequest<N>>,
    /// Buffered messages that should be handled and sent to the peer.
    pub(crate) queued_outgoing: QueuedOutgoingMessages<N>,
    /// The maximum time we wait for a response from a peer.
    pub(crate) internal_request_timeout: Arc<AtomicU64>,
    /// Interval when to check for timed out requests.
    pub(crate) internal_request_timeout_interval: Interval,
    /// If an [`ActiveSession`] does not receive a response at all within this duration then it is
    /// considered a protocol violation and the session will initiate a drop.
    pub(crate) protocol_breach_request_timeout: Duration,
    /// Used to reserve a slot to guarantee that the termination message is delivered
    pub(crate) terminate_message:
        Option<(PollSender<ActiveSessionMessage<N>>, ActiveSessionMessage<N>)>,
    /// The eth69 range info for the remote peer.
    /// This is `None` for peers negotiating versions below `eth/69`.
    pub(crate) range_info: Option<BlockRangeInfo>,
    /// The eth69 range info for the local node (this node).
    /// This represents the range of blocks that this node can serve to other peers.
    pub(crate) local_range_info: BlockRangeInfo,
    /// Optional interval for sending periodic range updates to the remote peer (eth69+)
    /// The interval is set to one epoch duration (~6.4 minutes), but updates are only sent when
    /// the block height has advanced by at least one epoch (32 blocks) since the last update
    pub(crate) range_update_interval: Option<Interval>,
    /// The last latest block number we sent in a range update
    /// Used to avoid sending unnecessary updates when block height hasn't changed significantly
    pub(crate) last_sent_latest_block: Option<u64>,
}

impl<N: NetworkPrimitives> ActiveSession<N> {
    /// Returns `true` if the session is currently in the process of disconnecting
    fn is_disconnecting(&self) -> bool {
        self.conn.inner().is_disconnecting()
    }

    /// Returns the next request id
    const fn next_id(&mut self) -> u64 {
        let id = self.next_id;
        self.next_id += 1;
        id
    }

    /// Shrinks the capacity of the internal buffers.
    pub fn shrink_to_fit(&mut self) {
        self.received_requests_from_remote.shrink_to(MIN_RECEIVED_REQUESTS_CAPACITY);
        self.queued_outgoing.shrink_to(MAX_QUEUED_OUTGOING_RESPONSES);
    }

    /// Returns how many responses we've currently queued up.
    fn queued_response_count(&self) -> usize {
        self.queued_outgoing.messages.iter().filter(|m| m.is_response()).count()
    }

    /// Handle a message read from the connection.
    ///
    /// Returns an error if the message is considered to be in violation of the protocol.
    fn on_incoming_message(&mut self, msg: EthMessage<N>) -> OnIncomingMessageOutcome<N> {
        /// A macro that handles an incoming request
        /// This creates a new channel and tries to send the sender half to the session while
        /// storing the receiver half internally so the pending response can be polled.
        macro_rules! on_request {
            ($req:ident, $resp_item:ident, $req_item:ident) => {{
                let RequestPair { request_id, message: request } = $req;
                let (tx, response) = oneshot::channel();
                let received = ReceivedRequest {
                    request_id,
                    rx: PeerResponse::$resp_item { response },
                    received: Instant::now(),
                };
                self.received_requests_from_remote.push(received);
                self.try_emit_request(PeerMessage::EthRequest(PeerRequest::$req_item {
                    request,
                    response: tx,
                }))
                .into()
            }};
        }

        /// Processes a response received from the peer
        macro_rules! on_response {
            ($resp:ident, $item:ident) => {{
                let RequestPair { request_id, message } = $resp;
                if let Some(req) = self.inflight_requests.remove(&request_id) {
                    match req.request {
                        RequestState::Waiting(PeerRequest::$item { response, .. }) => {
                            trace!(peer_id=?self.remote_peer_id, ?request_id, "received response from peer");
                            let _ = response.send(Ok(message));
                            self.update_request_timeout(req.timestamp, Instant::now());
                        }
                        RequestState::Waiting(request) => {
                            request.send_bad_response();
                        }
                        RequestState::TimedOut => {
                            // request was already timed out internally
                            self.update_request_timeout(req.timestamp, Instant::now());
                        }
                    }
                } else {
                    trace!(peer_id=?self.remote_peer_id, ?request_id, "received response to unknown request");
                    // we received a response to a request we never sent
                    self.on_bad_message();
                }

                OnIncomingMessageOutcome::Ok
            }};
        }

        match msg {
            message @ EthMessage::Status(_) => OnIncomingMessageOutcome::BadMessage {
                error: EthStreamError::EthHandshakeError(EthHandshakeError::StatusNotInHandshake),
                message,
            },
            EthMessage::NewBlockHashes(msg) => {
                self.try_emit_broadcast(PeerMessage::NewBlockHashes(msg)).into()
            }
            EthMessage::NewBlock(msg) => {
                let block = NewBlockMessage {
                    hash: msg.block().header().hash_slow(),
                    block: Arc::new(*msg),
                };
                self.try_emit_broadcast(PeerMessage::NewBlock(block)).into()
            }
            EthMessage::Transactions(msg) => {
                self.try_emit_broadcast(PeerMessage::ReceivedTransaction(msg)).into()
            }
            EthMessage::NewPooledTransactionHashes66(msg) => {
                self.try_emit_broadcast(PeerMessage::PooledTransactions(msg.into())).into()
            }
            EthMessage::NewPooledTransactionHashes68(msg) => {
                self.try_emit_broadcast(PeerMessage::PooledTransactions(msg.into())).into()
            }
            EthMessage::GetBlockHeaders(req) => {
                on_request!(req, BlockHeaders, GetBlockHeaders)
            }
            EthMessage::BlockHeaders(resp) => {
                on_response!(resp, GetBlockHeaders)
            }
            EthMessage::GetBlockBodies(req) => {
                on_request!(req, BlockBodies, GetBlockBodies)
            }
            EthMessage::BlockBodies(resp) => {
                on_response!(resp, GetBlockBodies)
            }
            EthMessage::GetPooledTransactions(req) => {
                on_request!(req, PooledTransactions, GetPooledTransactions)
            }
            EthMessage::PooledTransactions(resp) => {
                on_response!(resp, GetPooledTransactions)
            }
            EthMessage::GetNodeData(req) => {
                on_request!(req, NodeData, GetNodeData)
            }
            EthMessage::NodeData(resp) => {
                on_response!(resp, GetNodeData)
            }
            EthMessage::GetReceipts(req) => {
                if self.conn.version() >= EthVersion::Eth69 {
                    on_request!(req, Receipts69, GetReceipts69)
                } else {
                    on_request!(req, Receipts, GetReceipts)
                }
            }
            EthMessage::GetReceipts70(req) => {
                on_request!(req, Receipts70, GetReceipts70)
            }
            EthMessage::Receipts(resp) => {
                on_response!(resp, GetReceipts)
            }
            EthMessage::Receipts69(resp) => {
                on_response!(resp, GetReceipts69)
            }
            EthMessage::Receipts70(resp) => {
                on_response!(resp, GetReceipts70)
            }
            EthMessage::GetBlockAccessLists(req) => {
                on_request!(req, BlockAccessLists, GetBlockAccessLists)
            }
            EthMessage::BlockAccessLists(resp) => {
                on_response!(resp, GetBlockAccessLists)
            }
            EthMessage::Cells(resp) => {
                on_response!(resp, GetCells)
            }
            EthMessage::BlockRangeUpdate(msg) => {
                // Validate that earliest <= latest according to the spec
                if msg.earliest > msg.latest {
                    return OnIncomingMessageOutcome::BadMessage {
                        error: EthStreamError::InvalidMessage(MessageError::Other(format!(
                            "invalid block range: earliest ({}) > latest ({})",
                            msg.earliest, msg.latest
                        ))),
                        message: EthMessage::BlockRangeUpdate(msg),
                    };
                }

                // Validate that the latest hash is not zero
                if msg.latest_hash.is_zero() {
                    return OnIncomingMessageOutcome::BadMessage {
                        error: EthStreamError::InvalidMessage(MessageError::Other(
                            "invalid block range: latest_hash cannot be zero".to_string(),
                        )),
                        message: EthMessage::BlockRangeUpdate(msg),
                    };
                }

                if let Some(range_info) = self.range_info.as_ref() {
                    range_info.update(msg.earliest, msg.latest, msg.latest_hash);
                }

                OnIncomingMessageOutcome::Ok
            }
            EthMessage::GetCells(req) => {
                on_request!(req, Cells, GetCells)
            }
            EthMessage::Other(bytes) => self.try_emit_broadcast(PeerMessage::Other(bytes)).into(),
        }
    }

    /// Handle an internal peer request that will be sent to the remote.
    fn on_internal_peer_request(&mut self, request: PeerRequest<N>, deadline: Instant) {
        let version = self.conn.version();
        if !Self::is_request_supported_for_version(&request, version) {
            debug!(
                target: "net",
                ?request,
                peer_id=?self.remote_peer_id,
                ?version,
                "Request not supported for negotiated eth version",
            );
            request.send_err_response(RequestError::UnsupportedCapability);
            return;
        }

        let request_id = self.next_id();
        trace!(?request, peer_id=?self.remote_peer_id, ?request_id, "sending request to peer");
        let msg = request.create_request_message(request_id).map_versioned(version);

        self.queued_outgoing.push_back(msg.into());
        let req = InflightRequest {
            request: RequestState::Waiting(request),
            timestamp: Instant::now(),
            deadline,
        };
        self.inflight_requests.insert(request_id, req);
    }

    #[inline]
    fn is_request_supported_for_version(request: &PeerRequest<N>, version: EthVersion) -> bool {
        request.is_supported_by_eth_version(version)
    }

    /// Handle a message received from the internal network
    fn on_internal_peer_message(&mut self, msg: PeerMessage<N>) {
        match msg {
            PeerMessage::NewBlockHashes(msg) => {
                self.queued_outgoing.push_back(EthMessage::NewBlockHashes(msg).into());
            }
            PeerMessage::NewBlock(msg) => {
                self.queued_outgoing.push_back(EthBroadcastMessage::NewBlock(msg.block).into());
            }
            PeerMessage::PooledTransactions(msg) => {
                if msg.is_valid_for_version(self.conn.version()) {
                    self.queued_outgoing.push_pooled_hashes(msg);
                } else {
                    self.queued_outgoing.broadcast_items.sub(msg.len());
                    debug!(target: "net", ?msg, version=?self.conn.version(), "Message is invalid for connection version, skipping");
                }
            }
            PeerMessage::EthRequest(req) => {
                let deadline = self.request_deadline();
                self.on_internal_peer_request(req, deadline);
            }
            PeerMessage::SendTransactions(msg) => {
                self.queued_outgoing.push_back(EthBroadcastMessage::Transactions(msg).into());
            }
            PeerMessage::BlockRangeUpdated(_) => {}
            PeerMessage::ReceivedTransaction(_) => {
                unreachable!("Not emitted by network")
            }
            PeerMessage::Other(other) => {
                self.queued_outgoing.push_back(OutgoingMessage::Raw(other));
            }
        }
    }

    /// Returns the deadline timestamp at which the request times out
    fn request_deadline(&self) -> Instant {
        Instant::now() +
            Duration::from_millis(self.internal_request_timeout.load(Ordering::Relaxed))
    }

    /// Handle a Response to the peer
    ///
    /// This will queue the response to be sent to the peer
    fn handle_outgoing_response(&mut self, id: u64, resp: PeerResponseResult<N>) {
        match resp.try_into_message(id) {
            Ok(msg) => {
                self.queued_outgoing.push_back(msg.into());
            }
            Err(err) => {
                debug!(target: "net", %err, "Failed to respond to received request");
            }
        }
    }

    /// Send a message back to the [`SessionManager`](super::SessionManager).
    ///
    /// Returns the message if the bounded channel is currently unable to handle this message.
    #[expect(clippy::result_large_err)]
    fn try_emit_broadcast(&self, message: PeerMessage<N>) -> Result<(), ActiveSessionMessage<N>> {
        let Some(sender) = self.to_session_manager.inner().get_ref() else { return Ok(()) };

        match sender
            .try_send(ActiveSessionMessage::ValidMessage { peer_id: self.remote_peer_id, message })
        {
            Ok(_) => Ok(()),
            Err(err) => {
                trace!(
                    target: "net",
                    %err,
                    "no capacity for incoming broadcast",
                );
                match err {
                    TrySendError::Full(msg) => Err(msg),
                    TrySendError::Closed(_) => Ok(()),
                }
            }
        }
    }

    /// Send a message back to the [`SessionManager`](super::SessionManager)
    /// covering both broadcasts and incoming requests.
    ///
    /// Returns the message if the bounded channel is currently unable to handle this message.
    #[expect(clippy::result_large_err)]
    fn try_emit_request(&self, message: PeerMessage<N>) -> Result<(), ActiveSessionMessage<N>> {
        let Some(sender) = self.to_session_manager.inner().get_ref() else { return Ok(()) };

        match sender
            .try_send(ActiveSessionMessage::ValidMessage { peer_id: self.remote_peer_id, message })
        {
            Ok(_) => Ok(()),
            Err(err) => {
                trace!(
                    target: "net",
                    %err,
                    "no capacity for incoming request",
                );
                match err {
                    TrySendError::Full(msg) => Err(msg),
                    TrySendError::Closed(_) => {
                        // Note: this would mean the `SessionManager` was dropped, which is already
                        // handled by checking if the command receiver channel has been closed.
                        Ok(())
                    }
                }
            }
        }
    }

    /// Notify the manager that the peer sent a bad message
    fn on_bad_message(&self) {
        let Some(sender) = self.to_session_manager.inner().get_ref() else { return };
        let _ = sender.try_send(ActiveSessionMessage::BadMessage { peer_id: self.remote_peer_id });
    }

    /// Report back that this session has been closed.
    fn emit_disconnect(&mut self, cx: &mut Context<'_>) -> Poll<()> {
        trace!(target: "net::session", remote_peer_id=?self.remote_peer_id, "emitting disconnect");
        let msg = ActiveSessionMessage::Disconnected {
            peer_id: self.remote_peer_id,
            remote_addr: self.remote_addr,
        };

        self.terminate_message = Some((self.to_session_manager.inner().clone(), msg));
        self.poll_terminate_message(cx).expect("message is set")
    }

    /// Report back that this session has been closed due to an error
    fn close_on_error(&mut self, error: EthStreamError, cx: &mut Context<'_>) -> Poll<()> {
        let msg = ActiveSessionMessage::ClosedOnConnectionError {
            peer_id: self.remote_peer_id,
            remote_addr: self.remote_addr,
            error,
        };
        self.terminate_message = Some((self.to_session_manager.inner().clone(), msg));
        self.poll_terminate_message(cx).expect("message is set")
    }

    /// Starts the disconnect process
    fn start_disconnect(&mut self, reason: DisconnectReason) -> Result<(), EthStreamError> {
        Ok(self.conn.inner_mut().start_disconnect(reason)?)
    }

    /// Flushes the disconnect message and emits the corresponding message
    fn poll_disconnect(&mut self, cx: &mut Context<'_>) -> Poll<()> {
        debug_assert!(self.is_disconnecting(), "not disconnecting");

        // try to close the connection and flush out the remaining Disconnect message
        let _ = ready!(self.conn.poll_close_unpin(cx));
        self.emit_disconnect(cx)
    }

    /// Attempts to disconnect by sending the given disconnect reason
    fn try_disconnect(&mut self, reason: DisconnectReason, cx: &mut Context<'_>) -> Poll<()> {
        match self.start_disconnect(reason) {
            Ok(()) => {
                // we're done
                self.poll_disconnect(cx)
            }
            Err(err) => {
                debug!(target: "net::session", %err, remote_peer_id=?self.remote_peer_id, "could not send disconnect");
                self.close_on_error(err, cx)
            }
        }
    }

    /// Checks for _internally_ timed out requests.
    ///
    /// If a request misses its deadline, it is timed out internally.
    /// If a request misses the `protocol_breach_request_timeout` then this session is considered in
    /// protocol violation and will close.
    ///
    /// Returns `true` if a peer missed the `protocol_breach_request_timeout`, in which case the
    /// session should be terminated.
    #[must_use]
    fn check_timed_out_requests(&mut self, now: Instant) -> bool {
        for (id, req) in &mut self.inflight_requests {
            if req.is_timed_out(now) {
                if req.is_waiting() {
                    debug!(target: "net::session", ?id, remote_peer_id=?self.remote_peer_id, "timed out outgoing request");
                    req.timeout();
                } else if now - req.timestamp > self.protocol_breach_request_timeout {
                    return true
                }
            }
        }

        false
    }

    /// Updates the request timeout with a request's timestamps
    fn update_request_timeout(&mut self, sent: Instant, received: Instant) {
        let elapsed = received.saturating_duration_since(sent);

        let current = Duration::from_millis(self.internal_request_timeout.load(Ordering::Relaxed));
        let request_timeout = calculate_new_timeout(current, elapsed);
        self.internal_request_timeout.store(request_timeout.as_millis() as u64, Ordering::Relaxed);
        self.internal_request_timeout_interval = tokio::time::interval(request_timeout);
    }

    /// If a termination message is queued this will try to send it
    fn poll_terminate_message(&mut self, cx: &mut Context<'_>) -> Option<Poll<()>> {
        let (mut tx, msg) = self.terminate_message.take()?;
        match tx.poll_reserve(cx) {
            Poll::Pending => {
                self.terminate_message = Some((tx, msg));
                return Some(Poll::Pending)
            }
            Poll::Ready(Ok(())) => {
                let _ = tx.send_item(msg);
            }
            Poll::Ready(Err(_)) => {
                // channel closed
            }
        }
        // terminate the task
        Some(Poll::Ready(()))
    }
}

impl<N: NetworkPrimitives> Future for ActiveSession<N> {
    type Output = ();

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.get_mut();

        // if the session is terminating we have to send the termination message before we can close
        if let Some(terminate) = this.poll_terminate_message(cx) {
            return terminate
        }

        if this.is_disconnecting() {
            return this.poll_disconnect(cx)
        }

        // The receive loop can be CPU intensive since it involves message decoding which could take
        // up a lot of resources and increase latencies for other sessions if not yielded manually.
        // If the budget is exhausted we manually yield back control to the (coop) scheduler. This
        // manual yield point should prevent situations where polling appears to be frozen. See also <https://tokio.rs/blog/2020-04-preemption>
        // And tokio's docs on cooperative scheduling <https://docs.rs/tokio/latest/tokio/task/#cooperative-scheduling>
        let mut budget = 4;

        // The main poll loop that drives the session
        'main: loop {
            let mut progress = false;

            // we prioritize incoming commands sent from the session manager
            loop {
                match this.commands_rx.poll_next_unpin(cx) {
                    Poll::Pending => break,
                    Poll::Ready(None) => {
                        // this is only possible when the manager was dropped, in which case we also
                        // terminate this session
                        return Poll::Ready(())
                    }
                    Poll::Ready(Some(cmd)) => {
                        progress = true;
                        match cmd {
                            SessionCommand::Disconnect { reason } => {
                                debug!(
                                    target: "net::session",
                                    ?reason,
                                    remote_peer_id=?this.remote_peer_id,
                                    "Received disconnect command for session"
                                );
                                let reason =
                                    reason.unwrap_or(DisconnectReason::DisconnectRequested);

                                return this.try_disconnect(reason, cx)
                            }
                            SessionCommand::Message(msg) => {
                                this.on_internal_peer_message(msg);
                            }
                        }
                    }
                }
            }

            // Drain the unbounded channel (broadcast overflow + disconnect commands)
            while let Poll::Ready(Some(cmd)) = this.unbounded_rx.poll_recv(cx) {
                progress = true;
                match cmd {
                    SessionCommand::Message(msg) => {
                        this.unbounded_broadcast_msgs.increment(1);
                        this.on_internal_peer_message(msg);
                    }
                    SessionCommand::Disconnect { reason } => {
                        let reason = reason.unwrap_or(DisconnectReason::DisconnectRequested);
                        return this.try_disconnect(reason, cx)
                    }
                }
            }

            let deadline = this.request_deadline();

            while let Poll::Ready(Some(req)) = this.internal_request_rx.poll_next_unpin(cx) {
                progress = true;
                this.on_internal_peer_request(req, deadline);
            }

            // Advance all active requests.
            // We remove each request one by one and add them back.
            for idx in (0..this.received_requests_from_remote.len()).rev() {
                let mut req = this.received_requests_from_remote.swap_remove(idx);
                match req.rx.poll(cx) {
                    Poll::Pending => {
                        // not ready yet
                        this.received_requests_from_remote.push(req);
                    }
                    Poll::Ready(resp) => {
                        this.handle_outgoing_response(req.request_id, resp);
                    }
                }
            }

            // Send messages by advancing the sink and queuing in buffered messages
            while this.conn.poll_ready_unpin(cx).is_ready() {
                if let Some(msg) = this.queued_outgoing.pop_front() {
                    progress = true;
                    let res = match msg {
                        OutgoingMessage::Eth(msg) => this.conn.start_send_unpin(msg),
                        OutgoingMessage::Broadcast(msg) => this.conn.start_send_broadcast(msg),
                        OutgoingMessage::Raw(msg) => this.conn.start_send_raw(msg),
                    };
                    if let Err(err) = res {
                        debug!(target: "net::session", %err, remote_peer_id=?this.remote_peer_id, "failed to send message");
                        // notify the manager
                        return this.close_on_error(err, cx)
                    }
                } else {
                    // no more messages to send over the wire
                    break
                }
            }

            // read incoming messages from the wire
            'receive: loop {
                // ensure we still have enough budget for another iteration
                budget -= 1;
                if budget == 0 {
                    // make sure we're woken up again
                    cx.waker().wake_by_ref();
                    break 'main
                }

                // try to resend the pending message that we could not send because the channel was
                // full. [`PollSender`] will ensure that we're woken up again when the channel is
                // ready to receive the message, and will only error if the channel is closed.
                if let Some(msg) = this.pending_message_to_session.take() {
                    match this.to_session_manager.poll_reserve(cx) {
                        Poll::Ready(Ok(_)) => {
                            let _ = this.to_session_manager.send_item(msg);
                        }
                        Poll::Ready(Err(_)) => return Poll::Ready(()),
                        Poll::Pending => {
                            this.pending_message_to_session = Some(msg);
                            break 'receive
                        }
                    };
                }

                // check whether we should throttle incoming messages
                if this.received_requests_from_remote.len() > MAX_QUEUED_OUTGOING_RESPONSES {
                    // we're currently waiting for the responses to the peer's requests which aren't
                    // queued as outgoing yet
                    //
                    // Note: we don't need to register the waker here because we polled the requests
                    // above
                    break 'receive
                }

                // we also need to check if we have multiple responses queued up
                if this.queued_outgoing.messages.len() > MAX_QUEUED_OUTGOING_RESPONSES &&
                    this.queued_response_count() > MAX_QUEUED_OUTGOING_RESPONSES
                {
                    // if we've queued up more responses than allowed, we don't poll for new
                    // messages and break the receive loop early
                    //
                    // Note: we don't need to register the waker here because we still have
                    // queued messages and the sink impl registered the waker because we've
                    // already advanced it to `Pending` earlier
                    break 'receive
                }

                match this.conn.poll_next_unpin(cx) {
                    Poll::Pending => break,
                    Poll::Ready(None) => {
                        if this.is_disconnecting() {
                            break
                        }
                        debug!(target: "net::session", remote_peer_id=?this.remote_peer_id, "eth stream completed");
                        return this.emit_disconnect(cx)
                    }
                    Poll::Ready(Some(res)) => {
                        match res {
                            Ok(msg) => {
                                trace!(target: "net::session", msg_id=?msg.message_id(), remote_peer_id=?this.remote_peer_id, "received eth message");
                                // decode and handle message
                                match this.on_incoming_message(msg) {
                                    OnIncomingMessageOutcome::Ok => {
                                        // handled successfully
                                        progress = true;
                                    }
                                    OnIncomingMessageOutcome::BadMessage { error, message } => {
                                        debug!(target: "net::session", %error, msg=?message, remote_peer_id=?this.remote_peer_id, "received invalid protocol message");
                                        this.on_bad_message();
                                        return this
                                            .try_disconnect(DisconnectReason::ProtocolBreach, cx)
                                    }
                                    OnIncomingMessageOutcome::NoCapacity(msg) => {
                                        // failed to send due to lack of capacity
                                        this.pending_message_to_session = Some(msg);
                                    }
                                }
                            }
                            Err(err) => {
                                debug!(target: "net::session", %err, remote_peer_id=?this.remote_peer_id, "failed to receive message");
                                if err.is_protocol_breach() {
                                    this.on_bad_message();
                                    return this.try_disconnect(DisconnectReason::ProtocolBreach, cx)
                                }
                                return this.close_on_error(err, cx)
                            }
                        }
                    }
                }
            }

            if !progress {
                break 'main
            }
        }

        if let Some(interval) = &mut this.range_update_interval {
            // Check if we should send a range update based on block height changes
            while interval.poll_tick(cx).is_ready() {
                let current_latest = this.local_range_info.latest();
                let should_send = if let Some(last_sent) = this.last_sent_latest_block {
                    // Only send if block height has advanced by at least one epoch (32 blocks)
                    current_latest.saturating_sub(last_sent) >= EPOCH_SLOTS
                } else {
                    true // First update, always send
                };

                if should_send {
                    this.queued_outgoing.push_back(
                        EthMessage::BlockRangeUpdate(this.local_range_info.to_message()).into(),
                    );
                    this.last_sent_latest_block = Some(current_latest);
                }
            }
        }

        while this.internal_request_timeout_interval.poll_tick(cx).is_ready() {
            // check for timed out requests
            if this.check_timed_out_requests(Instant::now()) &&
                let Poll::Ready(Ok(_)) = this.to_session_manager.poll_reserve(cx)
            {
                let msg = ActiveSessionMessage::ProtocolBreach { peer_id: this.remote_peer_id };
                this.pending_message_to_session = Some(msg);
            }
        }

        this.shrink_to_fit();

        Poll::Pending
    }
}

/// Tracks a request received from the peer
pub(crate) struct ReceivedRequest<N: NetworkPrimitives> {
    /// Protocol Identifier
    request_id: u64,
    /// Receiver half of the channel that's supposed to receive the proper response.
    rx: PeerResponse<N>,
    /// Timestamp when we read this msg from the wire.
    #[expect(dead_code)]
    received: Instant,
}

/// A request that waits for a response from the peer
pub(crate) struct InflightRequest<R> {
    /// Request we sent to peer and the internal response channel
    request: RequestState<R>,
    /// Instant when the request was sent
    timestamp: Instant,
    /// Time limit for the response
    deadline: Instant,
}

// === impl InflightRequest ===

impl<N: NetworkPrimitives> InflightRequest<PeerRequest<N>> {
    /// Returns true if the request has timed out
    #[inline]
    fn is_timed_out(&self, now: Instant) -> bool {
        now > self.deadline
    }

    /// Returns true if we're still waiting for a response
    #[inline]
    const fn is_waiting(&self) -> bool {
        matches!(self.request, RequestState::Waiting(_))
    }

    /// This will timeout the request by sending an error response to the internal channel
    fn timeout(&mut self) {
        let mut req = RequestState::TimedOut;
        std::mem::swap(&mut self.request, &mut req);

        if let RequestState::Waiting(req) = req {
            req.send_err_response(RequestError::Timeout);
        }
    }
}

/// All outcome variants when handling an incoming message
enum OnIncomingMessageOutcome<N: NetworkPrimitives> {
    /// Message successfully handled.
    Ok,
    /// Message is considered to be in violation of the protocol
    BadMessage { error: EthStreamError, message: EthMessage<N> },
    /// Currently no capacity to handle the message
    NoCapacity(ActiveSessionMessage<N>),
}

impl<N: NetworkPrimitives> From<Result<(), ActiveSessionMessage<N>>>
    for OnIncomingMessageOutcome<N>
{
    fn from(res: Result<(), ActiveSessionMessage<N>>) -> Self {
        match res {
            Ok(_) => Self::Ok,
            Err(msg) => Self::NoCapacity(msg),
        }
    }
}

enum RequestState<R> {
    /// Waiting for the response
    Waiting(R),
    /// Request already timed out
    TimedOut,
}

/// Outgoing messages that can be sent over the wire.
#[derive(Debug)]
pub(crate) enum OutgoingMessage<N: NetworkPrimitives> {
    /// A message that is owned.
    Eth(EthMessage<N>),
    /// A message that may be shared by multiple sessions.
    Broadcast(EthBroadcastMessage<N>),
    /// A raw capability message
    Raw(RawCapabilityMessage),
}

impl<N: NetworkPrimitives> OutgoingMessage<N> {
    /// Returns true if this is a response.
    const fn is_response(&self) -> bool {
        match self {
            Self::Eth(msg) => msg.is_response(),
            _ => false,
        }
    }

    /// Returns the number of broadcast items in this message.
    ///
    /// For transaction hash announcements this is the number of hashes, for full transaction
    /// broadcasts it is the number of transactions, and for blocks it is 1.
    /// Request/response messages return 0.
    fn broadcast_item_count(&self) -> usize {
        match self {
            Self::Eth(msg) => match msg {
                EthMessage::NewBlockHashes(h) => h.len(),
                EthMessage::NewPooledTransactionHashes66(h) => h.len(),
                EthMessage::NewPooledTransactionHashes68(h) => h.hashes.len(),
                _ => 0,
            },
            Self::Broadcast(msg) => match msg {
                EthBroadcastMessage::NewBlock(_) => 1,
                EthBroadcastMessage::Transactions(txs) => txs.len(),
            },
            Self::Raw(_) => 0,
        }
    }

    /// Tries to merge pooled transaction hash announcements into this message, consuming the
    /// incoming hashes. Returns `Some(incoming)` back if the variants don't match.
    fn try_merge_hashes(
        &mut self,
        incoming: NewPooledTransactionHashes,
    ) -> Option<NewPooledTransactionHashes> {
        let Self::Eth(eth) = self else { return Some(incoming) };
        match (eth, incoming) {
            (
                EthMessage::NewPooledTransactionHashes66(existing),
                NewPooledTransactionHashes::Eth66(inc),
            ) => {
                existing.extend(inc);
                None
            }
            (
                EthMessage::NewPooledTransactionHashes68(existing),
                NewPooledTransactionHashes::Eth68(inc),
            ) => {
                existing.hashes.extend(inc.hashes);
                existing.sizes.extend(inc.sizes);
                existing.types.extend(inc.types);
                None
            }
            (_, incoming) => Some(incoming),
        }
    }
}

impl<N: NetworkPrimitives> From<EthMessage<N>> for OutgoingMessage<N> {
    fn from(value: EthMessage<N>) -> Self {
        Self::Eth(value)
    }
}

impl<N: NetworkPrimitives> From<EthBroadcastMessage<N>> for OutgoingMessage<N> {
    fn from(value: EthBroadcastMessage<N>) -> Self {
        Self::Broadcast(value)
    }
}

/// Calculates a new timeout using an updated estimation of the RTT
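///
/// Worked example (illustrative, using the constants above): a current timeout of 4 s and a
/// measured RTT of 500 ms give `4 s * (1 - 0.1) + 500 ms * 0.1 * 3 = 3.6 s + 0.15 s = 3.75 s`,
/// which is then clamped to `[MINIMUM_TIMEOUT, MAXIMUM_TIMEOUT]`.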
#[inline]
fn calculate_new_timeout(current_timeout: Duration, estimated_rtt: Duration) -> Duration {
    let new_timeout = estimated_rtt.mul_f64(SAMPLE_IMPACT) * TIMEOUT_SCALING;

    // this dampens sudden changes by taking a weighted mean of the old and new values
    let smoothened_timeout = current_timeout.mul_f64(1.0 - SAMPLE_IMPACT) + new_timeout;

    smoothened_timeout.clamp(MINIMUM_TIMEOUT, MAXIMUM_TIMEOUT)
}

/// A helper struct that wraps the queue of outgoing messages with broadcast-aware tracking.
///
/// Tracks both the total number of queued messages (via a metric gauge) and the total number of
/// broadcast items (tx hashes, transactions, blocks) via a shared atomic counter. The atomic
/// counter is shared with [`ActiveSessionHandle`](super::handle::ActiveSessionHandle) so the
/// [`SessionManager`](super::SessionManager) can apply size-based backpressure.
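///
/// Accounting sketch (illustrative; `msg` stands in for any queued [`OutgoingMessage`]):
/// ```ignore
/// let mut queue = QueuedOutgoingMessages::new(Gauge::noop(), BroadcastItemCounter::new());
/// queue.push_back(msg);      // gauge += 1
/// let _ = queue.pop_front(); // gauge -= 1, shared counter -= broadcast item count
/// ```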
pub(crate) struct QueuedOutgoingMessages<N: NetworkPrimitives> {
    messages: VecDeque<OutgoingMessage<N>>,
    count: Gauge,
    /// Shared counter of buffered broadcast items for size-based backpressure.
    broadcast_items: BroadcastItemCounter,
}

impl<N: NetworkPrimitives> QueuedOutgoingMessages<N> {
    pub(crate) const fn new(metric: Gauge, broadcast_items: BroadcastItemCounter) -> Self {
        Self { messages: VecDeque::new(), count: metric, broadcast_items }
    }

    pub(crate) fn push_back(&mut self, message: OutgoingMessage<N>) {
        self.messages.push_back(message);
        self.count.increment(1);
    }

    pub(crate) fn pop_front(&mut self) -> Option<OutgoingMessage<N>> {
        self.messages.pop_front().inspect(|msg| {
            self.count.decrement(1);
            let items = msg.broadcast_item_count();
            if items > 0 {
                self.broadcast_items.sub(items);
            }
        })
    }

    /// Pushes a pooled transaction hash announcement, merging into the last queued message if
    /// it is the same variant (eth66 or eth68).
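    ///
    /// Merge sketch (illustrative; the two eth68 announcements are hypothetical values):
    /// ```ignore
    /// // Queued back to back, the second announcement is folded into the first, so a
    /// // single message carrying the hashes, sizes and types of both goes on the wire.
    /// queue.push_pooled_hashes(first_eth68_announcement);
    /// queue.push_pooled_hashes(second_eth68_announcement);
    /// ```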
    pub(crate) fn push_pooled_hashes(&mut self, msg: NewPooledTransactionHashes) {
        let msg = if let Some(last) = self.messages.back_mut() {
            match last.try_merge_hashes(msg) {
                None => return,
                Some(msg) => msg,
            }
        } else {
            msg
        };
        self.messages.push_back(EthMessage::from(msg).into());
        self.count.increment(1);
    }

    pub(crate) fn shrink_to(&mut self, min_capacity: usize) {
        self.messages.shrink_to(min_capacity);
    }
}

impl<N: NetworkPrimitives> Drop for QueuedOutgoingMessages<N> {
    fn drop(&mut self) {
        // Ensure gauge is decremented for any remaining items to avoid metric leak on teardown.
        let remaining = self.messages.len();
        if remaining > 0 {
            self.count.decrement(remaining as f64);
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::session::{handle::PendingSessionEvent, start_pending_incoming_session};
    use alloy_eips::eip2124::ForkFilter;
    use reth_chainspec::MAINNET;
    use reth_ecies::stream::ECIESStream;
    use reth_eth_wire::{
        handshake::EthHandshake, EthNetworkPrimitives, EthStream, GetBlockAccessLists,
        GetBlockBodies, HelloMessageWithProtocols, P2PStream, StatusBuilder, UnauthedEthStream,
        UnauthedP2PStream, UnifiedStatus,
    };
    use reth_eth_wire_types::{message::MAX_MESSAGE_SIZE, EthMessageID, RawCapabilityMessage};
    use reth_ethereum_forks::EthereumHardfork;
    use reth_network_peers::pk2id;
    use reth_network_types::session::config::PROTOCOL_BREACH_REQUEST_TIMEOUT;
    use secp256k1::{SecretKey, SECP256K1};
    use tokio::{
        net::{TcpListener, TcpStream},
        sync::mpsc,
    };

    /// Returns a testing `HelloMessage` for the given secret key
1140    fn eth_hello(server_key: &SecretKey) -> HelloMessageWithProtocols {
1141        HelloMessageWithProtocols::builder(pk2id(&server_key.public_key(SECP256K1))).build()
1142    }
1143
1144    struct SessionBuilder<N: NetworkPrimitives = EthNetworkPrimitives> {
1145        _remote_capabilities: Arc<Capabilities>,
1146        active_session_tx: mpsc::Sender<ActiveSessionMessage<N>>,
1147        active_session_rx: ReceiverStream<ActiveSessionMessage<N>>,
1148        to_sessions: Vec<mpsc::Sender<SessionCommand<N>>>,
1149        secret_key: SecretKey,
1150        local_peer_id: PeerId,
1151        hello: HelloMessageWithProtocols,
1152        status: UnifiedStatus,
1153        fork_filter: ForkFilter,
1154        next_id: usize,
1155    }
1156
1157    impl<N: NetworkPrimitives> SessionBuilder<N> {
1158        fn next_id(&mut self) -> SessionId {
1159            let id = self.next_id;
1160            self.next_id += 1;
1161            SessionId(id)
1162        }
1163
1164        /// Connects a new Eth stream and executes the given closure with that established stream
1165        fn with_client_stream<F, O>(
1166            &self,
1167            local_addr: SocketAddr,
1168            f: F,
1169        ) -> Pin<Box<dyn Future<Output = ()> + Send>>
1170        where
1171            F: FnOnce(EthStream<P2PStream<ECIESStream<TcpStream>>, N>) -> O + Send + 'static,
1172            O: Future<Output = ()> + Send + Sync,
1173        {
1174            let mut status = self.status;
1175            let fork_filter = self.fork_filter.clone();
1176            let local_peer_id = self.local_peer_id;
1177            let mut hello = self.hello.clone();
1178            let key = SecretKey::new(&mut rand_08::thread_rng());
1179            hello.id = pk2id(&key.public_key(SECP256K1));
1180            Box::pin(async move {
1181                let outgoing = TcpStream::connect(local_addr).await.unwrap();
1182                let sink = ECIESStream::connect(outgoing, key, local_peer_id).await.unwrap();
1183
1184                let (p2p_stream, _) = UnauthedP2PStream::new(sink).handshake(hello).await.unwrap();
1185
1186                let eth_version = p2p_stream.shared_capabilities().eth_version().unwrap();
1187                status.set_eth_version(eth_version);
1188
1189                let (client_stream, _) = UnauthedEthStream::new(p2p_stream)
1190                    .handshake(status, fork_filter)
1191                    .await
1192                    .unwrap();
1193                f(client_stream).await
1194            })
1195        }
1196
        /// Accepts `stream` as an incoming connection, drives the full incoming
        /// session handshake, and converts the resulting
        /// [`PendingSessionEvent::Established`] into an [`ActiveSession`] wired to
        /// this builder's channels.
        async fn connect_incoming(&mut self, stream: TcpStream) -> ActiveSession<N> {
            // for an accepted connection, the remote address is the peer's address
            let remote_addr = stream.peer_addr().unwrap();
            let session_id = self.next_id();
            let (_disconnect_tx, disconnect_rx) = oneshot::channel();
            let (pending_sessions_tx, pending_sessions_rx) = mpsc::channel(1);

            tokio::task::spawn(start_pending_incoming_session(
                Arc::new(EthHandshake::default()),
                MAX_MESSAGE_SIZE,
                disconnect_rx,
                session_id,
                stream,
                pending_sessions_tx,
                remote_addr,
                self.secret_key,
                self.hello.clone(),
                self.status,
                self.fork_filter.clone(),
                Default::default(),
            ));

            let mut stream = ReceiverStream::new(pending_sessions_rx);

            match stream.next().await.unwrap() {
                PendingSessionEvent::Established {
                    session_id,
                    remote_addr,
                    peer_id,
                    capabilities,
                    conn,
                    ..
                } => {
                    let (_to_session_tx, messages_rx) = mpsc::channel(10);
                    let (commands_to_session, commands_rx) = mpsc::channel(10);
                    let (_unbounded_tx, unbounded_rx) = mpsc::unbounded_channel();
                    let poll_sender = PollSender::new(self.active_session_tx.clone());

                    self.to_sessions.push(commands_to_session);

                    ActiveSession {
                        next_id: 0,
                        remote_peer_id: peer_id,
                        remote_addr,
                        remote_capabilities: Arc::clone(&capabilities),
                        session_id,
                        commands_rx: ReceiverStream::new(commands_rx),
                        unbounded_rx,
                        // tests don't record metrics
                        unbounded_broadcast_msgs: Counter::noop(),
                        to_session_manager: MeteredPollSender::new(
                            poll_sender,
                            "network_active_session",
                        ),
                        pending_message_to_session: None,
                        internal_request_rx: ReceiverStream::new(messages_rx).fuse(),
                        inflight_requests: Default::default(),
                        conn,
                        queued_outgoing: QueuedOutgoingMessages::new(
                            Gauge::noop(),
                            BroadcastItemCounter::new(),
                        ),
                        received_requests_from_remote: Default::default(),
                        internal_request_timeout_interval: tokio::time::interval(
                            INITIAL_REQUEST_TIMEOUT,
                        ),
                        internal_request_timeout: Arc::new(AtomicU64::new(
                            INITIAL_REQUEST_TIMEOUT.as_millis() as u64,
                        )),
                        protocol_breach_request_timeout: PROTOCOL_BREACH_REQUEST_TIMEOUT,
                        terminate_message: None,
                        range_info: None,
                        local_range_info: BlockRangeInfo::new(
                            0,
                            1000,
                            alloy_primitives::B256::ZERO,
                        ),
                        range_update_interval: None,
                        last_sent_latest_block: None,
                    }
                }
                ev => {
                    panic!("unexpected message {ev:?}")
                }
            }
        }
    }

    impl Default for SessionBuilder {
        fn default() -> Self {
            let (active_session_tx, active_session_rx) = mpsc::channel(100);

            // each builder gets a fresh random identity
            let (secret_key, pk) = SECP256K1.generate_keypair(&mut rand_08::thread_rng());
            let local_peer_id = pk2id(&pk);

            Self {
                next_id: 0,
                _remote_capabilities: Arc::new(Capabilities::from(vec![])),
                active_session_tx,
                active_session_rx: ReceiverStream::new(active_session_rx),
                to_sessions: vec![],
                hello: eth_hello(&secret_key),
                secret_key,
                local_peer_id,
                status: StatusBuilder::default().build(),
                fork_filter: MAINNET
                    .hardfork_fork_filter(EthereumHardfork::Frontier)
                    .expect("The Frontier fork filter should exist on mainnet"),
            }
        }
    }

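    // A minimal usage sketch of the helpers above (added for illustration, not part
    // of the original test suite): `with_client_stream` drives the dialing side
    // through the ECIES, p2p and eth handshakes, while `connect_incoming` upgrades
    // the accepted connection into an `ActiveSession`. The tests below all follow
    // this pattern.
    #[tokio::test(flavor = "multi_thread")]
    async fn sketch_builder_roundtrip() {
        let mut builder = SessionBuilder::default();

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();

        // client side: complete the handshakes, then drop the stream
        let fut = builder.with_client_stream(local_addr, async move |_client_stream| {});
        tokio::task::spawn(fut);

        // server side: accept the connection and upgrade it into a session
        let (incoming, _) = listener.accept().await.unwrap();
        let _session = builder.connect_incoming(incoming).await;
    }
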
    // Tests that a locally initiated disconnect delivers the expected reason to the
    // remote peer before the connection is closed.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_disconnect() {
        let mut builder = SessionBuilder::default();

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();

        let expected_disconnect = DisconnectReason::UselessPeer;

        let fut = builder.with_client_stream(local_addr, async move |mut client_stream| {
            let msg = client_stream.next().await.unwrap().unwrap_err();
            assert_eq!(msg.as_disconnected().unwrap(), expected_disconnect);
        });

        tokio::task::spawn(async move {
            let (incoming, _) = listener.accept().await.unwrap();
            let mut session = builder.connect_incoming(incoming).await;

            session.start_disconnect(expected_disconnect).unwrap();
            session.await
        });

        fut.await;
    }

    // Tests that a message that fails to decode triggers a `ProtocolBreach`
    // disconnect.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_invalid_message_disconnects_with_protocol_breach() {
        let mut builder = SessionBuilder::default();

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();

        let fut = builder.with_client_stream(local_addr, async move |mut client_stream| {
            // `0xc0` is an empty RLP list, which is not a valid request-id wrapped
            // `PooledTransactions` payload and must fail to decode
            client_stream
                .start_send_raw(RawCapabilityMessage::eth(
                    EthMessageID::PooledTransactions,
                    vec![0xc0].into(),
                ))
                .unwrap();
            client_stream.flush().await.unwrap();

            let msg = client_stream.next().await.unwrap().unwrap_err();
            assert_eq!(msg.as_disconnected(), Some(DisconnectReason::ProtocolBreach));
        });

        let (tx, rx) = oneshot::channel();

        tokio::task::spawn(async move {
            let (incoming, _) = listener.accept().await.unwrap();
            let session = builder.connect_incoming(incoming).await;
            session.await;

            tx.send(()).unwrap();
        });

        fut.await;
        rx.await.unwrap();
    }

    // Tests that the session future completes when the remote drops the connection.
    #[tokio::test(flavor = "multi_thread")]
    async fn handle_dropped_stream() {
        let mut builder = SessionBuilder::default();

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();

        let fut = builder.with_client_stream(local_addr, async move |client_stream| {
            drop(client_stream);
            tokio::time::sleep(Duration::from_secs(1)).await
        });

        let (tx, rx) = oneshot::channel();

        tokio::task::spawn(async move {
            let (incoming, _) = listener.accept().await.unwrap();
            let session = builder.connect_incoming(incoming).await;
            session.await;

            tx.send(()).unwrap();
        });

        tokio::task::spawn(fut);

        rx.await.unwrap();
    }

    // Tests that a burst of inbound broadcast messages is consumed without stalling
    // the session.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_send_many_messages() {
        reth_tracing::init_test_tracing();
        let mut builder = SessionBuilder::default();

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();

        let num_messages = 100;

        let fut = builder.with_client_stream(local_addr, async move |mut client_stream| {
            for _ in 0..num_messages {
                client_stream
                    .send(EthMessage::NewPooledTransactionHashes66(Vec::new().into()))
                    .await
                    .unwrap();
            }
        });

        let (tx, rx) = oneshot::channel();

        tokio::task::spawn(async move {
            let (incoming, _) = listener.accept().await.unwrap();
            let session = builder.connect_incoming(incoming).await;
            session.await;

            tx.send(()).unwrap();
        });

        tokio::task::spawn(fut);

        rx.await.unwrap();
    }

    // Tests that an unanswered request first fails with `RequestError::Timeout` and,
    // because the peer never responds at all, eventually escalates to a
    // `ProtocolBreach`.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_request_timeout() {
        reth_tracing::init_test_tracing();

        let mut builder = SessionBuilder::default();

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();

        // shrink the timeouts so the test completes quickly
        let request_timeout = Duration::from_millis(100);
        let drop_timeout = Duration::from_millis(1500);

        let fut = builder.with_client_stream(local_addr, async move |client_stream| {
            // hold the stream open without ever answering the request
            let _client_stream = client_stream;
            tokio::time::sleep(drop_timeout * 60).await;
        });
        tokio::task::spawn(fut);

        let (incoming, _) = listener.accept().await.unwrap();
        let mut session = builder.connect_incoming(incoming).await;
        session
            .internal_request_timeout
            .store(request_timeout.as_millis() as u64, Ordering::Relaxed);
        session.protocol_breach_request_timeout = drop_timeout;
        session.internal_request_timeout_interval =
            tokio::time::interval_at(tokio::time::Instant::now(), request_timeout);
        let (tx, rx) = oneshot::channel();
        let req = PeerRequest::GetBlockBodies { request: GetBlockBodies(vec![]), response: tx };
        session.on_internal_peer_request(req, Instant::now());
        tokio::spawn(session);

        let err = rx.await.unwrap().unwrap_err();
        assert_eq!(err, RequestError::Timeout);

        // wait for protocol breach error
        let msg = builder.active_session_rx.next().await.unwrap();
        match msg {
            ActiveSessionMessage::ProtocolBreach { .. } => {}
            ev => unreachable!("{ev:?}"),
        }
    }

    // `GetBlockAccessLists` requests may only be sent to peers that negotiated
    // eth/71 or later.
    #[test]
    fn test_reject_bal_request_for_eth70() {
        let (tx, _rx) = oneshot::channel();
        let request: PeerRequest<EthNetworkPrimitives> =
            PeerRequest::GetBlockAccessLists { request: GetBlockAccessLists(vec![]), response: tx };

        assert!(!ActiveSession::<EthNetworkPrimitives>::is_request_supported_for_version(
            &request,
            EthVersion::Eth70
        ));
        assert!(ActiveSession::<EthNetworkPrimitives>::is_request_supported_for_version(
            &request,
            EthVersion::Eth71
        ));
    }

    // Tests that an idle session keeps the connection alive until the remote
    // initiates a clean disconnect.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_keep_alive() {
        let mut builder = SessionBuilder::default();

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();

        let fut = builder.with_client_stream(local_addr, async move |mut client_stream| {
            // stay idle for a while, then disconnect cleanly
            let _ = tokio::time::timeout(Duration::from_secs(5), client_stream.next()).await;
            client_stream.into_inner().disconnect(DisconnectReason::UselessPeer).await.unwrap();
        });

        let (tx, rx) = oneshot::channel();

        tokio::task::spawn(async move {
            let (incoming, _) = listener.accept().await.unwrap();
            let session = builder.connect_incoming(incoming).await;
            session.await;

            tx.send(()).unwrap();
        });

        tokio::task::spawn(fut);

        rx.await.unwrap();
    }

    // Tests that an incoming message received while the session-manager channel is
    // full is buffered by the session and delivered once capacity frees up.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_send_at_capacity() {
        let mut builder = SessionBuilder::default();

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();

        let fut = builder.with_client_stream(local_addr, async move |mut client_stream| {
            client_stream
                .send(EthMessage::NewPooledTransactionHashes68(Default::default()))
                .await
                .unwrap();
            // keep the connection open while the session delivers the message
            let _ = tokio::time::timeout(Duration::from_secs(100), client_stream.next()).await;
        });
        tokio::task::spawn(fut);

        let (incoming, _) = listener.accept().await.unwrap();
        let session = builder.connect_incoming(incoming).await;

        // fill the entire message buffer with an unrelated message
        let mut num_fill_messages = 0;
        loop {
            if builder
                .active_session_tx
                .try_send(ActiveSessionMessage::ProtocolBreach { peer_id: PeerId::random() })
                .is_err()
            {
                break
            }
            num_fill_messages += 1;
        }

        tokio::task::spawn(async move {
            session.await;
        });

        tokio::time::sleep(Duration::from_millis(100)).await;

        // drain the filler messages to free capacity...
        for _ in 0..num_fill_messages {
            let message = builder.active_session_rx.next().await.unwrap();
            match message {
                ActiveSessionMessage::ProtocolBreach { .. } => {}
                ev => unreachable!("{ev:?}"),
            }
        }

        // ...after which the buffered peer message must come through
        let message = builder.active_session_rx.next().await.unwrap();
        match message {
            ActiveSessionMessage::ValidMessage {
                message: PeerMessage::PooledTransactions(_),
                ..
            } => {}
            _ => unreachable!(),
        }
    }

    #[test]
    fn timeout_calculation_sanity_tests() {
        let rtt = Duration::from_secs(5);
        // timeout for an RTT of `rtt`
        let timeout = rtt * TIMEOUT_SCALING;

        // if the rtt hasn't changed, the timeout shouldn't change
        assert_eq!(calculate_new_timeout(timeout, rtt), timeout);

        // if the rtt changed, the new timeout should move in the same direction, but
        // by less than the full change: samples are smoothed with weight `SAMPLE_IMPACT`
        assert!(calculate_new_timeout(timeout, rtt / 2) < timeout);
        assert!(calculate_new_timeout(timeout, rtt / 2) > timeout / 2);
        assert!(calculate_new_timeout(timeout, rtt * 2) > timeout);
        assert!(calculate_new_timeout(timeout, rtt * 2) < timeout * 2);
    }
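
    // A companion sketch (added for illustration): it assumes `calculate_new_timeout`
    // clamps its result to the `MINIMUM_TIMEOUT..=MAXIMUM_TIMEOUT` bounds defined at
    // the top of this file, so extreme RTT samples cannot drag the timeout outside
    // that range.
    #[test]
    fn timeout_calculation_respects_bounds() {
        // a near-zero RTT must not push the timeout below the floor
        let floored = calculate_new_timeout(MINIMUM_TIMEOUT, Duration::from_millis(1));
        assert!(floored >= MINIMUM_TIMEOUT);
        // a huge RTT must not push the timeout above the ceiling
        let capped = calculate_new_timeout(MAXIMUM_TIMEOUT, Duration::from_secs(3600));
        assert!(capped <= MAXIMUM_TIMEOUT);
    }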
}