reth_network_p2p/test_utils/
headers.rs

//! Testing support for header-related interfaces.
2
3use crate::{
4    download::DownloadClient,
5    error::{DownloadError, DownloadResult, PeerRequestResult, RequestError},
6    headers::{
7        client::{HeadersClient, HeadersRequest},
8        downloader::{HeaderDownloader, SyncTarget},
9        error::HeadersDownloaderResult,
10    },
11    priority::Priority,
12};
13use alloy_consensus::Header;
14use futures::{Future, FutureExt, Stream, StreamExt};
15use reth_eth_wire_types::HeadersDirection;
16use reth_network_peers::{PeerId, WithPeerId};
17use reth_primitives_traits::SealedHeader;
18use std::{
19    fmt,
20    pin::Pin,
21    sync::{
22        atomic::{AtomicU64, Ordering},
23        Arc,
24    },
25    task::{ready, Context, Poll},
26};
27use tokio::sync::Mutex;
28
29/// A test downloader which just returns the values that have been pushed to it.
30#[derive(Debug)]
31pub struct TestHeaderDownloader {
32    client: TestHeadersClient,
33    limit: u64,
34    download: Option<TestDownload>,
35    queued_headers: Vec<SealedHeader>,
36    batch_size: usize,
37}
38
39impl TestHeaderDownloader {
40    /// Instantiates the downloader with the mock responses
41    pub const fn new(client: TestHeadersClient, limit: u64, batch_size: usize) -> Self {
42        Self { client, limit, download: None, batch_size, queued_headers: Vec::new() }
43    }
44
45    fn create_download(&self) -> TestDownload {
46        TestDownload {
47            client: self.client.clone(),
48            limit: self.limit,
49            fut: None,
50            buffer: vec![],
51            done: false,
52        }
53    }
54}
55
56impl HeaderDownloader for TestHeaderDownloader {
57    type Header = Header;
58
59    fn update_local_head(&mut self, _head: SealedHeader) {}
60
61    fn update_sync_target(&mut self, _target: SyncTarget) {}
62
63    fn set_batch_size(&mut self, limit: usize) {
64        self.batch_size = limit;
65    }
66}
67
68impl Stream for TestHeaderDownloader {
69    type Item = HeadersDownloaderResult<Vec<SealedHeader>, Header>;
70
71    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
72        let this = self.get_mut();
73        loop {
74            if this.queued_headers.len() == this.batch_size {
75                return Poll::Ready(Some(Ok(std::mem::take(&mut this.queued_headers))))
76            }
77            if this.download.is_none() {
78                this.download = Some(this.create_download());
79            }
80
81            match ready!(this.download.as_mut().unwrap().poll_next_unpin(cx)) {
82                None => return Poll::Ready(Some(Ok(std::mem::take(&mut this.queued_headers)))),
83                Some(header) => this.queued_headers.push(header.unwrap()),
84            }
85        }
86    }
87}
88
89type TestHeadersFut = Pin<Box<dyn Future<Output = PeerRequestResult<Vec<Header>>> + Sync + Send>>;
90
91struct TestDownload {
92    client: TestHeadersClient,
93    limit: u64,
94    fut: Option<TestHeadersFut>,
95    buffer: Vec<SealedHeader>,
96    done: bool,
97}
98
99impl TestDownload {
100    fn get_or_init_fut(&mut self) -> &mut TestHeadersFut {
101        if self.fut.is_none() {
102            let request = HeadersRequest {
103                limit: self.limit,
104                direction: HeadersDirection::Rising,
105                start: 0u64.into(), // ignored
106            };
107            let client = self.client.clone();
108            self.fut = Some(Box::pin(client.get_headers(request)));
109        }
110        self.fut.as_mut().unwrap()
111    }
112}
113
114impl fmt::Debug for TestDownload {
115    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
116        f.debug_struct("TestDownload")
117            .field("client", &self.client)
118            .field("limit", &self.limit)
119            .field("buffer", &self.buffer)
120            .field("done", &self.done)
121            .finish_non_exhaustive()
122    }
123}
124
125impl Stream for TestDownload {
126    type Item = DownloadResult<SealedHeader>;
127
128    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
129        let this = self.get_mut();
130
131        loop {
132            if let Some(header) = this.buffer.pop() {
133                return Poll::Ready(Some(Ok(header)))
134            } else if this.done {
135                return Poll::Ready(None)
136            }
137
138            match ready!(this.get_or_init_fut().poll_unpin(cx)) {
139                Ok(resp) => {
140                    // Skip head and seal headers
141                    let mut headers =
142                        resp.1.into_iter().skip(1).map(SealedHeader::seal_slow).collect::<Vec<_>>();
143                    headers.sort_unstable_by_key(|h| h.number);
144                    for h in headers {
145                        this.buffer.push(h);
146                    }
147                    this.done = true;
148                }
149                Err(err) => {
150                    this.done = true;
151                    return Poll::Ready(Some(Err(match err {
152                        RequestError::Timeout => DownloadError::Timeout,
153                        _ => DownloadError::RequestError(err),
154                    })))
155                }
156            }
157        }
158    }
159}
160
161/// A test client for fetching headers
162#[derive(Debug, Default, Clone)]
163pub struct TestHeadersClient {
164    responses: Arc<Mutex<Vec<Header>>>,
165    error: Arc<Mutex<Option<RequestError>>>,
166    request_attempts: Arc<AtomicU64>,
167}
168
169impl TestHeadersClient {
170    /// Return the number of times client was polled
171    pub fn request_attempts(&self) -> u64 {
172        self.request_attempts.load(Ordering::SeqCst)
173    }
174
175    /// Adds headers to the set.
176    pub async fn extend(&self, headers: impl IntoIterator<Item = Header>) {
177        let mut lock = self.responses.lock().await;
178        lock.extend(headers);
179    }
180
181    /// Clears the set.
182    pub async fn clear(&self) {
183        let mut lock = self.responses.lock().await;
184        lock.clear();
185    }
186
187    /// Set response error
188    pub async fn set_error(&self, err: RequestError) {
189        let mut lock = self.error.lock().await;
190        lock.replace(err);
191    }
192}
193
194impl DownloadClient for TestHeadersClient {
195    fn report_bad_message(&self, _peer_id: PeerId) {
196        // noop
197    }
198
199    fn num_connected_peers(&self) -> usize {
200        0
201    }
202}
203
204impl HeadersClient for TestHeadersClient {
205    type Header = Header;
206    type Output = TestHeadersFut;
207
208    fn get_headers_with_priority(
209        &self,
210        request: HeadersRequest,
211        _priority: Priority,
212    ) -> Self::Output {
213        let responses = self.responses.clone();
214        let error = self.error.clone();
215
216        self.request_attempts.fetch_add(1, Ordering::SeqCst);
217
218        Box::pin(async move {
219            if let Some(err) = &mut *error.lock().await {
220                return Err(err.clone())
221            }
222
223            let mut lock = responses.lock().await;
224            let len = lock.len().min(request.limit as usize);
225            let resp = lock.drain(..len).collect();
226            let with_peer_id = WithPeerId::from((PeerId::default(), resp));
227            Ok(with_peer_id)
228        })
229    }
230}