Skip to main content

reth_network/
state.rs

1//! Keeps track of the state of the network.
2
3use crate::{
4    cache::LruCache,
5    discovery::Discovery,
6    fetch::{BlockResponseOutcome, FetchAction, StateFetcher},
7    message::{BlockRequest, NewBlockMessage, PeerResponse, PeerResponseResult},
8    peers::{PeerAction, PeersManager},
9    session::BlockRangeInfo,
10    FetchClient,
11};
12use alloy_consensus::BlockHeader;
13use alloy_primitives::B256;
14use rand::seq::SliceRandom;
15use reth_eth_wire::{
16    BlockHashNumber, Capabilities, DisconnectReason, EthNetworkPrimitives, GetReceipts70,
17    NetworkPrimitives, NewBlockHashes, NewBlockPayload, UnifiedStatus,
18};
19use reth_ethereum_forks::ForkId;
20use reth_network_api::{DiscoveredEvent, DiscoveryEvent, PeerRequest, PeerRequestSender};
21use reth_network_p2p::receipts::client::ReceiptsResponse;
22use reth_network_peers::PeerId;
23use reth_network_types::{PeerAddr, PeerKind};
24use reth_primitives_traits::Block;
25use std::{
26    collections::{HashMap, VecDeque},
27    fmt,
28    net::{IpAddr, SocketAddr},
29    ops::Deref,
30    sync::{
31        atomic::{AtomicU64, AtomicUsize},
32        Arc,
33    },
34    task::{Context, Poll},
35};
36use tokio::sync::oneshot;
37use tracing::{debug, trace};
38
/// Cache limit of blocks to keep track of for a single peer.
///
/// This is the capacity of each [`ActivePeer`]'s `blocks` [`LruCache`], i.e. the number of block
/// hashes remembered as "known" by that peer when deciding whether to announce a block to it.
const PEER_BLOCK_CACHE_LIMIT: u32 = 512;
41
42/// Wrapper type for the [`BlockNumReader`] trait.
43pub(crate) struct BlockNumReader(Box<dyn reth_storage_api::BlockNumReader>);
44
45impl BlockNumReader {
46    /// Create a new instance with the given reader.
47    pub fn new(reader: impl reth_storage_api::BlockNumReader + 'static) -> Self {
48        Self(Box::new(reader))
49    }
50}
51
52impl fmt::Debug for BlockNumReader {
53    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
54        f.debug_struct("BlockNumReader").field("inner", &"<dyn BlockNumReader>").finish()
55    }
56}
57
/// Lets callers invoke the storage trait's methods (e.g. `block_number`) directly on the wrapper.
impl Deref for BlockNumReader {
    type Target = Box<dyn reth_storage_api::BlockNumReader>;

