reth_downloaders/headers/task.rs

1use alloy_primitives::Sealable;
2use futures::{FutureExt, Stream};
3use futures_util::StreamExt;
4use pin_project::pin_project;
5use reth_network_p2p::headers::{
6    downloader::{HeaderDownloader, SyncTarget},
7    error::HeadersDownloaderResult,
8};
9use reth_primitives::SealedHeader;
10use reth_tasks::{TaskSpawner, TokioTaskExecutor};
11use std::{
12    fmt::Debug,
13    future::Future,
14    pin::Pin,
15    task::{ready, Context, Poll},
16};
17use tokio::sync::{mpsc, mpsc::UnboundedSender};
18use tokio_stream::wrappers::{ReceiverStream, UnboundedReceiverStream};
19use tokio_util::sync::PollSender;
20
/// The maximum number of header results to hold in the buffer.
///
/// Used as the capacity of the bounded channel that carries results (header batches or errors)
/// from the spawned downloader task back to the [`TaskDownloader`]. While that channel is full,
/// the spawned task does not poll its inner downloader for more headers.
pub const HEADERS_TASK_BUFFER_SIZE: usize = 8;
23
/// A [`HeaderDownloader`] handle that drives another [`HeaderDownloader`] running on a spawned
/// task.
///
/// The wrapped downloader lives on its own task (see [`TaskDownloader::spawn`] and
/// [`TaskDownloader::spawn_with`]). Update calls made on this handle are forwarded to that task
/// over a channel, and the task's output is received back as this type's [`Stream`] items.
#[derive(Debug)]
#[pin_project]
pub struct TaskDownloader<H: Sealable> {
    /// Results produced by the spawned task. The underlying channel holds at most
    /// [`HEADERS_TASK_BUFFER_SIZE`] pending items.
    #[pin]
    from_downloader: ReceiverStream<HeadersDownloaderResult<Vec<SealedHeader<H>>, H>>,
    /// Forwards sync-gap, local-head, sync-target and batch-size updates to the spawned task.
    to_downloader: UnboundedSender<DownloaderUpdates<H>>,
}
32
33// === impl TaskDownloader ===
34
35impl<H: Sealable + Send + Sync + Unpin + 'static> TaskDownloader<H> {
36    /// Spawns the given `downloader` via [`tokio::task::spawn`] and returns a [`TaskDownloader`]
37    /// that's connected to that task.
38    ///
39    /// # Panics
40    ///
41    /// This method panics if called outside of a Tokio runtime
42    ///
43    /// # Example
44    ///
45    /// ```
46    /// # use std::sync::Arc;
47    /// # use reth_downloaders::headers::reverse_headers::ReverseHeadersDownloader;
48    /// # use reth_downloaders::headers::task::TaskDownloader;
49    /// # use reth_consensus::HeaderValidator;
50    /// # use reth_network_p2p::headers::client::HeadersClient;
51    /// # use reth_primitives_traits::BlockHeader;
52    /// # fn t<H: HeadersClient<Header: BlockHeader> + 'static>(consensus:Arc<dyn HeaderValidator<H::Header>>, client: Arc<H>) {
53    ///    let downloader = ReverseHeadersDownloader::<H>::builder().build(
54    ///        client,
55    ///        consensus
56    ///     );
57    ///   let downloader = TaskDownloader::spawn(downloader);
58    /// # }
59    pub fn spawn<T>(downloader: T) -> Self
60    where
61        T: HeaderDownloader<Header = H> + 'static,
62    {
63        Self::spawn_with(downloader, &TokioTaskExecutor::default())
64    }
65
66    /// Spawns the given `downloader` via the given [`TaskSpawner`] returns a [`TaskDownloader`]
67    /// that's connected to that task.
68    pub fn spawn_with<T, S>(downloader: T, spawner: &S) -> Self
69    where
70        T: HeaderDownloader<Header = H> + 'static,
71        S: TaskSpawner,
72    {
73        let (headers_tx, headers_rx) = mpsc::channel(HEADERS_TASK_BUFFER_SIZE);
74        let (to_downloader, updates_rx) = mpsc::unbounded_channel();
75
76        let downloader = SpawnedDownloader {
77            headers_tx: PollSender::new(headers_tx),
78            updates: UnboundedReceiverStream::new(updates_rx),
79            downloader,
80        };
81        spawner.spawn(downloader.boxed());
82
83        Self { from_downloader: ReceiverStream::new(headers_rx), to_downloader }
84    }
85}
86
87impl<H: Sealable + Debug + Send + Sync + Unpin + 'static> HeaderDownloader for TaskDownloader<H> {
88    type Header = H;
89
90    fn update_sync_gap(&mut self, head: SealedHeader<H>, target: SyncTarget) {
91        let _ = self.to_downloader.send(DownloaderUpdates::UpdateSyncGap(head, target));
92    }
93
94    fn update_local_head(&mut self, head: SealedHeader<H>) {
95        let _ = self.to_downloader.send(DownloaderUpdates::UpdateLocalHead(head));
96    }
97
98    fn update_sync_target(&mut self, target: SyncTarget) {
99        let _ = self.to_downloader.send(DownloaderUpdates::UpdateSyncTarget(target));
100    }
101
102    fn set_batch_size(&mut self, limit: usize) {
103        let _ = self.to_downloader.send(DownloaderUpdates::SetBatchSize(limit));
104    }
105}
106
107impl<H: Sealable> Stream for TaskDownloader<H> {
108    type Item = HeadersDownloaderResult<Vec<SealedHeader<H>>, H>;
109
110    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
111        self.project().from_downloader.poll_next(cx)
112    }
113}
114
/// A [`HeaderDownloader`] that runs on its own task
///
/// This is the task-side counterpart of [`TaskDownloader`]: it applies incoming
/// [`DownloaderUpdates`] to the wrapped downloader and forwards the downloader's output through
/// `headers_tx`.
// NOTE(review): the lint expectation presumably covers `clippy::type_complexity` on the nested
// field types below — confirm.
#[expect(clippy::complexity)]
struct SpawnedDownloader<T: HeaderDownloader> {
    /// Updates sent by the [`TaskDownloader`] handle; closing this stream ends the task.
    updates: UnboundedReceiverStream<DownloaderUpdates<T::Header>>,
    /// Sending half of the bounded results channel read by the [`TaskDownloader`] handle.
    headers_tx: PollSender<HeadersDownloaderResult<Vec<SealedHeader<T::Header>>, T::Header>>,
    /// The wrapped downloader that performs the actual work.
    downloader: T,
}
122
123impl<T: HeaderDownloader> Future for SpawnedDownloader<T> {
124    type Output = ();
125
126    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
127        let this = self.get_mut();
128
129        loop {
130            loop {
131                match this.updates.poll_next_unpin(cx) {
132                    Poll::Pending => break,
133                    Poll::Ready(None) => {
134                        // channel closed, this means [TaskDownloader] was dropped, so we can also
135                        // exit
136                        return Poll::Ready(())
137                    }
138                    Poll::Ready(Some(update)) => match update {
139                        DownloaderUpdates::UpdateSyncGap(head, target) => {
140                            this.downloader.update_sync_gap(head, target);
141                        }
142                        DownloaderUpdates::UpdateLocalHead(head) => {
143                            this.downloader.update_local_head(head);
144                        }
145                        DownloaderUpdates::UpdateSyncTarget(target) => {
146                            this.downloader.update_sync_target(target);
147                        }
148                        DownloaderUpdates::SetBatchSize(limit) => {
149                            this.downloader.set_batch_size(limit);
150                        }
151                    },
152                }
153            }
154
155            match ready!(this.headers_tx.poll_reserve(cx)) {
156                Ok(()) => {
157                    match ready!(this.downloader.poll_next_unpin(cx)) {
158                        Some(headers) => {
159                            if this.headers_tx.send_item(headers).is_err() {
160                                // channel closed, this means [TaskDownloader] was dropped, so we
161                                // can also exit
162                                return Poll::Ready(())
163                            }
164                        }
165                        None => return Poll::Pending,
166                    }
167                }
168                Err(_) => {
169                    // channel closed, this means [TaskDownloader] was dropped, so
170                    // we can also exit
171                    return Poll::Ready(())
172                }
173            }
174        }
175    }
176}
177
/// Commands delegated to the spawned [`HeaderDownloader`]
///
/// Sent by the [`TaskDownloader`] handle and applied by [`SpawnedDownloader`] to the wrapped
/// downloader, in the order they were sent.
#[derive(Debug)]
enum DownloaderUpdates<H> {
    /// Applied via [`HeaderDownloader::update_sync_gap`].
    UpdateSyncGap(SealedHeader<H>, SyncTarget),
    /// Applied via [`HeaderDownloader::update_local_head`].
    UpdateLocalHead(SealedHeader<H>),
    /// Applied via [`HeaderDownloader::update_sync_target`].
    UpdateSyncTarget(SyncTarget),
    /// Applied via [`HeaderDownloader::set_batch_size`].
    SetBatchSize(usize),
}
186
#[cfg(test)]
mod tests {
    use super::*;
    use crate::headers::{
        reverse_headers::ReverseHeadersDownloaderBuilder, test_utils::child_header,
    };
    use reth_consensus::test_utils::TestConsensus;
    use reth_network_p2p::test_utils::TestHeadersClient;
    use std::sync::Arc;

