reth_downloaders/bodies/task.rs
1use alloy_primitives::BlockNumber;
2use futures::Stream;
3use futures_util::{FutureExt, StreamExt};
4use pin_project::pin_project;
5use reth_network_p2p::{
6 bodies::downloader::{BodyDownloader, BodyDownloaderResult},
7 error::DownloadResult,
8};
9use reth_primitives_traits::Block;
10use reth_tasks::{TaskSpawner, TokioTaskExecutor};
11use std::{
12 fmt::Debug,
13 future::Future,
14 ops::RangeInclusive,
15 pin::Pin,
16 task::{ready, Context, Poll},
17};
18use tokio::sync::{mpsc, mpsc::UnboundedSender};
19use tokio_stream::wrappers::{ReceiverStream, UnboundedReceiverStream};
20use tokio_util::sync::PollSender;
21
/// Capacity of the bounded channel that carries [`BodyDownloaderResult`]s from the spawned
/// downloader task back to the [`TaskDownloader`] handle.
pub const BODIES_TASK_BUFFER_SIZE: usize = 4;
24
25#[derive(Debug)]
27#[pin_project]
28pub struct TaskDownloader<B: Block> {
29 #[pin]
30 from_downloader: ReceiverStream<BodyDownloaderResult<B>>,
31 to_downloader: UnboundedSender<RangeInclusive<BlockNumber>>,
32}
33
34impl<B: Block + 'static> TaskDownloader<B> {
35 pub fn spawn<T>(downloader: T) -> Self
67 where
68 T: BodyDownloader<Block = B> + 'static,
69 {
70 Self::spawn_with(downloader, &TokioTaskExecutor::default())
71 }
72
73 pub fn spawn_with<T, S>(downloader: T, spawner: &S) -> Self
76 where
77 T: BodyDownloader<Block = B> + 'static,
78 S: TaskSpawner,
79 {
80 let (bodies_tx, bodies_rx) = mpsc::channel(BODIES_TASK_BUFFER_SIZE);
81 let (to_downloader, updates_rx) = mpsc::unbounded_channel();
82
83 let downloader = SpawnedDownloader {
84 bodies_tx: PollSender::new(bodies_tx),
85 updates: UnboundedReceiverStream::new(updates_rx),
86 downloader,
87 };
88
89 spawner.spawn(downloader.boxed());
90
91 Self { from_downloader: ReceiverStream::new(bodies_rx), to_downloader }
92 }
93}
94
95impl<B: Block + 'static> BodyDownloader for TaskDownloader<B> {
96 type Block = B;
97
98 fn set_download_range(&mut self, range: RangeInclusive<BlockNumber>) -> DownloadResult<()> {
99 let _ = self.to_downloader.send(range);
100 Ok(())
101 }
102}
103
104impl<B: Block + 'static> Stream for TaskDownloader<B> {
105 type Item = BodyDownloaderResult<B>;
106
107 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
108 self.project().from_downloader.poll_next(cx)
109 }
110}
111
112struct SpawnedDownloader<T: BodyDownloader> {
114 updates: UnboundedReceiverStream<RangeInclusive<BlockNumber>>,
115 bodies_tx: PollSender<BodyDownloaderResult<T::Block>>,
116 downloader: T,
117}
118
119impl<T: BodyDownloader> Future for SpawnedDownloader<T> {
120 type Output = ();
121
122 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
123 let this = self.get_mut();
124
125 loop {
126 while let Poll::Ready(update) = this.updates.poll_next_unpin(cx) {
127 if let Some(range) = update {
128 if let Err(err) = this.downloader.set_download_range(range) {
129 tracing::error!(target: "downloaders::bodies", %err, "Failed to set bodies download range");
130
131 let mut bodies_tx = this.bodies_tx.clone();
133
134 let forward_error_result = ready!(bodies_tx.poll_reserve(cx))
135 .and_then(|_| bodies_tx.send_item(Err(err)));
136 if forward_error_result.is_err() {
137 return Poll::Ready(())
140 }
141 }
142 } else {
143 return Poll::Ready(())
146 }
147 }
148
149 match ready!(this.bodies_tx.poll_reserve(cx)) {
150 Ok(()) => match ready!(this.downloader.poll_next_unpin(cx)) {
151 Some(bodies) => {
152 if this.bodies_tx.send_item(bodies).is_err() {
153 return Poll::Ready(())
156 }
157 }
158 None => return Poll::Pending,
159 },
160 Err(_) => {
161 return Poll::Ready(())
164 }
165 }
166 }
167 }
168}
169
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        bodies::{
            bodies::BodiesDownloaderBuilder,
            test_utils::{insert_headers, zip_blocks},
        },
        test_utils::{generate_bodies, TestBodiesClient},
    };
    use assert_matches::assert_matches;
    use reth_consensus::test_utils::TestConsensus;
    use reth_network_p2p::error::DownloadError;
    use reth_provider::test_utils::create_test_provider_factory;
    use std::sync::Arc;

    /// Runs a bodies downloader for blocks `0..=19` on a spawned task. Checks that the first
    /// yielded batch matches the generated blocks and that the client was queried exactly once.
    #[tokio::test(flavor = "multi_thread")]
    async fn download_one_by_one_on_task() {
        reth_tracing::init_test_tracing();

        let provider_factory = create_test_provider_factory();
        let (headers, mut bodies) = generate_bodies(0..=19);
        insert_headers(provider_factory.db_ref().db(), &headers);

        let client = Arc::new(
            TestBodiesClient::default().with_bodies(bodies.clone()).with_should_delay(true),
        );
        let inner = BodiesDownloaderBuilder::default().build::<reth_primitives::Block, _, _>(
            Arc::clone(&client),
            Arc::new(TestConsensus::default()),
            provider_factory,
        );
        let mut task_downloader = TaskDownloader::spawn(inner);

        task_downloader.set_download_range(0..=19).expect("failed to set download range");

        assert_matches!(
            task_downloader.next().await,
            Some(Ok(res)) => assert_eq!(res, zip_blocks(headers.iter(), &mut bodies))
        );
        assert_eq!(client.times_requested(), 1);
    }

    /// An inverted (empty) range must surface as `DownloadError::InvalidBodyRange` on the
    /// stream, even though `set_download_range` itself returns `Ok`.
    #[tokio::test(flavor = "multi_thread")]
    // `1..=0` is deliberately empty to trigger the downloader's range validation.
    #[allow(clippy::reversed_empty_ranges)]
    async fn set_download_range_error_returned() {
        reth_tracing::init_test_tracing();
        let provider_factory = create_test_provider_factory();

        let inner = BodiesDownloaderBuilder::default().build::<reth_primitives::Block, _, _>(
            Arc::new(TestBodiesClient::default()),
            Arc::new(TestConsensus::default()),
            provider_factory,
        );
        let mut task_downloader = TaskDownloader::spawn(inner);

        task_downloader.set_download_range(1..=0).expect("failed to set download range");
        assert_matches!(
            task_downloader.next().await,
            Some(Err(DownloadError::InvalidBodyRange { .. }))
        );
    }
}
230}