// reth_downloaders/headers/task.rs
use futures::{FutureExt, Stream};
use futures_util::StreamExt;
use pin_project::pin_project;
use reth_network_p2p::headers::{
downloader::{HeaderDownloader, SyncTarget},
error::HeadersDownloaderResult,
};
use reth_primitives::SealedHeader;
use reth_tasks::{TaskSpawner, TokioTaskExecutor};
use std::{
fmt::Debug,
future::Future,
pin::Pin,
task::{ready, Context, Poll},
};
use tokio::sync::{mpsc, mpsc::UnboundedSender};
use tokio_stream::wrappers::{ReceiverStream, UnboundedReceiverStream};
use tokio_util::sync::PollSender;
/// Capacity of the bounded channel that carries download results from the spawned task back to
/// the [`TaskDownloader`] handle.
///
/// Once this many results are buffered and unread, the spawned task stops polling its inner
/// downloader until the handle consumes some of them.
pub const HEADERS_TASK_BUFFER_SIZE: usize = 8;
/// A [`HeaderDownloader`] handle whose actual downloader runs on a separately spawned task.
///
/// Download results arrive over a bounded channel and are exposed through this type's
/// [`Stream`] impl. Calls to the [`HeaderDownloader`] setters are forwarded to the task as
/// messages over an unbounded channel.
#[derive(Debug)]
#[pin_project]
pub struct TaskDownloader<H> {
    /// Receiving half of the bounded results channel fed by the spawned task.
    #[pin]
    from_downloader: ReceiverStream<HeadersDownloaderResult<Vec<SealedHeader<H>>, H>>,
    /// Sending half of the unbounded channel that forwards configuration updates to the spawned
    /// task.
    to_downloader: UnboundedSender<DownloaderUpdates<H>>,
}
impl<H: Send + Sync + Unpin + 'static> TaskDownloader<H> {
    /// Spawns `downloader` using a default [`TokioTaskExecutor`] and returns a handle connected
    /// to the spawned task.
    ///
    /// This is shorthand for [`TaskDownloader::spawn_with`] with a freshly constructed executor.
    pub fn spawn<T>(downloader: T) -> Self
    where
        T: HeaderDownloader<Header = H> + 'static,
    {
        let executor = TokioTaskExecutor::default();
        Self::spawn_with(downloader, &executor)
    }

    /// Spawns `downloader` with the given `spawner` and returns a handle connected to it.
    ///
    /// Two channels link the handle and the task:
    /// - a bounded channel (capacity [`HEADERS_TASK_BUFFER_SIZE`]) that carries download results
    ///   back to the handle;
    /// - an unbounded channel that carries configuration updates to the task.
    pub fn spawn_with<T, S>(downloader: T, spawner: &S) -> Self
    where
        T: HeaderDownloader<Header = H> + 'static,
        S: TaskSpawner,
    {
        // Control path: handle -> task.
        let (updates_tx, updates_rx) = mpsc::unbounded_channel();
        // Data path: task -> handle, bounded so the task applies backpressure.
        let (results_tx, results_rx) = mpsc::channel(HEADERS_TASK_BUFFER_SIZE);

        let task = SpawnedDownloader {
            downloader,
            updates: UnboundedReceiverStream::new(updates_rx),
            headers_tx: PollSender::new(results_tx),
        };
        spawner.spawn(task.boxed());

        Self { to_downloader: updates_tx, from_downloader: ReceiverStream::new(results_rx) }
    }
}
impl<H: Debug + Send + Sync + Unpin + 'static> HeaderDownloader for TaskDownloader<H> {
    type Header = H;

    // Every setter below only enqueues a message for the spawned task. Sending on the unbounded
    // channel fails only once the receiving side has been dropped, which means the spawned task
    // is gone. There is nothing left to configure then, so the error is deliberately ignored.

    fn update_sync_gap(&mut self, head: SealedHeader<H>, target: SyncTarget) {
        let msg = DownloaderUpdates::UpdateSyncGap(head, target);
        let _ = self.to_downloader.send(msg);
    }

    fn update_local_head(&mut self, head: SealedHeader<H>) {
        let msg = DownloaderUpdates::UpdateLocalHead(head);
        let _ = self.to_downloader.send(msg);
    }

    fn update_sync_target(&mut self, target: SyncTarget) {
        let msg = DownloaderUpdates::UpdateSyncTarget(target);
        let _ = self.to_downloader.send(msg);
    }

    fn set_batch_size(&mut self, limit: usize) {
        let msg = DownloaderUpdates::SetBatchSize(limit);
        let _ = self.to_downloader.send(msg);
    }
}
impl<H> Stream for TaskDownloader<H> {
    type Item = HeadersDownloaderResult<Vec<SealedHeader<H>>, H>;

