Skip to main content

reth_downloaders/headers/
task.rs

1use alloy_primitives::Sealable;
2use futures::Stream;
3use futures_util::StreamExt;
4use pin_project::pin_project;
5use reth_network_p2p::headers::{
6    downloader::{HeaderDownloader, SyncTarget},
7    error::HeadersDownloaderResult,
8};
9use reth_primitives_traits::SealedHeader;
10use reth_tasks::Runtime;
11use std::{
12    fmt::Debug,
13    future::Future,
14    pin::Pin,
15    task::{ready, Context, Poll},
16};
17use tokio::sync::{mpsc, mpsc::UnboundedSender};
18use tokio_stream::wrappers::{ReceiverStream, UnboundedReceiverStream};
19use tokio_util::sync::PollSender;
20
/// Capacity of the bounded channel that carries downloader results (each either a batch of
/// headers or an error) from the spawned task back to the [`TaskDownloader`] handle.
///
/// Once this many results are buffered and unread, the spawned task stops polling its inner
/// downloader until the consumer catches up.
pub const HEADERS_TASK_BUFFER_SIZE: usize = 8;
23
24/// A [HeaderDownloader] that drives a spawned [HeaderDownloader] on a spawned task.
25#[derive(Debug)]
26#[pin_project]
27pub struct TaskDownloader<H: Sealable> {
28    #[pin]
29    from_downloader: ReceiverStream<HeadersDownloaderResult<Vec<SealedHeader<H>>, H>>,
30    to_downloader: UnboundedSender<DownloaderUpdates<H>>,
31}
32
33// === impl TaskDownloader ===
34
35impl<H: Sealable + Send + Sync + Unpin + 'static> TaskDownloader<H> {
36    /// Spawns the given `downloader` via the given [`Runtime`] and returns a [`TaskDownloader`]
37    /// that's connected to that task.
38    pub fn spawn_with<T>(downloader: T, runtime: &Runtime) -> Self
39    where
40        T: HeaderDownloader<Header = H> + 'static,
41    {
42        let (headers_tx, headers_rx) = mpsc::channel(HEADERS_TASK_BUFFER_SIZE);
43        let (to_downloader, updates_rx) = mpsc::unbounded_channel();
44
45        let downloader = SpawnedDownloader {
46            headers_tx: PollSender::new(headers_tx),
47            updates: UnboundedReceiverStream::new(updates_rx),
48            downloader,
49        };
50        runtime.spawn_task(downloader);
51
52        Self { from_downloader: ReceiverStream::new(headers_rx), to_downloader }
53    }
54}
55
56impl<H: Sealable + Debug + Send + Sync + Unpin + 'static> HeaderDownloader for TaskDownloader<H> {
57    type Header = H;
58
59    fn update_sync_gap(&mut self, head: SealedHeader<H>, target: SyncTarget) {
60        let _ = self.to_downloader.send(DownloaderUpdates::UpdateSyncGap(head, target));
61    }
62
63    fn update_local_head(&mut self, head: SealedHeader<H>) {
64        let _ = self.to_downloader.send(DownloaderUpdates::UpdateLocalHead(head));
65    }
66
67    fn update_sync_target(&mut self, target: SyncTarget) {
68        let _ = self.to_downloader.send(DownloaderUpdates::UpdateSyncTarget(target));
69    }
70
71    fn set_batch_size(&mut self, limit: usize) {
72        let _ = self.to_downloader.send(DownloaderUpdates::SetBatchSize(limit));
73    }
74}
75
76impl<H: Sealable> Stream for TaskDownloader<H> {
77    type Item = HeadersDownloaderResult<Vec<SealedHeader<H>>, H>;
78
79    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
80        self.project().from_downloader.poll_next(cx)
81    }
82}
83
84/// A [`HeaderDownloader`] that runs on its own task
85#[expect(clippy::complexity)]
86struct SpawnedDownloader<T: HeaderDownloader> {
87    updates: UnboundedReceiverStream<DownloaderUpdates<T::Header>>,
88    headers_tx: PollSender<HeadersDownloaderResult<Vec<SealedHeader<T::Header>>, T::Header>>,
89    downloader: T,
90}
91
92impl<T: HeaderDownloader> Future for SpawnedDownloader<T> {
93    type Output = ();
94
95    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
96        let this = self.get_mut();
97
98        loop {
99            loop {
100                match this.updates.poll_next_unpin(cx) {
101                    Poll::Pending => break,
102                    Poll::Ready(None) => {
103                        // channel closed, this means [TaskDownloader] was dropped, so we can also
104                        // exit
105                        return Poll::Ready(())
106                    }
107                    Poll::Ready(Some(update)) => match update {
108                        DownloaderUpdates::UpdateSyncGap(head, target) => {
109                            this.downloader.update_sync_gap(head, target);
110                        }
111                        DownloaderUpdates::UpdateLocalHead(head) => {
112                            this.downloader.update_local_head(head);
113                        }
114                        DownloaderUpdates::UpdateSyncTarget(target) => {
115                            this.downloader.update_sync_target(target);
116                        }
117                        DownloaderUpdates::SetBatchSize(limit) => {
118                            this.downloader.set_batch_size(limit);
119                        }
120                    },
121                }
122            }
123
124            match ready!(this.headers_tx.poll_reserve(cx)) {
125                Ok(()) => {
126                    match ready!(this.downloader.poll_next_unpin(cx)) {
127                        Some(headers) => {
128                            if this.headers_tx.send_item(headers).is_err() {
129                                // channel closed, this means [TaskDownloader] was dropped, so we
130                                // can also exit
131                                return Poll::Ready(())
132                            }
133                        }
134                        None => return Poll::Pending,
135                    }
136                }
137                Err(_) => {
138                    // channel closed, this means [TaskDownloader] was dropped, so
139                    // we can also exit
140                    return Poll::Ready(())
141                }
142            }
143        }
144    }
145}
146
147/// Commands delegated to the spawned [`HeaderDownloader`]
148#[derive(Debug)]
149enum DownloaderUpdates<H> {
150    UpdateSyncGap(SealedHeader<H>, SyncTarget),
151    UpdateLocalHead(SealedHeader<H>),
152    UpdateSyncTarget(SyncTarget),
153    SetBatchSize(usize),
154}
155
#[cfg(test)]
mod tests {
    use super::*;
    use crate::headers::{
        reverse_headers::ReverseHeadersDownloaderBuilder, test_utils::child_header,
    };
    use reth_consensus::test_utils::TestConsensus;
    use reth_network_p2p::test_utils::TestHeadersClient;
    use std::sync::Arc;

