use std::{
cmp::Ordering,
task::{ready, Context, Poll},
};
use futures_util::TryStreamExt;
use reth_codecs::Compact;
use reth_primitives_traits::BlockBody;
use tracing::*;
use alloy_primitives::TxNumber;
use reth_db::{tables, transaction::DbTx};
use reth_db_api::{
cursor::{DbCursorRO, DbCursorRW},
transaction::DbTxMut,
};
use reth_network_p2p::bodies::{downloader::BodyDownloader, response::BlockResponse};
use reth_primitives::StaticFileSegment;
use reth_provider::{
providers::{StaticFileProvider, StaticFileWriter},
BlockReader, BlockWriter, DBProvider, ProviderError, StaticFileProviderFactory, StatsReader,
};
use reth_stages_api::{
EntitiesCheckpoint, ExecInput, ExecOutput, Stage, StageCheckpoint, StageError, StageId,
UnwindInput, UnwindOutput,
};
use reth_storage_errors::provider::ProviderResult;
/// The body stage: downloads block bodies for already-synced headers via the
/// configured [`BodyDownloader`] and persists them (transactions go to static
/// files, body indices to the database — see `Stage::execute` below).
#[derive(Debug)]
pub struct BodyStage<D: BodyDownloader> {
    /// The downloader used to fetch block bodies.
    downloader: D,
    // Responses buffered by `poll_execute_ready` and consumed by `execute`.
    // `Some` means a downloaded batch is waiting to be written.
    buffer: Option<Vec<BlockResponse<D::Body>>>,
}
impl<D: BodyDownloader> BodyStage<D> {
    /// Creates a new [`BodyStage`] driven by the given downloader, with no
    /// buffered responses.
    pub const fn new(downloader: D) -> Self {
        Self { buffer: None, downloader }
    }
}
impl<Provider, D> Stage<Provider> for BodyStage<D>
where
    Provider: DBProvider<Tx: DbTxMut>
        + StaticFileProviderFactory
        + StatsReader
        + BlockReader
        + BlockWriter<Body = D::Body>,
    D: BodyDownloader<Body: BlockBody<Transaction: Compact>>,
{
    /// Returns this stage's identifier.
    fn id(&self) -> StageId {
        StageId::Bodies
    }

    /// Polls the downloader for the next batch of bodies and stashes it in
    /// `self.buffer` for the subsequent `execute` call.
    ///
    /// Returns `Ready(Ok(()))` immediately when the target is already reached
    /// or a batch is still buffered from a previous poll.
    fn poll_execute_ready(
        &mut self,
        cx: &mut Context<'_>,
        input: ExecInput,
    ) -> Poll<Result<(), StageError>> {
        if input.target_reached() || self.buffer.is_some() {
            return Poll::Ready(Ok(()))
        }
        // Update the download range before polling; NOTE(review): this is
        // re-issued on every poll until a batch arrives — presumably the
        // downloader treats a repeated identical range as a no-op; confirm.
        self.downloader.set_download_range(input.next_block_range())?;
        let maybe_next_result = ready!(self.downloader.try_poll_next_unpin(cx));
        let response = match maybe_next_result {
            Some(Ok(downloaded)) => {
                self.buffer = Some(downloaded);
                Ok(())
            }
            Some(Err(err)) => Err(err.into()),
            // Stream ended without yielding the requested range.
            None => Err(StageError::ChannelClosed),
        };
        Poll::Ready(response)
    }

    /// Writes the buffered block bodies: transactions are appended to the
    /// `Transactions` static-file segment, body indices are written through
    /// `append_block_bodies`. Before writing, reconciles the next transaction
    /// number between the database and the static files.
    fn execute(&mut self, provider: &Provider, input: ExecInput) -> Result<ExecOutput, StageError> {
        if input.target_reached() {
            return Ok(ExecOutput::done(input.checkpoint()))
        }
        let (from_block, to_block) = input.next_block_range().into_inner();
        // Next tx number according to the database: one past the highest key
        // in `TransactionBlocks`, or 0 for an empty table.
        let mut next_tx_num = provider
            .tx_ref()
            .cursor_read::<tables::TransactionBlocks>()?
            .last()?
            .map(|(id, _)| id + 1)
            .unwrap_or_default();
        let static_file_provider = provider.static_file_provider();
        let mut static_file_producer =
            static_file_provider.get_writer(from_block, StaticFileSegment::Transactions)?;
        // Next tx number according to the static files.
        let next_static_file_tx_num = static_file_provider
            .get_highest_static_file_tx(StaticFileSegment::Transactions)
            .map(|id| id + 1)
            .unwrap_or_default();
        match next_static_file_tx_num.cmp(&next_tx_num) {
            // Static files are ahead of the database (e.g. interrupted run):
            // prune the surplus transactions so both agree, then commit.
            Ordering::Greater => {
                static_file_producer
                    .prune_transactions(next_static_file_tx_num - next_tx_num, from_block - 1)?;
                static_file_producer.commit()?;
            }
            // Database is ahead of the static files: data is missing on the
            // static-file side and a heal/unwind is required.
            Ordering::Less => {
                return Err(missing_static_data_error(
                    next_static_file_tx_num.saturating_sub(1),
                    &static_file_provider,
                    provider,
                )?)
            }
            Ordering::Equal => {}
        }
        debug!(target: "sync::stages::bodies", stage_progress = from_block, target = to_block, start_tx_id = next_tx_num, "Commencing sync");
        let buffer = self.buffer.take().ok_or(StageError::MissingDownloadBuffer)?;
        trace!(target: "sync::stages::bodies", bodies_len = buffer.len(), "Writing blocks");
        let mut highest_block = from_block;
        for response in &buffer {
            let block_number = response.block_number();
            // Advance the static-file block pointer (block 0 is assumed to be
            // initialized already, hence the guard).
            if block_number > 0 {
                let appended_block_number = static_file_producer.increment_block(block_number)?;
                if appended_block_number != block_number {
                    return Err(StageError::InconsistentBlockNumber {
                        segment: StaticFileSegment::Transactions,
                        database: block_number,
                        static_file: appended_block_number,
                    })
                }
            }
            match response {
                BlockResponse::Full(block) => {
                    // Append each transaction, verifying the writer assigns the
                    // number we expect — any drift means corruption.
                    for transaction in block.body.transactions() {
                        let appended_tx_number =
                            static_file_producer.append_transaction(next_tx_num, transaction)?;
                        if appended_tx_number != next_tx_num {
                            return Err(StageError::InconsistentTxNumber {
                                segment: StaticFileSegment::Transactions,
                                database: next_tx_num,
                                static_file: appended_tx_number,
                            })
                        }
                        next_tx_num += 1;
                    }
                }
                // Empty blocks carry no transactions to persist.
                BlockResponse::Empty(_) => {}
            };
            highest_block = block_number;
        }
        // Persist the body indices (and related tables) through the provider.
        provider.append_block_bodies(
            buffer.into_iter().map(|response| (response.block_number(), response.into_body())),
        )?;
        // The stage is done once the last written block hits the target.
        let done = highest_block == to_block;
        Ok(ExecOutput {
            checkpoint: StageCheckpoint::new(highest_block)
                .with_entities_stage_checkpoint(stage_checkpoint(provider)?),
            done,
        })
    }

    /// Unwinds body data above `input.unwind_to`: deletes body indices,
    /// ommers, withdrawals and tx-block mappings from the database, then
    /// prunes the corresponding transactions from the static files.
    fn unwind(
        &mut self,
        provider: &Provider,
        input: UnwindInput,
    ) -> Result<UnwindOutput, StageError> {
        // Drop any batch buffered for a now-stale range.
        self.buffer.take();
        let static_file_provider = provider.static_file_provider();
        let tx = provider.tx_ref();
        let mut body_cursor = tx.cursor_write::<tables::BlockBodyIndices>()?;
        let mut ommers_cursor = tx.cursor_write::<tables::BlockOmmers>()?;
        let mut withdrawals_cursor = tx.cursor_write::<tables::BlockWithdrawals>()?;
        let mut tx_block_cursor = tx.cursor_write::<tables::TransactionBlocks>()?;
        // Walk the body indices from the highest block downwards, deleting
        // everything above the unwind target.
        let mut rev_walker = body_cursor.walk_back(None)?;
        while let Some((number, block_meta)) = rev_walker.next().transpose()? {
            if number <= input.unwind_to {
                break
            }
            if ommers_cursor.seek_exact(number)?.is_some() {
                ommers_cursor.delete_current()?;
            }
            if withdrawals_cursor.seek_exact(number)?.is_some() {
                withdrawals_cursor.delete_current()?;
            }
            // Only non-empty blocks have a tx-block mapping to remove.
            if !block_meta.is_empty() &&
                tx_block_cursor.seek_exact(block_meta.last_tx_num())?.is_some()
            {
                tx_block_cursor.delete_current()?;
            }
            rev_walker.delete_current()?;
        }
        let mut static_file_producer =
            static_file_provider.latest_writer(StaticFileSegment::Transactions)?;
        // Highest tx number remaining in the database after the deletions.
        let db_tx_num =
            body_cursor.last()?.map(|(_, block_meta)| block_meta.last_tx_num()).unwrap_or_default();
        let static_file_tx_num: u64 = static_file_provider
            .get_highest_static_file_tx(StaticFileSegment::Transactions)
            .unwrap_or_default();
        // Static files lagging behind the database means data went missing.
        if db_tx_num > static_file_tx_num {
            return Err(missing_static_data_error(
                static_file_tx_num,
                &static_file_provider,
                provider,
            )?)
        }
        // Bring the static files back in line with the database.
        static_file_producer
            .prune_transactions(static_file_tx_num.saturating_sub(db_tx_num), input.unwind_to)?;
        Ok(UnwindOutput {
            checkpoint: StageCheckpoint::new(input.unwind_to)
                .with_entities_stage_checkpoint(stage_checkpoint(provider)?),
        })
    }
}
/// Builds a [`StageError::MissingStaticFileData`] for the case where the
/// database knows about transactions beyond `last_tx_num` that are absent
/// from the `Transactions` static-file segment.
///
/// Walks backwards from the highest static-file block until it finds a block
/// whose body indices end at or before `last_tx_num`; the block after that one
/// is the first block with missing static-file data.
fn missing_static_data_error<Provider>(
    last_tx_num: TxNumber,
    static_file_provider: &StaticFileProvider,
    provider: &Provider,
) -> Result<StageError, ProviderError>
where
    Provider: BlockReader,
{
    let mut last_block = static_file_provider
        .get_highest_static_file_block(StaticFileSegment::Transactions)
        .unwrap_or_default();
    loop {
        // Found the last block fully covered by the static files.
        if let Some(indices) = provider.block_body_indices(last_block)? {
            if indices.last_tx_num() <= last_tx_num {
                break
            }
        }
        // Guard against underflow: block 0 is as far back as we can search.
        if last_block == 0 {
            break
        }
        last_block -= 1;
    }
    // Report the first block past the covered range; falls back to a default
    // header if it cannot be loaded.
    let missing_block = Box::new(provider.sealed_header(last_block + 1)?.unwrap_or_default());
    Ok(StageError::MissingStaticFileData {
        block: missing_block,
        segment: StaticFileSegment::Transactions,
    })
}
/// Computes the entities checkpoint for the body stage: bodies written so far
/// (entries in `BlockBodyIndices`) against the total number of headers known
/// to the static-file provider.
fn stage_checkpoint<Provider>(provider: &Provider) -> ProviderResult<EntitiesCheckpoint>
where
    Provider: StatsReader + StaticFileProviderFactory,
{
    // One body index entry is written per processed block.
    let processed = provider.count_entries::<tables::BlockBodyIndices>()? as u64;
    // Total work is one body per header available in static files.
    let total = provider.static_file_provider().count_entries::<tables::Headers>()? as u64;
    Ok(EntitiesCheckpoint { processed, total })
}
#[cfg(test)]
mod tests {
use assert_matches::assert_matches;
use reth_provider::StaticFileProviderFactory;
use reth_stages_api::StageUnitCheckpoint;
use test_utils::*;
use crate::test_utils::{
stage_test_suite_ext, ExecuteStageTestRunner, StageTestRunner, UnwindStageTestRunner,
};
use super::*;
// Generates the standard stage test suite for the body stage via its runner.
stage_test_suite_ext!(BodyTestRunner, body);
/// When the downloader batch size is smaller than the requested range, the
/// stage should make partial progress and report `done: false`.
#[tokio::test]
async fn partial_body_download() {
    let (stage_progress, previous_stage) = (1, 200);
    let mut runner = BodyTestRunner::default();
    let input = ExecInput {
        target: Some(previous_stage),
        checkpoint: Some(StageCheckpoint::new(stage_progress)),
    };
    runner.seed_execution(input).expect("failed to seed execution");
    // Batch size well below the 199-block range forces a partial download.
    let batch_size = 10;
    runner.set_batch_size(batch_size);
    let rx = runner.execute(input);
    let output = rx.await.unwrap();
    // Static-file writes must be committed before assertions can read them.
    runner.db().factory.static_file_provider().commit().unwrap();
    assert_matches!(
        output,
        Ok(ExecOutput { checkpoint: StageCheckpoint {
            block_number,
            stage_checkpoint: Some(StageUnitCheckpoint::Entities(EntitiesCheckpoint {
                processed, total }))
        }, done: false }) if block_number < 200 &&
            processed == batch_size + 1 && total == previous_stage + 1
    );
    assert!(runner.validate_execution(input, output.ok()).is_ok(), "execution validation");
}
/// When the batch size covers the whole range, a single execution should
/// reach the target block and report `done: true`.
#[tokio::test]
async fn full_body_download() {
    let (stage_progress, previous_stage) = (1, 20);
    let mut runner = BodyTestRunner::default();
    let input = ExecInput {
        target: Some(previous_stage),
        checkpoint: Some(StageCheckpoint::new(stage_progress)),
    };
    runner.seed_execution(input).expect("failed to seed execution");
    // 40 > 19 remaining blocks, so everything downloads in one batch.
    runner.set_batch_size(40);
    let rx = runner.execute(input);
    let output = rx.await.unwrap();
    // Static-file writes must be committed before assertions can read them.
    runner.db().factory.static_file_provider().commit().unwrap();
    assert_matches!(
        output,
        Ok(ExecOutput {
            checkpoint: StageCheckpoint {
                block_number: 20,
                stage_checkpoint: Some(StageUnitCheckpoint::Entities(EntitiesCheckpoint {
                    processed,
                    total
                }))
            },
            done: true
        }) if processed + 1 == total && total == previous_stage + 1
    );
    assert!(runner.validate_execution(input, output.ok()).is_ok(), "execution validation");
}
/// Two consecutive executions should resume from the first run's checkpoint
/// and finish the range on the second run.
#[tokio::test]
async fn sync_from_previous_progress() {
    let (stage_progress, previous_stage) = (1, 21);
    let mut runner = BodyTestRunner::default();
    let input = ExecInput {
        target: Some(previous_stage),
        checkpoint: Some(StageCheckpoint::new(stage_progress)),
    };
    runner.seed_execution(input).expect("failed to seed execution");
    // First run only covers part of the range.
    let batch_size = 10;
    runner.set_batch_size(batch_size);
    let rx = runner.execute(input);
    let first_run = rx.await.unwrap();
    // Commit so the second run (and assertions) see the static-file state.
    runner.db().factory.static_file_provider().commit().unwrap();
    assert_matches!(
        first_run,
        Ok(ExecOutput { checkpoint: StageCheckpoint {
            block_number,
            stage_checkpoint: Some(StageUnitCheckpoint::Entities(EntitiesCheckpoint {
                processed,
                total
            }))
        }, done: false }) if block_number >= 10 &&
            processed - 1 == batch_size && total == previous_stage + 1
    );
    let first_run_checkpoint = first_run.unwrap().checkpoint;
    // Second run resumes from the first run's checkpoint.
    let input =
        ExecInput { target: Some(previous_stage), checkpoint: Some(first_run_checkpoint) };
    let rx = runner.execute(input);
    let output = rx.await.unwrap();
    runner.db().factory.static_file_provider().commit().unwrap();
    assert_matches!(
        output,
        Ok(ExecOutput { checkpoint: StageCheckpoint {
            block_number,
            stage_checkpoint: Some(StageUnitCheckpoint::Entities(EntitiesCheckpoint {
                processed,
                total
            }))
        }, done: true }) if block_number > first_run_checkpoint.block_number &&
            processed + 1 == total && total == previous_stage + 1
    );
    assert_matches!(
        runner.validate_execution(input, output.ok()),
        Ok(_),
        "execution validation"
    );
}
/// After a full sync, artificially prune one transaction from the static
/// files and verify that `unwind` still succeeds and reconciles the state.
#[tokio::test]
async fn unwind_missing_tx() {
    let (stage_progress, previous_stage) = (1, 20);
    let mut runner = BodyTestRunner::default();
    let input = ExecInput {
        target: Some(previous_stage),
        checkpoint: Some(StageCheckpoint::new(stage_progress)),
    };
    runner.seed_execution(input).expect("failed to seed execution");
    runner.set_batch_size(40);
    let rx = runner.execute(input);
    let output = rx.await.unwrap();
    // Commit static-file writes before inspecting them.
    runner.db().factory.static_file_provider().commit().unwrap();
    assert_matches!(
        output,
        Ok(ExecOutput { checkpoint: StageCheckpoint {
            block_number,
            stage_checkpoint: Some(StageUnitCheckpoint::Entities(EntitiesCheckpoint {
                processed,
                total
            }))
        }, done: true }) if block_number == previous_stage &&
            processed + 1 == total && total == previous_stage + 1
    );
    let checkpoint = output.unwrap().checkpoint;
    runner
        .validate_db_blocks(input.checkpoint().block_number, checkpoint.block_number)
        .expect("Written block data invalid");
    // Simulate missing static-file data by pruning the last transaction.
    let static_file_provider = runner.db().factory.static_file_provider();
    {
        let mut static_file_producer =
            static_file_provider.latest_writer(StaticFileSegment::Transactions).unwrap();
        static_file_producer.prune_transactions(1, checkpoint.block_number).unwrap();
        static_file_producer.commit().unwrap();
    }
    let unwind_to = 1;
    let input = UnwindInput { bad_block: None, checkpoint, unwind_to };
    let res = runner.unwind(input).await;
    assert_matches!(
        res,
        Ok(UnwindOutput { checkpoint: StageCheckpoint {
            block_number: 1,
            stage_checkpoint: Some(StageUnitCheckpoint::Entities(EntitiesCheckpoint {
                processed: 1,
                total
            }))
        }}) if total == previous_stage + 1
    );
    assert_matches!(runner.validate_unwind(input), Ok(_), "unwind validation");
}
mod test_utils {
use crate::{
stages::bodies::BodyStage,
test_utils::{
ExecuteStageTestRunner, StageTestRunner, TestRunnerError, TestStageDB,
UnwindStageTestRunner,
},
};
use alloy_consensus::Header;
use alloy_primitives::{BlockHash, BlockNumber, TxNumber, B256};
use futures_util::Stream;
use reth_db::{static_file::HeaderMask, tables};
use reth_db_api::{
cursor::DbCursorRO,
models::{StoredBlockBodyIndices, StoredBlockOmmers},
transaction::{DbTx, DbTxMut},
};
use reth_network_p2p::{
bodies::{
downloader::{BodyDownloader, BodyDownloaderResult},
response::BlockResponse,
},
error::DownloadResult,
};
use reth_primitives::{BlockBody, SealedBlock, SealedHeader, StaticFileSegment};
use reth_provider::{
providers::StaticFileWriter, test_utils::MockNodeTypesWithDB, HeaderProvider,
ProviderFactory, StaticFileProviderFactory, TransactionsProvider,
};
use reth_stages_api::{ExecInput, ExecOutput, UnwindInput};
use reth_testing_utils::generators::{
self, random_block_range, random_signed_tx, BlockRangeParams,
};
use std::{
collections::{HashMap, VecDeque},
ops::RangeInclusive,
pin::Pin,
task::{Context, Poll},
};
/// Hash used as the parent of the first randomly generated block.
pub(crate) const GENESIS_HASH: B256 = B256::ZERO;
/// Maps a sealed block to the `(hash, body)` pair used to key the test
/// downloader's response map.
pub(crate) fn body_by_hash(block: &SealedBlock) -> (B256, BlockBody) {
    let hash = block.hash();
    let body = block.body.clone();
    (hash, body)
}
/// Test harness wiring a [`BodyStage`] to a test database and a canned set of
/// body responses served by [`TestBodyDownloader`].
pub(crate) struct BodyTestRunner {
    // Bodies the test downloader will serve, keyed by block hash.
    responses: HashMap<B256, BlockBody>,
    // Backing test database / provider factory.
    db: TestStageDB,
    // Maximum number of bodies yielded per downloader poll.
    batch_size: u64,
}
impl Default for BodyTestRunner {
    /// Runner with an empty response map, a fresh test database, and a batch
    /// size of 1000 bodies per poll.
    fn default() -> Self {
        let db = TestStageDB::default();
        Self { batch_size: 1000, responses: HashMap::default(), db }
    }
}
impl BodyTestRunner {
    /// Overrides the downloader batch size used by subsequently built stages.
    pub(crate) fn set_batch_size(&mut self, batch_size: u64) {
        self.batch_size = batch_size;
    }
    /// Replaces the set of canned body responses served by the downloader.
    pub(crate) fn set_responses(&mut self, responses: HashMap<B256, BlockBody>) {
        self.responses = responses;
    }
}
impl StageTestRunner for BodyTestRunner {
    type S = BodyStage<TestBodyDownloader>;