    /// Returns a reference to the boxed reader.
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}
65
/// The [`NetworkState`] keeps track of the state of all peers in the network.
///
/// This includes:
///   - [`Discovery`]: manages the discovery protocol, essentially a stream of discovery updates
///   - [`PeersManager`]: keeps track of connected peers and issues new outgoing connections
///     depending on the configured capacity.
///   - [`StateFetcher`]: streams download requests (received from outside via channel) which are
///     then sent to the session of the peer.
///
/// This type also dispatches block requests to peer sessions and routes the peers' responses
/// back to the [`StateFetcher`].
#[derive(Debug)]
pub struct NetworkState<N: NetworkPrimitives = EthNetworkPrimitives> {
    /// All active peers and their state, keyed by peer id.
    active_peers: HashMap<PeerId, ActivePeer<N>>,
    /// Manages connections to peers.
    peers_manager: PeersManager,
    /// Actions buffered until they are handed out by [`NetworkState::poll`].
    queued_messages: VecDeque<StateAction<N>>,
    /// The client type that can interact with the chain.
    ///
    /// Used to look up the block number for a peer's [`UnifiedStatus`] block hash after a
    /// session was established, when the status itself does not carry the number.
    client: BlockNumReader,
    /// Network discovery.
    discovery: Discovery,
    /// The type that handles requests.
    ///
    /// The fetcher streams `RLPx` related requests on a per-peer basis to this type. This type
    /// will then queue in the request and notify the fetcher once the result has been
    /// received.
    state_fetcher: StateFetcher<N>,
}
98
99impl<N: NetworkPrimitives> NetworkState<N> {
100    /// Create a new state instance with the given params
101    pub(crate) fn new(
102        client: BlockNumReader,
103        discovery: Discovery,
104        peers_manager: PeersManager,
105        num_active_peers: Arc<AtomicUsize>,
106    ) -> Self {
107        let state_fetcher = StateFetcher::new(peers_manager.handle(), num_active_peers);
108        Self {
109            active_peers: Default::default(),
110            peers_manager,
111            queued_messages: Default::default(),
112            client,
113            discovery,
114            state_fetcher,
115        }
116    }
117
    /// Returns mutable access to the [`PeersManager`].
    pub(crate) const fn peers_mut(&mut self) -> &mut PeersManager {
        &mut self.peers_manager
    }

    /// Returns mutable access to the [`Discovery`] service.
    pub(crate) const fn discovery_mut(&mut self) -> &mut Discovery {
        &mut self.discovery
    }

    /// Returns shared access to the [`PeersManager`].
    pub(crate) const fn peers(&self) -> &PeersManager {
        &self.peers_manager
    }

    /// Returns a new [`FetchClient`] obtained from the internal [`StateFetcher`].
    ///
    /// Requests made through the client are served only while this state is being polled.
    pub(crate) fn fetch_client(&self) -> FetchClient<N> {
        self.state_fetcher.client()
    }

    /// How many peers we're currently connected to, i.e. the number of peers with an active
    /// session tracked in `active_peers`.
    pub fn num_active_peers(&self) -> usize {
        self.active_peers.len()
    }
142
143    /// Event hook for an activated session for the peer.
144    ///
145    /// Returns `Ok` if the session is valid, returns an `Err` if the session is not accepted and
146    /// should be rejected.
147    pub(crate) fn on_session_activated(
148        &mut self,
149        peer: PeerId,
150        capabilities: Arc<Capabilities>,
151        status: Arc<UnifiedStatus>,
152        request_tx: PeerRequestSender<PeerRequest<N>>,
153        timeout: Arc<AtomicU64>,
154        range_info: Option<BlockRangeInfo>,
155    ) {
156        debug_assert!(!self.active_peers.contains_key(&peer), "Already connected; not possible");
157
158        // Use the block number from the peer's status (eth/69+) if available,
159        // otherwise fall back to a local lookup by hash.
160        let block_number = status.latest_block.unwrap_or_else(|| {
161            self.client.block_number(status.blockhash).ok().flatten().unwrap_or_default()
162        });
163        self.state_fetcher.new_active_peer(
164            peer,
165            status.blockhash,
166            block_number,
167            Arc::clone(&capabilities),
168            timeout,
169            range_info,
170        );
171
172        self.active_peers.insert(
173            peer,
174            ActivePeer {
175                best_hash: status.blockhash,
176                capabilities,
177                request_tx,
178                pending_response: None,
179                blocks: LruCache::new(PEER_BLOCK_CACHE_LIMIT),
180            },
181        );
182    }
183
    /// Event hook for a disconnected session for the given peer.
    ///
    /// Removes the peer from the available set of peers (dropping any pending response receiver
    /// it held). It then notifies the [`StateFetcher`], which is expected to close all inflight
    /// requests for that peer.
    pub(crate) fn on_session_closed(&mut self, peer: PeerId) {
        self.active_peers.remove(&peer);
        self.state_fetcher.on_session_closed(&peer);
    }
191
192    /// Starts propagating the new block to peers that haven't reported the block yet.
193    ///
194    /// This is supposed to be invoked after the block was validated.
195    ///
196    /// > It then sends the block to a small fraction of connected peers (usually the square root of
197    /// > the total number of peers) using the `NewBlock` message.
198    ///
199    /// See also <https://github.com/ethereum/devp2p/blob/master/caps/eth.md>
200    pub(crate) fn announce_new_block(&mut self, msg: NewBlockMessage<N::NewBlockPayload>) {
201        // send a `NewBlock` message to a fraction of the connected peers (square root of the total
202        // number of peers)
203        let num_propagate = (self.active_peers.len() as f64).sqrt() as u64 + 1;
204
205        let number = msg.block.block().header().number();
206        let mut count = 0;
207
208        // Shuffle to propagate to a random sample of peers on every block announcement
209        let mut peers: Vec<_> = self.active_peers.iter_mut().collect();
210        peers.shuffle(&mut rand::rng());
211
212        for (peer_id, peer) in peers {
213            if peer.blocks.contains(&msg.hash) {
214                // skip peers which already reported the block
215                continue
216            }
217
218            // Queue a `NewBlock` message for the peer
219            if count < num_propagate {
220                self.queued_messages
221                    .push_back(StateAction::NewBlock { peer_id: *peer_id, block: msg.clone() });
222
223                // update peer block info
224                if self.state_fetcher.update_peer_block(peer_id, msg.hash, number) {
225                    peer.best_hash = msg.hash;
226                }
227
228                // mark the block as seen by the peer
229                peer.blocks.insert(msg.hash);
230
231                count += 1;
232            }
233
234            if count >= num_propagate {
235                break
236            }
237        }
238    }
239
240    /// Completes the block propagation process started in [`NetworkState::announce_new_block()`]
241    /// but sending `NewBlockHash` broadcast to all peers that haven't seen it yet.
242    pub(crate) fn announce_new_block_hash(&mut self, msg: NewBlockMessage<N::NewBlockPayload>) {
243        let number = msg.block.block().header().number();
244        let hashes = NewBlockHashes(vec![BlockHashNumber { hash: msg.hash, number }]);
245        for (peer_id, peer) in &mut self.active_peers {
246            if peer.blocks.contains(&msg.hash) {
247                // skip peers which already reported the block
248                continue
249            }
250
251            if self.state_fetcher.update_peer_block(peer_id, msg.hash, number) {
252                peer.best_hash = msg.hash;
253            }
254
255            self.queued_messages.push_back(StateAction::NewBlockHashes {
256                peer_id: *peer_id,
257                hashes: hashes.clone(),
258            });
259        }
260    }
261
262    /// Updates the block information for the peer.
263    pub(crate) fn update_peer_block(&mut self, peer_id: &PeerId, hash: B256, number: u64) {
264        if let Some(peer) = self.active_peers.get_mut(peer_id) {
265            peer.best_hash = hash;
266        }
267        self.state_fetcher.update_peer_block(peer_id, hash, number);
268    }
269
    /// Invoked when a new [`ForkId`] is activated.
    ///
    /// Forwards the new fork id to the [`Discovery`] service.
    pub(crate) fn update_fork_id(&self, fork_id: ForkId) {
        self.discovery.update_fork_id(fork_id)
    }
274
275    /// Invoked after a `NewBlock` message was received by the peer.
276    ///
277    /// This will keep track of blocks we know a peer has
278    pub(crate) fn on_new_block(&mut self, peer_id: PeerId, hash: B256) {
279        // Mark the blocks as seen
280        if let Some(peer) = self.active_peers.get_mut(&peer_id) {
281            peer.blocks.insert(hash);
282        }
283    }
284
285    /// Invoked for a `NewBlockHashes` broadcast message.
286    pub(crate) fn on_new_block_hashes(&mut self, peer_id: PeerId, hashes: Vec<BlockHashNumber>) {
287        // Mark the blocks as seen
288        if let Some(peer) = self.active_peers.get_mut(&peer_id) {
289            peer.blocks.extend(hashes.into_iter().map(|b| b.hash));
290        }
291    }
292
    /// Bans the [`IpAddr`] in the discovery service.
    ///
    /// Emits a `trace` log, then delegates to `Discovery::ban_ip`.
    pub(crate) fn ban_ip_discovery(&self, ip: IpAddr) {
        trace!(target: "net", ?ip, "Banning discovery");
        self.discovery.ban_ip(ip)
    }

    /// Bans the [`PeerId`] and [`IpAddr`] in the discovery service.
    ///
    /// Emits a `trace` log, then delegates to `Discovery::ban`.
    pub(crate) fn ban_discovery(&self, peer_id: PeerId, ip: IpAddr) {
        trace!(target: "net", ?peer_id, ?ip, "Banning discovery");
        self.discovery.ban(peer_id, ip)
    }
304
    /// Marks the given peer as trusted.
    ///
    /// Delegates to `PeersManager::add_trusted_peer_id`.
    pub(crate) fn add_trusted_peer_id(&mut self, peer_id: PeerId) {
        self.peers_manager.add_trusted_peer_id(peer_id)
    }

    /// Adds a trusted peer that may use a hostname, with periodic DNS re-resolution.
    ///
    /// Delegates to `PeersManager::add_trusted_peer_node`.
    pub(crate) fn add_trusted_peer_node(&mut self, trusted: reth_network_peers::TrustedPeer) {
        self.peers_manager.add_trusted_peer_node(trusted)
    }

    /// Adds a peer and its address with the given kind to the peerset.
    ///
    /// Delegates to `PeersManager::add_peer_kind`, passing `None` for its last (optional)
    /// argument.
    pub(crate) fn add_peer_kind(
        &mut self,
        peer_id: PeerId,
        kind: Option<PeerKind>,
        addr: PeerAddr,
    ) {
        self.peers_manager.add_peer_kind(peer_id, kind, addr, None)
    }

    /// Connects a peer and its address with the given kind.
    ///
    /// Delegates to `PeersManager::add_and_connect_kind`, passing `None` for its last (optional)
    /// argument.
    pub(crate) fn add_and_connect(&mut self, peer_id: PeerId, kind: PeerKind, addr: PeerAddr) {
        self.peers_manager.add_and_connect_kind(peer_id, kind, addr, None)
    }
329
330    /// Removes a peer and its address with the given kind from the peerset.
331    pub(crate) fn remove_peer_kind(&mut self, peer_id: PeerId, kind: PeerKind) {
332        match kind {
333            PeerKind::Basic | PeerKind::Static => self.peers_manager.remove_peer(peer_id),
334            PeerKind::Trusted => self.peers_manager.remove_peer_from_trusted_set(peer_id),
335        }
336    }
337
338    /// Event hook for events received from the discovery service.
339    fn on_discovery_event(&mut self, event: DiscoveryEvent) {
340        match event {
341            DiscoveryEvent::NewNode(DiscoveredEvent::EventQueued { peer_id, addr, fork_id }) => {
342                self.queued_messages.push_back(StateAction::DiscoveredNode {
343                    peer_id,
344                    addr,
345                    fork_id,
346                });
347            }
348            DiscoveryEvent::EnrForkId(record, fork_id) => {
349                let peer_id = record.id;
350                let tcp_addr = record.tcp_addr();
351                if tcp_addr.port() == 0 {
352                    return
353                }
354                let udp_addr = record.udp_addr();
355                let addr = PeerAddr::new(tcp_addr, Some(udp_addr));
356                self.queued_messages.push_back(StateAction::DiscoveredEnrForkId {
357                    peer_id,
358                    addr,
359                    fork_id,
360                });
361            }
362        }
363    }
364
365    /// Event hook for new actions derived from the peer management set.
366    fn on_peer_action(&mut self, action: PeerAction) {
367        match action {
368            PeerAction::Connect { peer_id, remote_addr } => {
369                self.queued_messages.push_back(StateAction::Connect { peer_id, remote_addr });
370            }
371            PeerAction::Disconnect { peer_id, reason } => {
372                self.state_fetcher.on_pending_disconnect(&peer_id);
373                self.queued_messages.push_back(StateAction::Disconnect { peer_id, reason });
374            }
375            PeerAction::DisconnectBannedIncoming { peer_id } |
376            PeerAction::DisconnectUntrustedIncoming { peer_id } => {
377                self.state_fetcher.on_pending_disconnect(&peer_id);
378                self.queued_messages.push_back(StateAction::Disconnect { peer_id, reason: None });
379            }
380            PeerAction::DiscoveryBanPeerId { peer_id, ip_addr } => {
381                self.ban_discovery(peer_id, ip_addr)
382            }
383            PeerAction::DiscoveryBanIp { ip_addr } => self.ban_ip_discovery(ip_addr),
384            PeerAction::PeerAdded(peer_id) => {
385                self.queued_messages.push_back(StateAction::PeerAdded(peer_id))
386            }
387            PeerAction::PeerRemoved(peer_id) => {
388                self.queued_messages.push_back(StateAction::PeerRemoved(peer_id))
389            }
390            PeerAction::BanPeer { .. } | PeerAction::UnBanPeer { .. } => {}
391        }
392    }
393
394    /// Sends The message to the peer's session and queues in a response.
395    ///
396    /// Caution: this will replace an already pending response. It's the responsibility of the
397    /// caller to select the peer.
398    fn handle_block_request(&mut self, peer_id: PeerId, request: BlockRequest) {
399        if let Some(ref mut peer) = self.active_peers.get_mut(&peer_id) {
400            let (request, response) = match request {
401                BlockRequest::GetBlockHeaders(request) => {
402                    let (response, rx) = oneshot::channel();
403                    let request = PeerRequest::GetBlockHeaders { request, response };
404                    let response = PeerResponse::BlockHeaders { response: rx };
405                    (request, response)
406                }
407                BlockRequest::GetBlockBodies(request) => {
408                    let (response, rx) = oneshot::channel();
409                    let request = PeerRequest::GetBlockBodies { request, response };
410                    let response = PeerResponse::BlockBodies { response: rx };
411                    (request, response)
412                }
413                BlockRequest::GetBlockAccessLists(request) => {
414                    let (response, rx) = oneshot::channel();
415                    let request = PeerRequest::GetBlockAccessLists { request, response };
416                    let response = PeerResponse::BlockAccessLists { response: rx };
417                    (request, response)
418                }
419                BlockRequest::GetReceipts(request) => {
420                    if peer.capabilities.supports_eth_v70() {
421                        let (response, rx) = oneshot::channel();
422                        let request = PeerRequest::GetReceipts70 {
423                            request: GetReceipts70 {
424                                first_block_receipt_index: 0,
425                                block_hashes: request.0,
426                            },
427                            response,
428                        };
429                        let response = PeerResponse::Receipts70 { response: rx };
430                        (request, response)
431                    } else if peer.capabilities.supports_eth_v69() {
432                        let (response, rx) = oneshot::channel();
433                        let request = PeerRequest::GetReceipts69 { request, response };
434                        let response = PeerResponse::Receipts69 { response: rx };
435                        (request, response)
436                    } else {
437                        let (response, rx) = oneshot::channel();
438                        let request = PeerRequest::GetReceipts { request, response };
439                        let response = PeerResponse::Receipts { response: rx };
440                        (request, response)
441                    }
442                }
443            };
444            let _ = peer.request_tx.to_session_tx.try_send(request);
445            peer.pending_response = Some(response);
446        }
447    }
448
449    /// Handle the outcome of processed response, for example directly queue another request.
450    fn on_block_response_outcome(&mut self, outcome: BlockResponseOutcome) {
451        match outcome {
452            BlockResponseOutcome::Request(peer, request) => {
453                self.handle_block_request(peer, request);
454            }
455            BlockResponseOutcome::BadResponse(peer, reputation_change) => {
456                self.peers_manager.apply_reputation_change(&peer, reputation_change);
457            }
458        }
459    }
460
    /// Invoked when a response is received from a connected peer.
    ///
    /// Delegates the response result to the fetcher which may return an outcome specific
    /// instruction that needs to be handled in [`Self::on_block_response_outcome`]. This could be
    /// a follow-up request or an instruction to slash the peer's reputation.
    ///
    /// All three receipt response flavours (legacy, eth/69, eth/70) are normalized into a
    /// [`ReceiptsResponse`] before being passed to the fetcher. Response kinds not matched
    /// explicitly below produce no outcome and are ignored.
    fn on_eth_response(&mut self, peer: PeerId, resp: PeerResponseResult<N>) {
        let outcome = match resp {
            PeerResponseResult::BlockHeaders(res) => {
                self.state_fetcher.on_block_headers_response(peer, res)
            }
            PeerResponseResult::BlockBodies(res) => {
                self.state_fetcher.on_block_bodies_response(peer, res)
            }
            PeerResponseResult::Receipts(res) => {
                // Legacy eth/66-68: strip bloom filters and wrap in ReceiptsResponse
                let normalized = res.map(|blocks| {
                    let receipts = blocks
                        .into_iter()
                        .map(|block_receipts| {
                            block_receipts.into_iter().map(|rwb| rwb.receipt).collect()
                        })
                        .collect();
                    ReceiptsResponse::new(receipts)
                });
                self.state_fetcher.on_receipts_response(peer, normalized)
            }
            PeerResponseResult::Receipts69(res) => {
                // eth/69 receipts carry no bloom field, so they are wrapped as-is
                let normalized = res.map(ReceiptsResponse::new);
                self.state_fetcher.on_receipts_response(peer, normalized)
            }
            PeerResponseResult::Receipts70(res) => {
                // eth/70 responses have a dedicated conversion into ReceiptsResponse
                let normalized = res.map(ReceiptsResponse::from);
                self.state_fetcher.on_receipts_response(peer, normalized)
            }
            PeerResponseResult::BlockAccessLists(res) => {
                self.state_fetcher.on_block_access_lists_response(peer, res)
            }
            // any other response kind is not handled by the fetcher
            _ => None,
        };

        if let Some(outcome) = outcome {
            self.on_block_response_outcome(outcome);
        }
    }
505
    /// Advances the state.
    ///
    /// Returns the next queued [`StateAction`] if there is one. Otherwise it drives, in this
    /// order:
    /// - the discovery service,
    /// - the [`StateFetcher`], dispatching its block requests,
    /// - the pending responses of all active peers,
    /// - the [`PeersManager`].
    ///
    /// Any of these may queue new actions. The loop repeats until an action can be returned; if a
    /// full pass queued nothing, `Poll::Pending` is returned.
    pub(crate) fn poll(&mut self, cx: &mut Context<'_>) -> Poll<StateAction<N>> {
        loop {
            // drain buffered messages
            if let Some(message) = self.queued_messages.pop_front() {
                return Poll::Ready(message)
            }

            while let Poll::Ready(discovery) = self.discovery.poll(cx) {
                self.on_discovery_event(discovery);
            }

            while let Poll::Ready(action) = self.state_fetcher.poll(cx) {
                match action {
                    FetchAction::BlockRequest { peer_id, request } => {
                        self.handle_block_request(peer_id, request)
                    }
                }
            }

            // Repeat until no response is ready: handling a response may dispatch a follow-up
            // request (see `on_block_response_outcome`) whose response is polled on the next pass.
            loop {
                // need to buffer results here to make borrow checker happy
                let mut closed_sessions = Vec::new();
                let mut received_responses = Vec::new();

                // poll all connected peers for responses
                for (id, peer) in &mut self.active_peers {
                    let Some(mut response) = peer.pending_response.take() else { continue };
                    match response.poll(cx) {
                        Poll::Ready(res) => {
                            // check if the error is due to a closed channel to the session
                            if res.err().is_some_and(|err| err.is_channel_closed()) {
                                debug!(
                                    target: "net",
                                    ?id,
                                    "Request canceled, response channel from session closed."
                                );
                                // if the channel is closed, this means the peer session is also
                                // closed, in which case we can invoke
                                // [Self::on_session_closed]
                                // immediately, preventing followup requests and propagating the
                                // connection dropped error
                                closed_sessions.push(*id);
                            } else {
                                received_responses.push((*id, res));
                            }
                        }
                        Poll::Pending => {
                            // not ready yet, store again.
                            peer.pending_response = Some(response);
                        }
                    };
                }

                for peer in closed_sessions {
                    self.on_session_closed(peer)
                }

                if received_responses.is_empty() {
                    break;
                }

                for (peer_id, resp) in received_responses {
                    self.on_eth_response(peer_id, resp);
                }
            }

            // poll peer manager
            while let Poll::Ready(action) = self.peers_manager.poll(cx) {
                self.on_peer_action(action);
            }

            // If anything above queued an action, loop again so it is returned by the drain at
            // the top; otherwise every source has been polled to `Pending` and we yield.
            if self.queued_messages.is_empty() {
                return Poll::Pending
            }
        }
    }
585}
586
/// Tracks the state of a Peer with an active Session.
///
/// For example known blocks, so we can decide what to announce.
#[derive(Debug)]
pub(crate) struct ActivePeer<N: NetworkPrimitives> {
    /// Best block hash of the peer, as last recorded by [`NetworkState`].
    pub(crate) best_hash: B256,
    /// The capabilities of the remote peer.
    ///
    /// Among other things, used to pick the receipts request flavour for this peer.
    pub(crate) capabilities: Arc<Capabilities>,
    /// A communication channel directly to the session task.
    pub(crate) request_tx: PeerRequestSender<PeerRequest<N>>,
    /// The response receiver for a currently active request to that peer.
    ///
    /// At most one request is tracked at a time; dispatching a new one replaces this value.
    pub(crate) pending_response: Option<PeerResponse<N>>,
    /// Blocks we know the peer has.
    ///
    /// Hashes are added when the peer announces blocks and when we send it a `NewBlock` message
    /// (see `NetworkState::announce_new_block`). The cache is created with capacity
    /// `PEER_BLOCK_CACHE_LIMIT`.
    pub(crate) blocks: LruCache<B256>,
}
603
/// Message variants triggered by the [`NetworkState`] and handed out by `NetworkState::poll`.
#[derive(Debug)]
pub(crate) enum StateAction<N: NetworkPrimitives> {
    /// Dispatch a `NewBlock` message to the peer
    NewBlock {
        /// Target of the message
        peer_id: PeerId,
        /// The `NewBlock` message
        block: NewBlockMessage<N::NewBlockPayload>,
    },
    /// Dispatch a `NewBlockHashes` announcement to the peer
    NewBlockHashes {
        /// Target of the message
        peer_id: PeerId,
        /// `NewBlockHashes` message to send to the peer.
        hashes: NewBlockHashes,
    },
    /// Create a new connection to the given node (`peer_id`) at `remote_addr`.
    Connect { remote_addr: SocketAddr, peer_id: PeerId },
    /// Disconnect an existing connection
    Disconnect {
        /// The peer whose connection should be closed.
        peer_id: PeerId,
        /// Why the disconnect was initiated
        reason: Option<DisconnectReason>,
    },
    /// Retrieved a [`ForkId`] from the peer via ENR request, See <https://eips.ethereum.org/EIPS/eip-868>
    DiscoveredEnrForkId {
        /// The peer the ENR record belongs to.
        peer_id: PeerId,
        /// The address of the peer.
        addr: PeerAddr,
        /// The reported [`ForkId`] by this peer.
        fork_id: ForkId,
    },
    /// A new node was found through the discovery, possibly with a `ForkId`
    DiscoveredNode { peer_id: PeerId, addr: PeerAddr, fork_id: Option<ForkId> },
    /// A peer was added
    PeerAdded(PeerId),
    /// A peer was dropped
    PeerRemoved(PeerId),
}
643
#[cfg(test)]
mod tests {
    use crate::{
        discovery::Discovery,
        fetch::StateFetcher,
        peers::PeersManager,
        state::{BlockNumReader, NetworkState},
        PeerRequest,
    };
    use alloy_consensus::Header;
    use alloy_primitives::B256;
    use reth_eth_wire::{BlockBodies, Capabilities, Capability, EthNetworkPrimitives, EthVersion};
    use reth_ethereum_primitives::BlockBody;
    use reth_network_api::PeerRequestSender;
    use reth_network_p2p::{bodies::client::BodiesClient, error::RequestError};
    use reth_network_peers::PeerId;
    use reth_storage_api::noop::NoopProvider;
    use std::{
        future::poll_fn,
        sync::{atomic::AtomicU64, Arc},
    };
    use tokio::sync::mpsc;
    use tokio_stream::{wrappers::ReceiverStream, StreamExt};

