reth_downloaders/headers/task.rs
1use alloy_primitives::Sealable;
2use futures::{FutureExt, Stream};
3use futures_util::StreamExt;
4use pin_project::pin_project;
5use reth_network_p2p::headers::{
6 downloader::{HeaderDownloader, SyncTarget},
7 error::HeadersDownloaderResult,
8};
9use reth_primitives::SealedHeader;
10use reth_tasks::{TaskSpawner, TokioTaskExecutor};
11use std::{
12 fmt::Debug,
13 future::Future,
14 pin::Pin,
15 task::{ready, Context, Poll},
16};
17use tokio::sync::{mpsc, mpsc::UnboundedSender};
18use tokio_stream::wrappers::{ReceiverStream, UnboundedReceiverStream};
19use tokio_util::sync::PollSender;
20
/// Capacity of the bounded channel that carries header batches from the spawned download task
/// back to the [`TaskDownloader`] handle. Once this many batches are buffered, the task waits
/// for the consumer before polling the downloader again.
pub const HEADERS_TASK_BUFFER_SIZE: usize = 8;
23
24#[derive(Debug)]
26#[pin_project]
27pub struct TaskDownloader<H: Sealable> {
28 #[pin]
29 from_downloader: ReceiverStream<HeadersDownloaderResult<Vec<SealedHeader<H>>, H>>,
30 to_downloader: UnboundedSender<DownloaderUpdates<H>>,
31}
32
33impl<H: Sealable + Send + Sync + Unpin + 'static> TaskDownloader<H> {
36 pub fn spawn<T>(downloader: T) -> Self
60 where
61 T: HeaderDownloader<Header = H> + 'static,
62 {
63 Self::spawn_with(downloader, &TokioTaskExecutor::default())
64 }
65
66 pub fn spawn_with<T, S>(downloader: T, spawner: &S) -> Self
69 where
70 T: HeaderDownloader<Header = H> + 'static,
71 S: TaskSpawner,
72 {
73 let (headers_tx, headers_rx) = mpsc::channel(HEADERS_TASK_BUFFER_SIZE);
74 let (to_downloader, updates_rx) = mpsc::unbounded_channel();
75
76 let downloader = SpawnedDownloader {
77 headers_tx: PollSender::new(headers_tx),
78 updates: UnboundedReceiverStream::new(updates_rx),
79 downloader,
80 };
81 spawner.spawn(downloader.boxed());
82
83 Self { from_downloader: ReceiverStream::new(headers_rx), to_downloader }
84 }
85}
86
87impl<H: Sealable + Debug + Send + Sync + Unpin + 'static> HeaderDownloader for TaskDownloader<H> {
88 type Header = H;
89
90 fn update_sync_gap(&mut self, head: SealedHeader<H>, target: SyncTarget) {
91 let _ = self.to_downloader.send(DownloaderUpdates::UpdateSyncGap(head, target));
92 }
93
94 fn update_local_head(&mut self, head: SealedHeader<H>) {
95 let _ = self.to_downloader.send(DownloaderUpdates::UpdateLocalHead(head));
96 }
97
98 fn update_sync_target(&mut self, target: SyncTarget) {
99 let _ = self.to_downloader.send(DownloaderUpdates::UpdateSyncTarget(target));
100 }
101
102 fn set_batch_size(&mut self, limit: usize) {
103 let _ = self.to_downloader.send(DownloaderUpdates::SetBatchSize(limit));
104 }
105}
106
107impl<H: Sealable> Stream for TaskDownloader<H> {
108 type Item = HeadersDownloaderResult<Vec<SealedHeader<H>>, H>;
109
110 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
111 self.project().from_downloader.poll_next(cx)
112 }
113}
114
115#[expect(clippy::complexity)]
117struct SpawnedDownloader<T: HeaderDownloader> {
118 updates: UnboundedReceiverStream<DownloaderUpdates<T::Header>>,
119 headers_tx: PollSender<HeadersDownloaderResult<Vec<SealedHeader<T::Header>>, T::Header>>,
120 downloader: T,
121}
122
123impl<T: HeaderDownloader> Future for SpawnedDownloader<T> {
124 type Output = ();
125
126 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
127 let this = self.get_mut();
128
129 loop {
130 loop {
131 match this.updates.poll_next_unpin(cx) {
132 Poll::Pending => break,
133 Poll::Ready(None) => {
134 return Poll::Ready(())
137 }
138 Poll::Ready(Some(update)) => match update {
139 DownloaderUpdates::UpdateSyncGap(head, target) => {
140 this.downloader.update_sync_gap(head, target);
141 }
142 DownloaderUpdates::UpdateLocalHead(head) => {
143 this.downloader.update_local_head(head);
144 }
145 DownloaderUpdates::UpdateSyncTarget(target) => {
146 this.downloader.update_sync_target(target);
147 }
148 DownloaderUpdates::SetBatchSize(limit) => {
149 this.downloader.set_batch_size(limit);
150 }
151 },
152 }
153 }
154
155 match ready!(this.headers_tx.poll_reserve(cx)) {
156 Ok(()) => {
157 match ready!(this.downloader.poll_next_unpin(cx)) {
158 Some(headers) => {
159 if this.headers_tx.send_item(headers).is_err() {
160 return Poll::Ready(())
163 }
164 }
165 None => return Poll::Pending,
166 }
167 }
168 Err(_) => {
169 return Poll::Ready(())
172 }
173 }
174 }
175 }
176}
177
178#[derive(Debug)]
180enum DownloaderUpdates<H> {
181 UpdateSyncGap(SealedHeader<H>, SyncTarget),
182 UpdateLocalHead(SealedHeader<H>),
183 UpdateSyncTarget(SyncTarget),
184 SetBatchSize(usize),
185}
186
#[cfg(test)]
mod tests {
    use super::*;
    use crate::headers::{
        reverse_headers::ReverseHeadersDownloaderBuilder, test_utils::child_header,
    };
    use reth_consensus::test_utils::TestConsensus;
    use reth_network_p2p::test_utils::TestHeadersClient;
    use std::sync::Arc;

    /// Runs a reverse headers downloader (one header per request and per batch) on a task and
    /// checks that the handle yields the chain from the tip downwards, one header at a time.
    #[tokio::test(flavor = "multi_thread")]
    async fn download_one_by_one_on_task() {
        reth_tracing::init_test_tracing();

        // Four-header chain where each header is the child of the previous: p3 <- p2 <- p1 <- p0.
        let p3 = SealedHeader::default();
        let p2 = child_header(&p3);
        let p1 = child_header(&p2);
        let p0 = child_header(&p1);

        let client = Arc::new(TestHeadersClient::default());
        let consensus = Arc::new(TestConsensus::default());
        let inner = ReverseHeadersDownloaderBuilder::default()
            .stream_batch_size(1)
            .request_limit(1)
            .build(Arc::clone(&client), consensus);

        let mut downloader = TaskDownloader::spawn(inner);
        downloader.update_local_head(p3.clone());
        downloader.update_sync_target(SyncTarget::Tip(p0.hash()));

        client
            .extend(vec![
                p0.as_ref().clone(),
                p1.as_ref().clone(),
                p2.as_ref().clone(),
                p3.as_ref().clone(),
            ])
            .await;

        // p3 is the local head, so only the headers above it are expected, tip first.
        for expected in [p0, p1, p2] {
            let headers = downloader.next().await.unwrap();
            assert_eq!(headers, Ok(vec![expected]));
        }
    }
}