1use alloy_consensus::BlockHeader;
2use alloy_eips::{eip1898::BlockWithParent, NumHash};
3use alloy_primitives::{BlockHash, BlockNumber, Bytes, B256};
4use futures_util::StreamExt;
5use reth_config::config::EtlConfig;
6use reth_consensus::HeaderValidator;
7use reth_db_api::{
8 cursor::{DbCursorRO, DbCursorRW},
9 table::Value,
10 tables,
11 transaction::{DbTx, DbTxMut},
12 DbTxUnwindExt, RawKey, RawTable, RawValue,
13};
14use reth_etl::Collector;
15use reth_network_p2p::headers::{
16 downloader::{HeaderDownloader, HeaderSyncGap, SyncTarget},
17 error::HeadersDownloaderError,
18};
19use reth_primitives_traits::{serde_bincode_compat, FullBlockHeader, NodePrimitives, SealedHeader};
20use reth_provider::{
21 providers::StaticFileWriter, BlockHashReader, DBProvider, HeaderProvider,
22 HeaderSyncGapProvider, StaticFileProviderFactory,
23};
24use reth_stages_api::{
25 BlockErrorKind, CheckpointBlockRange, EntitiesCheckpoint, ExecInput, ExecOutput,
26 HeadersCheckpoint, Stage, StageCheckpoint, StageError, StageId, UnwindInput, UnwindOutput,
27};
28use reth_static_file_types::StaticFileSegment;
29use reth_storage_errors::provider::ProviderError;
30use std::{
31 sync::Arc,
32 task::{ready, Context, Poll},
33};
34use tokio::sync::watch;
35use tracing::*;
36
/// The headers stage.
///
/// Downloaded headers are first buffered in two ETL collectors (see `poll_execute_ready`) and
/// are later validated and written out in one go by `execute` via `write_headers`.
#[derive(Debug)]
pub struct HeaderStage<Provider, Downloader: HeaderDownloader> {
    /// Provider queried in `poll_execute_ready` for the local tip header.
    provider: Provider,
    /// Stream-like downloader polled for batches of headers.
    downloader: Downloader,
    /// Watch channel whose current value is used as the tip of the sync target.
    tip: watch::Receiver<B256>,
    /// Validator applied to every header right before it is appended in `write_headers`.
    consensus: Arc<dyn HeaderValidator<Downloader::Header>>,
    /// Gap computed by the last `poll_execute_ready` call; read by `execute` and cleared by
    /// `unwind`.
    sync_gap: Option<HeaderSyncGap<Downloader::Header>>,
    /// ETL collector holding `block hash -> block number` entries.
    hash_collector: Collector<BlockHash, BlockNumber>,
    /// ETL collector holding `block number -> bincode-serialized sealed header` entries.
    header_collector: Collector<BlockNumber, Bytes>,
    /// Set once `poll_execute_ready` has either filled the collectors or found the gap closed;
    /// reset by `execute`.
    is_etl_ready: bool,
}
69
70impl<Provider, Downloader> HeaderStage<Provider, Downloader>
73where
74 Downloader: HeaderDownloader,
75{
76 pub fn new(
78 database: Provider,
79 downloader: Downloader,
80 tip: watch::Receiver<B256>,
81 consensus: Arc<dyn HeaderValidator<Downloader::Header>>,
82 etl_config: EtlConfig,
83 ) -> Self {
84 Self {
85 provider: database,
86 downloader,
87 tip,
88 consensus,
89 sync_gap: None,
90 hash_collector: Collector::new(etl_config.file_size / 2, etl_config.dir.clone()),
91 header_collector: Collector::new(etl_config.file_size / 2, etl_config.dir),
92 is_etl_ready: false,
93 }
94 }
95
96 fn write_headers<P>(&mut self, provider: &P) -> Result<BlockNumber, StageError>
101 where
102 P: DBProvider<Tx: DbTxMut> + StaticFileProviderFactory,
103 Downloader: HeaderDownloader<Header = <P::Primitives as NodePrimitives>::BlockHeader>,
104 <P::Primitives as NodePrimitives>::BlockHeader: Value + FullBlockHeader,
105 {
106 let total_headers = self.header_collector.len();
107
108 info!(target: "sync::stages::headers", total = total_headers, "Writing headers");
109
110 let static_file_provider = provider.static_file_provider();
111
112 let mut last_header_number = static_file_provider
115 .get_highest_static_file_block(StaticFileSegment::Headers)
116 .unwrap_or_default();
117
118 let mut td = static_file_provider
120 .header_td_by_number(last_header_number)?
121 .ok_or(ProviderError::TotalDifficultyNotFound(last_header_number))?;
122
123 let mut writer = static_file_provider.latest_writer(StaticFileSegment::Headers)?;
126 let interval = (total_headers / 10).max(1);
127 for (index, header) in self.header_collector.iter()?.enumerate() {
128 let (_, header_buf) = header?;
129
130 if index > 0 && index % interval == 0 && total_headers > 100 {
131 info!(target: "sync::stages::headers", progress = %format!("{:.2}%", (index as f64 / total_headers as f64) * 100.0), "Writing headers");
132 }
133
134 let sealed_header: SealedHeader<Downloader::Header> =
135 bincode::deserialize::<serde_bincode_compat::SealedHeader<'_, _>>(&header_buf)
136 .map_err(|err| StageError::Fatal(Box::new(err)))?
137 .into();
138
139 let (header, header_hash) = sealed_header.split_ref();
140 if header.number() == 0 {
141 continue
142 }
143 last_header_number = header.number();
144
145 td += header.difficulty();
147
148 self.consensus.validate_header(&sealed_header).map_err(|error| StageError::Block {
149 block: Box::new(BlockWithParent::new(
150 sealed_header.parent_hash(),
151 NumHash::new(sealed_header.number(), sealed_header.hash()),
152 )),
153 error: BlockErrorKind::Validation(error),
154 })?;
155
156 writer.append_header(header, td, header_hash)?;
158 }
159
160 info!(target: "sync::stages::headers", total = total_headers, "Writing headers hash index");
161
162 let mut cursor_header_numbers =
163 provider.tx_ref().cursor_write::<RawTable<tables::HeaderNumbers>>()?;
164 let mut first_sync = false;
165
166 if provider.tx_ref().entries::<RawTable<tables::HeaderNumbers>>()? == 1 {
169 if let Some((hash, block_number)) = cursor_header_numbers.last()? {
170 if block_number.value()? == 0 {
171 self.hash_collector.insert(hash.key()?, 0)?;
172 cursor_header_numbers.delete_current()?;
173 first_sync = true;
174 }
175 }
176 }
177
178 for (index, hash_to_number) in self.hash_collector.iter()?.enumerate() {
181 let (hash, number) = hash_to_number?;
182
183 if index > 0 && index % interval == 0 && total_headers > 100 {
184 info!(target: "sync::stages::headers", progress = %format!("{:.2}%", (index as f64 / total_headers as f64) * 100.0), "Writing headers hash index");
185 }
186
187 if first_sync {
188 cursor_header_numbers.append(
189 RawKey::<BlockHash>::from_vec(hash),
190 &RawValue::<BlockNumber>::from_vec(number),
191 )?;
192 } else {
193 cursor_header_numbers.upsert(
194 RawKey::<BlockHash>::from_vec(hash),
195 &RawValue::<BlockNumber>::from_vec(number),
196 )?;
197 }
198 }
199
200 Ok(last_header_number)
201 }
202}
203
204impl<Provider, P, D> Stage<Provider> for HeaderStage<P, D>
205where
206 Provider: DBProvider<Tx: DbTxMut> + StaticFileProviderFactory,
207 P: HeaderSyncGapProvider<Header = <Provider::Primitives as NodePrimitives>::BlockHeader>,
208 D: HeaderDownloader<Header = <Provider::Primitives as NodePrimitives>::BlockHeader>,
209 <Provider::Primitives as NodePrimitives>::BlockHeader: FullBlockHeader + Value,
210{
211 fn id(&self) -> StageId {
213 StageId::Headers
214 }
215
216 fn poll_execute_ready(
217 &mut self,
218 cx: &mut Context<'_>,
219 input: ExecInput,
220 ) -> Poll<Result<(), StageError>> {
221 let current_checkpoint = input.checkpoint();
222
223 if self.is_etl_ready {
225 return Poll::Ready(Ok(()))
226 }
227
228 let local_head = self.provider.local_tip_header(current_checkpoint.block_number)?;
230 let target = SyncTarget::Tip(*self.tip.borrow());
231 let gap = HeaderSyncGap { local_head, target };
232 let tip = gap.target.tip();
233 self.sync_gap = Some(gap.clone());
234
235 if gap.is_closed() {
237 info!(
238 target: "sync::stages::headers",
239 checkpoint = %current_checkpoint.block_number,
240 target = ?tip,
241 "Target block already reached"
242 );
243 self.is_etl_ready = true;
244 return Poll::Ready(Ok(()))
245 }
246
247 debug!(target: "sync::stages::headers", ?tip, head = ?gap.local_head.hash(), "Commencing sync");
248 let local_head_number = gap.local_head.number();
249
250 self.downloader.update_sync_gap(gap.local_head, gap.target);
252
253 loop {
255 match ready!(self.downloader.poll_next_unpin(cx)) {
256 Some(Ok(headers)) => {
257 info!(target: "sync::stages::headers", total = headers.len(), from_block = headers.first().map(|h| h.number()), to_block = headers.last().map(|h| h.number()), "Received headers");
258 for header in headers {
259 let header_number = header.number();
260
261 self.hash_collector.insert(header.hash(), header_number)?;
262 self.header_collector.insert(
263 header_number,
264 Bytes::from(
265 bincode::serialize(&serde_bincode_compat::SealedHeader::from(
266 &header,
267 ))
268 .map_err(|err| StageError::Fatal(Box::new(err)))?,
269 ),
270 )?;
271
272 if header_number == local_head_number + 1 {
275 self.is_etl_ready = true;
276 return Poll::Ready(Ok(()))
277 }
278 }
279 }
280 Some(Err(HeadersDownloaderError::DetachedHead { local_head, header, error })) => {
281 error!(target: "sync::stages::headers", %error, "Cannot attach header to head");
282 return Poll::Ready(Err(StageError::DetachedHead {
283 local_head: Box::new(local_head.block_with_parent()),
284 header: Box::new(header.block_with_parent()),
285 error,
286 }))
287 }
288 None => return Poll::Ready(Err(StageError::ChannelClosed)),
289 }
290 }
291 }
292
293 fn execute(&mut self, provider: &Provider, input: ExecInput) -> Result<ExecOutput, StageError> {
296 let current_checkpoint = input.checkpoint();
297
298 if self.sync_gap.as_ref().ok_or(StageError::MissingSyncGap)?.is_closed() {
299 self.is_etl_ready = false;
300 return Ok(ExecOutput::done(current_checkpoint))
301 }
302
303 if !self.is_etl_ready {
305 return Err(StageError::MissingDownloadBuffer)
306 }
307
308 self.is_etl_ready = false;
310
311 let to_be_processed = self.hash_collector.len() as u64;
313 let last_header_number = self.write_headers(provider)?;
314
315 self.hash_collector.clear();
317 self.header_collector.clear();
318
319 Ok(ExecOutput {
320 checkpoint: StageCheckpoint::new(last_header_number).with_headers_stage_checkpoint(
321 HeadersCheckpoint {
322 block_range: CheckpointBlockRange {
323 from: input.checkpoint().block_number,
324 to: last_header_number,
325 },
326 progress: EntitiesCheckpoint {
327 processed: input.checkpoint().block_number + to_be_processed,
328 total: last_header_number,
329 },
330 },
331 ),
332 done: true,
335 })
336 }
337
338 fn unwind(
340 &mut self,
341 provider: &Provider,
342 input: UnwindInput,
343 ) -> Result<UnwindOutput, StageError> {
344 self.sync_gap.take();
345
346 provider
350 .tx_ref()
351 .unwind_table_by_walker::<tables::CanonicalHeaders, tables::HeaderNumbers>(
352 (input.unwind_to + 1)..,
353 )?;
354 provider.tx_ref().unwind_table_by_num::<tables::CanonicalHeaders>(input.unwind_to)?;
355 provider
356 .tx_ref()
357 .unwind_table_by_num::<tables::HeaderTerminalDifficulties>(input.unwind_to)?;
358 let unfinalized_headers_unwound =
359 provider.tx_ref().unwind_table_by_num::<tables::Headers>(input.unwind_to)?;
360
361 let static_file_provider = provider.static_file_provider();
364 let highest_block = static_file_provider
365 .get_highest_static_file_block(StaticFileSegment::Headers)
366 .unwrap_or_default();
367 let static_file_headers_to_unwind = highest_block - input.unwind_to;
368 for block_number in (input.unwind_to + 1)..=highest_block {
369 let hash = static_file_provider.block_hash(block_number)?;
370 if let Some(header_hash) = hash {
376 provider.tx_ref().delete::<tables::HeaderNumbers>(header_hash, None)?;
377 }
378 }
379
380 let mut writer = static_file_provider.latest_writer(StaticFileSegment::Headers)?;
382 writer.prune_headers(static_file_headers_to_unwind)?;
383
384 let stage_checkpoint =
387 input.checkpoint.headers_stage_checkpoint().map(|stage_checkpoint| HeadersCheckpoint {
388 block_range: stage_checkpoint.block_range,
389 progress: EntitiesCheckpoint {
390 processed: stage_checkpoint.progress.processed.saturating_sub(
391 static_file_headers_to_unwind + unfinalized_headers_unwound as u64,
392 ),
393 total: stage_checkpoint.progress.total,
394 },
395 });
396
397 let mut checkpoint = StageCheckpoint::new(input.unwind_to);
398 if let Some(stage_checkpoint) = stage_checkpoint {
399 checkpoint = checkpoint.with_headers_stage_checkpoint(stage_checkpoint);
400 }
401
402 Ok(UnwindOutput { checkpoint })
403 }
404}
405
// Tests for the headers stage: the shared stage test suite plus two end-to-end runs against the
// reverse ("linear") headers downloader.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::{
        stage_test_suite, ExecuteStageTestRunner, StageTestRunner, UnwindStageTestRunner,
    };
    use alloy_primitives::B256;
    use assert_matches::assert_matches;
    use reth_ethereum_primitives::BlockBody;
    use reth_execution_types::ExecutionOutcome;
    use reth_primitives_traits::{RecoveredBlock, SealedBlock};
    use reth_provider::{BlockWriter, ProviderFactory, StaticFileProviderFactory};
    use reth_stages_api::StageUnitCheckpoint;
    use reth_testing_utils::generators::{self, random_header, random_header_range};
    use reth_trie::{updates::TrieUpdates, HashedPostStateSorted};
    use test_runner::HeadersTestRunner;

    // Test runner that wires a `HeaderStage` to a test database, a test headers client and a
    // tip channel.
    mod test_runner {
        use super::*;
        use crate::test_utils::{TestRunnerError, TestStageDB};
        use reth_consensus::test_utils::TestConsensus;
        use reth_downloaders::headers::reverse_headers::{
            ReverseHeadersDownloader, ReverseHeadersDownloaderBuilder,
        };
        use reth_network_p2p::test_utils::{TestHeaderDownloader, TestHeadersClient};
        use reth_provider::{test_utils::MockNodeTypesWithDB, BlockNumReader};
        use tokio::sync::watch;

        /// Runner for header stage tests, generic over the downloader under test.
        pub(crate) struct HeadersTestRunner<D: HeaderDownloader> {
            /// Client that serves the headers the test feeds into it.
            pub(crate) client: TestHeadersClient,
            /// Sender/receiver pair for the tip; the receiver is handed to each stage.
            channel: (watch::Sender<B256>, watch::Receiver<B256>),
            /// Builds a fresh downloader every time `stage()` is called.
            downloader_factory: Box<dyn Fn() -> D + Send + Sync + 'static>,
            /// Test database backing the stage.
            db: TestStageDB,
            /// Consensus implementation handed to each stage.
            consensus: Arc<TestConsensus>,
        }

        impl Default for HeadersTestRunner<TestHeaderDownloader> {
            /// Builds a runner backed by `TestHeaderDownloader`, with the tip initialised to
            /// `B256::ZERO`.
            fn default() -> Self {
                let client = TestHeadersClient::default();
                Self {
                    client: client.clone(),
                    channel: watch::channel(B256::ZERO),
                    consensus: Arc::new(TestConsensus::default()),
                    downloader_factory: Box::new(move || {
                        TestHeaderDownloader::new(
                            client.clone(),
                            Arc::new(TestConsensus::default()),
                            1000,
                            1000,
                        )
                    }),
                    db: TestStageDB::default(),
                }
            }
        }

        impl<D: HeaderDownloader<Header = alloy_consensus::Header> + 'static> StageTestRunner
            for HeadersTestRunner<D>
        {
            type S = HeaderStage<ProviderFactory<MockNodeTypesWithDB>, D>;

            fn db(&self) -> &TestStageDB {
                &self.db
            }

            /// Builds a new stage instance; note that every call creates a fresh downloader and
            /// fresh ETL collectors.
            fn stage(&self) -> Self::S {
                HeaderStage::new(
                    self.db.factory.clone(),
                    (*self.downloader_factory)(),
                    self.channel.1.clone(),
                    self.consensus.clone(),
                    EtlConfig::default(),
                )
            }
        }

        impl<D: HeaderDownloader<Header = alloy_consensus::Header> + 'static> ExecuteStageTestRunner
            for HeadersTestRunner<D>
        {
            type Seed = Vec<SealedHeader>;

            /// Inserts random headers for `0..=checkpoint` into the database and returns the
            /// headers still to be downloaded, prefixed with the current head. Returns an empty
            /// seed when the target does not lie above the checkpoint.
            fn seed_execution(&mut self, input: ExecInput) -> Result<Self::Seed, TestRunnerError> {
                let mut rng = generators::rng();
                let start = input.checkpoint().block_number;
                let headers = random_header_range(&mut rng, 0..start + 1, B256::ZERO);
                let head = headers.last().cloned().unwrap();
                self.db.insert_headers_with_td(headers.iter())?;

                // Exclusive upper bound of the range to generate.
                let end = input.target.unwrap_or_default() + 1;

                if start + 1 >= end {
                    return Ok(Vec::default())
                }

                let mut headers = random_header_range(&mut rng, start + 1..end, head.hash());
                headers.insert(0, head);
                Ok(headers)
            }

            /// Checks the stored data after execution: when the checkpoint advanced, every
            /// block in `initial_checkpoint..new_checkpoint` must have a hash, a matching
            /// number, a header with that hash and the expected total difficulty; otherwise no
            /// header entry may exist above the initial checkpoint.
            fn validate_execution(
                &self,
                input: ExecInput,
                output: Option<ExecOutput>,
            ) -> Result<(), TestRunnerError> {
                let initial_checkpoint = input.checkpoint().block_number;
                match output {
                    Some(output) if output.checkpoint.block_number > initial_checkpoint => {
                        let provider = self.db.factory.provider()?;
                        // Total difficulty of the block before the initial checkpoint, or the
                        // default if it is not stored.
                        let mut td = provider
                            .header_td_by_number(initial_checkpoint.saturating_sub(1))?
                            .unwrap_or_default();

                        for block_num in initial_checkpoint..output.checkpoint.block_number {
                            // Look up the header hash.
                            let hash = provider.block_hash(block_num)?.expect("no header hash");

                            // The hash must map back to the same block number.
                            assert_eq!(provider.block_number(hash)?, Some(block_num));

                            // The stored header must exist and seal to the same hash.
                            let header = provider.header_by_number(block_num)?;
                            assert!(header.is_some());
                            let header = SealedHeader::seal_slow(header.unwrap());
                            assert_eq!(header.hash(), hash);

                            // The stored total difficulty must equal the running sum.
                            td += header.difficulty;
                            assert_eq!(provider.header_td_by_number(block_num)?, Some(td));
                        }
                    }
                    _ => self.check_no_header_entry_above(initial_checkpoint)?,
                };
                Ok(())
            }

            /// Feeds the seeded headers to the client and publishes the tip. With an empty seed
            /// a random block-0 header is inserted and used as the tip instead.
            async fn after_execution(&self, headers: Self::Seed) -> Result<(), TestRunnerError> {
                self.client.extend(headers.iter().map(|h| h.clone_header())).await;
                let tip = if headers.is_empty() {
                    let tip = random_header(&mut generators::rng(), 0, None);
                    self.db.insert_headers(std::iter::once(&tip))?;
                    tip.hash()
                } else {
                    headers.last().unwrap().hash()
                };
                self.send_tip(tip);
                Ok(())
            }
        }

        impl<D: HeaderDownloader<Header = alloy_consensus::Header> + 'static> UnwindStageTestRunner
            for HeadersTestRunner<D>
        {
            /// After an unwind, no header table may contain an entry above `unwind_to`.
            fn validate_unwind(&self, input: UnwindInput) -> Result<(), TestRunnerError> {
                self.check_no_header_entry_above(input.unwind_to)
            }
        }

        impl HeadersTestRunner<ReverseHeadersDownloader<TestHeadersClient>> {
            /// Builds a runner backed by `ReverseHeadersDownloader` with a stream batch size of
            /// 500.
            pub(crate) fn with_linear_downloader() -> Self {
                let client = TestHeadersClient::default();
                Self {
                    client: client.clone(),
                    channel: watch::channel(B256::ZERO),
                    downloader_factory: Box::new(move || {
                        ReverseHeadersDownloaderBuilder::default()
                            .stream_batch_size(500)
                            .build(client.clone(), Arc::new(TestConsensus::default()))
                    }),
                    db: TestStageDB::default(),
                    consensus: Arc::new(TestConsensus::default()),
                }
            }
        }

        impl<D: HeaderDownloader> HeadersTestRunner<D> {
            /// Asserts that none of the header-related tables holds an entry above `block`.
            pub(crate) fn check_no_header_entry_above(
                &self,
                block: BlockNumber,
            ) -> Result<(), TestRunnerError> {
                self.db
                    .ensure_no_entry_above_by_value::<tables::HeaderNumbers, _>(block, |val| val)?;
                self.db.ensure_no_entry_above::<tables::CanonicalHeaders, _>(block, |key| key)?;
                self.db.ensure_no_entry_above::<tables::Headers, _>(block, |key| key)?;
                self.db.ensure_no_entry_above::<tables::HeaderTerminalDifficulties, _>(
                    block,
                    |num| num,
                )?;
                Ok(())
            }

            /// Publishes `tip` on the tip channel; panics if the send fails.
            pub(crate) fn send_tip(&self, tip: B256) {
                self.channel.0.send(tip).expect("failed to send tip");
            }
        }
    }

    stage_test_suite!(HeadersTestRunner, headers);

    /// Executes the stage with the linear downloader, appends ten more blocks through the
    /// provider, then unwinds back to the tip and validates the result.
    #[tokio::test]
    async fn execute_with_linear_downloader_unwind() {
        let mut runner = HeadersTestRunner::with_linear_downloader();
        let (checkpoint, previous_stage) = (1000, 1200);
        let input = ExecInput {
            target: Some(previous_stage),
            checkpoint: Some(StageCheckpoint::new(checkpoint)),
        };
        let headers = runner.seed_execution(input).expect("failed to seed execution");
        let rx = runner.execute(input);

        // Feed the client in reverse order.
        runner.client.extend(headers.iter().rev().map(|h| h.clone_header())).await;

        // Publish the tip so the stage has a sync target.
        let tip = headers.last().unwrap();
        runner.send_tip(tip.hash());

        let result = rx.await.unwrap();
        runner.db().factory.static_file_provider().commit().unwrap();
        // `headers` includes the local head, hence the `- 1` in the `processed` check.
        assert_matches!(result, Ok(ExecOutput { checkpoint: StageCheckpoint {
            block_number,
            stage_checkpoint: Some(StageUnitCheckpoint::Headers(HeadersCheckpoint {
                block_range: CheckpointBlockRange {
                    from,
                    to
                },
                progress: EntitiesCheckpoint {
                    processed,
                    total,
                }
            }))
        }, done: true }) if block_number == tip.number &&
            from == checkpoint && to == previous_stage &&
            processed == checkpoint + headers.len() as u64 - 1 && total == tip.number
        );
        assert!(runner.validate_execution(input, result.ok()).is_ok(), "validation failed");
        // NOTE(review): `runner.stage()` builds a brand-new stage, so these two assertions
        // inspect fresh collectors rather than those of the stage that executed - confirm
        // this is intended.
        assert!(runner.stage().hash_collector.is_empty());
        assert!(runner.stage().header_collector.is_empty());

        // Generate ten more headers chained off the tip.
        let sealed_headers =
            random_header_range(&mut generators::rng(), tip.number..tip.number + 10, tip.hash());

        // Wrap them into blocks with empty bodies and no senders.
        let sealed_blocks = sealed_headers
            .iter()
            .map(|header| {
                RecoveredBlock::new_sealed(
                    SealedBlock::from_sealed_parts(header.clone(), BlockBody::default()),
                    vec![],
                )
            })
            .collect();

        // Append the blocks through a read-write provider and commit.
        let provider = runner.db().factory.provider_rw().unwrap();
        provider
            .append_blocks_with_state(
                sealed_blocks,
                &ExecutionOutcome::default(),
                HashedPostStateSorted::default(),
                TrieUpdates::default(),
            )
            .unwrap();
        provider.commit().unwrap();

        // Unwind the ten appended blocks, back down to the tip.
        let unwind_input = UnwindInput {
            checkpoint: StageCheckpoint::new(tip.number + 10),
            unwind_to: tip.number,
            bad_block: None,
        };

        let unwind_output = runner.unwind(unwind_input).await.unwrap();
        assert_eq!(unwind_output.checkpoint.block_number, tip.number);

        // No header entries may remain above the unwind target.
        assert!(runner.validate_unwind(unwind_input).is_ok());
    }

    /// Executes the stage with the linear downloader and checks the resulting checkpoint and
    /// stored headers.
    #[tokio::test]
    async fn execute_with_linear_downloader() {
        let mut runner = HeadersTestRunner::with_linear_downloader();
        let (checkpoint, previous_stage) = (1000, 1200);
        let input = ExecInput {
            target: Some(previous_stage),
            checkpoint: Some(StageCheckpoint::new(checkpoint)),
        };
        let headers = runner.seed_execution(input).expect("failed to seed execution");
        let rx = runner.execute(input);

        // Feed the client in reverse order.
        runner.client.extend(headers.iter().rev().map(|h| h.clone_header())).await;

        // Publish the tip so the stage has a sync target.
        let tip = headers.last().unwrap();
        runner.send_tip(tip.hash());

        let result = rx.await.unwrap();
        runner.db().factory.static_file_provider().commit().unwrap();
        // `headers` includes the local head, hence the `- 1` in the `processed` check.
        assert_matches!(result, Ok(ExecOutput { checkpoint: StageCheckpoint {
            block_number,
            stage_checkpoint: Some(StageUnitCheckpoint::Headers(HeadersCheckpoint {
                block_range: CheckpointBlockRange {
                    from,
                    to
                },
                progress: EntitiesCheckpoint {
                    processed,
                    total,
                }
            }))
        }, done: true }) if block_number == tip.number &&
            from == checkpoint && to == previous_stage &&
            processed == checkpoint + headers.len() as u64 - 1 && total == tip.number
        );
        assert!(runner.validate_execution(input, result.ok()).is_ok(), "validation failed");
        // NOTE(review): `runner.stage()` builds a brand-new stage, so these two assertions
        // inspect fresh collectors rather than those of the stage that executed - confirm
        // this is intended.
        assert!(runner.stage().hash_collector.is_empty());
        assert!(runner.stage().header_collector.is_empty());
    }
}