reth_network_p2p/test_utils/
headers.rs1use crate::{
4 download::DownloadClient,
5 error::{DownloadError, DownloadResult, PeerRequestResult, RequestError},
6 headers::{
7 client::{HeadersClient, HeadersRequest},
8 downloader::{HeaderDownloader, SyncTarget},
9 error::HeadersDownloaderResult,
10 },
11 priority::Priority,
12};
13use alloy_consensus::Header;
14use futures::{Future, FutureExt, Stream, StreamExt};
15use reth_eth_wire_types::HeadersDirection;
16use reth_network_peers::{PeerId, WithPeerId};
17use reth_primitives_traits::SealedHeader;
18use std::{
19 fmt,
20 pin::Pin,
21 sync::{
22 atomic::{AtomicU64, Ordering},
23 Arc,
24 },
25 task::{ready, Context, Poll},
26};
27use tokio::sync::Mutex;
28
29#[derive(Debug)]
31pub struct TestHeaderDownloader {
32 client: TestHeadersClient,
33 limit: u64,
34 download: Option<TestDownload>,
35 queued_headers: Vec<SealedHeader>,
36 batch_size: usize,
37}
38
39impl TestHeaderDownloader {
40 pub const fn new(client: TestHeadersClient, limit: u64, batch_size: usize) -> Self {
42 Self { client, limit, download: None, batch_size, queued_headers: Vec::new() }
43 }
44
45 fn create_download(&self) -> TestDownload {
46 TestDownload {
47 client: self.client.clone(),
48 limit: self.limit,
49 fut: None,
50 buffer: vec![],
51 done: false,
52 }
53 }
54}
55
56impl HeaderDownloader for TestHeaderDownloader {
57 type Header = Header;
58
59 fn update_local_head(&mut self, _head: SealedHeader) {}
60
61 fn update_sync_target(&mut self, _target: SyncTarget) {}
62
63 fn set_batch_size(&mut self, limit: usize) {
64 self.batch_size = limit;
65 }
66}
67
68impl Stream for TestHeaderDownloader {
69 type Item = HeadersDownloaderResult<Vec<SealedHeader>, Header>;
70
71 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
72 let this = self.get_mut();
73 loop {
74 if this.queued_headers.len() == this.batch_size {
75 return Poll::Ready(Some(Ok(std::mem::take(&mut this.queued_headers))))
76 }
77 if this.download.is_none() {
78 this.download = Some(this.create_download());
79 }
80
81 match ready!(this.download.as_mut().unwrap().poll_next_unpin(cx)) {
82 None => return Poll::Ready(Some(Ok(std::mem::take(&mut this.queued_headers)))),
83 Some(header) => this.queued_headers.push(header.unwrap()),
84 }
85 }
86 }
87}
88
89type TestHeadersFut = Pin<Box<dyn Future<Output = PeerRequestResult<Vec<Header>>> + Sync + Send>>;
90
91struct TestDownload {
92 client: TestHeadersClient,
93 limit: u64,
94 fut: Option<TestHeadersFut>,
95 buffer: Vec<SealedHeader>,
96 done: bool,
97}
98
99impl TestDownload {
100 fn get_or_init_fut(&mut self) -> &mut TestHeadersFut {
101 if self.fut.is_none() {
102 let request = HeadersRequest {
103 limit: self.limit,
104 direction: HeadersDirection::Rising,
105 start: 0u64.into(), };
107 let client = self.client.clone();
108 self.fut = Some(Box::pin(client.get_headers(request)));
109 }
110 self.fut.as_mut().unwrap()
111 }
112}
113
114impl fmt::Debug for TestDownload {
115 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
116 f.debug_struct("TestDownload")
117 .field("client", &self.client)
118 .field("limit", &self.limit)
119 .field("buffer", &self.buffer)
120 .field("done", &self.done)
121 .finish_non_exhaustive()
122 }
123}
124
125impl Stream for TestDownload {
126 type Item = DownloadResult<SealedHeader>;
127
128 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
129 let this = self.get_mut();
130
131 loop {
132 if let Some(header) = this.buffer.pop() {
133 return Poll::Ready(Some(Ok(header)))
134 } else if this.done {
135 return Poll::Ready(None)
136 }
137
138 match ready!(this.get_or_init_fut().poll_unpin(cx)) {
139 Ok(resp) => {
140 let mut headers =
142 resp.1.into_iter().skip(1).map(SealedHeader::seal_slow).collect::<Vec<_>>();
143 headers.sort_unstable_by_key(|h| h.number);
144 for h in headers {
145 this.buffer.push(h);
146 }
147 this.done = true;
148 }
149 Err(err) => {
150 this.done = true;
151 return Poll::Ready(Some(Err(match err {
152 RequestError::Timeout => DownloadError::Timeout,
153 _ => DownloadError::RequestError(err),
154 })))
155 }
156 }
157 }
158 }
159}
160
161#[derive(Debug, Default, Clone)]
163pub struct TestHeadersClient {
164 responses: Arc<Mutex<Vec<Header>>>,
165 error: Arc<Mutex<Option<RequestError>>>,
166 request_attempts: Arc<AtomicU64>,
167}
168
169impl TestHeadersClient {
170 pub fn request_attempts(&self) -> u64 {
172 self.request_attempts.load(Ordering::SeqCst)
173 }
174
175 pub async fn extend(&self, headers: impl IntoIterator<Item = Header>) {
177 let mut lock = self.responses.lock().await;
178 lock.extend(headers);
179 }
180
181 pub async fn clear(&self) {
183 let mut lock = self.responses.lock().await;
184 lock.clear();
185 }
186
187 pub async fn set_error(&self, err: RequestError) {
189 let mut lock = self.error.lock().await;
190 lock.replace(err);
191 }
192}
193
194impl DownloadClient for TestHeadersClient {
195 fn report_bad_message(&self, _peer_id: PeerId) {
196 }
198
199 fn num_connected_peers(&self) -> usize {
200 0
201 }
202}
203
204impl HeadersClient for TestHeadersClient {
205 type Header = Header;
206 type Output = TestHeadersFut;
207
208 fn get_headers_with_priority(
209 &self,
210 request: HeadersRequest,
211 _priority: Priority,
212 ) -> Self::Output {
213 let responses = self.responses.clone();
214 let error = self.error.clone();
215
216 self.request_attempts.fetch_add(1, Ordering::SeqCst);
217
218 Box::pin(async move {
219 if let Some(err) = &mut *error.lock().await {
220 return Err(err.clone())
221 }
222
223 let mut lock = responses.lock().await;
224 let len = lock.len().min(request.limit as usize);
225 let resp = lock.drain(..len).collect();
226 let with_peer_id = WithPeerId::from((PeerId::default(), resp));
227 Ok(with_peer_id)
228 })
229 }
230}