    /// Returns a testing instance of the [`NetworkState`].
    ///
    /// As in `NetworkState::new`, the [`StateFetcher`] gets the handle of the same
    /// [`PeersManager`] that the state owns. The handle therefore stays connected instead of
    /// pointing at a manager that is dropped at the end of this function.
    fn state() -> NetworkState<EthNetworkPrimitives> {
        let peers = PeersManager::default();
        let handle = peers.handle();
        NetworkState {
            active_peers: Default::default(),
            peers_manager: peers,
            queued_messages: Default::default(),
            client: BlockNumReader(Box::new(NoopProvider::default())),
            discovery: Discovery::noop(),
            state_fetcher: StateFetcher::new(handle, Default::default()),
        }
    }

    fn capabilities() -> Arc<Capabilities> {
        Arc::new(vec![Capability::from(EthVersion::Eth67)].into())
    }

    // tests that ongoing requests are answered with connection dropped if the session that received
    // that request drops the request object.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_dropped_active_session() {
        let mut state = state();
        let client = state.fetch_client();

        let peer_id = PeerId::random();
        let (tx, session_rx) = mpsc::channel(1);
        let peer_tx = PeerRequestSender::new(peer_id, tx);

        state.on_session_activated(
            peer_id,
            capabilities(),
            Arc::default(),
            peer_tx,
            Arc::new(AtomicU64::new(1)),
            None,
        );

        assert!(state.active_peers.contains_key(&peer_id));

        let body = BlockBody { ommers: vec![Header::default()], ..Default::default() };

        let body_response = body.clone();

        // this mimics an active session that receives the requests from the state
        tokio::task::spawn(async move {
            let mut stream = ReceiverStream::new(session_rx);
            let resp = stream.next().await.unwrap();
            match resp {
                PeerRequest::GetBlockBodies { response, .. } => {
                    response.send(Ok(BlockBodies(vec![body_response]))).unwrap();
                }
                _ => unreachable!(),
            }

            // wait for the next request, then drop
            let _resp = stream.next().await.unwrap();
        });

        // spawn the state as future
        tokio::task::spawn(async move {
            loop {
                poll_fn(|cx| state.poll(cx)).await;
            }
        });

        // send requests to the state via the client
        let (peer, bodies) = client.get_block_bodies(vec![B256::random()]).await.unwrap().split();
        assert_eq!(peer, peer_id);
        assert_eq!(bodies, vec![body]);

        let resp = client.get_block_bodies(vec![B256::random()]).await;
        assert!(resp.is_err());
        assert_eq!(resp.unwrap_err(), RequestError::ConnectionDropped);
    }
}