    /// Returns the backing test database.
    fn db(&self) -> &TestStageDB {
        &self.db
    }

    /// Builds a fresh [`BodyStage`] over a downloader that clones the current
    /// response map and batch size.
    fn stage(&self) -> Self::S {
        BodyStage::new(TestBodyDownloader::new(
            self.db.factory.clone(),
            self.responses.clone(),
            self.batch_size,
        ))
    }
}
impl ExecuteStageTestRunner for BodyTestRunner {
    type Seed = Vec<SealedBlock>;

    /// Seeds the test database: generates a random chain from genesis to the
    /// target, inserts the headers, and (if the checkpoint block exists) writes
    /// its body indices, transactions and ommers so execution can resume from
    /// the checkpoint. Also registers all bodies as downloader responses.
    fn seed_execution(&mut self, input: ExecInput) -> Result<Self::Seed, TestRunnerError> {
        let start = input.checkpoint().block_number;
        let end = input.target();
        let static_file_provider = self.db.factory.static_file_provider();
        let mut rng = generators::rng();
        // Random chain rooted at GENESIS_HASH with 0-1 txs per block.
        let blocks = random_block_range(
            &mut rng,
            0..=end,
            BlockRangeParams {
                parent: Some(GENESIS_HASH),
                tx_count: 0..2,
                ..Default::default()
            },
        );
        self.db.insert_headers_with_td(blocks.iter().map(|block| &block.header))?;
        // Pre-populate body data up to the checkpoint block, if it exists.
        if let Some(progress) = blocks.get(start as usize) {
            {
                let tx = self.db.factory.provider_rw()?.into_tx();
                let mut static_file_producer = static_file_provider
                    .get_writer(start, StaticFileSegment::Transactions)?;
                let body = StoredBlockBodyIndices {
                    first_tx_num: 0,
                    tx_count: progress.body.transactions.len() as u64,
                };
                static_file_producer.set_block_range(0..=progress.number);
                // Write one random signed tx per index in the body's range.
                body.tx_num_range().try_for_each(|tx_num| {
                    let transaction = random_signed_tx(&mut rng);
                    static_file_producer.append_transaction(tx_num, &transaction).map(drop)
                })?;
                // Only non-empty bodies get a TransactionBlocks entry.
                if body.tx_count != 0 {
                    tx.put::<tables::TransactionBlocks>(
                        body.last_tx_num(),
                        progress.number,
                    )?;
                }
                tx.put::<tables::BlockBodyIndices>(progress.number, body)?;
                if !progress.ommers_hash_is_empty() {
                    tx.put::<tables::BlockOmmers>(
                        progress.number,
                        StoredBlockOmmers { ommers: progress.body.ommers.clone() },
                    )?;
                }
                static_file_producer.commit()?;
                tx.commit()?;
            }
        }
        self.set_responses(blocks.iter().map(body_by_hash).collect());
        Ok(blocks)
    }