    /// Runs a reverse headers downloader on its own task through `TaskDownloader`.
    ///
    /// Both the stream batch size and the request limit are set to one. The test checks that
    /// headers come back one per item, starting at the sync target.
    #[tokio::test(flavor = "multi_thread")]
    async fn download_one_by_one_on_task() {
        reth_tracing::init_test_tracing();

        // Four headers, each built from the previous one via `child_header`. `local_head` is
        // handed to the downloader as its local head and `tip` becomes the sync target.
        let local_head = SealedHeader::default();
        let h2 = child_header(&local_head);
        let h1 = child_header(&h2);
        let tip = child_header(&h1);

        let client = Arc::new(TestHeadersClient::default());
        let inner = ReverseHeadersDownloaderBuilder::default()
            .stream_batch_size(1)
            .request_limit(1)
            .build(Arc::clone(&client), Arc::new(TestConsensus::default()));

        let runtime = Runtime::test();
        let mut downloader = TaskDownloader::spawn_with(inner, &runtime);
        downloader.update_local_head(local_head.clone());
        downloader.update_sync_target(SyncTarget::Tip(tip.hash()));

        client
            .extend(vec![
                tip.as_ref().clone(),
                h1.as_ref().clone(),
                h2.as_ref().clone(),
                local_head.as_ref().clone(),
            ])
            .await;

        // Batch size one: each stream item must be exactly one header, walking back from `tip`.
        for expected in [tip, h1, h2] {
            let headers = downloader.next().await.unwrap();
            assert_eq!(headers.unwrap(), vec![expected]);
        }
    }
}