Skip to main content

reth_network/
state.rs

1//! Keeps track of the state of the network.
2
3use crate::{
4    cache::LruCache,
5    discovery::Discovery,
6    fetch::{BlockResponseOutcome, FetchAction, StateFetcher},
7    message::{BlockRequest, NewBlockMessage, PeerResponse, PeerResponseResult},
8    peers::{PeerAction, PeersManager},
9    session::BlockRangeInfo,
10    FetchClient,
11};
12use alloy_consensus::BlockHeader;
13use alloy_primitives::B256;
14use rand::seq::SliceRandom;
15use reth_eth_wire::{
16    BlockHashNumber, Capabilities, DisconnectReason, EthNetworkPrimitives, GetReceipts70,
17    NetworkPrimitives, NewBlockHashes, NewBlockPayload, UnifiedStatus,
18};
19use reth_ethereum_forks::ForkId;
20use reth_network_api::{DiscoveredEvent, DiscoveryEvent, PeerRequest, PeerRequestSender};
21use reth_network_p2p::receipts::client::ReceiptsResponse;
22use reth_network_peers::PeerId;
23use reth_network_types::{PeerAddr, PeerKind};
24use reth_primitives_traits::Block;
25use std::{
26    collections::{HashMap, VecDeque},
27    fmt,
28    net::{IpAddr, SocketAddr},
29    ops::Deref,
30    sync::{
31        atomic::{AtomicU64, AtomicUsize},
32        Arc,
33    },
34    task::{Context, Poll},
35};
36use tokio::sync::oneshot;
37use tracing::{debug, trace};
38
39/// Cache limit of blocks to keep track of for a single peer.
40const PEER_BLOCK_CACHE_LIMIT: u32 = 512;
41
42/// Wrapper type for the [`BlockNumReader`] trait.
43pub(crate) struct BlockNumReader(Box<dyn reth_storage_api::BlockNumReader>);
44
45impl BlockNumReader {
46    /// Create a new instance with the given reader.
47    pub fn new(reader: impl reth_storage_api::BlockNumReader + 'static) -> Self {
48        Self(Box::new(reader))
49    }
50}
51
52impl fmt::Debug for BlockNumReader {
53    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
54        f.debug_struct("BlockNumReader").field("inner", &"<dyn BlockNumReader>").finish()
55    }
56}
57
58impl Deref for BlockNumReader {
59    type Target = Box<dyn reth_storage_api::BlockNumReader>;
60
61    fn deref(&self) -> &Self::Target {
62        &self.0
63    }
64}
65
66/// The [`NetworkState`] keeps track of the state of all peers in the network.
67///
68/// This includes:
69///   - [`Discovery`]: manages the discovery protocol, essentially a stream of discovery updates
70///   - [`PeersManager`]: keeps track of connected peers and issues new outgoing connections
71///     depending on the configured capacity.
72///   - [`StateFetcher`]: streams download request (received from outside via channel) which are
73///     then send to the session of the peer.
74///
75/// This type is also responsible for responding for received request.
76#[derive(Debug)]
77pub struct NetworkState<N: NetworkPrimitives = EthNetworkPrimitives> {
78    /// All active peers and their state.
79    active_peers: HashMap<PeerId, ActivePeer<N>>,
80    /// Manages connections to peers.
81    peers_manager: PeersManager,
82    /// Buffered messages until polled.
83    queued_messages: VecDeque<StateAction<N>>,
84    /// The client type that can interact with the chain.
85    ///
86    /// This type is used to fetch the block number after we established a session and received the
87    /// [`UnifiedStatus`] block hash.
88    client: BlockNumReader,
89    /// Network discovery.
90    discovery: Discovery,
91    /// The type that handles requests.
92    ///
93    /// The fetcher streams `RLPx` related requests on a per-peer basis to this type. This type
94    /// will then queue in the request and notify the fetcher once the result has been
95    /// received.
96    state_fetcher: StateFetcher<N>,
97}
98
99impl<N: NetworkPrimitives> NetworkState<N> {
100    /// Create a new state instance with the given params
101    pub(crate) fn new(
102        client: BlockNumReader,
103        discovery: Discovery,
104        peers_manager: PeersManager,
105        num_active_peers: Arc<AtomicUsize>,
106    ) -> Self {
107        let state_fetcher = StateFetcher::new(peers_manager.handle(), num_active_peers);
108        Self {
109            active_peers: Default::default(),
110            peers_manager,
111            queued_messages: Default::default(),
112            client,
113            discovery,
114            state_fetcher,
115        }
116    }
117
118    /// Returns mutable access to the [`PeersManager`]
119    pub(crate) const fn peers_mut(&mut self) -> &mut PeersManager {
120        &mut self.peers_manager
121    }
122
123    /// Returns mutable access to the [`Discovery`]
124    pub(crate) const fn discovery_mut(&mut self) -> &mut Discovery {
125        &mut self.discovery
126    }
127
128    /// Returns access to the [`PeersManager`]
129    pub(crate) const fn peers(&self) -> &PeersManager {
130        &self.peers_manager
131    }
132
133    /// Returns a new [`FetchClient`]
134    pub(crate) fn fetch_client(&self) -> FetchClient<N> {
135        self.state_fetcher.client()
136    }
137
138    /// How many peers we're currently connected to.
139    pub fn num_active_peers(&self) -> usize {
140        self.active_peers.len()
141    }
142
143    /// Event hook for an activated session for the peer.
144    ///
145    /// Returns `Ok` if the session is valid, returns an `Err` if the session is not accepted and
146    /// should be rejected.
147    pub(crate) fn on_session_activated(
148        &mut self,
149        peer: PeerId,
150        capabilities: Arc<Capabilities>,
151        status: Arc<UnifiedStatus>,
152        request_tx: PeerRequestSender<PeerRequest<N>>,
153        timeout: Arc<AtomicU64>,
154        range_info: Option<BlockRangeInfo>,
155    ) {
156        debug_assert!(!self.active_peers.contains_key(&peer), "Already connected; not possible");
157
158        // Use the block number from the peer's status (eth/69+) if available,
159        // otherwise fall back to a local lookup by hash.
160        let block_number = status.latest_block.unwrap_or_else(|| {
161            self.client.block_number(status.blockhash).ok().flatten().unwrap_or_default()
162        });
163        self.state_fetcher.new_active_peer(
164            peer,
165            status.blockhash,
166            block_number,
167            Arc::clone(&capabilities),
168            timeout,
169            range_info,
170        );
171
172        self.active_peers.insert(
173            peer,
174            ActivePeer {
175                best_hash: status.blockhash,
176                capabilities,
177                request_tx,
178                pending_response: None,
179                blocks: LruCache::new(PEER_BLOCK_CACHE_LIMIT),
180            },
181        );
182    }
183
184    /// Event hook for a disconnected session for the given peer.
185    ///
186    /// This will remove the peer from the available set of peers and close all inflight requests.
187    pub(crate) fn on_session_closed(&mut self, peer: PeerId) {
188        self.active_peers.remove(&peer);
189        self.state_fetcher.on_session_closed(&peer);
190    }
191
192    /// Starts propagating the new block to peers that haven't reported the block yet.
193    ///
194    /// This is supposed to be invoked after the block was validated.
195    ///
196    /// > It then sends the block to a small fraction of connected peers (usually the square root of
197    /// > the total number of peers) using the `NewBlock` message.
198    ///
199    /// See also <https://github.com/ethereum/devp2p/blob/master/caps/eth.md>
200    pub(crate) fn announce_new_block(&mut self, msg: NewBlockMessage<N::NewBlockPayload>) {
201        // send a `NewBlock` message to a fraction of the connected peers (square root of the total
202        // number of peers)
203        let num_propagate = (self.active_peers.len() as f64).sqrt() as u64 + 1;
204
205        let number = msg.block.block().header().number();
206        let mut count = 0;
207
208        // Shuffle to propagate to a random sample of peers on every block announcement
209        let mut peers: Vec<_> = self.active_peers.iter_mut().collect();
210        peers.shuffle(&mut rand::rng());
211
212        for (peer_id, peer) in peers {
213            if peer.blocks.contains(&msg.hash) {
214                // skip peers which already reported the block
215                continue
216            }
217
218            // Queue a `NewBlock` message for the peer
219            if count < num_propagate {
220                self.queued_messages
221                    .push_back(StateAction::NewBlock { peer_id: *peer_id, block: msg.clone() });
222
223                // update peer block info
224                if self.state_fetcher.update_peer_block(peer_id, msg.hash, number) {
225                    peer.best_hash = msg.hash;
226                }
227
228                // mark the block as seen by the peer
229                peer.blocks.insert(msg.hash);
230
231                count += 1;
232            }
233
234            if count >= num_propagate {
235                break
236            }
237        }
238    }
239
240    /// Completes the block propagation process started in [`NetworkState::announce_new_block()`]
241    /// but sending `NewBlockHash` broadcast to all peers that haven't seen it yet.
242    pub(crate) fn announce_new_block_hash(&mut self, msg: NewBlockMessage<N::NewBlockPayload>) {
243        let number = msg.block.block().header().number();
244        let hashes = NewBlockHashes(vec![BlockHashNumber { hash: msg.hash, number }]);
245        for (peer_id, peer) in &mut self.active_peers {
246            if peer.blocks.contains(&msg.hash) {
247                // skip peers which already reported the block
248                continue
249            }
250
251            if self.state_fetcher.update_peer_block(peer_id, msg.hash, number) {
252                peer.best_hash = msg.hash;
253            }
254
255            self.queued_messages.push_back(StateAction::NewBlockHashes {
256                peer_id: *peer_id,
257                hashes: hashes.clone(),
258            });
259        }
260    }
261
262    /// Updates the block information for the peer.
263    pub(crate) fn update_peer_block(&mut self, peer_id: &PeerId, hash: B256, number: u64) {
264        if let Some(peer) = self.active_peers.get_mut(peer_id) {
265            peer.best_hash = hash;
266        }
267        self.state_fetcher.update_peer_block(peer_id, hash, number);
268    }
269
270    /// Invoked when a new [`ForkId`] is activated.
271    pub(crate) fn update_fork_id(&self, fork_id: ForkId) {
272        self.discovery.update_fork_id(fork_id)
273    }
274
275    /// Invoked after a `NewBlock` message was received by the peer.
276    ///
277    /// This will keep track of blocks we know a peer has
278    pub(crate) fn on_new_block(&mut self, peer_id: PeerId, hash: B256) {
279        // Mark the blocks as seen
280        if let Some(peer) = self.active_peers.get_mut(&peer_id) {
281            peer.blocks.insert(hash);
282        }
283    }
284
285    /// Invoked for a `NewBlockHashes` broadcast message.
286    pub(crate) fn on_new_block_hashes(&mut self, peer_id: PeerId, hashes: Vec<BlockHashNumber>) {
287        // Mark the blocks as seen
288        if let Some(peer) = self.active_peers.get_mut(&peer_id) {
289            peer.blocks.extend(hashes.into_iter().map(|b| b.hash));
290        }
291    }
292
293    /// Bans the [`IpAddr`] in the discovery service.
294    pub(crate) fn ban_ip_discovery(&self, ip: IpAddr) {
295        trace!(target: "net", ?ip, "Banning discovery");
296        self.discovery.ban_ip(ip)
297    }
298
299    /// Bans the [`PeerId`] and [`IpAddr`] in the discovery service.
300    pub(crate) fn ban_discovery(&self, peer_id: PeerId, ip: IpAddr) {
301        trace!(target: "net", ?peer_id, ?ip, "Banning discovery");
302        self.discovery.ban(peer_id, ip)
303    }
304
305    /// Marks the given peer as trusted.
306    pub(crate) fn add_trusted_peer_id(&mut self, peer_id: PeerId) {
307        self.peers_manager.add_trusted_peer_id(peer_id)
308    }
309
310    /// Adds a peer and its address with the given kind to the peerset.
311    pub(crate) fn add_peer_kind(
312        &mut self,
313        peer_id: PeerId,
314        kind: Option<PeerKind>,
315        addr: PeerAddr,
316    ) {
317        self.peers_manager.add_peer_kind(peer_id, kind, addr, None)
318    }
319
320    /// Connects a peer and its address with the given kind
321    pub(crate) fn add_and_connect(&mut self, peer_id: PeerId, kind: PeerKind, addr: PeerAddr) {
322        self.peers_manager.add_and_connect_kind(peer_id, kind, addr, None)
323    }
324
325    /// Removes a peer and its address with the given kind from the peerset.
326    pub(crate) fn remove_peer_kind(&mut self, peer_id: PeerId, kind: PeerKind) {
327        match kind {
328            PeerKind::Basic | PeerKind::Static => self.peers_manager.remove_peer(peer_id),
329            PeerKind::Trusted => self.peers_manager.remove_peer_from_trusted_set(peer_id),
330        }
331    }
332
333    /// Event hook for events received from the discovery service.
334    fn on_discovery_event(&mut self, event: DiscoveryEvent) {
335        match event {
336            DiscoveryEvent::NewNode(DiscoveredEvent::EventQueued { peer_id, addr, fork_id }) => {
337                self.queued_messages.push_back(StateAction::DiscoveredNode {
338                    peer_id,
339                    addr,
340                    fork_id,
341                });
342            }
343            DiscoveryEvent::EnrForkId(record, fork_id) => {
344                let peer_id = record.id;
345                let tcp_addr = record.tcp_addr();
346                if tcp_addr.port() == 0 {
347                    return
348                }
349                let udp_addr = record.udp_addr();
350                let addr = PeerAddr::new(tcp_addr, Some(udp_addr));
351                self.queued_messages.push_back(StateAction::DiscoveredEnrForkId {
352                    peer_id,
353                    addr,
354                    fork_id,
355                });
356            }
357        }
358    }
359
360    /// Event hook for new actions derived from the peer management set.
361    fn on_peer_action(&mut self, action: PeerAction) {
362        match action {
363            PeerAction::Connect { peer_id, remote_addr } => {
364                self.queued_messages.push_back(StateAction::Connect { peer_id, remote_addr });
365            }
366            PeerAction::Disconnect { peer_id, reason } => {
367                self.state_fetcher.on_pending_disconnect(&peer_id);
368                self.queued_messages.push_back(StateAction::Disconnect { peer_id, reason });
369            }
370            PeerAction::DisconnectBannedIncoming { peer_id } |
371            PeerAction::DisconnectUntrustedIncoming { peer_id } => {
372                self.state_fetcher.on_pending_disconnect(&peer_id);
373                self.queued_messages.push_back(StateAction::Disconnect { peer_id, reason: None });
374            }
375            PeerAction::DiscoveryBanPeerId { peer_id, ip_addr } => {
376                self.ban_discovery(peer_id, ip_addr)
377            }
378            PeerAction::DiscoveryBanIp { ip_addr } => self.ban_ip_discovery(ip_addr),
379            PeerAction::PeerAdded(peer_id) => {
380                self.queued_messages.push_back(StateAction::PeerAdded(peer_id))
381            }
382            PeerAction::PeerRemoved(peer_id) => {
383                self.queued_messages.push_back(StateAction::PeerRemoved(peer_id))
384            }
385            PeerAction::BanPeer { .. } | PeerAction::UnBanPeer { .. } => {}
386        }
387    }
388
389    /// Sends The message to the peer's session and queues in a response.
390    ///
391    /// Caution: this will replace an already pending response. It's the responsibility of the
392    /// caller to select the peer.
393    fn handle_block_request(&mut self, peer_id: PeerId, request: BlockRequest) {
394        if let Some(ref mut peer) = self.active_peers.get_mut(&peer_id) {
395            let (request, response) = match request {
396                BlockRequest::GetBlockHeaders(request) => {
397                    let (response, rx) = oneshot::channel();
398                    let request = PeerRequest::GetBlockHeaders { request, response };
399                    let response = PeerResponse::BlockHeaders { response: rx };
400                    (request, response)
401                }
402                BlockRequest::GetBlockBodies(request) => {
403                    let (response, rx) = oneshot::channel();
404                    let request = PeerRequest::GetBlockBodies { request, response };
405                    let response = PeerResponse::BlockBodies { response: rx };
406                    (request, response)
407                }
408                BlockRequest::GetBlockAccessLists(request) => {
409                    let (response, rx) = oneshot::channel();
410                    let request = PeerRequest::GetBlockAccessLists { request, response };
411                    let response = PeerResponse::BlockAccessLists { response: rx };
412                    (request, response)
413                }
414                BlockRequest::GetReceipts(request) => {
415                    if peer.capabilities.supports_eth_v70() {
416                        let (response, rx) = oneshot::channel();
417                        let request = PeerRequest::GetReceipts70 {
418                            request: GetReceipts70 {
419                                first_block_receipt_index: 0,
420                                block_hashes: request.0,
421                            },
422                            response,
423                        };
424                        let response = PeerResponse::Receipts70 { response: rx };
425                        (request, response)
426                    } else if peer.capabilities.supports_eth_v69() {
427                        let (response, rx) = oneshot::channel();
428                        let request = PeerRequest::GetReceipts69 { request, response };
429                        let response = PeerResponse::Receipts69 { response: rx };
430                        (request, response)
431                    } else {
432                        let (response, rx) = oneshot::channel();
433                        let request = PeerRequest::GetReceipts { request, response };
434                        let response = PeerResponse::Receipts { response: rx };
435                        (request, response)
436                    }
437                }
438            };
439            let _ = peer.request_tx.to_session_tx.try_send(request);
440            peer.pending_response = Some(response);
441        }
442    }
443
444    /// Handle the outcome of processed response, for example directly queue another request.
445    fn on_block_response_outcome(&mut self, outcome: BlockResponseOutcome) {
446        match outcome {
447            BlockResponseOutcome::Request(peer, request) => {
448                self.handle_block_request(peer, request);
449            }
450            BlockResponseOutcome::BadResponse(peer, reputation_change) => {
451                self.peers_manager.apply_reputation_change(&peer, reputation_change);
452            }
453        }
454    }
455
456    /// Invoked when received a response from a connected peer.
457    ///
458    /// Delegates the response result to the fetcher which may return an outcome specific
459    /// instruction that needs to be handled in [`Self::on_block_response_outcome`]. This could be
460    /// a follow-up request or an instruction to slash the peer's reputation.
461    fn on_eth_response(&mut self, peer: PeerId, resp: PeerResponseResult<N>) {
462        let outcome = match resp {
463            PeerResponseResult::BlockHeaders(res) => {
464                self.state_fetcher.on_block_headers_response(peer, res)
465            }
466            PeerResponseResult::BlockBodies(res) => {
467                self.state_fetcher.on_block_bodies_response(peer, res)
468            }
469            PeerResponseResult::Receipts(res) => {
470                // Legacy eth/66-68: strip bloom filters and wrap in ReceiptsResponse
471                let normalized = res.map(|blocks| {
472                    let receipts = blocks
473                        .into_iter()
474                        .map(|block_receipts| {
475                            block_receipts.into_iter().map(|rwb| rwb.receipt).collect()
476                        })
477                        .collect();
478                    ReceiptsResponse::new(receipts)
479                });
480                self.state_fetcher.on_receipts_response(peer, normalized)
481            }
482            PeerResponseResult::Receipts69(res) => {
483                let normalized = res.map(ReceiptsResponse::new);
484                self.state_fetcher.on_receipts_response(peer, normalized)
485            }
486            PeerResponseResult::Receipts70(res) => {
487                let normalized = res.map(ReceiptsResponse::from);
488                self.state_fetcher.on_receipts_response(peer, normalized)
489            }
490            PeerResponseResult::BlockAccessLists(res) => {
491                self.state_fetcher.on_block_access_lists_response(peer, res)
492            }
493            _ => None,
494        };
495
496        if let Some(outcome) = outcome {
497            self.on_block_response_outcome(outcome);
498        }
499    }
500
501    /// Advances the state
502    pub(crate) fn poll(&mut self, cx: &mut Context<'_>) -> Poll<StateAction<N>> {
503        loop {
504            // drain buffered messages
505            if let Some(message) = self.queued_messages.pop_front() {
506                return Poll::Ready(message)
507            }
508
509            while let Poll::Ready(discovery) = self.discovery.poll(cx) {
510                self.on_discovery_event(discovery);
511            }
512
513            while let Poll::Ready(action) = self.state_fetcher.poll(cx) {
514                match action {
515                    FetchAction::BlockRequest { peer_id, request } => {
516                        self.handle_block_request(peer_id, request)
517                    }
518                }
519            }
520
521            loop {
522                // need to buffer results here to make borrow checker happy
523                let mut closed_sessions = Vec::new();
524                let mut received_responses = Vec::new();
525
526                // poll all connected peers for responses
527                for (id, peer) in &mut self.active_peers {
528                    let Some(mut response) = peer.pending_response.take() else { continue };
529                    match response.poll(cx) {
530                        Poll::Ready(res) => {
531                            // check if the error is due to a closed channel to the session
532                            if res.err().is_some_and(|err| err.is_channel_closed()) {
533                                debug!(
534                                    target: "net",
535                                    ?id,
536                                    "Request canceled, response channel from session closed."
537                                );
538                                // if the channel is closed, this means the peer session is also
539                                // closed, in which case we can invoke the
540                                // [Self::on_closed_session]
541                                // immediately, preventing followup requests and propagate the
542                                // connection dropped error
543                                closed_sessions.push(*id);
544                            } else {
545                                received_responses.push((*id, res));
546                            }
547                        }
548                        Poll::Pending => {
549                            // not ready yet, store again.
550                            peer.pending_response = Some(response);
551                        }
552                    };
553                }
554
555                for peer in closed_sessions {
556                    self.on_session_closed(peer)
557                }
558
559                if received_responses.is_empty() {
560                    break;
561                }
562
563                for (peer_id, resp) in received_responses {
564                    self.on_eth_response(peer_id, resp);
565                }
566            }
567
568            // poll peer manager
569            while let Poll::Ready(action) = self.peers_manager.poll(cx) {
570                self.on_peer_action(action);
571            }
572
573            // We need to poll again in case we have received any responses because they may have
574            // triggered follow-up requests.
575            if self.queued_messages.is_empty() {
576                return Poll::Pending
577            }
578        }
579    }
580}
581
582/// Tracks the state of a Peer with an active Session.
583///
584/// For example known blocks,so we can decide what to announce.
585#[derive(Debug)]
586pub(crate) struct ActivePeer<N: NetworkPrimitives> {
587    /// Best block of the peer.
588    pub(crate) best_hash: B256,
589    /// The capabilities of the remote peer.
590    pub(crate) capabilities: Arc<Capabilities>,
591    /// A communication channel directly to the session task.
592    pub(crate) request_tx: PeerRequestSender<PeerRequest<N>>,
593    /// The response receiver for a currently active request to that peer.
594    pub(crate) pending_response: Option<PeerResponse<N>>,
595    /// Blocks we know the peer has.
596    pub(crate) blocks: LruCache<B256>,
597}
598
599/// Message variants triggered by the [`NetworkState`]
600#[derive(Debug)]
601pub(crate) enum StateAction<N: NetworkPrimitives> {
602    /// Dispatch a `NewBlock` message to the peer
603    NewBlock {
604        /// Target of the message
605        peer_id: PeerId,
606        /// The `NewBlock` message
607        block: NewBlockMessage<N::NewBlockPayload>,
608    },
609    NewBlockHashes {
610        /// Target of the message
611        peer_id: PeerId,
612        /// `NewBlockHashes` message to send to the peer.
613        hashes: NewBlockHashes,
614    },
615    /// Create a new connection to the given node.
616    Connect { remote_addr: SocketAddr, peer_id: PeerId },
617    /// Disconnect an existing connection
618    Disconnect {
619        peer_id: PeerId,
620        /// Why the disconnect was initiated
621        reason: Option<DisconnectReason>,
622    },
623    /// Retrieved a [`ForkId`] from the peer via ENR request, See <https://eips.ethereum.org/EIPS/eip-868>
624    DiscoveredEnrForkId {
625        peer_id: PeerId,
626        /// The address of the peer.
627        addr: PeerAddr,
628        /// The reported [`ForkId`] by this peer.
629        fork_id: ForkId,
630    },
631    /// A new node was found through the discovery, possibly with a `ForkId`
632    DiscoveredNode { peer_id: PeerId, addr: PeerAddr, fork_id: Option<ForkId> },
633    /// A peer was added
634    PeerAdded(PeerId),
635    /// A peer was dropped
636    PeerRemoved(PeerId),
637}
638
639#[cfg(test)]
640mod tests {
641    use crate::{
642        discovery::Discovery,
643        fetch::StateFetcher,
644        peers::PeersManager,
645        state::{BlockNumReader, NetworkState},
646        PeerRequest,
647    };
648    use alloy_consensus::Header;
649    use alloy_primitives::B256;
650    use reth_eth_wire::{BlockBodies, Capabilities, Capability, EthNetworkPrimitives, EthVersion};
651    use reth_ethereum_primitives::BlockBody;
652    use reth_network_api::PeerRequestSender;
653    use reth_network_p2p::{bodies::client::BodiesClient, error::RequestError};
654    use reth_network_peers::PeerId;
655    use reth_storage_api::noop::NoopProvider;
656    use std::{
657        future::poll_fn,
658        sync::{atomic::AtomicU64, Arc},
659    };
660    use tokio::sync::mpsc;
661    use tokio_stream::{wrappers::ReceiverStream, StreamExt};
662
663    /// Returns a testing instance of the [`NetworkState`].
664    fn state() -> NetworkState<EthNetworkPrimitives> {
665        let peers = PeersManager::default();
666        let handle = peers.handle();
667        NetworkState {
668            active_peers: Default::default(),
669            peers_manager: Default::default(),
670            queued_messages: Default::default(),
671            client: BlockNumReader(Box::new(NoopProvider::default())),
672            discovery: Discovery::noop(),
673            state_fetcher: StateFetcher::new(handle, Default::default()),
674        }
675    }
676
677    fn capabilities() -> Arc<Capabilities> {
678        Arc::new(vec![Capability::from(EthVersion::Eth67)].into())
679    }
680
681    // tests that ongoing requests are answered with connection dropped if the session that received
682    // that request is drops the request object.
683    #[tokio::test(flavor = "multi_thread")]
684    async fn test_dropped_active_session() {
685        let mut state = state();
686        let client = state.fetch_client();
687
688        let peer_id = PeerId::random();
689        let (tx, session_rx) = mpsc::channel(1);
690        let peer_tx = PeerRequestSender::new(peer_id, tx);
691
692        state.on_session_activated(
693            peer_id,
694            capabilities(),
695            Arc::default(),
696            peer_tx,
697            Arc::new(AtomicU64::new(1)),
698            None,
699        );
700
701        assert!(state.active_peers.contains_key(&peer_id));
702
703        let body = BlockBody { ommers: vec![Header::default()], ..Default::default() };
704
705        let body_response = body.clone();
706
707        // this mimics an active session that receives the requests from the state
708        tokio::task::spawn(async move {
709            let mut stream = ReceiverStream::new(session_rx);
710            let resp = stream.next().await.unwrap();
711            match resp {
712                PeerRequest::GetBlockBodies { response, .. } => {
713                    response.send(Ok(BlockBodies(vec![body_response]))).unwrap();
714                }
715                _ => unreachable!(),
716            }
717
718            // wait for the next request, then drop
719            let _resp = stream.next().await.unwrap();
720        });
721
722        // spawn the state as future
723        tokio::task::spawn(async move {
724            loop {
725                poll_fn(|cx| state.poll(cx)).await;
726            }
727        });
728
729        // send requests to the state via the client
730        let (peer, bodies) = client.get_block_bodies(vec![B256::random()]).await.unwrap().split();
731        assert_eq!(peer, peer_id);
732        assert_eq!(bodies, vec![body]);
733
734        let resp = client.get_block_bodies(vec![B256::random()]).await;
735        assert!(resp.is_err());
736        assert_eq!(resp.unwrap_err(), RequestError::ConnectionDropped);
737    }
738}