reth_downloaders/bodies/task.rs

1use alloy_primitives::BlockNumber;
2use futures::Stream;
3use futures_util::{FutureExt, StreamExt};
4use pin_project::pin_project;
5use reth_network_p2p::{
6    bodies::downloader::{BodyDownloader, BodyDownloaderResult},
7    error::DownloadResult,
8};
9use reth_primitives_traits::Block;
10use reth_tasks::{TaskSpawner, TokioTaskExecutor};
11use std::{
12    fmt::Debug,
13    future::Future,
14    ops::RangeInclusive,
15    pin::Pin,
16    task::{ready, Context, Poll},
17};
18use tokio::sync::{mpsc, mpsc::UnboundedSender};
19use tokio_stream::wrappers::{ReceiverStream, UnboundedReceiverStream};
20use tokio_util::sync::PollSender;
21
/// Capacity of the bounded channel carrying [`BodyDownloaderResult`]s from the spawned task back
/// to its handle, i.e. how many results may be buffered before the task stops producing more.
pub const BODIES_TASK_BUFFER_SIZE: usize = 4;
24
25/// A [BodyDownloader] that drives a spawned [BodyDownloader] on a spawned task.
26#[derive(Debug)]
27#[pin_project]
28pub struct TaskDownloader<B: Block> {
29    #[pin]
30    from_downloader: ReceiverStream<BodyDownloaderResult<B>>,
31    to_downloader: UnboundedSender<RangeInclusive<BlockNumber>>,
32}
33
34impl<B: Block + 'static> TaskDownloader<B> {
35    /// Spawns the given `downloader` via [`tokio::task::spawn`] returns a [`TaskDownloader`] that's
36    /// connected to that task.
37    ///
38    /// # Panics
39    ///
40    /// This method panics if called outside of a Tokio runtime
41    ///
42    /// # Example
43    ///
44    /// ```
45    /// use reth_consensus::{Consensus, ConsensusError};
46    /// use reth_downloaders::bodies::{bodies::BodiesDownloaderBuilder, task::TaskDownloader};
47    /// use reth_network_p2p::bodies::client::BodiesClient;
48    /// use reth_primitives_traits::{Block, InMemorySize};
49    /// use reth_storage_api::HeaderProvider;
50    /// use std::{fmt::Debug, sync::Arc};
51    ///
52    /// fn t<
53    ///     B: Block + 'static,
54    ///     C: BodiesClient<Body = B::Body> + 'static,
55    ///     Provider: HeaderProvider<Header = B::Header> + Unpin + 'static,
56    /// >(
57    ///     client: Arc<C>,
58    ///     consensus: Arc<dyn Consensus<B, Error = ConsensusError>>,
59    ///     provider: Provider,
60    /// ) {
61    ///     let downloader =
62    ///         BodiesDownloaderBuilder::default().build::<B, _, _>(client, consensus, provider);
63    ///     let downloader = TaskDownloader::spawn(downloader);
64    /// }
65    /// ```
66    pub fn spawn<T>(downloader: T) -> Self
67    where
68        T: BodyDownloader<Block = B> + 'static,
69    {
70        Self::spawn_with(downloader, &TokioTaskExecutor::default())
71    }
72
73    /// Spawns the given `downloader` via the given [`TaskSpawner`] returns a [`TaskDownloader`]
74    /// that's connected to that task.
75    pub fn spawn_with<T, S>(downloader: T, spawner: &S) -> Self
76    where
77        T: BodyDownloader<Block = B> + 'static,
78        S: TaskSpawner,
79    {
80        let (bodies_tx, bodies_rx) = mpsc::channel(BODIES_TASK_BUFFER_SIZE);
81        let (to_downloader, updates_rx) = mpsc::unbounded_channel();
82
83        let downloader = SpawnedDownloader {
84            bodies_tx: PollSender::new(bodies_tx),
85            updates: UnboundedReceiverStream::new(updates_rx),
86            downloader,
87        };
88
89        spawner.spawn(downloader.boxed());
90
91        Self { from_downloader: ReceiverStream::new(bodies_rx), to_downloader }
92    }
93}
94
95impl<B: Block + 'static> BodyDownloader for TaskDownloader<B> {
96    type Block = B;
97
98    fn set_download_range(&mut self, range: RangeInclusive<BlockNumber>) -> DownloadResult<()> {
99        let _ = self.to_downloader.send(range);
100        Ok(())
101    }
102}
103
104impl<B: Block + 'static> Stream for TaskDownloader<B> {
105    type Item = BodyDownloaderResult<B>;
106
107    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
108        self.project().from_downloader.poll_next(cx)
109    }
110}
111
112/// A [`BodyDownloader`] that runs on its own task
113struct SpawnedDownloader<T: BodyDownloader> {
114    updates: UnboundedReceiverStream<RangeInclusive<BlockNumber>>,
115    bodies_tx: PollSender<BodyDownloaderResult<T::Block>>,
116    downloader: T,
117}
118
119impl<T: BodyDownloader> Future for SpawnedDownloader<T> {
120    type Output = ();
121
122    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
123        let this = self.get_mut();
124
125        loop {
126            while let Poll::Ready(update) = this.updates.poll_next_unpin(cx) {
127                if let Some(range) = update {
128                    if let Err(err) = this.downloader.set_download_range(range) {
129                        tracing::error!(target: "downloaders::bodies", %err, "Failed to set bodies download range");
130
131                        // Clone the sender ensure its availability. See [PollSender::clone].
132                        let mut bodies_tx = this.bodies_tx.clone();
133
134                        let forward_error_result = ready!(bodies_tx.poll_reserve(cx))
135                            .and_then(|_| bodies_tx.send_item(Err(err)));
136                        if forward_error_result.is_err() {
137                            // channel closed, this means [TaskDownloader] was dropped,
138                            // so we can also exit
139                            return Poll::Ready(())
140                        }
141                    }
142                } else {
143                    // channel closed, this means [TaskDownloader] was dropped, so we can also
144                    // exit
145                    return Poll::Ready(())
146                }
147            }
148
149            match ready!(this.bodies_tx.poll_reserve(cx)) {
150                Ok(()) => match ready!(this.downloader.poll_next_unpin(cx)) {
151                    Some(bodies) => {
152                        if this.bodies_tx.send_item(bodies).is_err() {
153                            // channel closed, this means [TaskDownloader] was dropped, so we can
154                            // also exit
155                            return Poll::Ready(())
156                        }
157                    }
158                    None => return Poll::Pending,
159                },
160                Err(_) => {
161                    // channel closed, this means [TaskDownloader] was dropped, so we can also
162                    // exit
163                    return Poll::Ready(())
164                }
165            }
166        }
167    }
168}
169
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        bodies::{
            bodies::BodiesDownloaderBuilder,
            test_utils::{insert_headers, zip_blocks},
        },
        test_utils::{generate_bodies, TestBodiesClient},
    };
    use assert_matches::assert_matches;
    use reth_consensus::test_utils::TestConsensus;
    use reth_network_p2p::error::DownloadError;
    use reth_provider::test_utils::create_test_provider_factory;
    use std::sync::Arc;

    /// Downloads a full range through the task wrapper and expects it to arrive as one batch
    /// after exactly one request to the client.
    #[tokio::test(flavor = "multi_thread")]
    async fn download_one_by_one_on_task() {
        reth_tracing::init_test_tracing();

        let provider_factory = create_test_provider_factory();
        let (headers, mut bodies) = generate_bodies(0..=19);

        insert_headers(provider_factory.db_ref().db(), &headers);

        let bodies_client = Arc::new(
            TestBodiesClient::default().with_bodies(bodies.clone()).with_should_delay(true),
        );
        let inner = BodiesDownloaderBuilder::default().build::<reth_primitives::Block, _, _>(
            Arc::clone(&bodies_client),
            Arc::new(TestConsensus::default()),
            provider_factory,
        );
        let mut task_downloader = TaskDownloader::spawn(inner);

        task_downloader.set_download_range(0..=19).expect("failed to set download range");

        let expected = zip_blocks(headers.iter(), &mut bodies);
        assert_matches!(
            task_downloader.next().await,
            Some(Ok(received)) => assert_eq!(received, expected)
        );
        assert_eq!(bodies_client.times_requested(), 1);
    }

    /// An inverted range is rejected by the inner downloader, and that rejection must come back
    /// out of the task's result stream.
    #[tokio::test(flavor = "multi_thread")]
    // The reversed range is deliberate: it is the invalid input under test.
    #[allow(clippy::reversed_empty_ranges)]
    async fn set_download_range_error_returned() {
        reth_tracing::init_test_tracing();
        let provider_factory = create_test_provider_factory();

        let inner = BodiesDownloaderBuilder::default().build::<reth_primitives::Block, _, _>(
            Arc::new(TestBodiesClient::default()),
            Arc::new(TestConsensus::default()),
            provider_factory,
        );
        let mut task_downloader = TaskDownloader::spawn(inner);

        task_downloader.set_download_range(1..=0).expect("failed to set download range");

        let first = task_downloader.next().await;
        assert_matches!(first, Some(Err(DownloadError::InvalidBodyRange { .. })));
    }
}