Skip to main content

reth_downloaders/bodies/task.rs

1use alloy_primitives::BlockNumber;
2use futures::Stream;
3use futures_util::StreamExt;
4use pin_project::pin_project;
5use reth_network_p2p::{
6    bodies::downloader::{BodyDownloader, BodyDownloaderResult},
7    error::DownloadResult,
8};
9use reth_primitives_traits::Block;
10use reth_tasks::Runtime;
11use std::{
12    fmt::Debug,
13    future::Future,
14    ops::RangeInclusive,
15    pin::Pin,
16    task::{ready, Context, Poll},
17};
18use tokio::sync::{mpsc, mpsc::UnboundedSender};
19use tokio_stream::wrappers::{ReceiverStream, UnboundedReceiverStream};
20use tokio_util::sync::PollSender;
21
/// Capacity of the bounded channel that carries [`BodyDownloaderResult`]s from the spawned
/// downloader task back to its [`TaskDownloader`] handle; the task stops polling the inner
/// downloader while this many results are waiting to be consumed.
pub const BODIES_TASK_BUFFER_SIZE: usize = 4;
24
25/// A [BodyDownloader] that drives a spawned [BodyDownloader] on a spawned task.
26#[derive(Debug)]
27#[pin_project]
28pub struct TaskDownloader<B: Block> {
29    #[pin]
30    from_downloader: ReceiverStream<BodyDownloaderResult<B>>,
31    to_downloader: UnboundedSender<RangeInclusive<BlockNumber>>,
32}
33
34impl<B: Block + 'static> TaskDownloader<B> {
35    /// Spawns the given `downloader` via the given [`Runtime`] and returns a [`TaskDownloader`]
36    /// that's connected to that task.
37    pub fn spawn_with<T>(downloader: T, runtime: &Runtime) -> Self
38    where
39        T: BodyDownloader<Block = B> + 'static,
40    {
41        let (bodies_tx, bodies_rx) = mpsc::channel(BODIES_TASK_BUFFER_SIZE);
42        let (to_downloader, updates_rx) = mpsc::unbounded_channel();
43
44        let downloader = SpawnedDownloader {
45            bodies_tx: PollSender::new(bodies_tx),
46            updates: UnboundedReceiverStream::new(updates_rx),
47            downloader,
48        };
49
50        runtime.spawn_task(downloader);
51
52        Self { from_downloader: ReceiverStream::new(bodies_rx), to_downloader }
53    }
54}
55
56impl<B: Block + 'static> BodyDownloader for TaskDownloader<B> {
57    type Block = B;
58
59    fn set_download_range(&mut self, range: RangeInclusive<BlockNumber>) -> DownloadResult<()> {
60        let _ = self.to_downloader.send(range);
61        Ok(())
62    }
63}
64
65impl<B: Block + 'static> Stream for TaskDownloader<B> {
66    type Item = BodyDownloaderResult<B>;
67
68    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
69        self.project().from_downloader.poll_next(cx)
70    }
71}
72
73/// A [`BodyDownloader`] that runs on its own task
74struct SpawnedDownloader<T: BodyDownloader> {
75    updates: UnboundedReceiverStream<RangeInclusive<BlockNumber>>,
76    bodies_tx: PollSender<BodyDownloaderResult<T::Block>>,
77    downloader: T,
78}
79
80impl<T: BodyDownloader> Future for SpawnedDownloader<T> {
81    type Output = ();
82
83    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
84        let this = self.get_mut();
85
86        loop {
87            while let Poll::Ready(update) = this.updates.poll_next_unpin(cx) {
88                if let Some(range) = update {
89                    if let Err(err) = this.downloader.set_download_range(range) {
90                        tracing::error!(target: "downloaders::bodies", %err, "Failed to set bodies download range");
91
92                        // Clone the sender ensure its availability. See [PollSender::clone].
93                        let mut bodies_tx = this.bodies_tx.clone();
94
95                        let forward_error_result = ready!(bodies_tx.poll_reserve(cx))
96                            .and_then(|_| bodies_tx.send_item(Err(err)));
97                        if forward_error_result.is_err() {
98                            // channel closed, this means [TaskDownloader] was dropped,
99                            // so we can also exit
100                            return Poll::Ready(())
101                        }
102                    }
103                } else {
104                    // channel closed, this means [TaskDownloader] was dropped, so we can also
105                    // exit
106                    return Poll::Ready(())
107                }
108            }
109
110            match ready!(this.bodies_tx.poll_reserve(cx)) {
111                Ok(()) => match ready!(this.downloader.poll_next_unpin(cx)) {
112                    Some(bodies) => {
113                        if this.bodies_tx.send_item(bodies).is_err() {
114                            // channel closed, this means [TaskDownloader] was dropped, so we can
115                            // also exit
116                            return Poll::Ready(())
117                        }
118                    }
119                    None => return Poll::Pending,
120                },
121                Err(_) => {
122                    // channel closed, this means [TaskDownloader] was dropped, so we can also
123                    // exit
124                    return Poll::Ready(())
125                }
126            }
127        }
128    }
129}
130
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        bodies::{
            bodies::BodiesDownloaderBuilder,
            test_utils::{insert_headers, zip_blocks},
        },
        test_utils::{generate_bodies, TestBodiesClient},
    };
    use assert_matches::assert_matches;
    use reth_consensus::test_utils::TestConsensus;
    use reth_network_p2p::error::DownloadError;
    use reth_provider::test_utils::create_test_provider_factory;
    use std::sync::Arc;

    /// Downloading a full range through the task wrapper yields every body in the first result
    /// and issues exactly one request to the client.
    #[tokio::test(flavor = "multi_thread")]
    async fn download_one_by_one_on_task() {
        reth_tracing::init_test_tracing();

        let provider_factory = create_test_provider_factory();
        let (headers, mut bodies) = generate_bodies(0..=19);
        insert_headers(&provider_factory, &headers);

        let bodies_client = Arc::new(
            TestBodiesClient::default().with_bodies(bodies.clone()).with_should_delay(true),
        );
        let inner = BodiesDownloaderBuilder::default()
            .build::<reth_ethereum_primitives::Block, _, _>(
                bodies_client.clone(),
                Arc::new(TestConsensus::default()),
                provider_factory,
            );

        let runtime = Runtime::test();
        let mut task_downloader = TaskDownloader::spawn_with(inner, &runtime);
        task_downloader.set_download_range(0..=19).expect("failed to set download range");

        let first = task_downloader.next().await;
        assert_matches!(
            first,
            Some(Ok(res)) => assert_eq!(res, zip_blocks(headers.iter(), &mut bodies))
        );
        assert_eq!(bodies_client.times_requested(), 1);
    }

    /// An inverted range is accepted by the handle, rejected on the spawned task, and the
    /// resulting error is surfaced through the stream.
    #[tokio::test(flavor = "multi_thread")]
    #[expect(clippy::reversed_empty_ranges)]
    async fn set_download_range_error_returned() {
        reth_tracing::init_test_tracing();

        let provider_factory = create_test_provider_factory();
        let inner = BodiesDownloaderBuilder::default()
            .build::<reth_ethereum_primitives::Block, _, _>(
                Arc::new(TestBodiesClient::default()),
                Arc::new(TestConsensus::default()),
                provider_factory,
            );

        let runtime = Runtime::test();
        let mut task_downloader = TaskDownloader::spawn_with(inner, &runtime);
        task_downloader.set_download_range(1..=0).expect("failed to set download range");

        let first = task_downloader.next().await;
        assert_matches!(first, Some(Err(DownloadError::InvalidBodyRange { .. })));
    }
}