// reth_downloaders/bodies/task.rs
use alloy_primitives::BlockNumber;
use futures::Stream;
use futures_util::{FutureExt, StreamExt};
use pin_project::pin_project;
use reth_network_p2p::{
bodies::downloader::{BodyDownloader, BodyDownloaderResult},
error::DownloadResult,
};
use reth_tasks::{TaskSpawner, TokioTaskExecutor};
use std::{
fmt::Debug,
future::Future,
ops::RangeInclusive,
pin::Pin,
task::{ready, Context, Poll},
};
use tokio::sync::{mpsc, mpsc::UnboundedSender};
use tokio_stream::wrappers::{ReceiverStream, UnboundedReceiverStream};
use tokio_util::sync::PollSender;
/// Capacity of the bounded channel that carries [`BodyDownloaderResult`]s from the spawned
/// download task back to the [`TaskDownloader`] handle.
pub const BODIES_TASK_BUFFER_SIZE: usize = 4;
/// A [`BodyDownloader`] handle whose actual downloader runs on a separately spawned task.
///
/// Download ranges are forwarded to the task over a channel, and the task streams its results
/// back through [`Stream`].
#[derive(Debug)]
#[pin_project]
pub struct TaskDownloader<H, B> {
    /// Results produced by the spawned downloader task.
    #[pin]
    from_downloader: ReceiverStream<BodyDownloaderResult<H, B>>,
    /// Sends new download ranges to the spawned task.
    to_downloader: UnboundedSender<RangeInclusive<BlockNumber>>,
}
impl<H: Send + Sync + Unpin + 'static, B: Send + Sync + Unpin + 'static> TaskDownloader<H, B> {
    /// Spawns `downloader` using a default [`TokioTaskExecutor`] and returns a handle to it.
    ///
    /// Equivalent to [`TaskDownloader::spawn_with`] with that executor.
    pub fn spawn<T>(downloader: T) -> Self
    where
        T: BodyDownloader<Header = H, Body = B> + 'static,
    {
        let executor = TokioTaskExecutor::default();
        Self::spawn_with(downloader, &executor)
    }

    /// Spawns `downloader` on `spawner` and returns a handle to it.
    ///
    /// Range updates reach the task over an unbounded channel. Results come back over a channel
    /// holding at most [`BODIES_TASK_BUFFER_SIZE`] items.
    pub fn spawn_with<T, S>(downloader: T, spawner: &S) -> Self
    where
        T: BodyDownloader<Header = H, Body = B> + 'static,
        S: TaskSpawner,
    {
        let (updates_tx, updates_rx) = mpsc::unbounded_channel();
        let (results_tx, results_rx) = mpsc::channel(BODIES_TASK_BUFFER_SIZE);

        let task = SpawnedDownloader {
            downloader,
            updates: UnboundedReceiverStream::new(updates_rx),
            bodies_tx: PollSender::new(results_tx),
        };
        spawner.spawn(task.boxed());

        Self { to_downloader: updates_tx, from_downloader: ReceiverStream::new(results_rx) }
    }
}
impl<H, B> BodyDownloader for TaskDownloader<H, B>
where
    H: Debug + Send + Sync + Unpin + 'static,
    B: Debug + Send + Sync + Unpin + 'static,
{
    type Header = H;
    type Body = B;

    /// Hands `range` to the background task, which applies it to the wrapped downloader.
    ///
    /// This always returns `Ok(())`. If the wrapped downloader rejects the range, the error is
    /// delivered as an `Err` item on this stream instead.
    fn set_download_range(&mut self, range: RangeInclusive<BlockNumber>) -> DownloadResult<()> {
        // A failed send means the task's receiver is gone, i.e. the task has stopped. The range
        // cannot be applied anymore, and the result stream ends on its own, so this is ignored.
        let _ = self.to_downloader.send(range);
        Ok(())
    }
}
impl<H, B> Stream for TaskDownloader<H, B> {
    type Item = BodyDownloaderResult<H, B>;

    /// Yields the next result sent by the background task. Returns `None` once the task's sender
    /// is gone and every buffered result has been drained.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let projected = self.project();
        projected.from_downloader.poll_next(cx)
    }
}
/// Future run on the spawned task. It applies range updates sent by the [`TaskDownloader`] to
/// the wrapped downloader and forwards the downloader's results back.
struct SpawnedDownloader<T: BodyDownloader> {
    /// Range updates sent by the [`TaskDownloader`].
    updates: UnboundedReceiverStream<RangeInclusive<BlockNumber>>,
    /// Sender for downloaded results and range errors; the channel holds at most
    /// [`BODIES_TASK_BUFFER_SIZE`] items.
    bodies_tx: PollSender<BodyDownloaderResult<T::Header, T::Body>>,
    /// The wrapped downloader being driven.
    downloader: T,
}
impl<T: BodyDownloader> Future for SpawnedDownloader<T> {
    type Output = ();

