reth_network_p2p/test_utils/
headers.rs

1//! Testing support for headers related interfaces.
2
3use crate::{
4    download::DownloadClient,
5    error::{DownloadError, DownloadResult, PeerRequestResult, RequestError},
6    headers::{
7        client::{HeadersClient, HeadersRequest},
8        downloader::{HeaderDownloader, SyncTarget},
9        error::HeadersDownloaderResult,
10    },
11    priority::Priority,
12};
13use alloy_consensus::Header;
14use futures::{Future, FutureExt, Stream, StreamExt};
15use reth_consensus::{test_utils::TestConsensus, HeaderValidator};
16use reth_eth_wire_types::HeadersDirection;
17use reth_network_peers::{PeerId, WithPeerId};
18use reth_primitives_traits::SealedHeader;
19use std::{
20    fmt,
21    pin::Pin,
22    sync::{
23        atomic::{AtomicU64, Ordering},
24        Arc,
25    },
26    task::{ready, Context, Poll},
27};
28use tokio::sync::Mutex;
29
/// A test downloader which just returns the values that have been pushed to it.
///
/// Headers are requested from a [`TestHeadersClient`] through a lazily created internal
/// download and yielded by the [`Stream`] implementation in batches sized by `batch_size`.
#[derive(Debug)]
pub struct TestHeaderDownloader {
    /// Client the headers are requested from; cloned into each internal download.
    client: TestHeadersClient,
    /// Consensus handed to each internal download, which consults it before fetching.
    consensus: Arc<TestConsensus>,
    /// Maximum number of headers requested from the client.
    limit: u64,
    /// The in-progress download, created on first poll.
    download: Option<TestDownload>,
    /// Headers received so far that have not yet been yielded as part of a batch.
    queued_headers: Vec<SealedHeader>,
    /// Number of headers per yielded batch; changed via `set_batch_size`.
    batch_size: usize,
}
40
41impl TestHeaderDownloader {
42    /// Instantiates the downloader with the mock responses
43    pub const fn new(
44        client: TestHeadersClient,
45        consensus: Arc<TestConsensus>,
46        limit: u64,
47        batch_size: usize,
48    ) -> Self {
49        Self { client, consensus, limit, download: None, batch_size, queued_headers: Vec::new() }
50    }
51
52    fn create_download(&self) -> TestDownload {
53        TestDownload {
54            client: self.client.clone(),
55            consensus: Arc::clone(&self.consensus),
56            limit: self.limit,
57            fut: None,
58            buffer: vec![],
59            done: false,
60        }
61    }
62}
63
64impl HeaderDownloader for TestHeaderDownloader {
65    type Header = Header;
66
67    fn update_local_head(&mut self, _head: SealedHeader) {}
68
69    fn update_sync_target(&mut self, _target: SyncTarget) {}
70
71    fn set_batch_size(&mut self, limit: usize) {
72        self.batch_size = limit;
73    }
74}
75
76impl Stream for TestHeaderDownloader {
77    type Item = HeadersDownloaderResult<Vec<SealedHeader>, Header>;
78
79    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
80        let this = self.get_mut();
81        loop {
82            if this.queued_headers.len() == this.batch_size {
83                return Poll::Ready(Some(Ok(std::mem::take(&mut this.queued_headers))))
84            }
85            if this.download.is_none() {
86                this.download = Some(this.create_download());
87            }
88
89            match ready!(this.download.as_mut().unwrap().poll_next_unpin(cx)) {
90                None => return Poll::Ready(Some(Ok(std::mem::take(&mut this.queued_headers)))),
91                Some(header) => this.queued_headers.push(header.unwrap()),
92            }
93        }
94    }
95}
96
/// Boxed future returned by [`TestHeadersClient`] requests.
///
/// It resolves to the requested headers wrapped with a peer id, or to a request error.
type TestHeadersFut = Pin<Box<dyn Future<Output = PeerRequestResult<Vec<Header>>> + Sync + Send>>;
98
/// A single header download: issues one request to a [`TestHeadersClient`] and streams the
/// returned headers.
struct TestDownload {
    /// Client the request is sent to.
    client: TestHeadersClient,
    /// Consensus checked before fetching; a validation failure ends the download with an error.
    consensus: Arc<TestConsensus>,
    /// Maximum number of headers to request.
    limit: u64,
    /// The in-flight request, created lazily when first needed.
    fut: Option<TestHeadersFut>,
    /// Fetched headers waiting to be yielded; popped from the back.
    buffer: Vec<SealedHeader>,
    /// Set once the request has completed or failed; no further request is made afterwards.
    done: bool,
}
107
108impl TestDownload {
109    fn get_or_init_fut(&mut self) -> &mut TestHeadersFut {
110        if self.fut.is_none() {
111            let request = HeadersRequest {
112                limit: self.limit,
113                direction: HeadersDirection::Rising,
114                start: 0u64.into(), // ignored
115            };
116            let client = self.client.clone();
117            self.fut = Some(Box::pin(client.get_headers(request)));
118        }
119        self.fut.as_mut().unwrap()
120    }
121}
122
123impl fmt::Debug for TestDownload {
124    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
125        f.debug_struct("TestDownload")
126            .field("client", &self.client)
127            .field("consensus", &self.consensus)
128            .field("limit", &self.limit)
129            .field("buffer", &self.buffer)
130            .field("done", &self.done)
131            .finish_non_exhaustive()
132    }
133}
134
135impl Stream for TestDownload {
136    type Item = DownloadResult<SealedHeader>;
137
138    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
139        let this = self.get_mut();
140
141        loop {
142            if let Some(header) = this.buffer.pop() {
143                return Poll::Ready(Some(Ok(header)))
144            } else if this.done {
145                return Poll::Ready(None)
146            }
147
148            let empty: SealedHeader = SealedHeader::default();
149            if let Err(error) = this.consensus.validate_header_against_parent(&empty, &empty) {
150                this.done = true;
151                return Poll::Ready(Some(Err(DownloadError::HeaderValidation {
152                    hash: empty.hash(),
153                    number: empty.number,
154                    error: Box::new(error),
155                })))
156            }
157
158            match ready!(this.get_or_init_fut().poll_unpin(cx)) {
159                Ok(resp) => {
160                    // Skip head and seal headers
161                    let mut headers =
162                        resp.1.into_iter().skip(1).map(SealedHeader::seal_slow).collect::<Vec<_>>();
163                    headers.sort_unstable_by_key(|h| h.number);
164                    headers.into_iter().for_each(|h| this.buffer.push(h));
165                    this.done = true;
166                }
167                Err(err) => {
168                    this.done = true;
169                    return Poll::Ready(Some(Err(match err {
170                        RequestError::Timeout => DownloadError::Timeout,
171                        _ => DownloadError::RequestError(err),
172                    })))
173                }
174            }
175        }
176    }
177}
178
179/// A test client for fetching headers
180#[derive(Debug, Default, Clone)]
181pub struct TestHeadersClient {
182    responses: Arc<Mutex<Vec<Header>>>,
183    error: Arc<Mutex<Option<RequestError>>>,
184    request_attempts: Arc<AtomicU64>,
185}
186
187impl TestHeadersClient {
188    /// Return the number of times client was polled
189    pub fn request_attempts(&self) -> u64 {
190        self.request_attempts.load(Ordering::SeqCst)
191    }
192
193    /// Adds headers to the set.
194    pub async fn extend(&self, headers: impl IntoIterator<Item = Header>) {
195        let mut lock = self.responses.lock().await;
196        lock.extend(headers);
197    }
198
199    /// Clears the set.
200    pub async fn clear(&self) {
201        let mut lock = self.responses.lock().await;
202        lock.clear();
203    }
204
205    /// Set response error
206    pub async fn set_error(&self, err: RequestError) {
207        let mut lock = self.error.lock().await;
208        lock.replace(err);
209    }
210}
211
212impl DownloadClient for TestHeadersClient {
213    fn report_bad_message(&self, _peer_id: PeerId) {
214        // noop
215    }
216
217    fn num_connected_peers(&self) -> usize {
218        0
219    }
220}
221
222impl HeadersClient for TestHeadersClient {
223    type Header = Header;
224    type Output = TestHeadersFut;
225
226    fn get_headers_with_priority(
227        &self,
228        request: HeadersRequest,
229        _priority: Priority,
230    ) -> Self::Output {
231        let responses = self.responses.clone();
232        let error = self.error.clone();
233
234        self.request_attempts.fetch_add(1, Ordering::SeqCst);
235
236        Box::pin(async move {
237            if let Some(err) = &mut *error.lock().await {
238                return Err(err.clone())
239            }
240
241            let mut lock = responses.lock().await;
242            let len = lock.len().min(request.limit as usize);
243            let resp = lock.drain(..len).collect();
244            let with_peer_id = WithPeerId::from((PeerId::default(), resp));
245            Ok(with_peer_id)
246        })
247    }
248}