reth_downloaders/bodies/
task.rs1use alloy_primitives::BlockNumber;
2use futures::Stream;
3use futures_util::StreamExt;
4use pin_project::pin_project;
5use reth_network_p2p::{
6 bodies::downloader::{BodyDownloader, BodyDownloaderResult},
7 error::DownloadResult,
8};
9use reth_primitives_traits::Block;
10use reth_tasks::Runtime;
11use std::{
12 fmt::Debug,
13 future::Future,
14 ops::RangeInclusive,
15 pin::Pin,
16 task::{ready, Context, Poll},
17};
18use tokio::sync::{mpsc, mpsc::UnboundedSender};
19use tokio_stream::wrappers::{ReceiverStream, UnboundedReceiverStream};
20use tokio_util::sync::PollSender;
21
/// Capacity of the bounded channel that carries results from the spawned downloader back to
/// [`TaskDownloader`]; it is the argument given to `mpsc::channel` in
/// [`TaskDownloader::spawn_with`].
pub const BODIES_TASK_BUFFER_SIZE: usize = 4;
24
25#[derive(Debug)]
27#[pin_project]
28pub struct TaskDownloader<B: Block> {
29 #[pin]
30 from_downloader: ReceiverStream<BodyDownloaderResult<B>>,
31 to_downloader: UnboundedSender<RangeInclusive<BlockNumber>>,
32}
33
34impl<B: Block + 'static> TaskDownloader<B> {
35 pub fn spawn_with<T>(downloader: T, runtime: &Runtime) -> Self
38 where
39 T: BodyDownloader<Block = B> + 'static,
40 {
41 let (bodies_tx, bodies_rx) = mpsc::channel(BODIES_TASK_BUFFER_SIZE);
42 let (to_downloader, updates_rx) = mpsc::unbounded_channel();
43
44 let downloader = SpawnedDownloader {
45 bodies_tx: PollSender::new(bodies_tx),
46 updates: UnboundedReceiverStream::new(updates_rx),
47 downloader,
48 };
49
50 runtime.spawn_task(downloader);
51
52 Self { from_downloader: ReceiverStream::new(bodies_rx), to_downloader }
53 }
54}
55
56impl<B: Block + 'static> BodyDownloader for TaskDownloader<B> {
57 type Block = B;
58
59 fn set_download_range(&mut self, range: RangeInclusive<BlockNumber>) -> DownloadResult<()> {
60 let _ = self.to_downloader.send(range);
61 Ok(())
62 }
63}
64
65impl<B: Block + 'static> Stream for TaskDownloader<B> {
66 type Item = BodyDownloaderResult<B>;
67
68 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
69 self.project().from_downloader.poll_next(cx)
70 }
71}
72
73struct SpawnedDownloader<T: BodyDownloader> {
75 updates: UnboundedReceiverStream<RangeInclusive<BlockNumber>>,
76 bodies_tx: PollSender<BodyDownloaderResult<T::Block>>,
77 downloader: T,
78}
79
80impl<T: BodyDownloader> Future for SpawnedDownloader<T> {
81 type Output = ();
82
83 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
84 let this = self.get_mut();
85
86 loop {
87 while let Poll::Ready(update) = this.updates.poll_next_unpin(cx) {
88 if let Some(range) = update {
89 if let Err(err) = this.downloader.set_download_range(range) {
90 tracing::error!(target: "downloaders::bodies", %err, "Failed to set bodies download range");
91
92 let mut bodies_tx = this.bodies_tx.clone();
94
95 let forward_error_result = ready!(bodies_tx.poll_reserve(cx))
96 .and_then(|_| bodies_tx.send_item(Err(err)));
97 if forward_error_result.is_err() {
98 return Poll::Ready(())
101 }
102 }
103 } else {
104 return Poll::Ready(())
107 }
108 }
109
110 match ready!(this.bodies_tx.poll_reserve(cx)) {
111 Ok(()) => match ready!(this.downloader.poll_next_unpin(cx)) {
112 Some(bodies) => {
113 if this.bodies_tx.send_item(bodies).is_err() {
114 return Poll::Ready(())
117 }
118 }
119 None => return Poll::Pending,
120 },
121 Err(_) => {
122 return Poll::Ready(())
125 }
126 }
127 }
128 }
129}
130
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        bodies::{
            bodies::BodiesDownloaderBuilder,
            test_utils::{insert_headers, zip_blocks},
        },
        test_utils::{generate_bodies, TestBodiesClient},
    };
    use assert_matches::assert_matches;
    use reth_consensus::test_utils::TestConsensus;
    use reth_network_p2p::error::DownloadError;
    use reth_provider::test_utils::create_test_provider_factory;
    use std::sync::Arc;

    /// Requests blocks `0..=19` through a spawned [`TaskDownloader`] and checks that the first
    /// item it yields equals the generated headers zipped with their bodies, and that the test
    /// client reports exactly one request.
    #[tokio::test(flavor = "multi_thread")]
    async fn download_one_by_one_on_task() {
        reth_tracing::init_test_tracing();

        let factory = create_test_provider_factory();
        let (headers, mut bodies) = generate_bodies(0..=19);

        insert_headers(&factory, &headers);

        let client = Arc::new(
            TestBodiesClient::default().with_bodies(bodies.clone()).with_should_delay(true),
        );
        let inner = BodiesDownloaderBuilder::default()
            .build::<reth_ethereum_primitives::Block, _, _>(
                client.clone(),
                Arc::new(TestConsensus::default()),
                factory,
            );
        let runtime = Runtime::test();
        let mut task_downloader = TaskDownloader::spawn_with(inner, &runtime);

        task_downloader.set_download_range(0..=19).expect("failed to set download range");

        let first_item = task_downloader.next().await;
        assert_matches!(
            first_item,
            Some(Ok(res)) => assert_eq!(res, zip_blocks(headers.iter(), &mut bodies))
        );
        assert_eq!(client.times_requested(), 1);
    }

    /// Sends the reversed range `1..=0` and checks that the resulting
    /// [`DownloadError::InvalidBodyRange`] comes back through the stream, while
    /// `set_download_range` on the handle itself still succeeds.
    #[tokio::test(flavor = "multi_thread")]
    #[expect(clippy::reversed_empty_ranges)]
    async fn set_download_range_error_returned() {
        reth_tracing::init_test_tracing();
        let factory = create_test_provider_factory();

        let inner = BodiesDownloaderBuilder::default()
            .build::<reth_ethereum_primitives::Block, _, _>(
                Arc::new(TestBodiesClient::default()),
                Arc::new(TestConsensus::default()),
                factory,
            );
        let runtime = Runtime::test();
        let mut task_downloader = TaskDownloader::spawn_with(inner, &runtime);

        task_downloader.set_download_range(1..=0).expect("failed to set download range");

        let outcome = task_downloader.next().await;
        assert_matches!(outcome, Some(Err(DownloadError::InvalidBodyRange { .. })));
    }
}