    /// Drives the wrapped downloader: it applies queued range updates, then forwards downloaded
    /// results over the bounded channel.
    ///
    /// The future resolves once the [`TaskDownloader`] side is gone. That is detected either by
    /// the updates stream ending or by the bodies channel being closed.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.get_mut();
        'outer: loop {
            // Reserve an outgoing slot *before* dequeuing any range update. If applying a range
            // fails, the error can then be sent immediately. Reserving only after the update has
            // been dequeued would drop the error whenever the channel is full, and returning
            // `Pending` from that path would not keep a channel wake-up registered.
            if ready!(this.bodies_tx.poll_reserve(cx)).is_err() {
                // Channel closed: the `TaskDownloader` was dropped, so we can exit.
                return Poll::Ready(())
            }

            // Apply every queued range update. Leaving this loop normally means the updates
            // stream returned `Pending`, which registers a wake-up for the next update.
            while let Poll::Ready(update) = this.updates.poll_next_unpin(cx) {
                let range = match update {
                    Some(range) => range,
                    // Updates sender dropped: the `TaskDownloader` is gone, so we can exit.
                    None => return Poll::Ready(()),
                };
                if let Err(err) = this.downloader.set_download_range(range) {
                    tracing::error!(target: "downloaders::bodies", %err, "Failed to set bodies download range");
                    if this.bodies_tx.send_item(Err(err)).is_err() {
                        // Channel closed: the `TaskDownloader` was dropped, so we can exit.
                        return Poll::Ready(())
                    }
                    // The reserved slot was consumed by the error. Reserve a new one before
                    // touching further updates or the downloader.
                    continue 'outer
                }
            }

            match ready!(this.downloader.poll_next_unpin(cx)) {
                Some(bodies) => {
                    if this.bodies_tx.send_item(bodies).is_err() {
                        // Channel closed: the `TaskDownloader` was dropped, so we can exit.
                        return Poll::Ready(())
                    }
                }
                // The downloader has nothing more for the current range. Keep the reservation and
                // wait. The updates stream was polled to `Pending` above in this iteration, so a
                // new range will wake this task.
                None => return Poll::Pending,
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        bodies::{
            bodies::BodiesDownloaderBuilder,
            test_utils::{insert_headers, zip_blocks},
        },
        test_utils::{generate_bodies, TestBodiesClient},
    };
    use assert_matches::assert_matches;
    use reth_consensus::test_utils::TestConsensus;
    use reth_network_p2p::error::DownloadError;
    use reth_provider::test_utils::create_test_provider_factory;
    use std::sync::Arc;

    /// Setting a valid range on the task-backed downloader yields all requested blocks as the
    /// first stream item, fetched with a single client request.
    #[tokio::test(flavor = "multi_thread")]
    async fn download_one_by_one_on_task() {
        reth_tracing::init_test_tracing();

        let factory = create_test_provider_factory();
        let (headers, mut bodies) = generate_bodies(0..=19);
        insert_headers(factory.db_ref().db(), &headers);

        let client = Arc::new(
            TestBodiesClient::default().with_bodies(bodies.clone()).with_should_delay(true),
        );
        let inner = BodiesDownloaderBuilder::default().build(
            client.clone(),
            Arc::new(TestConsensus::default()),
            factory,
        );

        let mut downloader = TaskDownloader::spawn(inner);
        downloader.set_download_range(0..=19).expect("failed to set download range");

        let first = downloader.next().await;
        assert_matches!(
            first,
            Some(Ok(res)) => assert_eq!(res, zip_blocks(headers.iter(), &mut bodies))
        );
        assert_eq!(client.times_requested(), 1);
    }

    /// A reversed range is accepted by the handle, but the wrapped downloader's rejection
    /// surfaces as an error on the stream.
    #[tokio::test(flavor = "multi_thread")]
    #[allow(clippy::reversed_empty_ranges)] // `1..=0` is intentionally invalid
    async fn set_download_range_error_returned() {
        reth_tracing::init_test_tracing();

        let factory = create_test_provider_factory();
        let inner = BodiesDownloaderBuilder::default().build(
            Arc::new(TestBodiesClient::default()),
            Arc::new(TestConsensus::default()),
            factory,
        );

        let mut downloader = TaskDownloader::spawn(inner);
        downloader.set_download_range(1..=0).expect("failed to set download range");

        let first = downloader.next().await;
        assert_matches!(first, Some(Err(DownloadError::InvalidBodyRange { .. })));
    }
}