    /// Validates DB state at the highest block reached (from the output
    /// checkpoint, falling back to the input checkpoint on failure).
    fn validate_execution(
        &self,
        input: ExecInput,
        output: Option<ExecOutput>,
    ) -> Result<(), TestRunnerError> {
        let highest_block = match output.as_ref() {
            Some(output) => output.checkpoint,
            None => input.checkpoint(),
        }
        .block_number;
        self.validate_db_blocks(highest_block, highest_block)
    }
}
impl UnwindStageTestRunner for BodyTestRunner {
    /// Checks that no body-related table has entries above the unwind target,
    /// and no transaction-keyed table has entries past the last remaining tx.
    fn validate_unwind(&self, input: UnwindInput) -> Result<(), TestRunnerError> {
        self.db.ensure_no_entry_above::<tables::BlockBodyIndices, _>(
            input.unwind_to,
            |key| key,
        )?;
        self.db
            .ensure_no_entry_above::<tables::BlockOmmers, _>(input.unwind_to, |key| key)?;
        // Tx-keyed tables are only checked when some transactions remain.
        if let Some(last_tx_id) = self.get_last_tx_id()? {
            self.db
                .ensure_no_entry_above::<tables::Transactions, _>(last_tx_id, |key| key)?;
            self.db.ensure_no_entry_above::<tables::TransactionBlocks, _>(
                last_tx_id,
                |key| key,
            )?;
        }
        Ok(())
    }
}
impl BodyTestRunner {
    /// Returns the last transaction number covered by the highest block body
    /// in the database, or `None` if there are no bodies with transactions.
    pub(crate) fn get_last_tx_id(&self) -> Result<Option<TxNumber>, TestRunnerError> {
        let last_body = self.db.query(|tx| {
            let v = tx.cursor_read::<tables::BlockBodyIndices>()?.last()?;
            Ok(v)
        })?;
        Ok(match last_body {
            Some((_, body)) if body.tx_count != 0 => {
                Some(body.first_tx_num + body.tx_count - 1)
            }
            _ => None,
        })
    }

