// reth_network_p2p/test_utils/headers.rs
1use crate::{
4 download::DownloadClient,
5 error::{DownloadError, DownloadResult, PeerRequestResult, RequestError},
6 headers::{
7 client::{HeadersClient, HeadersRequest},
8 downloader::{HeaderDownloader, SyncTarget},
9 error::HeadersDownloaderResult,
10 },
11 priority::Priority,
12};
13use alloy_consensus::Header;
14use futures::{Future, FutureExt, Stream, StreamExt};
15use reth_consensus::{test_utils::TestConsensus, HeaderValidator};
16use reth_eth_wire_types::HeadersDirection;
17use reth_network_peers::{PeerId, WithPeerId};
18use reth_primitives_traits::SealedHeader;
19use std::{
20 fmt,
21 pin::Pin,
22 sync::{
23 atomic::{AtomicU64, Ordering},
24 Arc,
25 },
26 task::{ready, Context, Poll},
27};
28use tokio::sync::Mutex;
29
30#[derive(Debug)]
32pub struct TestHeaderDownloader {
33 client: TestHeadersClient,
34 consensus: Arc<TestConsensus>,
35 limit: u64,
36 download: Option<TestDownload>,
37 queued_headers: Vec<SealedHeader>,
38 batch_size: usize,
39}
40
41impl TestHeaderDownloader {
42 pub const fn new(
44 client: TestHeadersClient,
45 consensus: Arc<TestConsensus>,
46 limit: u64,
47 batch_size: usize,
48 ) -> Self {
49 Self { client, consensus, limit, download: None, batch_size, queued_headers: Vec::new() }
50 }
51
52 fn create_download(&self) -> TestDownload {
53 TestDownload {
54 client: self.client.clone(),
55 consensus: Arc::clone(&self.consensus),
56 limit: self.limit,
57 fut: None,
58 buffer: vec![],
59 done: false,
60 }
61 }
62}
63
64impl HeaderDownloader for TestHeaderDownloader {
65 type Header = Header;
66
67 fn update_local_head(&mut self, _head: SealedHeader) {}
68
69 fn update_sync_target(&mut self, _target: SyncTarget) {}
70
71 fn set_batch_size(&mut self, limit: usize) {
72 self.batch_size = limit;
73 }
74}
75
76impl Stream for TestHeaderDownloader {
77 type Item = HeadersDownloaderResult<Vec<SealedHeader>, Header>;
78
79 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
80 let this = self.get_mut();
81 loop {
82 if this.queued_headers.len() == this.batch_size {
83 return Poll::Ready(Some(Ok(std::mem::take(&mut this.queued_headers))))
84 }
85 if this.download.is_none() {
86 this.download = Some(this.create_download());
87 }
88
89 match ready!(this.download.as_mut().unwrap().poll_next_unpin(cx)) {
90 None => return Poll::Ready(Some(Ok(std::mem::take(&mut this.queued_headers)))),
91 Some(header) => this.queued_headers.push(header.unwrap()),
92 }
93 }
94 }
95}
96
97type TestHeadersFut = Pin<Box<dyn Future<Output = PeerRequestResult<Vec<Header>>> + Sync + Send>>;
98
99struct TestDownload {
100 client: TestHeadersClient,
101 consensus: Arc<TestConsensus>,
102 limit: u64,
103 fut: Option<TestHeadersFut>,
104 buffer: Vec<SealedHeader>,
105 done: bool,
106}
107
108impl TestDownload {
109 fn get_or_init_fut(&mut self) -> &mut TestHeadersFut {
110 if self.fut.is_none() {
111 let request = HeadersRequest {
112 limit: self.limit,
113 direction: HeadersDirection::Rising,
114 start: 0u64.into(), };
116 let client = self.client.clone();
117 self.fut = Some(Box::pin(client.get_headers(request)));
118 }
119 self.fut.as_mut().unwrap()
120 }
121}
122
123impl fmt::Debug for TestDownload {
124 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
125 f.debug_struct("TestDownload")
126 .field("client", &self.client)
127 .field("consensus", &self.consensus)
128 .field("limit", &self.limit)
129 .field("buffer", &self.buffer)
130 .field("done", &self.done)
131 .finish_non_exhaustive()
132 }
133}
134
135impl Stream for TestDownload {
136 type Item = DownloadResult<SealedHeader>;
137
138 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
139 let this = self.get_mut();
140
141 loop {
142 if let Some(header) = this.buffer.pop() {
143 return Poll::Ready(Some(Ok(header)))
144 } else if this.done {
145 return Poll::Ready(None)
146 }
147
148 let empty: SealedHeader = SealedHeader::default();
149 if let Err(error) = this.consensus.validate_header_against_parent(&empty, &empty) {
150 this.done = true;
151 return Poll::Ready(Some(Err(DownloadError::HeaderValidation {
152 hash: empty.hash(),
153 number: empty.number,
154 error: Box::new(error),
155 })))
156 }
157
158 match ready!(this.get_or_init_fut().poll_unpin(cx)) {
159 Ok(resp) => {
160 let mut headers =
162 resp.1.into_iter().skip(1).map(SealedHeader::seal_slow).collect::<Vec<_>>();
163 headers.sort_unstable_by_key(|h| h.number);
164 headers.into_iter().for_each(|h| this.buffer.push(h));
165 this.done = true;
166 }
167 Err(err) => {
168 this.done = true;
169 return Poll::Ready(Some(Err(match err {
170 RequestError::Timeout => DownloadError::Timeout,
171 _ => DownloadError::RequestError(err),
172 })))
173 }
174 }
175 }
176 }
177}
178
179#[derive(Debug, Default, Clone)]
181pub struct TestHeadersClient {
182 responses: Arc<Mutex<Vec<Header>>>,
183 error: Arc<Mutex<Option<RequestError>>>,
184 request_attempts: Arc<AtomicU64>,
185}
186
187impl TestHeadersClient {
188 pub fn request_attempts(&self) -> u64 {
190 self.request_attempts.load(Ordering::SeqCst)
191 }
192
193 pub async fn extend(&self, headers: impl IntoIterator<Item = Header>) {
195 let mut lock = self.responses.lock().await;
196 lock.extend(headers);
197 }
198
199 pub async fn clear(&self) {
201 let mut lock = self.responses.lock().await;
202 lock.clear();
203 }
204
205 pub async fn set_error(&self, err: RequestError) {
207 let mut lock = self.error.lock().await;
208 lock.replace(err);
209 }
210}
211
212impl DownloadClient for TestHeadersClient {
213 fn report_bad_message(&self, _peer_id: PeerId) {
214 }
216
217 fn num_connected_peers(&self) -> usize {
218 0
219 }
220}
221
222impl HeadersClient for TestHeadersClient {
223 type Header = Header;
224 type Output = TestHeadersFut;
225
226 fn get_headers_with_priority(
227 &self,
228 request: HeadersRequest,
229 _priority: Priority,
230 ) -> Self::Output {
231 let responses = self.responses.clone();
232 let error = self.error.clone();
233
234 self.request_attempts.fetch_add(1, Ordering::SeqCst);
235
236 Box::pin(async move {
237 if let Some(err) = &mut *error.lock().await {
238 return Err(err.clone())
239 }
240
241 let mut lock = responses.lock().await;
242 let len = lock.len().min(request.limit as usize);
243 let resp = lock.drain(..len).collect();
244 let with_peer_id = WithPeerId::from((PeerId::default(), resp));
245 Ok(with_peer_id)
246 })
247 }
248}