    /// With a stream batch size of one, the task-backed downloader should yield the chain from
    /// the tip towards the local head one header at a time.
    #[tokio::test(flavor = "multi_thread")]
    async fn download_one_by_one_on_task() {
        reth_tracing::init_test_tracing();

        // Four-header chain: `p3` is the local head, `p0` the sync tip.
        let p3 = SealedHeader::default();
        let p2 = child_header(&p3);
        let p1 = child_header(&p2);
        let p0 = child_header(&p1);

        let client = Arc::new(TestHeadersClient::default());
        let inner = ReverseHeadersDownloaderBuilder::default()
            .stream_batch_size(1)
            .request_limit(1)
            .build(Arc::clone(&client), Arc::new(TestConsensus::default()));

        let mut downloader = TaskDownloader::spawn(inner);
        downloader.update_local_head(p3.clone());
        downloader.update_sync_target(SyncTarget::Tip(p0.hash()));

        client
            .extend(vec![
                p0.as_ref().clone(),
                p1.as_ref().clone(),
                p2.as_ref().clone(),
                p3.as_ref().clone(),
            ])
            .await;

        // Headers arrive newest-first, one per batch, stopping before the local head.
        for expected in [p0, p1, p2] {
            let batch = downloader.next().await.unwrap();
            assert_eq!(batch, Ok(vec![expected]));
        }
    }
}