    /// Yields the next result produced by the spawned task.
    ///
    /// This delegates directly to the pinned results receiver.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let projected = self.project();
        projected.from_downloader.poll_next(cx)
    }
}
/// Future that owns the wrapped downloader and drives it on the spawned task.
///
/// It applies updates received from the [`TaskDownloader`] handle and forwards the downloader's
/// output to that handle.
// NOTE(review): the `expect` presumably silences clippy's type-complexity lint on the nested
// field types below — confirm.
#[expect(clippy::complexity)]
struct SpawnedDownloader<T: HeaderDownloader> {
    /// Configuration updates sent by the [`TaskDownloader`] handle.
    updates: UnboundedReceiverStream<DownloaderUpdates<T::Header>>,
    /// Sending half of the bounded results channel read by the [`TaskDownloader`] handle.
    headers_tx: PollSender<HeadersDownloaderResult<Vec<SealedHeader<T::Header>>, T::Header>>,
    /// The downloader being driven.
    downloader: T,
}
impl<T: HeaderDownloader> Future for SpawnedDownloader<T> {
type Output = ();
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.get_mut();
loop {
loop {
match this.updates.poll_next_unpin(cx) {
Poll::Pending => break,
Poll::Ready(None) => {
return Poll::Ready(())
}
Poll::Ready(Some(update)) => match update {
DownloaderUpdates::UpdateSyncGap(head, target) => {
this.downloader.update_sync_gap(head, target);
}
DownloaderUpdates::UpdateLocalHead(head) => {
this.downloader.update_local_head(head);
}
DownloaderUpdates::UpdateSyncTarget(target) => {
this.downloader.update_sync_target(target);
}
DownloaderUpdates::SetBatchSize(limit) => {
this.downloader.set_batch_size(limit);
}
},
}
}
match ready!(this.headers_tx.poll_reserve(cx)) {
Ok(()) => {
match ready!(this.downloader.poll_next_unpin(cx)) {
Some(headers) => {
if this.headers_tx.send_item(headers).is_err() {
return Poll::Ready(())
}
}
None => return Poll::Pending,
}
}
Err(_) => {
return Poll::Ready(())
}
}
}
}
}
/// Messages sent from a [`TaskDownloader`] handle to its spawned task.
///
/// Each variant mirrors one [`HeaderDownloader`] setter and carries that method's arguments.
enum DownloaderUpdates<H> {
    /// Applied via [`HeaderDownloader::update_sync_gap`].
    UpdateSyncGap(SealedHeader<H>, SyncTarget),
    /// Applied via [`HeaderDownloader::update_local_head`].
    UpdateLocalHead(SealedHeader<H>),
    /// Applied via [`HeaderDownloader::update_sync_target`].
    UpdateSyncTarget(SyncTarget),
    /// Applied via [`HeaderDownloader::set_batch_size`].
    SetBatchSize(usize),
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::headers::{
        reverse_headers::ReverseHeadersDownloaderBuilder, test_utils::child_header,
    };
    use reth_consensus::test_utils::TestConsensus;
    use reth_network_p2p::test_utils::TestHeadersClient;
    use std::sync::Arc;

    /// Drives a reverse headers downloader through a [`TaskDownloader`] with a stream batch
    /// size of one.
    ///
    /// Checks that headers arrive one per item, starting at the tip and walking back towards
    /// the local head.
    #[tokio::test(flavor = "multi_thread")]
    async fn download_one_by_one_on_task() {
        reth_tracing::init_test_tracing();

        // Chain layout: local_head <- p2 <- p1 <- tip
        let local_head = SealedHeader::default();
        let p2 = child_header(&local_head);
        let p1 = child_header(&p2);
        let tip = child_header(&p1);

        let client = Arc::new(TestHeadersClient::default());
        let inner = ReverseHeadersDownloaderBuilder::default()
            .stream_batch_size(1)
            .request_limit(1)
            .build(Arc::clone(&client), Arc::new(TestConsensus::default()));

        let mut downloader = TaskDownloader::spawn(inner);
        downloader.update_local_head(local_head.clone());
        downloader.update_sync_target(SyncTarget::Tip(tip.hash()));

        client
            .extend(vec![
                tip.as_ref().clone(),
                p1.as_ref().clone(),
                p2.as_ref().clone(),
                local_head.as_ref().clone(),
            ])
            .await;

        // With a batch size of one, every item carries exactly one header, newest first.
        for expected in [tip, p1, p2] {
            let headers = downloader.next().await.unwrap();
            assert_eq!(headers, Ok(vec![expected]));
        }
    }
}