    /// Walks all body indices and asserts the invariants the stage must keep:
    /// sequential block numbers past `prev_progress`, no bodies above
    /// `highest_block`, ommers entries only for blocks with ommers, tx-block
    /// mappings only for non-empty bodies, and every tx present in static files.
    pub(crate) fn validate_db_blocks(
        &self,
        prev_progress: BlockNumber,
        highest_block: BlockNumber,
    ) -> Result<(), TestRunnerError> {
        let static_file_provider = self.db.factory.static_file_provider();
        self.db.query(|tx| {
            let mut bodies_cursor = tx.cursor_read::<tables::BlockBodyIndices>()?;
            let mut ommers_cursor = tx.cursor_read::<tables::BlockOmmers>()?;
            let mut tx_block_cursor = tx.cursor_read::<tables::TransactionBlocks>()?;
            // Nothing to validate when the table is empty.
            let first_body_key = match bodies_cursor.first()? {
                Some((key, _)) => key,
                None => return Ok(()),
            };
            let mut prev_number: Option<BlockNumber> = None;
            for entry in bodies_cursor.walk(Some(first_body_key))? {
                let (number, body) = entry?;
                // Entries written by this run must be contiguous.
                if number > prev_progress {
                    if let Some(prev_key) = prev_number {
                        assert_eq!(prev_key + 1, number, "Body entries must be sequential");
                    }
                }
                assert!(
                    number <= highest_block,
                    "We wrote a block body outside of our synced range. Found block with number {number}, highest block according to stage is {highest_block}",
                );
                let header = static_file_provider.header_by_number(number)?.expect("to be present");
                // Ommers entry must exist exactly when the header has ommers.
                let stored_ommers = ommers_cursor.seek_exact(number)?;
                if header.ommers_hash_is_empty() {
                    assert!(stored_ommers.is_none(), "Unexpected ommers entry");
                } else {
                    assert!(stored_ommers.is_some(), "Missing ommers entry");
                }
                // Tx-block mapping must exist exactly for non-empty bodies.
                let tx_block_id = tx_block_cursor.seek_exact(body.last_tx_num())?.map(|(_,b)| b);
                if body.tx_count == 0 {
                    assert_ne!(tx_block_id,Some(number));
                } else {
                    assert_eq!(tx_block_id, Some(number));
                }
                // Every tx in the body's range must be in the static files.
                for tx_id in body.tx_num_range() {
                    assert!(static_file_provider.transaction_by_id(tx_id)?.is_some(), "Transaction is missing.");
                }
                prev_number = Some(number);
            }
            Ok(())
        })?;
        Ok(())
    }
}
/// A canned [`BodyDownloader`]: reads headers for the requested range from
/// static files and serves bodies out of a prefilled hash-keyed map.
#[derive(Debug)]
pub(crate) struct TestBodyDownloader {
    // Used to read headers for the download range.
    provider_factory: ProviderFactory<MockNodeTypesWithDB>,
    // Bodies to serve, keyed by block hash; entries are consumed when served.
    responses: HashMap<B256, BlockBody>,
    // Headers queued by `set_download_range`, drained by `poll_next`.
    headers: VecDeque<SealedHeader>,
    // Maximum number of responses yielded per poll.
    batch_size: u64,
}
impl TestBodyDownloader {
    /// Creates a downloader with an empty header queue; headers are loaded
    /// lazily by `set_download_range`.
    pub(crate) fn new(
        provider_factory: ProviderFactory<MockNodeTypesWithDB>,
        responses: HashMap<B256, BlockBody>,
        batch_size: u64,
    ) -> Self {
        Self { provider_factory, responses, headers: VecDeque::default(), batch_size }
    }
}
impl BodyDownloader for TestBodyDownloader {
    type Body = BlockBody;

