reth_network/test_utils/
testnet.rs

1//! A network implementation for testing purposes.
2
3use crate::{
4    builder::ETH_REQUEST_CHANNEL_CAPACITY,
5    error::NetworkError,
6    eth_requests::EthRequestHandler,
7    protocol::IntoRlpxSubProtocol,
8    transactions::{
9        config::{StrictEthAnnouncementFilter, TransactionPropagationKind},
10        policy::NetworkPolicies,
11        TransactionsHandle, TransactionsManager, TransactionsManagerConfig,
12    },
13    NetworkConfig, NetworkConfigBuilder, NetworkHandle, NetworkManager,
14};
15use futures::{FutureExt, StreamExt};
16use pin_project::pin_project;
17use reth_chainspec::{ChainSpecProvider, EthereumHardforks, Hardforks};
18use reth_eth_wire::{
19    protocol::Protocol, DisconnectReason, EthNetworkPrimitives, HelloMessageWithProtocols,
20};
21use reth_ethereum_primitives::{PooledTransactionVariant, TransactionSigned};
22use reth_network_api::{
23    events::{PeerEvent, SessionInfo},
24    test_utils::{PeersHandle, PeersHandleProvider},
25    NetworkEvent, NetworkEventListenerProvider, NetworkInfo, Peers,
26};
27use reth_network_peers::PeerId;
28use reth_storage_api::{
29    noop::NoopProvider, BlockReader, BlockReaderIdExt, HeaderProvider, StateProviderFactory,
30};
31use reth_tasks::TokioTaskExecutor;
32use reth_tokio_util::EventStream;
33use reth_transaction_pool::{
34    blobstore::InMemoryBlobStore,
35    test_utils::{TestPool, TestPoolBuilder},
36    EthTransactionPool, PoolTransaction, TransactionPool, TransactionValidationTaskExecutor,
37};
38use secp256k1::SecretKey;
39use std::{
40    fmt,
41    future::Future,
42    net::{Ipv4Addr, SocketAddr, SocketAddrV4},
43    pin::Pin,
44    task::{Context, Poll},
45};
46use tokio::{
47    sync::{
48        mpsc::{channel, unbounded_channel},
49        oneshot,
50    },
51    task::JoinHandle,
52};
53
54/// A test network consisting of multiple peers.
55pub struct Testnet<C, Pool> {
56    /// All running peers in the network.
57    peers: Vec<Peer<C, Pool>>,
58}
59
60// === impl Testnet ===
61
62impl<C> Testnet<C, TestPool>
63where
64    C: BlockReader + HeaderProvider + Clone + 'static + ChainSpecProvider<ChainSpec: Hardforks>,
65{
66    /// Same as [`Self::try_create_with`] but panics on error
67    pub async fn create_with(num_peers: usize, provider: C) -> Self {
68        Self::try_create_with(num_peers, provider).await.unwrap()
69    }
70
71    /// Creates a new [`Testnet`] with the given number of peers and the provider.
72    pub async fn try_create_with(num_peers: usize, provider: C) -> Result<Self, NetworkError> {
73        let mut this = Self { peers: Vec::with_capacity(num_peers) };
74        for _ in 0..num_peers {
75            let config = PeerConfig::new(provider.clone());
76            this.add_peer_with_config(config).await?;
77        }
78        Ok(this)
79    }
80
81    /// Extend the list of peers with new peers that are configured with each of the given
82    /// [`PeerConfig`]s.
83    pub async fn extend_peer_with_config(
84        &mut self,
85        configs: impl IntoIterator<Item = PeerConfig<C>>,
86    ) -> Result<(), NetworkError> {
87        let peers = configs.into_iter().map(|c| c.launch()).collect::<Vec<_>>();
88        let peers = futures::future::join_all(peers).await;
89        for peer in peers {
90            self.peers.push(peer?);
91        }
92        Ok(())
93    }
94}
95
96impl<C, Pool> Testnet<C, Pool>
97where
98    C: BlockReader + HeaderProvider + Clone + 'static,
99    Pool: TransactionPool,
100{
101    /// Return a mutable slice of all peers.
102    pub fn peers_mut(&mut self) -> &mut [Peer<C, Pool>] {
103        &mut self.peers
104    }
105
106    /// Return a slice of all peers.
107    pub fn peers(&self) -> &[Peer<C, Pool>] {
108        &self.peers
109    }
110
111    /// Remove a peer from the [`Testnet`] and return it.
112    ///
113    /// # Panics
114    /// If the index is out of bounds.
115    pub fn remove_peer(&mut self, index: usize) -> Peer<C, Pool> {
116        self.peers.remove(index)
117    }
118
119    /// Return a mutable iterator over all peers.
120    pub fn peers_iter_mut(&mut self) -> impl Iterator<Item = &mut Peer<C, Pool>> + '_ {
121        self.peers.iter_mut()
122    }
123
124    /// Return an iterator over all peers.
125    pub fn peers_iter(&self) -> impl Iterator<Item = &Peer<C, Pool>> + '_ {
126        self.peers.iter()
127    }
128
129    /// Add a peer to the [`Testnet`] with the given [`PeerConfig`].
130    pub async fn add_peer_with_config(
131        &mut self,
132        config: PeerConfig<C>,
133    ) -> Result<(), NetworkError> {
134        let PeerConfig { config, client, secret_key } = config;
135
136        let network = NetworkManager::new(config).await?;
137        let peer = Peer {
138            network,
139            client,
140            secret_key,
141            request_handler: None,
142            transactions_manager: None,
143            pool: None,
144        };
145        self.peers.push(peer);
146        Ok(())
147    }
148
149    /// Returns all handles to the networks
150    pub fn handles(&self) -> impl Iterator<Item = NetworkHandle<EthNetworkPrimitives>> + '_ {
151        self.peers.iter().map(|p| p.handle())
152    }
153
154    /// Maps the pool of each peer with the given closure
155    pub fn map_pool<F, P>(self, f: F) -> Testnet<C, P>
156    where
157        F: Fn(Peer<C, Pool>) -> Peer<C, P>,
158        P: TransactionPool,
159    {
160        Testnet { peers: self.peers.into_iter().map(f).collect() }
161    }
162
163    /// Apply a closure on each peer
164    pub fn for_each<F>(&self, f: F)
165    where
166        F: Fn(&Peer<C, Pool>),
167    {
168        self.peers.iter().for_each(f)
169    }
170
171    /// Apply a closure on each peer
172    pub fn for_each_mut<F>(&mut self, f: F)
173    where
174        F: FnMut(&mut Peer<C, Pool>),
175    {
176        self.peers.iter_mut().for_each(f)
177    }
178}
179
180impl<C, Pool> Testnet<C, Pool>
181where
182    C: ChainSpecProvider<ChainSpec: EthereumHardforks>
183        + StateProviderFactory
184        + BlockReaderIdExt
185        + HeaderProvider
186        + Clone
187        + 'static,
188    Pool: TransactionPool,
189{
190    /// Installs an eth pool on each peer
191    pub fn with_eth_pool(self) -> Testnet<C, EthTransactionPool<C, InMemoryBlobStore>> {
192        self.map_pool(|peer| {
193            let blob_store = InMemoryBlobStore::default();
194            let pool = TransactionValidationTaskExecutor::eth(
195                peer.client.clone(),
196                blob_store.clone(),
197                TokioTaskExecutor::default(),
198            );
199            peer.map_transactions_manager(EthTransactionPool::eth_pool(
200                pool,
201                blob_store,
202                Default::default(),
203            ))
204        })
205    }
206
207    /// Installs an eth pool on each peer with custom transaction manager config
208    pub fn with_eth_pool_config(
209        self,
210        tx_manager_config: TransactionsManagerConfig,
211    ) -> Testnet<C, EthTransactionPool<C, InMemoryBlobStore>> {
212        self.with_eth_pool_config_and_policy(tx_manager_config, Default::default())
213    }
214
215    /// Installs an eth pool on each peer with custom transaction manager config and policy.
216    pub fn with_eth_pool_config_and_policy(
217        self,
218        tx_manager_config: TransactionsManagerConfig,
219        policy: TransactionPropagationKind,
220    ) -> Testnet<C, EthTransactionPool<C, InMemoryBlobStore>> {
221        self.map_pool(|peer| {
222            let blob_store = InMemoryBlobStore::default();
223            let pool = TransactionValidationTaskExecutor::eth(
224                peer.client.clone(),
225                blob_store.clone(),
226                TokioTaskExecutor::default(),
227            );
228
229            peer.map_transactions_manager_with(
230                EthTransactionPool::eth_pool(pool, blob_store, Default::default()),
231                tx_manager_config.clone(),
232                policy,
233            )
234        })
235    }
236}
237
238impl<C, Pool> Testnet<C, Pool>
239where
240    C: BlockReader<
241            Block = reth_ethereum_primitives::Block,
242            Receipt = reth_ethereum_primitives::Receipt,
243            Header = alloy_consensus::Header,
244        > + HeaderProvider
245        + Clone
246        + Unpin
247        + 'static,
248    Pool: TransactionPool<
249            Transaction: PoolTransaction<
250                Consensus = TransactionSigned,
251                Pooled = PooledTransactionVariant,
252            >,
253        > + Unpin
254        + 'static,
255{
256    /// Spawns the testnet to a separate task
257    pub fn spawn(self) -> TestnetHandle<C, Pool> {
258        let (tx, rx) = oneshot::channel::<oneshot::Sender<Self>>();
259        let peers = self.peers.iter().map(|peer| peer.peer_handle()).collect::<Vec<_>>();
260        let mut net = self;
261        let handle = tokio::task::spawn(async move {
262            let mut tx = None;
263            tokio::select! {
264                _ = &mut net => {}
265                inc = rx => {
266                    tx = inc.ok();
267                }
268            }
269            if let Some(tx) = tx {
270                let _ = tx.send(net);
271            }
272        });
273
274        TestnetHandle { _handle: handle, peers, terminate: tx }
275    }
276}
277
278impl Testnet<NoopProvider, TestPool> {
279    /// Same as [`Self::try_create`] but panics on error
280    pub async fn create(num_peers: usize) -> Self {
281        Self::try_create(num_peers).await.unwrap()
282    }
283
284    /// Creates a new [`Testnet`] with the given number of peers
285    pub async fn try_create(num_peers: usize) -> Result<Self, NetworkError> {
286        let mut this = Self::default();
287
288        this.extend_peer_with_config((0..num_peers).map(|_| Default::default())).await?;
289        Ok(this)
290    }
291
292    /// Add a peer to the [`Testnet`]
293    pub async fn add_peer(&mut self) -> Result<(), NetworkError> {
294        self.add_peer_with_config(Default::default()).await
295    }
296}
297
298impl<C, Pool> Default for Testnet<C, Pool> {
299    fn default() -> Self {
300        Self { peers: Vec::new() }
301    }
302}
303
304impl<C, Pool> fmt::Debug for Testnet<C, Pool> {
305    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
306        f.debug_struct("Testnet {{}}").finish_non_exhaustive()
307    }
308}
309
310impl<C, Pool> Future for Testnet<C, Pool>
311where
312    C: BlockReader<
313            Block = reth_ethereum_primitives::Block,
314            Receipt = reth_ethereum_primitives::Receipt,
315            Header = alloy_consensus::Header,
316        > + HeaderProvider
317        + Unpin
318        + 'static,
319    Pool: TransactionPool<
320            Transaction: PoolTransaction<
321                Consensus = TransactionSigned,
322                Pooled = PooledTransactionVariant,
323            >,
324        > + Unpin
325        + 'static,
326{
327    type Output = ();
328
329    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
330        let this = self.get_mut();
331        for peer in &mut this.peers {
332            let _ = peer.poll_unpin(cx);
333        }
334        Poll::Pending
335    }
336}
337
338/// A handle to a [`Testnet`] that can be shared.
339#[derive(Debug)]
340pub struct TestnetHandle<C, Pool> {
341    _handle: JoinHandle<()>,
342    peers: Vec<PeerHandle<Pool>>,
343    terminate: oneshot::Sender<oneshot::Sender<Testnet<C, Pool>>>,
344}
345
346// === impl TestnetHandle ===
347
348impl<C, Pool> TestnetHandle<C, Pool> {
349    /// Terminates the task and returns the [`Testnet`] back.
350    pub async fn terminate(self) -> Testnet<C, Pool> {
351        let (tx, rx) = oneshot::channel();
352        self.terminate.send(tx).unwrap();
353        rx.await.unwrap()
354    }
355
356    /// Returns the [`PeerHandle`]s of this [`Testnet`].
357    pub fn peers(&self) -> &[PeerHandle<Pool>] {
358        &self.peers
359    }
360
361    /// Connects all peers with each other.
362    ///
363    /// This establishes sessions concurrently between all peers.
364    ///
365    /// Returns once all sessions are established.
366    pub async fn connect_peers(&self) {
367        if self.peers.len() < 2 {
368            return
369        }
370
371        // add an event stream for _each_ peer
372        let streams =
373            self.peers.iter().map(|handle| NetworkEventStream::new(handle.event_listener()));
374
375        // add all peers to each other
376        for (idx, handle) in self.peers.iter().enumerate().take(self.peers.len() - 1) {
377            for idx in (idx + 1)..self.peers.len() {
378                let neighbour = &self.peers[idx];
379                handle.network.add_peer(*neighbour.peer_id(), neighbour.local_addr());
380            }
381        }
382
383        // await all sessions to be established
384        let num_sessions_per_peer = self.peers.len() - 1;
385        let fut = streams.into_iter().map(|mut stream| async move {
386            stream.take_session_established(num_sessions_per_peer).await
387        });
388
389        futures::future::join_all(fut).await;
390    }
391}
392
393/// A peer in the [`Testnet`].
394#[pin_project]
395#[derive(Debug)]
396pub struct Peer<C, Pool = TestPool> {
397    #[pin]
398    network: NetworkManager<EthNetworkPrimitives>,
399    #[pin]
400    request_handler: Option<EthRequestHandler<C, EthNetworkPrimitives>>,
401    #[pin]
402    transactions_manager: Option<
403        TransactionsManager<
404            Pool,
405            EthNetworkPrimitives,
406            NetworkPolicies<TransactionPropagationKind, StrictEthAnnouncementFilter>,
407        >,
408    >,
409    pool: Option<Pool>,
410    client: C,
411    secret_key: SecretKey,
412}
413
414// === impl Peer ===
415
416impl<C, Pool> Peer<C, Pool>
417where
418    C: BlockReader + HeaderProvider + Clone + 'static,
419    Pool: TransactionPool,
420{
421    /// Returns the number of connected peers.
422    pub fn num_peers(&self) -> usize {
423        self.network.num_connected_peers()
424    }
425
426    /// Adds an additional protocol handler to the peer.
427    pub fn add_rlpx_sub_protocol(&mut self, protocol: impl IntoRlpxSubProtocol) {
428        self.network.add_rlpx_sub_protocol(protocol);
429    }
430
431    /// Returns a handle to the peer's network.
432    pub fn peer_handle(&self) -> PeerHandle<Pool> {
433        PeerHandle {
434            network: self.network.handle().clone(),
435            pool: self.pool.clone(),
436            transactions: self.transactions_manager.as_ref().map(|mgr| mgr.handle()),
437        }
438    }
439
440    /// The address that listens for incoming connections.
441    pub const fn local_addr(&self) -> SocketAddr {
442        self.network.local_addr()
443    }
444
445    /// The [`PeerId`] of this peer.
446    pub fn peer_id(&self) -> PeerId {
447        *self.network.peer_id()
448    }
449
450    /// Returns mutable access to the network.
451    pub const fn network_mut(&mut self) -> &mut NetworkManager<EthNetworkPrimitives> {
452        &mut self.network
453    }
454
455    /// Returns the [`NetworkHandle`] of this peer.
456    pub fn handle(&self) -> NetworkHandle<EthNetworkPrimitives> {
457        self.network.handle().clone()
458    }
459
460    /// Returns the [`TestPool`] of this peer.
461    pub const fn pool(&self) -> Option<&Pool> {
462        self.pool.as_ref()
463    }
464
465    /// Set a new request handler that's connected to the peer's network
466    pub fn install_request_handler(&mut self) {
467        let (tx, rx) = channel(ETH_REQUEST_CHANNEL_CAPACITY);
468        self.network.set_eth_request_handler(tx);
469        let peers = self.network.peers_handle();
470        let request_handler = EthRequestHandler::new(self.client.clone(), peers, rx);
471        self.request_handler = Some(request_handler);
472    }
473
474    /// Set a new transactions manager that's connected to the peer's network
475    pub fn install_transactions_manager(&mut self, pool: Pool) {
476        let (tx, rx) = unbounded_channel();
477        self.network.set_transactions(tx);
478        let transactions_manager = TransactionsManager::new(
479            self.handle(),
480            pool.clone(),
481            rx,
482            TransactionsManagerConfig::default(),
483        );
484        self.transactions_manager = Some(transactions_manager);
485        self.pool = Some(pool);
486    }
487
488    /// Set a new transactions manager that's connected to the peer's network
489    pub fn map_transactions_manager<P>(self, pool: P) -> Peer<C, P>
490    where
491        P: TransactionPool,
492    {
493        let Self { mut network, request_handler, client, secret_key, .. } = self;
494        let (tx, rx) = unbounded_channel();
495        network.set_transactions(tx);
496        let transactions_manager = TransactionsManager::new(
497            network.handle().clone(),
498            pool.clone(),
499            rx,
500            TransactionsManagerConfig::default(),
501        );
502        Peer {
503            network,
504            request_handler,
505            transactions_manager: Some(transactions_manager),
506            pool: Some(pool),
507            client,
508            secret_key,
509        }
510    }
511
512    /// Map transactions manager with custom config
513    pub fn map_transactions_manager_with_config<P>(
514        self,
515        pool: P,
516        config: TransactionsManagerConfig,
517    ) -> Peer<C, P>
518    where
519        P: TransactionPool,
520    {
521        self.map_transactions_manager_with(pool, config, Default::default())
522    }
523
524    /// Map transactions manager with custom config and the given policy.
525    pub fn map_transactions_manager_with<P>(
526        self,
527        pool: P,
528        config: TransactionsManagerConfig,
529        policy: TransactionPropagationKind,
530    ) -> Peer<C, P>
531    where
532        P: TransactionPool,
533    {
534        let Self { mut network, request_handler, client, secret_key, .. } = self;
535        let (tx, rx) = unbounded_channel();
536        network.set_transactions(tx);
537
538        let announcement_policy = StrictEthAnnouncementFilter::default();
539        let policies = NetworkPolicies::new(policy, announcement_policy);
540
541        let transactions_manager = TransactionsManager::with_policy(
542            network.handle().clone(),
543            pool.clone(),
544            rx,
545            config,
546            policies,
547        );
548
549        Peer {
550            network,
551            request_handler,
552            transactions_manager: Some(transactions_manager),
553            pool: Some(pool),
554            client,
555            secret_key,
556        }
557    }
558}
559
560impl<C> Peer<C>
561where
562    C: BlockReader + HeaderProvider + Clone + 'static,
563{
564    /// Installs a new [`TestPool`]
565    pub fn install_test_pool(&mut self) {
566        self.install_transactions_manager(TestPoolBuilder::default().into())
567    }
568}
569
570impl<C, Pool> Future for Peer<C, Pool>
571where
572    C: BlockReader<
573            Block = reth_ethereum_primitives::Block,
574            Receipt = reth_ethereum_primitives::Receipt,
575            Header = alloy_consensus::Header,
576        > + HeaderProvider
577        + Unpin
578        + 'static,
579    Pool: TransactionPool<
580            Transaction: PoolTransaction<
581                Consensus = TransactionSigned,
582                Pooled = PooledTransactionVariant,
583            >,
584        > + Unpin
585        + 'static,
586{
587    type Output = ();
588
589    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
590        let this = self.project();
591
592        if let Some(request) = this.request_handler.as_pin_mut() {
593            let _ = request.poll(cx);
594        }
595
596        if let Some(tx_manager) = this.transactions_manager.as_pin_mut() {
597            let _ = tx_manager.poll(cx);
598        }
599
600        this.network.poll(cx)
601    }
602}
603
604/// A helper config for setting up the reth networking stack.
605#[derive(Debug)]
606pub struct PeerConfig<C = NoopProvider> {
607    config: NetworkConfig<C>,
608    client: C,
609    secret_key: SecretKey,
610}
611
612/// A handle to a peer in the [`Testnet`].
613#[derive(Debug)]
614pub struct PeerHandle<Pool> {
615    network: NetworkHandle<EthNetworkPrimitives>,
616    transactions: Option<TransactionsHandle<EthNetworkPrimitives>>,
617    pool: Option<Pool>,
618}
619
620// === impl PeerHandle ===
621
622impl<Pool> PeerHandle<Pool> {
623    /// Returns the [`PeerId`] used in the network.
624    pub fn peer_id(&self) -> &PeerId {
625        self.network.peer_id()
626    }
627
628    /// Returns the [`PeersHandle`] from the network.
629    pub fn peer_handle(&self) -> &PeersHandle {
630        self.network.peers_handle()
631    }
632
633    /// Returns the local socket as configured for the network.
634    pub fn local_addr(&self) -> SocketAddr {
635        self.network.local_addr()
636    }
637
638    /// Creates a new [`NetworkEvent`] listener channel.
639    pub fn event_listener(&self) -> EventStream<NetworkEvent> {
640        self.network.event_listener()
641    }
642
643    /// Returns the [`TransactionsHandle`] of this peer.
644    pub const fn transactions(&self) -> Option<&TransactionsHandle> {
645        self.transactions.as_ref()
646    }
647
648    /// Returns the [`TestPool`] of this peer.
649    pub const fn pool(&self) -> Option<&Pool> {
650        self.pool.as_ref()
651    }
652
653    /// Returns the [`NetworkHandle`] of this peer.
654    pub const fn network(&self) -> &NetworkHandle<EthNetworkPrimitives> {
655        &self.network
656    }
657}
658
659// === impl PeerConfig ===
660
661impl<C> PeerConfig<C>
662where
663    C: BlockReader + HeaderProvider + Clone + 'static,
664{
665    /// Launches the network and returns the [Peer] that manages it
666    pub async fn launch(self) -> Result<Peer<C>, NetworkError> {
667        let Self { config, client, secret_key } = self;
668        let network = NetworkManager::new(config).await?;
669        let peer = Peer {
670            network,
671            client,
672            secret_key,
673            request_handler: None,
674            transactions_manager: None,
675            pool: None,
676        };
677        Ok(peer)
678    }
679
680    /// Initialize the network with a random secret key, allowing the devp2p and discovery to bind
681    /// to any available IP and port.
682    pub fn new(client: C) -> Self
683    where
684        C: ChainSpecProvider<ChainSpec: Hardforks>,
685    {
686        let secret_key = SecretKey::new(&mut rand_08::thread_rng());
687        let config = Self::network_config_builder(secret_key).build(client.clone());
688        Self { config, client, secret_key }
689    }
690
691    /// Initialize the network with a given secret key, allowing devp2p and discovery to bind any
692    /// available IP and port.
693    pub fn with_secret_key(client: C, secret_key: SecretKey) -> Self
694    where
695        C: ChainSpecProvider<ChainSpec: Hardforks>,
696    {
697        let config = Self::network_config_builder(secret_key).build(client.clone());
698        Self { config, client, secret_key }
699    }
700
701    /// Initialize the network with a given capabilities.
702    pub fn with_protocols(client: C, protocols: impl IntoIterator<Item = Protocol>) -> Self
703    where
704        C: ChainSpecProvider<ChainSpec: Hardforks>,
705    {
706        let secret_key = SecretKey::new(&mut rand_08::thread_rng());
707
708        let builder = Self::network_config_builder(secret_key);
709        let hello_message =
710            HelloMessageWithProtocols::builder(builder.get_peer_id()).protocols(protocols).build();
711        let config = builder.hello_message(hello_message).build(client.clone());
712
713        Self { config, client, secret_key }
714    }
715
716    fn network_config_builder(secret_key: SecretKey) -> NetworkConfigBuilder {
717        NetworkConfigBuilder::new(secret_key)
718            .listener_addr(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)))
719            .discovery_addr(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)))
720            .disable_dns_discovery()
721            .disable_discv4_discovery()
722    }
723}
724
725impl Default for PeerConfig {
726    fn default() -> Self {
727        Self::new(NoopProvider::default())
728    }
729}
730
731/// A helper type to await network events
732///
733/// This makes it easier to await established connections
734#[derive(Debug)]
735pub struct NetworkEventStream {
736    inner: EventStream<NetworkEvent>,
737}
738
739// === impl NetworkEventStream ===
740
741impl NetworkEventStream {
742    /// Create a new [`NetworkEventStream`] from the given network event receiver stream.
743    pub const fn new(inner: EventStream<NetworkEvent>) -> Self {
744        Self { inner }
745    }
746
747    /// Awaits the next event for a session to be closed
748    pub async fn next_session_closed(&mut self) -> Option<(PeerId, Option<DisconnectReason>)> {
749        while let Some(ev) = self.inner.next().await {
750            if let NetworkEvent::Peer(PeerEvent::SessionClosed { peer_id, reason }) = ev {
751                return Some((peer_id, reason))
752            }
753        }
754        None
755    }
756
757    /// Awaits the next event for an established session
758    pub async fn next_session_established(&mut self) -> Option<PeerId> {
759        while let Some(ev) = self.inner.next().await {
760            match ev {
761                NetworkEvent::ActivePeerSession { info, .. } |
762                NetworkEvent::Peer(PeerEvent::SessionEstablished(info)) => {
763                    return Some(info.peer_id)
764                }
765                _ => {}
766            }
767        }
768        None
769    }
770
771    /// Awaits the next `num` events for an established session
772    pub async fn take_session_established(&mut self, mut num: usize) -> Vec<PeerId> {
773        if num == 0 {
774            return Vec::new();
775        }
776        let mut peers = Vec::with_capacity(num);
777        while let Some(ev) = self.inner.next().await {
778            if let NetworkEvent::ActivePeerSession { info: SessionInfo { peer_id, .. }, .. } = ev {
779                peers.push(peer_id);
780                num -= 1;
781                if num == 0 {
782                    return peers;
783                }
784            }
785        }
786        peers
787    }
788
789    /// Ensures that the first two events are a [`NetworkEvent::Peer`] and
790    /// [`PeerEvent::PeerAdded`][`NetworkEvent::ActivePeerSession`], returning the [`PeerId`] of the
791    /// established session.
792    pub async fn peer_added_and_established(&mut self) -> Option<PeerId> {
793        let peer_id = match self.inner.next().await {
794            Some(NetworkEvent::Peer(PeerEvent::PeerAdded(peer_id))) => peer_id,
795            _ => return None,
796        };
797
798        match self.inner.next().await {
799            Some(NetworkEvent::ActivePeerSession {
800                info: SessionInfo { peer_id: peer_id2, .. },
801                ..
802            }) => {
803                debug_assert_eq!(
804                    peer_id, peer_id2,
805                    "PeerAdded peer_id {peer_id} does not match SessionEstablished peer_id {peer_id2}"
806                );
807                Some(peer_id)
808            }
809            _ => None,
810        }
811    }
812
813    /// Awaits the next event for a peer added.
814    pub async fn peer_added(&mut self) -> Option<PeerId> {
815        let peer_id = match self.inner.next().await {
816            Some(NetworkEvent::Peer(PeerEvent::PeerAdded(peer_id))) => peer_id,
817            _ => return None,
818        };
819
820        Some(peer_id)
821    }
822
823    /// Awaits the next event for a peer removed.
824    pub async fn peer_removed(&mut self) -> Option<PeerId> {
825        let peer_id = match self.inner.next().await {
826            Some(NetworkEvent::Peer(PeerEvent::PeerRemoved(peer_id))) => peer_id,
827            _ => return None,
828        };
829
830        Some(peer_id)
831    }
832}