1use super::missing_static_data_error;
2use futures_util::TryStreamExt;
3use reth_db_api::{
4 cursor::DbCursorRO,
5 tables,
6 transaction::{DbTx, DbTxMut},
7};
8use reth_network_p2p::bodies::{downloader::BodyDownloader, response::BlockResponse};
9use reth_provider::{
10 providers::StaticFileWriter, BlockReader, BlockWriter, DBProvider, ProviderError,
11 StaticFileProviderFactory, StatsReader,
12};
13use reth_stages_api::{
14 EntitiesCheckpoint, ExecInput, ExecOutput, Stage, StageCheckpoint, StageError, StageId,
15 UnwindInput, UnwindOutput,
16};
17use reth_static_file_types::StaticFileSegment;
18use reth_storage_errors::provider::ProviderResult;
19use std::{
20 cmp::Ordering,
21 task::{ready, Context, Poll},
22};
23use tracing::*;
24
/// Stage that downloads block bodies and writes them through the provider.
///
/// Work is split across two calls: [`Stage::poll_execute_ready`] fills `buffer` with the next
/// downloaded batch, and [`Stage::execute`] takes that batch out of the buffer and persists the
/// bodies via [`BlockWriter::append_block_bodies`].
#[derive(Debug)]
pub struct BodyStage<D: BodyDownloader> {
    /// Downloader that fetches bodies for the requested block range.
    downloader: D,
    /// The most recently downloaded batch, waiting to be written by [`Stage::execute`].
    ///
    /// `None` if no batch has been downloaded since the last write or unwind.
    buffer: Option<Vec<BlockResponse<D::Block>>>,
}
62
63impl<D: BodyDownloader> BodyStage<D> {
64 pub const fn new(downloader: D) -> Self {
66 Self { downloader, buffer: None }
67 }
68}
69
70pub(crate) fn ensure_consistency<Provider>(
72 provider: &Provider,
73 unwind_block: Option<u64>,
74) -> Result<(), StageError>
75where
76 Provider: DBProvider<Tx: DbTxMut> + BlockReader + StaticFileProviderFactory,
77{
78 let next_tx_num = provider
80 .tx_ref()
81 .cursor_read::<tables::TransactionBlocks>()?
82 .last()?
83 .map(|(id, _)| id + 1)
84 .unwrap_or_default();
85
86 let static_file_provider = provider.static_file_provider();
87
88 let next_static_file_tx_num = static_file_provider
91 .get_highest_static_file_tx(StaticFileSegment::Transactions)
92 .map(|id| id + 1)
93 .unwrap_or_default();
94
95 match next_static_file_tx_num.cmp(&next_tx_num) {
96 Ordering::Greater => {
100 let highest_db_block = provider.tx_ref().entries::<tables::BlockBodyIndices>()? as u64;
101 let mut static_file_producer =
102 static_file_provider.latest_writer(StaticFileSegment::Transactions)?;
103 static_file_producer
104 .prune_transactions(next_static_file_tx_num - next_tx_num, highest_db_block)?;
105 static_file_producer.commit()?;
108 }
109 Ordering::Less => {
113 if let Some(unwind_to) = unwind_block {
116 let next_tx_num_after_unwind = provider
117 .block_body_indices(unwind_to)?
118 .map(|b| b.next_tx_num())
119 .ok_or(ProviderError::BlockBodyIndicesNotFound(unwind_to))?;
120
121 if next_tx_num_after_unwind > next_static_file_tx_num {
123 return Err(missing_static_data_error(
124 next_static_file_tx_num.saturating_sub(1),
125 &static_file_provider,
126 provider,
127 StaticFileSegment::Transactions,
128 )?)
129 }
130 } else {
131 return Err(missing_static_data_error(
132 next_static_file_tx_num.saturating_sub(1),
133 &static_file_provider,
134 provider,
135 StaticFileSegment::Transactions,
136 )?)
137 }
138 }
139 Ordering::Equal => {}
140 }
141
142 Ok(())
143}
144
145impl<Provider, D> Stage<Provider> for BodyStage<D>
146where
147 Provider: DBProvider<Tx: DbTxMut>
148 + StaticFileProviderFactory
149 + StatsReader
150 + BlockReader
151 + BlockWriter<Block = D::Block>,
152 D: BodyDownloader,
153{
154 fn id(&self) -> StageId {
156 StageId::Bodies
157 }
158
159 fn poll_execute_ready(
160 &mut self,
161 cx: &mut Context<'_>,
162 input: ExecInput,
163 ) -> Poll<Result<(), StageError>> {
164 if input.target_reached() || self.buffer.is_some() {
165 return Poll::Ready(Ok(()))
166 }
167
168 self.downloader.set_download_range(input.next_block_range())?;
170
171 let maybe_next_result = ready!(self.downloader.try_poll_next_unpin(cx));
173
174 let response = match maybe_next_result {
177 Some(Ok(downloaded)) => {
178 self.buffer = Some(downloaded);
179 Ok(())
180 }
181 Some(Err(err)) => Err(err.into()),
182 None => Err(StageError::ChannelClosed),
183 };
184 Poll::Ready(response)
185 }
186
187 fn execute(&mut self, provider: &Provider, input: ExecInput) -> Result<ExecOutput, StageError> {
190 if input.target_reached() {
191 return Ok(ExecOutput::done(input.checkpoint()))
192 }
193 let (from_block, to_block) = input.next_block_range().into_inner();
194
195 ensure_consistency(provider, None)?;
196
197 debug!(target: "sync::stages::bodies", stage_progress = from_block, target = to_block, "Commencing sync");
198
199 let buffer = self.buffer.take().ok_or(StageError::MissingDownloadBuffer)?;
200 trace!(target: "sync::stages::bodies", bodies_len = buffer.len(), "Writing blocks");
201 let highest_block = buffer.last().map(|r| r.block_number()).unwrap_or(from_block);
202
203 provider.append_block_bodies(
205 buffer
206 .into_iter()
207 .map(|response| (response.block_number(), response.into_body()))
208 .collect(),
209 )?;
210
211 let done = highest_block == to_block;
215 Ok(ExecOutput {
216 checkpoint: StageCheckpoint::new(highest_block)
217 .with_entities_stage_checkpoint(stage_checkpoint(provider)?),
218 done,
219 })
220 }
221
222 fn unwind(
224 &mut self,
225 provider: &Provider,
226 input: UnwindInput,
227 ) -> Result<UnwindOutput, StageError> {
228 self.buffer.take();
229
230 ensure_consistency(provider, Some(input.unwind_to))?;
231 provider.remove_bodies_above(input.unwind_to)?;
232
233 Ok(UnwindOutput {
234 checkpoint: StageCheckpoint::new(input.unwind_to)
235 .with_entities_stage_checkpoint(stage_checkpoint(provider)?),
236 })
237 }
238}
239
240fn stage_checkpoint<Provider>(provider: &Provider) -> ProviderResult<EntitiesCheckpoint>
244where
245 Provider: StatsReader + StaticFileProviderFactory,
246{
247 Ok(EntitiesCheckpoint {
248 processed: provider.count_entries::<tables::BlockBodyIndices>()? as u64,
249 total: provider.static_file_provider().count_entries::<tables::Headers>()? as u64,
253 })
254}
255
256#[cfg(test)]
257mod tests {
258 use super::*;
259 use crate::test_utils::{
260 stage_test_suite_ext, ExecuteStageTestRunner, StageTestRunner, UnwindStageTestRunner,
261 };
262 use assert_matches::assert_matches;
263 use reth_provider::StaticFileProviderFactory;
264 use reth_stages_api::StageUnitCheckpoint;
265 use test_utils::*;
266
    // Instantiates the shared stage test suite from `crate::test_utils` for `BodyTestRunner`;
    // `body` is forwarded to the macro (presumably as the generated tests' name prefix).
    stage_test_suite_ext!(BodyTestRunner, body);
268
269 #[tokio::test]
271 async fn partial_body_download() {
272 let (stage_progress, previous_stage) = (1, 200);
273
274 let mut runner = BodyTestRunner::default();
276 let input = ExecInput {
277 target: Some(previous_stage),
278 checkpoint: Some(StageCheckpoint::new(stage_progress)),
279 };
280 runner.seed_execution(input).expect("failed to seed execution");
281
282 let batch_size = 10;
285 runner.set_batch_size(batch_size);
286
287 let rx = runner.execute(input);
289
290 let output = rx.await.unwrap();
293 runner.db().factory.static_file_provider().commit().unwrap();
294 assert_matches!(
295 output,
296 Ok(ExecOutput { checkpoint: StageCheckpoint {
297 block_number,
298 stage_checkpoint: Some(StageUnitCheckpoint::Entities(EntitiesCheckpoint {
299 processed, total }))
302 }, done: false }) if block_number < 200 &&
303 processed == batch_size + 1 && total == previous_stage + 1
304 );
305 assert!(runner.validate_execution(input, output.ok()).is_ok(), "execution validation");
306 }
307
308 #[tokio::test]
310 async fn full_body_download() {
311 let (stage_progress, previous_stage) = (1, 20);
312
313 let mut runner = BodyTestRunner::default();
315 let input = ExecInput {
316 target: Some(previous_stage),
317 checkpoint: Some(StageCheckpoint::new(stage_progress)),
318 };
319 runner.seed_execution(input).expect("failed to seed execution");
320
321 runner.set_batch_size(40);
323
324 let rx = runner.execute(input);
326
327 let output = rx.await.unwrap();
330 runner.db().factory.static_file_provider().commit().unwrap();
331 assert_matches!(
332 output,
333 Ok(ExecOutput {
334 checkpoint: StageCheckpoint {
335 block_number: 20,
336 stage_checkpoint: Some(StageUnitCheckpoint::Entities(EntitiesCheckpoint {
337 processed,
338 total
339 }))
340 },
341 done: true
342 }) if processed + 1 == total && total == previous_stage + 1
343 );
344 assert!(runner.validate_execution(input, output.ok()).is_ok(), "execution validation");
345 }
346
347 #[tokio::test]
349 async fn sync_from_previous_progress() {
350 let (stage_progress, previous_stage) = (1, 21);
351
352 let mut runner = BodyTestRunner::default();
354 let input = ExecInput {
355 target: Some(previous_stage),
356 checkpoint: Some(StageCheckpoint::new(stage_progress)),
357 };
358 runner.seed_execution(input).expect("failed to seed execution");
359
360 let batch_size = 10;
361 runner.set_batch_size(batch_size);
362
363 let rx = runner.execute(input);
365
366 let first_run = rx.await.unwrap();
368 runner.db().factory.static_file_provider().commit().unwrap();
369 assert_matches!(
370 first_run,
371 Ok(ExecOutput { checkpoint: StageCheckpoint {
372 block_number,
373 stage_checkpoint: Some(StageUnitCheckpoint::Entities(EntitiesCheckpoint {
374 processed,
375 total
376 }))
377 }, done: false }) if block_number >= 10 &&
378 processed - 1 == batch_size && total == previous_stage + 1
379 );
380 let first_run_checkpoint = first_run.unwrap().checkpoint;
381
382 let input =
384 ExecInput { target: Some(previous_stage), checkpoint: Some(first_run_checkpoint) };
385 let rx = runner.execute(input);
386
387 let output = rx.await.unwrap();
389 runner.db().factory.static_file_provider().commit().unwrap();
390 assert_matches!(
391 output,
392 Ok(ExecOutput { checkpoint: StageCheckpoint {
393 block_number,
394 stage_checkpoint: Some(StageUnitCheckpoint::Entities(EntitiesCheckpoint {
395 processed,
396 total
397 }))
398 }, done: true }) if block_number > first_run_checkpoint.block_number &&
399 processed + 1 == total && total == previous_stage + 1
400 );
401 assert_matches!(
402 runner.validate_execution(input, output.ok()),
403 Ok(_),
404 "execution validation"
405 );
406 }
407
408 #[tokio::test]
410 async fn unwind_missing_tx() {
411 let (stage_progress, previous_stage) = (1, 20);
412
413 let mut runner = BodyTestRunner::default();
415 let input = ExecInput {
416 target: Some(previous_stage),
417 checkpoint: Some(StageCheckpoint::new(stage_progress)),
418 };
419 runner.seed_execution(input).expect("failed to seed execution");
420
421 runner.set_batch_size(40);
423
424 let rx = runner.execute(input);
426
427 let output = rx.await.unwrap();
430 runner.db().factory.static_file_provider().commit().unwrap();
431 assert_matches!(
432 output,
433 Ok(ExecOutput { checkpoint: StageCheckpoint {
434 block_number,
435 stage_checkpoint: Some(StageUnitCheckpoint::Entities(EntitiesCheckpoint {
436 processed,
437 total
438 }))
439 }, done: true }) if block_number == previous_stage &&
440 processed + 1 == total && total == previous_stage + 1
441 );
442 let checkpoint = output.unwrap().checkpoint;
443 runner
444 .validate_db_blocks(input.checkpoint().block_number, checkpoint.block_number)
445 .expect("Written block data invalid");
446
447 let static_file_provider = runner.db().factory.static_file_provider();
449 {
450 let mut static_file_producer =
451 static_file_provider.latest_writer(StaticFileSegment::Transactions).unwrap();
452 static_file_producer.prune_transactions(1, checkpoint.block_number).unwrap();
453 static_file_producer.commit().unwrap();
454 }
455 let unwind_to = 1;
457 let input = UnwindInput { bad_block: None, checkpoint, unwind_to };
458 let res = runner.unwind(input).await;
459 assert_matches!(
460 res,
461 Ok(UnwindOutput { checkpoint: StageCheckpoint {
462 block_number: 1,
463 stage_checkpoint: Some(StageUnitCheckpoint::Entities(EntitiesCheckpoint {
464 processed: 1,
465 total
466 }))
467 }}) if total == previous_stage + 1
468 );
469
470 assert_matches!(runner.validate_unwind(input), Ok(_), "unwind validation");
471 }
472
473 mod test_utils {
474 use crate::{
475 stages::bodies::BodyStage,
476 test_utils::{
477 ExecuteStageTestRunner, StageTestRunner, TestRunnerError, TestStageDB,
478 UnwindStageTestRunner,
479 },
480 };
481 use alloy_consensus::{BlockHeader, Header};
482 use alloy_primitives::{BlockNumber, TxNumber, B256};
483 use futures_util::Stream;
484 use reth_db::{static_file::HeaderWithHashMask, tables};
485 use reth_db_api::{
486 cursor::DbCursorRO,
487 models::{StoredBlockBodyIndices, StoredBlockOmmers},
488 transaction::{DbTx, DbTxMut},
489 };
490 use reth_ethereum_primitives::{Block, BlockBody};
491 use reth_network_p2p::{
492 bodies::{
493 downloader::{BodyDownloader, BodyDownloaderResult},
494 response::BlockResponse,
495 },
496 error::DownloadResult,
497 };
498 use reth_primitives_traits::{SealedBlock, SealedHeader};
499 use reth_provider::{
500 providers::StaticFileWriter, test_utils::MockNodeTypesWithDB, HeaderProvider,
501 ProviderFactory, StaticFileProviderFactory, TransactionsProvider,
502 };
503 use reth_stages_api::{ExecInput, ExecOutput, UnwindInput};
504 use reth_static_file_types::StaticFileSegment;
505 use reth_testing_utils::generators::{
506 self, random_block_range, random_signed_tx, BlockRangeParams,
507 };
508 use std::{
509 collections::{HashMap, VecDeque},
510 ops::RangeInclusive,
511 pin::Pin,
512 task::{Context, Poll},
513 };
514
        /// Parent hash given to the first randomly generated block of the test chain.
        pub(crate) const GENESIS_HASH: B256 = B256::ZERO;
517
518 pub(crate) fn body_by_hash(block: &SealedBlock<Block>) -> (B256, BlockBody) {
520 (block.hash(), block.body().clone())
521 }
522
        /// Test runner for [`BodyStage`], backed by a [`TestStageDB`] and a
        /// [`TestBodyDownloader`].
        pub(crate) struct BodyTestRunner {
            /// Bodies handed to the test downloader, keyed by block hash.
            responses: HashMap<B256, BlockBody>,
            /// Test database the stage reads from and writes to.
            db: TestStageDB,
            /// Batch size passed to the test downloader.
            batch_size: u64,
        }
529
530 impl Default for BodyTestRunner {
531 fn default() -> Self {
532 Self { responses: HashMap::default(), db: TestStageDB::default(), batch_size: 1000 }
533 }
534 }
535
536 impl BodyTestRunner {
537 pub(crate) fn set_batch_size(&mut self, batch_size: u64) {
538 self.batch_size = batch_size;
539 }
540
541 pub(crate) fn set_responses(&mut self, responses: HashMap<B256, BlockBody>) {
542 self.responses = responses;
543 }
544 }
545
546 impl StageTestRunner for BodyTestRunner {
547 type S = BodyStage<TestBodyDownloader>;
548
549 fn db(&self) -> &TestStageDB {
550 &self.db
551 }
552
553 fn stage(&self) -> Self::S {
554 BodyStage::new(TestBodyDownloader::new(
555 self.db.factory.clone(),
556 self.responses.clone(),
557 self.batch_size,
558 ))
559 }
560 }
561
562 impl ExecuteStageTestRunner for BodyTestRunner {
563 type Seed = Vec<SealedBlock<Block>>;
564
565 fn seed_execution(&mut self, input: ExecInput) -> Result<Self::Seed, TestRunnerError> {
566 let start = input.checkpoint().block_number;
567 let end = input.target();
568
569 let static_file_provider = self.db.factory.static_file_provider();
570
571 let mut rng = generators::rng();
572
573 let blocks = random_block_range(
575 &mut rng,
576 0..=end,
577 BlockRangeParams {
578 parent: Some(GENESIS_HASH),
579 tx_count: 0..2,
580 ..Default::default()
581 },
582 );
583 self.db.insert_headers_with_td(blocks.iter().map(|block| block.sealed_header()))?;
584 if let Some(progress) = blocks.get(start as usize) {
585 {
587 let tx = self.db.factory.provider_rw()?.into_tx();
588 let mut static_file_producer = static_file_provider
589 .get_writer(start, StaticFileSegment::Transactions)?;
590
591 let body = StoredBlockBodyIndices {
592 first_tx_num: 0,
593 tx_count: progress.transaction_count() as u64,
594 };
595
596 static_file_producer.set_block_range(0..=progress.number);
597
598 body.tx_num_range().try_for_each(|tx_num| {
599 let transaction = random_signed_tx(&mut rng);
600 static_file_producer.append_transaction(tx_num, &transaction).map(drop)
601 })?;
602
603 if body.tx_count != 0 {
604 tx.put::<tables::TransactionBlocks>(
605 body.last_tx_num(),
606 progress.number,
607 )?;
608 }
609
610 tx.put::<tables::BlockBodyIndices>(progress.number, body)?;
611
612 if !progress.ommers_hash_is_empty() {
613 tx.put::<tables::BlockOmmers>(
614 progress.number,
615 StoredBlockOmmers { ommers: progress.body().ommers.clone() },
616 )?;
617 }
618
619 static_file_producer.commit()?;
620 tx.commit()?;
621 }
622 }
623 self.set_responses(blocks.iter().map(body_by_hash).collect());
624 Ok(blocks)
625 }
626
627 fn validate_execution(
628 &self,
629 input: ExecInput,
630 output: Option<ExecOutput>,
631 ) -> Result<(), TestRunnerError> {
632 let highest_block = match output.as_ref() {
633 Some(output) => output.checkpoint,
634 None => input.checkpoint(),
635 }
636 .block_number;
637 self.validate_db_blocks(highest_block, highest_block)
638 }
639 }
640
641 impl UnwindStageTestRunner for BodyTestRunner {
642 fn validate_unwind(&self, input: UnwindInput) -> Result<(), TestRunnerError> {
643 self.db.ensure_no_entry_above::<tables::BlockBodyIndices, _>(
644 input.unwind_to,
645 |key| key,
646 )?;
647 self.db
648 .ensure_no_entry_above::<tables::BlockOmmers, _>(input.unwind_to, |key| key)?;
649 if let Some(last_tx_id) = self.get_last_tx_id()? {
650 self.db
651 .ensure_no_entry_above::<tables::Transactions, _>(last_tx_id, |key| key)?;
652 self.db.ensure_no_entry_above::<tables::TransactionBlocks, _>(
653 last_tx_id,
654 |key| key,
655 )?;
656 }
657 Ok(())
658 }
659 }
660
661 impl BodyTestRunner {
662 pub(crate) fn get_last_tx_id(&self) -> Result<Option<TxNumber>, TestRunnerError> {
664 let last_body = self.db.query(|tx| {
665 let v = tx.cursor_read::<tables::BlockBodyIndices>()?.last()?;
666 Ok(v)
667 })?;
668 Ok(match last_body {
669 Some((_, body)) if body.tx_count != 0 => {
670 Some(body.first_tx_num + body.tx_count - 1)
671 }
672 _ => None,
673 })
674 }
675
676 pub(crate) fn validate_db_blocks(
678 &self,
679 prev_progress: BlockNumber,
680 highest_block: BlockNumber,
681 ) -> Result<(), TestRunnerError> {
682 let static_file_provider = self.db.factory.static_file_provider();
683
684 self.db.query(|tx| {
685 let mut bodies_cursor = tx.cursor_read::<tables::BlockBodyIndices>()?;
687 let mut ommers_cursor = tx.cursor_read::<tables::BlockOmmers>()?;
688 let mut tx_block_cursor = tx.cursor_read::<tables::TransactionBlocks>()?;
689
690 let first_body_key = match bodies_cursor.first()? {
691 Some((key, _)) => key,
692 None => return Ok(()),
693 };
694
695 let mut prev_number: Option<BlockNumber> = None;
696
697
698 for entry in bodies_cursor.walk(Some(first_body_key))? {
699 let (number, body) = entry?;
700
701 if number > prev_progress
704 && let Some(prev_key) = prev_number {
705 assert_eq!(prev_key + 1, number, "Body entries must be sequential");
706 }
707
708 assert!(
710 number <= highest_block,
711 "We wrote a block body outside of our synced range. Found block with number {number}, highest block according to stage is {highest_block}",
712 );
713
714 let header = static_file_provider.header_by_number(number)?.expect("to be present");
715 let stored_ommers = ommers_cursor.seek_exact(number)?;
717 if header.ommers_hash_is_empty() {
718 assert!(stored_ommers.is_none(), "Unexpected ommers entry");
719 } else {
720 assert!(stored_ommers.is_some(), "Missing ommers entry");
721 }
722
723 let tx_block_id = tx_block_cursor.seek_exact(body.last_tx_num())?.map(|(_,b)| b);
724 if body.tx_count == 0 {
725 assert_ne!(tx_block_id,Some(number));
726 } else {
727 assert_eq!(tx_block_id, Some(number));
728 }
729
730 for tx_id in body.tx_num_range() {
731 assert!(static_file_provider.transaction_by_id(tx_id)?.is_some(), "Transaction is missing.");
732 }
733
734 prev_number = Some(number);
735 }
736 Ok(())
737 })?;
738 Ok(())
739 }
740 }
741
        /// A [`BodyDownloader`] for tests that serves pre-generated bodies.
        ///
        /// Headers for the requested range are read from static files; each non-empty header's
        /// body is looked up by hash in `responses`.
        #[derive(Debug)]
        pub(crate) struct TestBodyDownloader {
            /// Factory used to read headers for the requested download range.
            provider_factory: ProviderFactory<MockNodeTypesWithDB>,
            /// Bodies to serve, keyed by block hash; each entry is removed once served.
            responses: HashMap<B256, BlockBody>,
            /// Headers queued for download, in request order.
            headers: VecDeque<SealedHeader>,
            /// Maximum number of responses yielded per stream item.
            batch_size: u64,
        }
750
751 impl TestBodyDownloader {
752 pub(crate) fn new(
753 provider_factory: ProviderFactory<MockNodeTypesWithDB>,
754 responses: HashMap<B256, BlockBody>,
755 batch_size: u64,
756 ) -> Self {
757 Self { provider_factory, responses, headers: VecDeque::default(), batch_size }
758 }
759 }
760
761 impl BodyDownloader for TestBodyDownloader {
762 type Block = Block;
763
764 fn set_download_range(
765 &mut self,
766 range: RangeInclusive<BlockNumber>,
767 ) -> DownloadResult<()> {
768 let static_file_provider = self.provider_factory.static_file_provider();
769
770 for header in static_file_provider.fetch_range_iter(
771 StaticFileSegment::Headers,
772 *range.start()..*range.end() + 1,
773 |cursor, number| cursor.get_two::<HeaderWithHashMask<Header>>(number.into()),
774 )? {
775 let (header, hash) = header?;
776 self.headers.push_back(SealedHeader::new(header, hash));
777 }
778
779 Ok(())
780 }
781 }
782
783 impl Stream for TestBodyDownloader {
784 type Item = BodyDownloaderResult<Block>;
785 fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
786 let this = self.get_mut();
787
788 if this.headers.is_empty() {
789 return Poll::Ready(None)
790 }
791
792 let mut response =
793 Vec::with_capacity(std::cmp::min(this.headers.len(), this.batch_size as usize));
794 while let Some(header) = this.headers.pop_front() {
795 if header.is_empty() {
796 response.push(BlockResponse::Empty(header))
797 } else {
798 let body =
799 this.responses.remove(&header.hash()).expect("requested unknown body");
800 response.push(BlockResponse::Full(SealedBlock::from_sealed_parts(
801 header, body,
802 )));
803 }
804
805 if response.len() as u64 >= this.batch_size {
806 break
807 }
808 }
809
810 if !response.is_empty() {
811 return Poll::Ready(Some(Ok(response)))
812 }
813
814 panic!("requested bodies without setting headers")
815 }
816 }
817 }
818}