    /// Loads the sealed headers for `range` from the `Headers` static-file
    /// segment and queues them for `poll_next`.
    fn set_download_range(
        &mut self,
        range: RangeInclusive<BlockNumber>,
    ) -> DownloadResult<()> {
        let static_file_provider = self.provider_factory.static_file_provider();
        // Inclusive range is converted to a half-open one for the fetch.
        for header in static_file_provider.fetch_range_iter(
            StaticFileSegment::Headers,
            *range.start()..*range.end() + 1,
            |cursor, number| cursor.get_two::<HeaderMask<Header, BlockHash>>(number.into()),
        )? {
            let (header, hash) = header?;
            self.headers.push_back(SealedHeader::new(header, hash));
        }
        Ok(())
    }
}
impl Stream for TestBodyDownloader {
    type Item = BodyDownloaderResult<BlockBody>;

    /// Drains up to `batch_size` queued headers, pairing each non-empty one
    /// with its canned body. Yields `None` when no headers are queued; panics
    /// if a body was requested for an unknown header.
    fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();
        if this.headers.is_empty() {
            return Poll::Ready(None)
        }
        let mut response =
            Vec::with_capacity(std::cmp::min(this.headers.len(), this.batch_size as usize));
        while let Some(header) = this.headers.pop_front() {
            if header.is_empty() {
                // Empty blocks need no body lookup.
                response.push(BlockResponse::Empty(header))
            } else {
                let body =
                    this.responses.remove(&header.hash()).expect("requested unknown body");
                response.push(BlockResponse::Full(SealedBlock { header, body }));
            }
            // Cap the batch at `batch_size` responses.
            if response.len() as u64 >= this.batch_size {
                break
            }
        }
        if !response.is_empty() {
            return Poll::Ready(Some(Ok(response)))
        }
        // Unreachable given the non-empty check above; kept as a loud failure.
        panic!("requested bodies without setting headers")
    }
}
}
}