reth_downloaders/headers/
task.rs1use alloy_primitives::Sealable;
2use futures::Stream;
3use futures_util::StreamExt;
4use pin_project::pin_project;
5use reth_network_p2p::headers::{
6 downloader::{HeaderDownloader, SyncTarget},
7 error::HeadersDownloaderResult,
8};
9use reth_primitives_traits::SealedHeader;
10use reth_tasks::Runtime;
11use std::{
12 fmt::Debug,
13 future::Future,
14 pin::Pin,
15 task::{ready, Context, Poll},
16};
17use tokio::sync::{mpsc, mpsc::UnboundedSender};
18use tokio_stream::wrappers::{ReceiverStream, UnboundedReceiverStream};
19use tokio_util::sync::PollSender;
20
/// Capacity of the bounded channel that carries download results from the spawned
/// downloader task back to its [`TaskDownloader`] handle.
///
/// Once this many results are buffered and unread, the spawned task stops pulling new
/// batches from its downloader until the consumer catches up.
pub const HEADERS_TASK_BUFFER_SIZE: usize = 8;
23
24#[derive(Debug)]
26#[pin_project]
27pub struct TaskDownloader<H: Sealable> {
28 #[pin]
29 from_downloader: ReceiverStream<HeadersDownloaderResult<Vec<SealedHeader<H>>, H>>,
30 to_downloader: UnboundedSender<DownloaderUpdates<H>>,
31}
32
33impl<H: Sealable + Send + Sync + Unpin + 'static> TaskDownloader<H> {
36 pub fn spawn_with<T>(downloader: T, runtime: &Runtime) -> Self
39 where
40 T: HeaderDownloader<Header = H> + 'static,
41 {
42 let (headers_tx, headers_rx) = mpsc::channel(HEADERS_TASK_BUFFER_SIZE);
43 let (to_downloader, updates_rx) = mpsc::unbounded_channel();
44
45 let downloader = SpawnedDownloader {
46 headers_tx: PollSender::new(headers_tx),
47 updates: UnboundedReceiverStream::new(updates_rx),
48 downloader,
49 };
50 runtime.spawn_task(downloader);
51
52 Self { from_downloader: ReceiverStream::new(headers_rx), to_downloader }
53 }
54}
55
56impl<H: Sealable + Debug + Send + Sync + Unpin + 'static> HeaderDownloader for TaskDownloader<H> {
57 type Header = H;
58
59 fn update_sync_gap(&mut self, head: SealedHeader<H>, target: SyncTarget) {
60 let _ = self.to_downloader.send(DownloaderUpdates::UpdateSyncGap(head, target));
61 }
62
63 fn update_local_head(&mut self, head: SealedHeader<H>) {
64 let _ = self.to_downloader.send(DownloaderUpdates::UpdateLocalHead(head));
65 }
66
67 fn update_sync_target(&mut self, target: SyncTarget) {
68 let _ = self.to_downloader.send(DownloaderUpdates::UpdateSyncTarget(target));
69 }
70
71 fn set_batch_size(&mut self, limit: usize) {
72 let _ = self.to_downloader.send(DownloaderUpdates::SetBatchSize(limit));
73 }
74}
75
76impl<H: Sealable> Stream for TaskDownloader<H> {
77 type Item = HeadersDownloaderResult<Vec<SealedHeader<H>>, H>;
78
79 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
80 self.project().from_downloader.poll_next(cx)
81 }
82}
83
84#[expect(clippy::complexity)]
86struct SpawnedDownloader<T: HeaderDownloader> {
87 updates: UnboundedReceiverStream<DownloaderUpdates<T::Header>>,
88 headers_tx: PollSender<HeadersDownloaderResult<Vec<SealedHeader<T::Header>>, T::Header>>,
89 downloader: T,
90}
91
92impl<T: HeaderDownloader> Future for SpawnedDownloader<T> {
93 type Output = ();
94
95 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
96 let this = self.get_mut();
97
98 loop {
99 loop {
100 match this.updates.poll_next_unpin(cx) {
101 Poll::Pending => break,
102 Poll::Ready(None) => {
103 return Poll::Ready(())
106 }
107 Poll::Ready(Some(update)) => match update {
108 DownloaderUpdates::UpdateSyncGap(head, target) => {
109 this.downloader.update_sync_gap(head, target);
110 }
111 DownloaderUpdates::UpdateLocalHead(head) => {
112 this.downloader.update_local_head(head);
113 }
114 DownloaderUpdates::UpdateSyncTarget(target) => {
115 this.downloader.update_sync_target(target);
116 }
117 DownloaderUpdates::SetBatchSize(limit) => {
118 this.downloader.set_batch_size(limit);
119 }
120 },
121 }
122 }
123
124 match ready!(this.headers_tx.poll_reserve(cx)) {
125 Ok(()) => {
126 match ready!(this.downloader.poll_next_unpin(cx)) {
127 Some(headers) => {
128 if this.headers_tx.send_item(headers).is_err() {
129 return Poll::Ready(())
132 }
133 }
134 None => return Poll::Pending,
135 }
136 }
137 Err(_) => {
138 return Poll::Ready(())
141 }
142 }
143 }
144 }
145}
146
147#[derive(Debug)]
149enum DownloaderUpdates<H> {
150 UpdateSyncGap(SealedHeader<H>, SyncTarget),
151 UpdateLocalHead(SealedHeader<H>),
152 UpdateSyncTarget(SyncTarget),
153 SetBatchSize(usize),
154}
155
#[cfg(test)]
mod tests {
    use super::*;
    use crate::headers::{
        reverse_headers::ReverseHeadersDownloaderBuilder, test_utils::child_header,
    };
    use reth_consensus::test_utils::TestConsensus;
    use reth_network_p2p::test_utils::TestHeadersClient;
    use std::sync::Arc;

    /// Runs a `ReverseHeadersDownloader` (batch size and request limit both 1) behind a
    /// `TaskDownloader` and checks that batches come back one header at a time, starting at
    /// the sync-target tip.
    #[tokio::test(flavor = "multi_thread")]
    async fn download_one_by_one_on_task() {
        reth_tracing::init_test_tracing();

        // Each header is built with `child_header` from the previous one. p3 becomes the
        // local head and p0 the sync target.
        let p3 = SealedHeader::default();
        let p2 = child_header(&p3);
        let p1 = child_header(&p2);
        let p0 = child_header(&p1);

        let client = Arc::new(TestHeadersClient::default());
        let inner = ReverseHeadersDownloaderBuilder::default()
            .stream_batch_size(1)
            .request_limit(1)
            .build(Arc::clone(&client), Arc::new(TestConsensus::default()));

        let runtime = Runtime::test();
        let mut downloader = TaskDownloader::spawn_with(inner, &runtime);
        downloader.update_local_head(p3.clone());
        downloader.update_sync_target(SyncTarget::Tip(p0.hash()));

        client
            .extend(vec![
                p0.as_ref().clone(),
                p1.as_ref().clone(),
                p2.as_ref().clone(),
                p3.as_ref().clone(),
            ])
            .await;

        // Expect one single-header batch per poll, walking back from the tip.
        for expected in [p0, p1, p2] {
            let batch = downloader.next().await.unwrap();
            assert_eq!(batch.unwrap(), vec![expected]);
        }
    }
}