1use super::missing_static_data_error;
2use futures_util::TryStreamExt;
3use reth_db_api::{
4 cursor::DbCursorRO,
5 tables,
6 transaction::{DbTx, DbTxMut},
7};
8use reth_network_p2p::bodies::{downloader::BodyDownloader, response::BlockResponse};
9use reth_provider::{
10 providers::StaticFileWriter, BlockReader, BlockWriter, DBProvider, ProviderError,
11 StaticFileProviderFactory, StatsReader, StorageLocation,
12};
13use reth_stages_api::{
14 EntitiesCheckpoint, ExecInput, ExecOutput, Stage, StageCheckpoint, StageError, StageId,
15 UnwindInput, UnwindOutput,
16};
17use reth_static_file_types::StaticFileSegment;
18use reth_storage_errors::provider::ProviderResult;
19use std::{
20 cmp::Ordering,
21 task::{ready, Context, Poll},
22};
23use tracing::*;
24
/// The body stage downloads block bodies and writes them to storage.
///
/// Downloading happens in [`Stage::poll_execute_ready`], which stores one batch of block
/// responses in `buffer`. [`Stage::execute`] then takes that batch and appends the bodies,
/// writing their transactions to static files.
#[derive(Debug)]
pub struct BodyStage<D: BodyDownloader> {
    /// The body downloader.
    downloader: D,
    /// Block responses downloaded by `poll_execute_ready` that `execute` has not yet written.
    /// Dropped on unwind.
    buffer: Option<Vec<BlockResponse<D::Block>>>,
}
62
63impl<D: BodyDownloader> BodyStage<D> {
64 pub const fn new(downloader: D) -> Self {
66 Self { downloader, buffer: None }
67 }
68
69 fn ensure_consistency<Provider>(
71 &self,
72 provider: &Provider,
73 unwind_block: Option<u64>,
74 ) -> Result<(), StageError>
75 where
76 Provider: DBProvider<Tx: DbTxMut> + BlockReader + StaticFileProviderFactory,
77 {
78 let next_tx_num = provider
80 .tx_ref()
81 .cursor_read::<tables::TransactionBlocks>()?
82 .last()?
83 .map(|(id, _)| id + 1)
84 .unwrap_or_default();
85
86 let static_file_provider = provider.static_file_provider();
87
88 let next_static_file_tx_num = static_file_provider
91 .get_highest_static_file_tx(StaticFileSegment::Transactions)
92 .map(|id| id + 1)
93 .unwrap_or_default();
94
95 match next_static_file_tx_num.cmp(&next_tx_num) {
96 Ordering::Greater => {
100 let highest_db_block =
101 provider.tx_ref().entries::<tables::BlockBodyIndices>()? as u64;
102 let mut static_file_producer =
103 static_file_provider.latest_writer(StaticFileSegment::Transactions)?;
104 static_file_producer
105 .prune_transactions(next_static_file_tx_num - next_tx_num, highest_db_block)?;
106 static_file_producer.commit()?;
109 }
110 Ordering::Less => {
114 if let Some(unwind_to) = unwind_block {
117 let next_tx_num_after_unwind = provider
118 .block_body_indices(unwind_to)?
119 .map(|b| b.next_tx_num())
120 .ok_or(ProviderError::BlockBodyIndicesNotFound(unwind_to))?;
121
122 if next_tx_num_after_unwind > next_static_file_tx_num {
124 return Err(missing_static_data_error(
125 next_static_file_tx_num.saturating_sub(1),
126 &static_file_provider,
127 provider,
128 StaticFileSegment::Transactions,
129 )?)
130 }
131 } else {
132 return Err(missing_static_data_error(
133 next_static_file_tx_num.saturating_sub(1),
134 &static_file_provider,
135 provider,
136 StaticFileSegment::Transactions,
137 )?)
138 }
139 }
140 Ordering::Equal => {}
141 }
142
143 Ok(())
144 }
145}
146
147impl<Provider, D> Stage<Provider> for BodyStage<D>
148where
149 Provider: DBProvider<Tx: DbTxMut>
150 + StaticFileProviderFactory
151 + StatsReader
152 + BlockReader
153 + BlockWriter<Block = D::Block>,
154 D: BodyDownloader,
155{
156 fn id(&self) -> StageId {
158 StageId::Bodies
159 }
160
161 fn poll_execute_ready(
162 &mut self,
163 cx: &mut Context<'_>,
164 input: ExecInput,
165 ) -> Poll<Result<(), StageError>> {
166 if input.target_reached() || self.buffer.is_some() {
167 return Poll::Ready(Ok(()))
168 }
169
170 self.downloader.set_download_range(input.next_block_range())?;
172
173 let maybe_next_result = ready!(self.downloader.try_poll_next_unpin(cx));
175
176 let response = match maybe_next_result {
179 Some(Ok(downloaded)) => {
180 self.buffer = Some(downloaded);
181 Ok(())
182 }
183 Some(Err(err)) => Err(err.into()),
184 None => Err(StageError::ChannelClosed),
185 };
186 Poll::Ready(response)
187 }
188
189 fn execute(&mut self, provider: &Provider, input: ExecInput) -> Result<ExecOutput, StageError> {
192 if input.target_reached() {
193 return Ok(ExecOutput::done(input.checkpoint()))
194 }
195 let (from_block, to_block) = input.next_block_range().into_inner();
196
197 self.ensure_consistency(provider, None)?;
198
199 debug!(target: "sync::stages::bodies", stage_progress = from_block, target = to_block, "Commencing sync");
200
201 let buffer = self.buffer.take().ok_or(StageError::MissingDownloadBuffer)?;
202 trace!(target: "sync::stages::bodies", bodies_len = buffer.len(), "Writing blocks");
203 let highest_block = buffer.last().map(|r| r.block_number()).unwrap_or(from_block);
204
205 provider.append_block_bodies(
207 buffer
208 .into_iter()
209 .map(|response| (response.block_number(), response.into_body()))
210 .collect(),
211 StorageLocation::StaticFiles,
213 )?;
214
215 let done = highest_block == to_block;
219 Ok(ExecOutput {
220 checkpoint: StageCheckpoint::new(highest_block)
221 .with_entities_stage_checkpoint(stage_checkpoint(provider)?),
222 done,
223 })
224 }
225
226 fn unwind(
228 &mut self,
229 provider: &Provider,
230 input: UnwindInput,
231 ) -> Result<UnwindOutput, StageError> {
232 self.buffer.take();
233
234 self.ensure_consistency(provider, Some(input.unwind_to))?;
235 provider.remove_bodies_above(input.unwind_to, StorageLocation::Both)?;
236
237 Ok(UnwindOutput {
238 checkpoint: StageCheckpoint::new(input.unwind_to)
239 .with_entities_stage_checkpoint(stage_checkpoint(provider)?),
240 })
241 }
242}
243
244fn stage_checkpoint<Provider>(provider: &Provider) -> ProviderResult<EntitiesCheckpoint>
248where
249 Provider: StatsReader + StaticFileProviderFactory,
250{
251 Ok(EntitiesCheckpoint {
252 processed: provider.count_entries::<tables::BlockBodyIndices>()? as u64,
253 total: provider.static_file_provider().count_entries::<tables::Headers>()? as u64,
257 })
258}
259
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::{
        stage_test_suite_ext, ExecuteStageTestRunner, StageTestRunner, UnwindStageTestRunner,
    };
    use assert_matches::assert_matches;
    use reth_provider::StaticFileProviderFactory;
    use reth_stages_api::StageUnitCheckpoint;
    use test_utils::*;

    // Generates the shared stage test suite for `BodyTestRunner`.
    stage_test_suite_ext!(BodyTestRunner, body);

    /// Checks that a single execution writes at most `batch_size` bodies when the target is
    /// further away, and reports `done: false`.
    #[tokio::test]
    async fn partial_body_download() {
        let (stage_progress, previous_stage) = (1, 200);

        // Set up the test runner and seed headers 0..=200 plus the body of block 1.
        let mut runner = BodyTestRunner::default();
        let input = ExecInput {
            target: Some(previous_stage),
            checkpoint: Some(StageCheckpoint::new(stage_progress)),
        };
        runner.seed_execution(input).expect("failed to seed execution");

        // The batch size (max bodies per execution) is smaller than the remaining range.
        let batch_size = 10;
        runner.set_batch_size(batch_size);

        // Run the stage.
        let rx = runner.execute(input);

        // Only one batch is expected to be written, even though the target is higher.
        let output = rx.await.unwrap();
        runner.db().factory.static_file_provider().commit().unwrap();
        assert_matches!(
            output,
            Ok(ExecOutput { checkpoint: StageCheckpoint {
                block_number,
                // `processed`: 1 seeded body + one batch; `total`: headers 0..=previous_stage.
                stage_checkpoint: Some(StageUnitCheckpoint::Entities(EntitiesCheckpoint {
                    processed, total }))
            }, done: false }) if block_number < 200 &&
                processed == batch_size + 1 && total == previous_stage + 1
        );
        assert!(runner.validate_execution(input, output.ok()).is_ok(), "execution validation");
    }

    /// Checks that all bodies up to the target are written in one execution when the batch
    /// size exceeds the remaining range.
    #[tokio::test]
    async fn full_body_download() {
        let (stage_progress, previous_stage) = (1, 20);

        // Set up the test runner and seed headers 0..=20 plus the body of block 1.
        let mut runner = BodyTestRunner::default();
        let input = ExecInput {
            target: Some(previous_stage),
            checkpoint: Some(StageCheckpoint::new(stage_progress)),
        };
        runner.seed_execution(input).expect("failed to seed execution");

        // The batch size is larger than the number of blocks left to sync.
        runner.set_batch_size(40);

        // Run the stage.
        let rx = runner.execute(input);

        // Everything up to the target is expected in a single run.
        let output = rx.await.unwrap();
        runner.db().factory.static_file_provider().commit().unwrap();
        assert_matches!(
            output,
            Ok(ExecOutput {
                checkpoint: StageCheckpoint {
                    block_number: 20,
                    stage_checkpoint: Some(StageUnitCheckpoint::Entities(EntitiesCheckpoint {
                        processed,
                        total
                    }))
                },
                done: true
            }) if processed + 1 == total && total == previous_stage + 1
        );
        assert!(runner.validate_execution(input, output.ok()).is_ok(), "execution validation");
    }

    /// Checks that the stage resumes from the checkpoint of a previous partial run and then
    /// finishes the remaining range.
    #[tokio::test]
    async fn sync_from_previous_progress() {
        let (stage_progress, previous_stage) = (1, 21);

        // Set up the test runner.
        let mut runner = BodyTestRunner::default();
        let input = ExecInput {
            target: Some(previous_stage),
            checkpoint: Some(StageCheckpoint::new(stage_progress)),
        };
        runner.seed_execution(input).expect("failed to seed execution");

        let batch_size = 10;
        runner.set_batch_size(batch_size);

        // First run: expected to stop after one batch.
        let rx = runner.execute(input);

        let first_run = rx.await.unwrap();
        runner.db().factory.static_file_provider().commit().unwrap();
        assert_matches!(
            first_run,
            Ok(ExecOutput { checkpoint: StageCheckpoint {
                block_number,
                stage_checkpoint: Some(StageUnitCheckpoint::Entities(EntitiesCheckpoint {
                    processed,
                    total
                }))
            }, done: false }) if block_number >= 10 &&
                processed - 1 == batch_size && total == previous_stage + 1
        );
        let first_run_checkpoint = first_run.unwrap().checkpoint;

        // Second run: resumes from the first run's checkpoint.
        let input =
            ExecInput { target: Some(previous_stage), checkpoint: Some(first_run_checkpoint) };
        let rx = runner.execute(input);

        // The remaining blocks are expected to be written and the stage to be done.
        let output = rx.await.unwrap();
        runner.db().factory.static_file_provider().commit().unwrap();
        assert_matches!(
            output,
            Ok(ExecOutput { checkpoint: StageCheckpoint {
                block_number,
                stage_checkpoint: Some(StageUnitCheckpoint::Entities(EntitiesCheckpoint {
                    processed,
                    total
                }))
            }, done: true }) if block_number > first_run_checkpoint.block_number &&
                processed + 1 == total && total == previous_stage + 1
        );
        assert_matches!(
            runner.validate_execution(input, output.ok()),
            Ok(_),
            "execution validation"
        );
    }

    /// Checks that unwinding to block 1 succeeds after the last transaction has been removed
    /// from the static files. Removing it creates a database/static file mismatch.
    #[tokio::test]
    async fn unwind_missing_tx() {
        let (stage_progress, previous_stage) = (1, 20);

        // Set up the test runner.
        let mut runner = BodyTestRunner::default();
        let input = ExecInput {
            target: Some(previous_stage),
            checkpoint: Some(StageCheckpoint::new(stage_progress)),
        };
        runner.seed_execution(input).expect("failed to seed execution");

        // The batch size is large enough to sync to the target in one run.
        runner.set_batch_size(40);

        // Run the stage.
        let rx = runner.execute(input);

        // Full sync is expected.
        let output = rx.await.unwrap();
        runner.db().factory.static_file_provider().commit().unwrap();
        assert_matches!(
            output,
            Ok(ExecOutput { checkpoint: StageCheckpoint {
                block_number,
                stage_checkpoint: Some(StageUnitCheckpoint::Entities(EntitiesCheckpoint {
                    processed,
                    total
                }))
            }, done: true }) if block_number == previous_stage &&
                processed + 1 == total && total == previous_stage + 1
        );
        let checkpoint = output.unwrap().checkpoint;
        runner
            .validate_db_blocks(input.checkpoint().block_number, checkpoint.block_number)
            .expect("Written block data invalid");

        // Drop the last transaction from the static files so they fall behind the database.
        let static_file_provider = runner.db().factory.static_file_provider();
        {
            let mut static_file_producer =
                static_file_provider.latest_writer(StaticFileSegment::Transactions).unwrap();
            static_file_producer.prune_transactions(1, checkpoint.block_number).unwrap();
            static_file_producer.commit().unwrap();
        }
        // Unwind back to block 1; this is expected to succeed despite the missing transaction.
        let unwind_to = 1;
        let input = UnwindInput { bad_block: None, checkpoint, unwind_to };
        let res = runner.unwind(input).await;
        assert_matches!(
            res,
            Ok(UnwindOutput { checkpoint: StageCheckpoint {
                block_number: 1,
                stage_checkpoint: Some(StageUnitCheckpoint::Entities(EntitiesCheckpoint {
                    processed: 1,
                    total
                }))
            }}) if total == previous_stage + 1
        );

        assert_matches!(runner.validate_unwind(input), Ok(_), "unwind validation");
    }

    mod test_utils {
        use crate::{
            stages::bodies::BodyStage,
            test_utils::{
                ExecuteStageTestRunner, StageTestRunner, TestRunnerError, TestStageDB,
                UnwindStageTestRunner,
            },
        };
        use alloy_consensus::{BlockHeader, Header};
        use alloy_primitives::{BlockNumber, TxNumber, B256};
        use futures_util::Stream;
        use reth_db::{static_file::HeaderWithHashMask, tables};
        use reth_db_api::{
            cursor::DbCursorRO,
            models::{StoredBlockBodyIndices, StoredBlockOmmers},
            transaction::{DbTx, DbTxMut},
        };
        use reth_ethereum_primitives::{Block, BlockBody};
        use reth_network_p2p::{
            bodies::{
                downloader::{BodyDownloader, BodyDownloaderResult},
                response::BlockResponse,
            },
            error::DownloadResult,
        };
        use reth_primitives_traits::{SealedBlock, SealedHeader};
        use reth_provider::{
            providers::StaticFileWriter, test_utils::MockNodeTypesWithDB, HeaderProvider,
            ProviderFactory, StaticFileProviderFactory, TransactionsProvider,
        };
        use reth_stages_api::{ExecInput, ExecOutput, UnwindInput};
        use reth_static_file_types::StaticFileSegment;
        use reth_testing_utils::generators::{
            self, random_block_range, random_signed_tx, BlockRangeParams,
        };
        use std::{
            collections::{HashMap, VecDeque},
            ops::RangeInclusive,
            pin::Pin,
            task::{Context, Poll},
        };

        /// Parent hash given to the first block of the randomly generated test chain.
        pub(crate) const GENESIS_HASH: B256 = B256::ZERO;

        /// Maps a sealed block to a `(block hash, body)` pair, used to build downloader
        /// responses.
        pub(crate) fn body_by_hash(block: &SealedBlock<Block>) -> (B256, BlockBody) {
            (block.hash(), block.body().clone())
        }

        /// Test runner for [`BodyStage`] backed by a [`TestBodyDownloader`].
        pub(crate) struct BodyTestRunner {
            /// Bodies served by the test downloader, keyed by block hash.
            responses: HashMap<B256, BlockBody>,
            /// Test database shared with the stage.
            db: TestStageDB,
            /// Maximum number of block responses the test downloader yields per poll.
            batch_size: u64,
        }

        impl Default for BodyTestRunner {
            fn default() -> Self {
                Self { responses: HashMap::default(), db: TestStageDB::default(), batch_size: 1000 }
            }
        }

        impl BodyTestRunner {
            /// Sets the maximum number of responses the downloader yields per poll.
            pub(crate) fn set_batch_size(&mut self, batch_size: u64) {
                self.batch_size = batch_size;
            }

            /// Replaces the bodies served by the downloader.
            pub(crate) fn set_responses(&mut self, responses: HashMap<B256, BlockBody>) {
                self.responses = responses;
            }
        }

        impl StageTestRunner for BodyTestRunner {
            type S = BodyStage<TestBodyDownloader>;

            fn db(&self) -> &TestStageDB {
                &self.db
            }

            /// Builds a fresh stage whose downloader serves a copy of the current responses.
            fn stage(&self) -> Self::S {
                BodyStage::new(TestBodyDownloader::new(
                    self.db.factory.clone(),
                    self.responses.clone(),
                    self.batch_size,
                ))
            }
        }

        impl ExecuteStageTestRunner for BodyTestRunner {
            type Seed = Vec<SealedBlock<Block>>;

            /// Prepares the test state for an execution:
            /// - generates random blocks `0..=target` and inserts all of their headers;
            /// - seeds the body indices, transactions and ommers of the checkpoint block;
            /// - registers every generated body as a downloader response.
            fn seed_execution(&mut self, input: ExecInput) -> Result<Self::Seed, TestRunnerError> {
                let start = input.checkpoint().block_number;
                let end = input.target();

                let static_file_provider = self.db.factory.static_file_provider();

                let mut rng = generators::rng();

                // Generate the whole range starting at block 0 so that every header exists.
                let blocks = random_block_range(
                    &mut rng,
                    0..=end,
                    BlockRangeParams {
                        parent: Some(GENESIS_HASH),
                        tx_count: 0..2,
                        ..Default::default()
                    },
                );
                self.db.insert_headers_with_td(blocks.iter().map(|block| block.sealed_header()))?;
                if let Some(progress) = blocks.get(start as usize) {
                    // Insert the body data of the checkpoint block only.
                    {
                        let tx = self.db.factory.provider_rw()?.into_tx();
                        let mut static_file_producer = static_file_provider
                            .get_writer(start, StaticFileSegment::Transactions)?;

                        // Its transactions are numbered starting at 0.
                        let body = StoredBlockBodyIndices {
                            first_tx_num: 0,
                            tx_count: progress.transaction_count() as u64,
                        };

                        static_file_producer.set_block_range(0..=progress.number);

                        // Store random transactions for the seeded body in static files.
                        body.tx_num_range().try_for_each(|tx_num| {
                            let transaction = random_signed_tx(&mut rng);
                            static_file_producer.append_transaction(tx_num, &transaction).map(drop)
                        })?;

                        if body.tx_count != 0 {
                            tx.put::<tables::TransactionBlocks>(
                                body.last_tx_num(),
                                progress.number,
                            )?;
                        }

                        tx.put::<tables::BlockBodyIndices>(progress.number, body)?;

                        if !progress.ommers_hash_is_empty() {
                            tx.put::<tables::BlockOmmers>(
                                progress.number,
                                StoredBlockOmmers { ommers: progress.body().ommers.clone() },
                            )?;
                        }

                        static_file_producer.commit()?;
                        tx.commit()?;
                    }
                }
                self.set_responses(blocks.iter().map(body_by_hash).collect());
                Ok(blocks)
            }

            /// Validates the stored bodies against the output checkpoint, or against the input
            /// checkpoint when there is no output.
            fn validate_execution(
                &self,
                input: ExecInput,
                output: Option<ExecOutput>,
            ) -> Result<(), TestRunnerError> {
                let highest_block = match output.as_ref() {
                    Some(output) => output.checkpoint,
                    None => input.checkpoint(),
                }
                .block_number;
                self.validate_db_blocks(highest_block, highest_block)
            }
        }

        impl UnwindStageTestRunner for BodyTestRunner {
            /// Checks the state after an unwind:
            /// - no body indices or ommers remain above the unwind target;
            /// - no transaction entries remain above the last transaction of the highest
            ///   remaining body.
            fn validate_unwind(&self, input: UnwindInput) -> Result<(), TestRunnerError> {
                self.db.ensure_no_entry_above::<tables::BlockBodyIndices, _>(
                    input.unwind_to,
                    |key| key,
                )?;
                self.db
                    .ensure_no_entry_above::<tables::BlockOmmers, _>(input.unwind_to, |key| key)?;
                if let Some(last_tx_id) = self.get_last_tx_id()? {
                    self.db
                        .ensure_no_entry_above::<tables::Transactions, _>(last_tx_id, |key| key)?;
                    self.db.ensure_no_entry_above::<tables::TransactionBlocks, _>(
                        last_tx_id,
                        |key| key,
                    )?;
                }
                Ok(())
            }
        }

        impl BodyTestRunner {
            /// Returns the number of the last transaction in the highest stored block body.
            ///
            /// Returns `None` if no body is stored or that body has no transactions.
            pub(crate) fn get_last_tx_id(&self) -> Result<Option<TxNumber>, TestRunnerError> {
                let last_body = self.db.query(|tx| {
                    let v = tx.cursor_read::<tables::BlockBodyIndices>()?.last()?;
                    Ok(v)
                })?;
                Ok(match last_body {
                    Some((_, body)) if body.tx_count != 0 => {
                        Some(body.first_tx_num + body.tx_count - 1)
                    }
                    _ => None,
                })
            }

            /// Walks every stored body index and asserts that:
            /// - bodies above `prev_progress` are sequential;
            /// - no body lies above `highest_block`;
            /// - an ommers entry exists exactly when the header's ommers hash is non-empty;
            /// - `TransactionBlocks` maps the last transaction of each non-empty body to its
            ///   block;
            /// - every transaction of the body exists in static files.
            pub(crate) fn validate_db_blocks(
                &self,
                prev_progress: BlockNumber,
                highest_block: BlockNumber,
            ) -> Result<(), TestRunnerError> {
                let static_file_provider = self.db.factory.static_file_provider();

                self.db.query(|tx| {
                    // Cursors over the body-related tables.
                    let mut bodies_cursor = tx.cursor_read::<tables::BlockBodyIndices>()?;
                    let mut ommers_cursor = tx.cursor_read::<tables::BlockOmmers>()?;
                    let mut tx_block_cursor = tx.cursor_read::<tables::TransactionBlocks>()?;

                    // Nothing to validate if no body is stored.
                    let first_body_key = match bodies_cursor.first()? {
                        Some((key, _)) => key,
                        None => return Ok(()),
                    };

                    let mut prev_number: Option<BlockNumber> = None;

                    for entry in bodies_cursor.walk(Some(first_body_key))? {
                        let (number, body) = entry?;

                        // Sequentiality is only required for bodies above `prev_progress`;
                        // anything at or below it was seeded by the test.
                        if number > prev_progress {
                            if let Some(prev_key) = prev_number {
                                assert_eq!(prev_key + 1, number, "Body entries must be sequential");
                            }
                        }

                        // No body may be written above the highest block reported by the stage.
                        assert!(
                            number <= highest_block,
                            "We wrote a block body outside of our synced range. Found block with number {number}, highest block according to stage is {highest_block}",
                        );

                        let header = static_file_provider.header_by_number(number)?.expect("to be present");
                        // An ommers entry must exist exactly when the header declares ommers.
                        let stored_ommers = ommers_cursor.seek_exact(number)?;
                        if header.ommers_hash_is_empty() {
                            assert!(stored_ommers.is_none(), "Unexpected ommers entry");
                        } else {
                            assert!(stored_ommers.is_some(), "Missing ommers entry");
                        }

                        // Only non-empty bodies may map their last tx number to this block.
                        let tx_block_id = tx_block_cursor.seek_exact(body.last_tx_num())?.map(|(_,b)| b);
                        if body.tx_count == 0 {
                            assert_ne!(tx_block_id,Some(number));
                        } else {
                            assert_eq!(tx_block_id, Some(number));
                        }

                        for tx_id in body.tx_num_range() {
                            assert!(static_file_provider.transaction_by_id(tx_id)?.is_some(), "Transaction is missing.");
                        }

                        prev_number = Some(number);
                    }
                    Ok(())
                })?;
                Ok(())
            }
        }

        /// A [`BodyDownloader`] that serves pre-seeded bodies for headers read from static files.
        #[derive(Debug)]
        pub(crate) struct TestBodyDownloader {
            /// Factory used to read headers from static files.
            provider_factory: ProviderFactory<MockNodeTypesWithDB>,
            /// Bodies to return, keyed by block hash; each is removed once served.
            responses: HashMap<B256, BlockBody>,
            /// Headers queued for download, in range order.
            headers: VecDeque<SealedHeader>,
            /// Maximum number of responses yielded per poll.
            batch_size: u64,
        }

        impl TestBodyDownloader {
            /// Creates a downloader with an empty header queue.
            pub(crate) fn new(
                provider_factory: ProviderFactory<MockNodeTypesWithDB>,
                responses: HashMap<B256, BlockBody>,
                batch_size: u64,
            ) -> Self {
                Self { provider_factory, responses, headers: VecDeque::default(), batch_size }
            }
        }

        impl BodyDownloader for TestBodyDownloader {
            type Block = Block;

            /// Reads the headers in `range` from static files and appends them to the queue.
            ///
            /// Previously queued headers are not cleared.
            fn set_download_range(
                &mut self,
                range: RangeInclusive<BlockNumber>,
            ) -> DownloadResult<()> {
                let static_file_provider = self.provider_factory.static_file_provider();

                for header in static_file_provider.fetch_range_iter(
                    StaticFileSegment::Headers,
                    *range.start()..*range.end() + 1,
                    |cursor, number| cursor.get_two::<HeaderWithHashMask<Header>>(number.into()),
                )? {
                    let (header, hash) = header?;
                    self.headers.push_back(SealedHeader::new(header, hash));
                }

                Ok(())
            }
        }

        impl Stream for TestBodyDownloader {
            type Item = BodyDownloaderResult<Block>;

            /// Yields up to `batch_size` responses from the queued headers.
            ///
            /// - Headers reporting `is_empty()` become `BlockResponse::Empty`.
            /// - Other headers take their body from `responses`, and polling panics if that
            ///   body is unknown.
            /// - Returns `None` when no headers are queued.
            fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
                let this = self.get_mut();

                if this.headers.is_empty() {
                    return Poll::Ready(None)
                }

                let mut response =
                    Vec::with_capacity(std::cmp::min(this.headers.len(), this.batch_size as usize));
                while let Some(header) = this.headers.pop_front() {
                    if header.is_empty() {
                        response.push(BlockResponse::Empty(header))
                    } else {
                        let body =
                            this.responses.remove(&header.hash()).expect("requested unknown body");
                        response.push(BlockResponse::Full(SealedBlock::from_sealed_parts(
                            header, body,
                        )));
                    }

                    if response.len() as u64 >= this.batch_size {
                        break
                    }
                }

                if !response.is_empty() {
                    return Poll::Ready(Some(Ok(response)))
                }

                // The queue was checked to be non-empty above, so at least one response was
                // produced; reaching this point indicates misuse of the downloader.
                panic!("requested bodies without setting headers")
            }
        }
    }
}