use crate::metrics::{BodyDownloaderMetrics, ResponseMetrics};
use alloy_consensus::BlockHeader;
use alloy_primitives::B256;
use futures::{Future, FutureExt};
use reth_consensus::Consensus;
use reth_network_p2p::{
bodies::{client::BodiesClient, response::BlockResponse},
error::{DownloadError, DownloadResult},
priority::Priority,
};
use reth_network_peers::{PeerId, WithPeerId};
use reth_primitives::{BlockBody, GotExpected, SealedBlock, SealedHeader};
use reth_primitives_traits::InMemorySize;
use std::{
collections::VecDeque,
mem,
pin::Pin,
sync::Arc,
task::{ready, Context, Poll},
};
/// A future that downloads the bodies for a batch of headers.
///
/// It keeps resubmitting requests for the bodies that are still missing until every header
/// has a buffered [`BlockResponse`]. It then resolves to those responses, in the same order
/// as the headers it was given.
pub(crate) struct BodiesRequestFuture<H, B: BodiesClient> {
    /// Client that body requests are sent to.
    client: Arc<B>,
    /// Consensus implementation used to pre-validate each block before it is buffered.
    consensus: Arc<dyn Consensus<H, B::Body>>,
    /// Downloader-wide metrics (error counts, total bodies downloaded).
    metrics: BodyDownloaderMetrics,
    /// Metrics describing the most recently processed response (size and length).
    response_metrics: ResponseMetrics,
    /// Headers whose responses have not been buffered yet, front first.
    pending_headers: VecDeque<SealedHeader<H>>,
    /// Responses collected so far; handed out when no headers remain pending.
    buffer: Vec<BlockResponse<H, B::Body>>,
    /// The in-flight bodies request, if any.
    fut: Option<B::Output>,
    /// Number of hashes in the most recently submitted request; used to reject responses
    /// that contain more bodies than were asked for.
    last_request_len: Option<usize>,
}
impl<H, B> BodiesRequestFuture<H, B>
where
H: BlockHeader,
B: BodiesClient + 'static,
{
pub(crate) fn new(
client: Arc<B>,
consensus: Arc<dyn Consensus<H, B::Body>>,
metrics: BodyDownloaderMetrics,
) -> Self {
Self {
client,
consensus,
metrics,
response_metrics: Default::default(),
pending_headers: Default::default(),
buffer: Default::default(),
last_request_len: None,
fut: None,
}
}
pub(crate) fn with_headers(mut self, headers: Vec<SealedHeader<H>>) -> Self {
self.buffer.reserve_exact(headers.len());
self.pending_headers = VecDeque::from(headers);
if let Some(req) = self.next_request() {
self.submit_request(req, Priority::Normal);
}
self
}
fn on_error(&mut self, error: DownloadError, peer_id: Option<PeerId>) {
self.metrics.increment_errors(&error);
tracing::debug!(target: "downloaders::bodies", ?peer_id, %error, "Error requesting bodies");
if let Some(peer_id) = peer_id {
self.client.report_bad_message(peer_id);
}
self.submit_request(
self.next_request().expect("existing hashes to resubmit"),
Priority::High,
);
}
fn next_request(&self) -> Option<Vec<B256>> {
let mut hashes =
self.pending_headers.iter().filter(|h| !h.is_empty()).map(|h| h.hash()).peekable();
hashes.peek().is_some().then(|| hashes.collect())
}
fn submit_request(&mut self, req: Vec<B256>, priority: Priority) {
tracing::trace!(target: "downloaders::bodies", request_len = req.len(), "Requesting bodies");
let client = Arc::clone(&self.client);
self.last_request_len = Some(req.len());
self.fut = Some(client.get_block_bodies_with_priority(req, priority));
}
fn on_block_response(&mut self, response: WithPeerId<Vec<B::Body>>) -> DownloadResult<()>
where
B::Body: InMemorySize,
{
let (peer_id, bodies) = response.split();
let request_len = self.last_request_len.unwrap_or_default();
let response_len = bodies.len();
tracing::trace!(target: "downloaders::bodies", request_len, response_len, ?peer_id, "Received bodies");
self.metrics.total_downloaded.increment(response_len as u64);
if bodies.is_empty() {
return Err(DownloadError::EmptyResponse)
}
if response_len > request_len {
return Err(DownloadError::TooManyBodies(GotExpected {
got: response_len,
expected: request_len,
}))
}
self.try_buffer_blocks(bodies)?;
if let Some(req) = self.next_request() {
self.submit_request(req, Priority::High);
} else {
self.fut = None;
}
Ok(())
}
fn try_buffer_blocks(&mut self, bodies: Vec<B::Body>) -> DownloadResult<()>
where
B::Body: InMemorySize,
{
let bodies_capacity = bodies.capacity();
let bodies_len = bodies.len();
let mut bodies = bodies.into_iter().peekable();
let mut total_size = bodies_capacity * mem::size_of::<BlockBody>();
while bodies.peek().is_some() {
let next_header = match self.pending_headers.pop_front() {
Some(header) => header,
None => return Ok(()), };
if next_header.is_empty() {
total_size += mem::size_of::<BlockBody>();
self.buffer.push(BlockResponse::Empty(next_header));
} else {
let next_body = bodies.next().unwrap();
total_size += next_body.size();
let block = SealedBlock::new(next_header, next_body);
if let Err(error) = self.consensus.validate_block_pre_execution(&block) {
let hash = block.hash();
let number = block.number();
self.pending_headers.push_front(block.header);
return Err(DownloadError::BodyValidation {
hash,
number,
error: Box::new(error),
})
}
self.buffer.push(BlockResponse::Full(block));
}
}
self.response_metrics.response_size_bytes.set(total_size as f64);
self.response_metrics.response_length.set(bodies_len as f64);
Ok(())
}
}
impl<H, B> Future for BodiesRequestFuture<H, B>
where
    H: BlockHeader + Unpin + Send + Sync + 'static,
    B: BodiesClient<Body: InMemorySize> + 'static,
{
    type Output = DownloadResult<Vec<BlockResponse<H, B::Body>>>;

    /// Drives the download until every pending header has a buffered response.
    ///
    /// Each pass of the loop does three things:
    /// 1. Resolve with the buffered responses once no headers are pending.
    /// 2. Otherwise, wait for the in-flight request (if there is one) and handle its outcome.
    /// 3. Move any leading empty headers straight into the buffer, since they have no body
    ///    to fetch.
    ///
    /// A closed client channel is the only request error that ends the future. All other
    /// failures go through `on_error`, which resubmits the request.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.get_mut();

        loop {
            if this.pending_headers.is_empty() {
                return Poll::Ready(Ok(std::mem::take(&mut this.buffer)))
            }

            if let Some(in_flight) = this.fut.as_mut() {
                let outcome = ready!(in_flight.poll_unpin(cx));
                match outcome {
                    Err(request_error) if request_error.is_channel_closed() => {
                        // The client is gone, so retrying cannot succeed.
                        return Poll::Ready(Err(request_error.into()))
                    }
                    Err(request_error) => this.on_error(request_error.into(), None),
                    Ok(response) => {
                        let sender = response.peer_id();
                        if let Err(response_error) = this.on_block_response(response) {
                            this.on_error(response_error, Some(sender));
                        }
                    }
                }
            }

            let leading_empty =
                this.pending_headers.iter().take_while(|header| header.is_empty()).count();
            let empties = this.pending_headers.drain(..leading_empty).map(BlockResponse::Empty);
            this.buffer.extend(empties);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        bodies::test_utils::zip_blocks,
        test_utils::{generate_bodies, TestBodiesClient},
    };
    use reth_consensus::test_utils::TestConsensus;
    use reth_testing_utils::{generators, generators::random_header_range};

    /// The future resolves with one `Empty` response per header and never queries the
    /// client when every header is empty.
    #[tokio::test]
    async fn request_returns_empty_bodies() {
        let mut rng = generators::rng();
        let headers = random_header_range(&mut rng, 0..20, B256::ZERO);
        let client = Arc::new(TestBodiesClient::default());
        let fut = BodiesRequestFuture::new(
            client.clone(),
            Arc::new(TestConsensus::default()),
            BodyDownloaderMetrics::default(),
        )
        .with_headers(headers.clone());

        assert_eq!(
            fut.await.unwrap(),
            headers.into_iter().map(BlockResponse::Empty).collect::<Vec<_>>()
        );
        assert_eq!(client.times_requested(), 0);
    }

    /// The future keeps submitting requests until every non-empty header has a body. When
    /// the client returns at most `batch_size` bodies per response, that takes
    /// `ceil(non_empty / batch_size)` requests.
    #[tokio::test]
    async fn request_submits_until_fulfilled() {
        let (headers, mut bodies) = generate_bodies(0..=19);
        let batch_size = 2;
        let client = Arc::new(
            TestBodiesClient::default().with_bodies(bodies.clone()).with_max_batch_size(batch_size),
        );
        let fut = BodiesRequestFuture::new(
            client.clone(),
            Arc::new(TestConsensus::default()),
            BodyDownloaderMetrics::default(),
        )
        .with_headers(headers.clone());

        assert_eq!(fut.await.unwrap(), zip_blocks(headers.iter(), &mut bodies));

        // Derive the expected count from `batch_size` rather than a hard-coded divisor.
        let non_empty = headers.iter().filter(|h| !h.is_empty()).count();
        assert_eq!(client.times_requested(), (non_empty as u64).div_ceil(batch_size as u64));
    }
}