1use alloy_consensus::BlockHeader;
2use alloy_primitives::{BlockHash, BlockNumber, Bytes, B256};
3use futures_util::StreamExt;
4use reth_config::config::EtlConfig;
5use reth_db_api::{
6 cursor::{DbCursorRO, DbCursorRW},
7 table::Value,
8 tables,
9 transaction::{DbTx, DbTxMut},
10 DbTxUnwindExt, RawKey, RawTable, RawValue,
11};
12use reth_etl::Collector;
13use reth_network_p2p::headers::{
14 downloader::{HeaderDownloader, HeaderSyncGap, SyncTarget},
15 error::HeadersDownloaderError,
16};
17use reth_primitives_traits::{serde_bincode_compat, FullBlockHeader, NodePrimitives, SealedHeader};
18use reth_provider::{
19 providers::StaticFileWriter, BlockHashReader, DBProvider, HeaderProvider,
20 HeaderSyncGapProvider, StaticFileProviderFactory,
21};
22use reth_stages_api::{
23 CheckpointBlockRange, EntitiesCheckpoint, ExecInput, ExecOutput, HeadersCheckpoint, Stage,
24 StageCheckpoint, StageError, StageId, UnwindInput, UnwindOutput,
25};
26use reth_static_file_types::StaticFileSegment;
27use reth_storage_errors::provider::ProviderError;
28use std::task::{ready, Context, Poll};
29
30use tokio::sync::watch;
31use tracing::*;
32
/// The headers stage.
///
/// Buffers downloaded headers in ETL [`Collector`]s and, once the gap to the
/// tip is filled, writes them to static files and builds the
/// hash -> number index in the database (see `write_headers`).
#[derive(Debug)]
pub struct HeaderStage<Provider, Downloader: HeaderDownloader> {
    /// Provider used to compute the sync gap from the local tip header.
    provider: Provider,
    /// Strategy for downloading the headers.
    downloader: Downloader,
    /// Receiver for the current chain-tip hash to sync towards.
    tip: watch::Receiver<B256>,
    /// Last computed sync gap (local head -> target); `None` until computed
    /// or after it has been consumed/invalidated.
    sync_gap: Option<HeaderSyncGap<Downloader::Header>>,
    /// ETL collector of `BlockHash -> BlockNumber` entries for the
    /// `HeaderNumbers` index.
    hash_collector: Collector<BlockHash, BlockNumber>,
    /// ETL collector of `BlockNumber -> bincode-encoded SealedHeader` entries.
    header_collector: Collector<BlockNumber, Bytes>,
    /// True once the collectors hold a complete range and are ready to be
    /// flushed by `execute`.
    is_etl_ready: bool,
}
63
impl<Provider, Downloader> HeaderStage<Provider, Downloader>
where
    Downloader: HeaderDownloader,
{
    /// Creates a new [`HeaderStage`].
    ///
    /// The configured ETL `file_size` budget is split evenly between the two
    /// collectors (hashes and header bodies).
    pub fn new(
        database: Provider,
        downloader: Downloader,
        tip: watch::Receiver<B256>,
        etl_config: EtlConfig,
    ) -> Self {
        Self {
            provider: database,
            downloader,
            tip,
            sync_gap: None,
            hash_collector: Collector::new(etl_config.file_size / 2, etl_config.dir.clone()),
            header_collector: Collector::new(etl_config.file_size / 2, etl_config.dir),
            is_etl_ready: false,
        }
    }

    /// Flushes the collected headers to static files and writes the
    /// hash -> number index to the database.
    ///
    /// Returns the number of the last header written, which becomes the new
    /// stage checkpoint. Does NOT clear the collectors — the caller
    /// (`execute`) does that.
    fn write_headers<P>(&mut self, provider: &P) -> Result<BlockNumber, StageError>
    where
        P: DBProvider<Tx: DbTxMut> + StaticFileProviderFactory,
        Downloader: HeaderDownloader<Header = <P::Primitives as NodePrimitives>::BlockHeader>,
        <P::Primitives as NodePrimitives>::BlockHeader: Value + FullBlockHeader,
    {
        let total_headers = self.header_collector.len();

        info!(target: "sync::stages::headers", total = total_headers, "Writing headers");

        let static_file_provider = provider.static_file_provider();

        // Start appending right after the highest header already in static
        // files (genesis/0 if none).
        let mut last_header_number = static_file_provider
            .get_highest_static_file_block(StaticFileSegment::Headers)
            .unwrap_or_default();

        // Running total difficulty, seeded from the last written header; each
        // appended header's difficulty is accumulated into it below.
        let mut td = static_file_provider
            .header_td_by_number(last_header_number)?
            .ok_or(ProviderError::TotalDifficultyNotFound(last_header_number))?;

        let mut writer = static_file_provider.latest_writer(StaticFileSegment::Headers)?;
        // Log progress roughly every 10% (guarded below so tiny syncs stay quiet).
        let interval = (total_headers / 10).max(1);
        // The header collector iterates in BlockNumber order, so appends are
        // sequential as the static-file writer requires.
        for (index, header) in self.header_collector.iter()?.enumerate() {
            let (_, header_buf) = header?;

            if index > 0 && index.is_multiple_of(interval) && total_headers > 100 {
                info!(target: "sync::stages::headers", progress = %format!("{:.2}%", (index as f64 / total_headers as f64) * 100.0), "Writing headers");
            }

            let sealed_header: SealedHeader<Downloader::Header> =
                bincode::deserialize::<serde_bincode_compat::SealedHeader<'_, _>>(&header_buf)
                    .map_err(|err| StageError::Fatal(Box::new(err)))?
                    .into();

            let (header, header_hash) = sealed_header.split_ref();
            // Genesis is written at init; never re-append it.
            if header.number() == 0 {
                continue
            }
            last_header_number = header.number();

            td += header.difficulty();

            writer.append_header(header, td, header_hash)?;
        }

        info!(target: "sync::stages::headers", total = total_headers, "Writing headers hash index");

        let mut cursor_header_numbers =
            provider.tx_ref().cursor_write::<RawTable<tables::HeaderNumbers>>()?;
        // First sync detection: if the table holds only the genesis row
        // (block 0), remove it and re-insert via the collector so the whole
        // index can be written with the faster `append` path in sorted order.
        let first_sync = if provider.tx_ref().entries::<RawTable<tables::HeaderNumbers>>()? == 1 &&
            let Some((hash, block_number)) = cursor_header_numbers.last()? &&
            block_number.value()? == 0
        {
            self.hash_collector.insert(hash.key()?, 0)?;
            cursor_header_numbers.delete_current()?;
            true
        } else {
            false
        };

        // The hash collector iterates in sorted key (hash) order, which is
        // what makes `append` on the cursor valid during a first sync.
        for (index, hash_to_number) in self.hash_collector.iter()?.enumerate() {
            let (hash, number) = hash_to_number?;

            if index > 0 && index.is_multiple_of(interval) && total_headers > 100 {
                info!(target: "sync::stages::headers", progress = %format!("{:.2}%", (index as f64 / total_headers as f64) * 100.0), "Writing headers hash index");
            }

            if first_sync {
                cursor_header_numbers.append(
                    RawKey::<BlockHash>::from_vec(hash),
                    &RawValue::<BlockNumber>::from_vec(number),
                )?;
            } else {
                // Incremental sync: existing entries may be present, so upsert.
                cursor_header_numbers.upsert(
                    RawKey::<BlockHash>::from_vec(hash),
                    &RawValue::<BlockNumber>::from_vec(number),
                )?;
            }
        }

        Ok(last_header_number)
    }
}
186
impl<Provider, P, D> Stage<Provider> for HeaderStage<P, D>
where
    Provider: DBProvider<Tx: DbTxMut> + StaticFileProviderFactory,
    P: HeaderSyncGapProvider<Header = <Provider::Primitives as NodePrimitives>::BlockHeader>,
    D: HeaderDownloader<Header = <Provider::Primitives as NodePrimitives>::BlockHeader>,
    <Provider::Primitives as NodePrimitives>::BlockHeader: FullBlockHeader + Value,
{
    /// Return the id of the stage.
    fn id(&self) -> StageId {
        StageId::Headers
    }

    /// Drives the downloader until the gap between the local head and the tip
    /// is fully buffered in the ETL collectors.
    ///
    /// Returns `Ready(Ok(()))` once the collectors are complete (or the gap is
    /// already closed); `execute` then performs the actual writes.
    fn poll_execute_ready(
        &mut self,
        cx: &mut Context<'_>,
        input: ExecInput,
    ) -> Poll<Result<(), StageError>> {
        let current_checkpoint = input.checkpoint();

        // Already buffered everything in a previous poll — nothing to do.
        if self.is_etl_ready {
            return Poll::Ready(Ok(()))
        }

        // Compute the gap from the local head (at the checkpoint) to the
        // current tip broadcast on the watch channel.
        let local_head = self.provider.local_tip_header(current_checkpoint.block_number)?;
        let target = SyncTarget::Tip(*self.tip.borrow());
        let gap = HeaderSyncGap { local_head, target };
        let tip = gap.target.tip();

        // Nothing to sync.
        if gap.is_closed() {
            info!(
                target: "sync::stages::headers",
                checkpoint = %current_checkpoint.block_number,
                target = ?tip,
                "Target block already reached"
            );
            self.is_etl_ready = true;
            self.sync_gap = Some(gap);
            return Poll::Ready(Ok(()))
        }

        debug!(target: "sync::stages::headers", ?tip, head = ?gap.local_head.hash(), "Commencing sync");
        let local_head_number = gap.local_head.number();

        // Only restart the downloader if the gap actually changed since the
        // last poll; otherwise keep draining the in-flight download.
        if self.sync_gap != Some(gap.clone()) {
            self.sync_gap = Some(gap.clone());
            self.downloader.update_sync_gap(gap.local_head, gap.target);
        }

        // Buffer downloaded batches into the ETL collectors. Headers arrive
        // in falling block-number order (tip -> head), so seeing
        // `local_head_number + 1` means the gap is fully covered.
        loop {
            match ready!(self.downloader.poll_next_unpin(cx)) {
                Some(Ok(headers)) => {
                    info!(target: "sync::stages::headers", total = headers.len(), from_block = headers.first().map(|h| h.number()), to_block = headers.last().map(|h| h.number()), "Received headers");
                    for header in headers {
                        let header_number = header.number();

                        self.hash_collector.insert(header.hash(), header_number)?;
                        self.header_collector.insert(
                            header_number,
                            Bytes::from(
                                bincode::serialize(&serde_bincode_compat::SealedHeader::from(
                                    &header,
                                ))
                                .map_err(|err| StageError::Fatal(Box::new(err)))?,
                            ),
                        )?;

                        // Headers are downloaded in reverse; reaching the
                        // block right above the local head closes the gap.
                        if header_number == local_head_number + 1 {
                            self.is_etl_ready = true;
                            return Poll::Ready(Ok(()))
                        }
                    }
                }
                Some(Err(HeadersDownloaderError::DetachedHead { local_head, header, error })) => {
                    error!(target: "sync::stages::headers", %error, "Cannot attach header to head");
                    // Invalidate the gap so the next poll recomputes it.
                    self.sync_gap = None;
                    return Poll::Ready(Err(StageError::DetachedHead {
                        local_head: Box::new(local_head.block_with_parent()),
                        header: Box::new(header.block_with_parent()),
                        error,
                    }))
                }
                None => {
                    self.sync_gap = None;
                    return Poll::Ready(Err(StageError::ChannelClosed))
                }
            }
        }
    }

    /// Writes the buffered headers to static files and the hash index to the
    /// database, then clears the ETL collectors.
    ///
    /// Must be preceded by a successful `poll_execute_ready`; otherwise the
    /// sync gap / download buffer is missing and an error is returned.
    fn execute(&mut self, provider: &Provider, input: ExecInput) -> Result<ExecOutput, StageError> {
        let current_checkpoint = input.checkpoint();

        // `take()` consumes the gap: each execute round requires a fresh poll.
        if self.sync_gap.take().ok_or(StageError::MissingSyncGap)?.is_closed() {
            self.is_etl_ready = false;
            return Ok(ExecOutput::done(current_checkpoint))
        }

        if !self.is_etl_ready {
            return Err(StageError::MissingDownloadBuffer)
        }

        // Reset for the next round before doing the (fallible) writes.
        self.is_etl_ready = false;

        // Snapshot the entry count before `write_headers` may mutate the
        // collector (first-sync genesis re-insert).
        let to_be_processed = self.hash_collector.len() as u64;
        let last_header_number = self.write_headers(provider)?;

        // Drop the buffered data; it has been persisted.
        self.hash_collector.clear();
        self.header_collector.clear();

        Ok(ExecOutput {
            checkpoint: StageCheckpoint::new(last_header_number).with_headers_stage_checkpoint(
                HeadersCheckpoint {
                    block_range: CheckpointBlockRange {
                        from: input.checkpoint().block_number,
                        to: last_header_number,
                    },
                    progress: EntitiesCheckpoint {
                        processed: input.checkpoint().block_number + to_be_processed,
                        total: last_header_number,
                    },
                },
            ),
            done: true,
        })
    }

    /// Unwinds headers above `input.unwind_to` from both the database tables
    /// and the static files, and adjusts the stage checkpoint accordingly.
    fn unwind(
        &mut self,
        provider: &Provider,
        input: UnwindInput,
    ) -> Result<UnwindOutput, StageError> {
        // Any previously computed gap is stale after an unwind.
        self.sync_gap.take();

        // Remove the hash index entries for unwound canonical headers, then
        // the canonical/TD/header table rows themselves.
        provider
            .tx_ref()
            .unwind_table_by_walker::<tables::CanonicalHeaders, tables::HeaderNumbers>(
                (input.unwind_to + 1)..,
            )?;
        provider.tx_ref().unwind_table_by_num::<tables::CanonicalHeaders>(input.unwind_to)?;
        provider
            .tx_ref()
            .unwind_table_by_num::<tables::HeaderTerminalDifficulties>(input.unwind_to)?;
        let unfinalized_headers_unwound =
            provider.tx_ref().unwind_table_by_num::<tables::Headers>(input.unwind_to)?;

        let static_file_provider = provider.static_file_provider();
        let highest_block = static_file_provider
            .get_highest_static_file_block(StaticFileSegment::Headers)
            .unwrap_or_default();
        // NOTE(review): assumes `unwind_to <= highest_block`; would underflow
        // otherwise — presumably guaranteed by the pipeline, TODO confirm.
        let static_file_headers_to_unwind = highest_block - input.unwind_to;
        // Delete hash-index entries for headers that only live in static files.
        for block_number in (input.unwind_to + 1)..=highest_block {
            let hash = static_file_provider.block_hash(block_number)?;
            if let Some(header_hash) = hash {
                provider.tx_ref().delete::<tables::HeaderNumbers>(header_hash, None)?;
            }
        }

        let mut writer = static_file_provider.latest_writer(StaticFileSegment::Headers)?;
        writer.prune_headers(static_file_headers_to_unwind)?;

        // Shrink the `processed` count by everything removed above, keeping
        // `total` as-is.
        let stage_checkpoint =
            input.checkpoint.headers_stage_checkpoint().map(|stage_checkpoint| HeadersCheckpoint {
                block_range: stage_checkpoint.block_range,
                progress: EntitiesCheckpoint {
                    processed: stage_checkpoint.progress.processed.saturating_sub(
                        static_file_headers_to_unwind + unfinalized_headers_unwound as u64,
                    ),
                    total: stage_checkpoint.progress.total,
                },
            });

        let mut checkpoint = StageCheckpoint::new(input.unwind_to);
        if let Some(stage_checkpoint) = stage_checkpoint {
            checkpoint = checkpoint.with_headers_stage_checkpoint(stage_checkpoint);
        }

        Ok(UnwindOutput { checkpoint })
    }
}
395
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::{
        stage_test_suite, ExecuteStageTestRunner, StageTestRunner, UnwindStageTestRunner,
    };
    use alloy_primitives::B256;
    use assert_matches::assert_matches;
    use reth_provider::{DatabaseProviderFactory, ProviderFactory, StaticFileProviderFactory};
    use reth_stages_api::StageUnitCheckpoint;
    use reth_testing_utils::generators::{self, random_header, random_header_range};
    use std::sync::Arc;
    use test_runner::HeadersTestRunner;

    /// Test harness for the header stage: wires a test headers client, a
    /// downloader factory, a tip watch channel, and a test database together.
    mod test_runner {
        use super::*;
        use crate::test_utils::{TestRunnerError, TestStageDB};
        use reth_consensus::test_utils::TestConsensus;
        use reth_downloaders::headers::reverse_headers::{
            ReverseHeadersDownloader, ReverseHeadersDownloaderBuilder,
        };
        use reth_network_p2p::test_utils::{TestHeaderDownloader, TestHeadersClient};
        use reth_provider::{test_utils::MockNodeTypesWithDB, BlockNumReader};
        use tokio::sync::watch;

        pub(crate) struct HeadersTestRunner<D: HeaderDownloader> {
            // Client the tests feed canned headers into.
            pub(crate) client: TestHeadersClient,
            // Sender/receiver pair used to announce the sync tip to the stage.
            channel: (watch::Sender<B256>, watch::Receiver<B256>),
            // Builds a fresh downloader per `stage()` call.
            downloader_factory: Box<dyn Fn() -> D + Send + Sync + 'static>,
            db: TestStageDB,
        }

        impl Default for HeadersTestRunner<TestHeaderDownloader> {
            fn default() -> Self {
                let client = TestHeadersClient::default();
                Self {
                    client: client.clone(),
                    channel: watch::channel(B256::ZERO),

                    downloader_factory: Box::new(move || {
                        TestHeaderDownloader::new(client.clone(), 1000, 1000)
                    }),
                    db: TestStageDB::default(),
                }
            }
        }

        impl<D: HeaderDownloader<Header = alloy_consensus::Header> + 'static> StageTestRunner
            for HeadersTestRunner<D>
        {
            type S = HeaderStage<ProviderFactory<MockNodeTypesWithDB>, D>;

            fn db(&self) -> &TestStageDB {
                &self.db
            }

            /// Builds a fresh stage instance sharing this runner's db, tip
            /// channel, and downloader factory.
            fn stage(&self) -> Self::S {
                HeaderStage::new(
                    self.db.factory.clone(),
                    (*self.downloader_factory)(),
                    self.channel.1.clone(),
                    EtlConfig::default(),
                )
            }
        }

        impl<D: HeaderDownloader<Header = alloy_consensus::Header> + 'static> ExecuteStageTestRunner
            for HeadersTestRunner<D>
        {
            type Seed = Vec<SealedHeader>;

            /// Seeds the db with headers up to the checkpoint and returns the
            /// chain segment (head + gap headers) the stage should download.
            fn seed_execution(&mut self, input: ExecInput) -> Result<Self::Seed, TestRunnerError> {
                let mut rng = generators::rng();
                let start = input.checkpoint().block_number;
                let headers = random_header_range(&mut rng, 0..start + 1, B256::ZERO);
                let head = headers.last().cloned().unwrap();
                self.db.insert_headers_with_td(headers.iter())?;

                // Gap is (start, end); nothing to download if already closed.
                let end = input.target.unwrap_or_default() + 1;

                if start + 1 >= end {
                    return Ok(Vec::default())
                }

                let mut headers = random_header_range(&mut rng, start + 1..end, head.hash());
                headers.insert(0, head);
                Ok(headers)
            }

            /// Validates that every block in the executed range has a
            /// consistent hash index, header, and total difficulty.
            fn validate_execution(
                &self,
                input: ExecInput,
                output: Option<ExecOutput>,
            ) -> Result<(), TestRunnerError> {
                let initial_checkpoint = input.checkpoint().block_number;
                match output {
                    Some(output) if output.checkpoint.block_number > initial_checkpoint => {
                        let provider = self.db.factory.provider()?;
                        let mut td = provider
                            .header_td_by_number(initial_checkpoint.saturating_sub(1))?
                            .unwrap_or_default();

                        for block_num in initial_checkpoint..output.checkpoint.block_number {
                            // Hash index round-trips: number -> hash -> number.
                            let hash = provider.block_hash(block_num)?.expect("no header hash");

                            assert_eq!(provider.block_number(hash)?, Some(block_num));

                            // Stored header re-seals to the indexed hash.
                            let header = provider.header_by_number(block_num)?;
                            assert!(header.is_some());
                            let header = SealedHeader::seal_slow(header.unwrap());
                            assert_eq!(header.hash(), hash);

                            // TD accumulates per block.
                            td += header.difficulty;
                            assert_eq!(provider.header_td_by_number(block_num)?, Some(td));
                        }
                    }
                    // No progress: db must hold nothing above the checkpoint.
                    _ => self.check_no_header_entry_above(initial_checkpoint)?,
                };
                Ok(())
            }

            /// Feeds the seeded headers to the client and announces the tip.
            async fn after_execution(&self, headers: Self::Seed) -> Result<(), TestRunnerError> {
                self.client.extend(headers.iter().map(|h| h.clone_header())).await;
                let tip = if headers.is_empty() {
                    // Empty seed: insert a random header so a tip exists.
                    let tip = random_header(&mut generators::rng(), 0, None);
                    self.db.insert_headers(std::iter::once(&tip))?;
                    tip.hash()
                } else {
                    headers.last().unwrap().hash()
                };
                self.send_tip(tip);
                Ok(())
            }
        }

        impl<D: HeaderDownloader<Header = alloy_consensus::Header> + 'static> UnwindStageTestRunner
            for HeadersTestRunner<D>
        {
            fn validate_unwind(&self, input: UnwindInput) -> Result<(), TestRunnerError> {
                self.check_no_header_entry_above(input.unwind_to)
            }
        }

        impl HeadersTestRunner<ReverseHeadersDownloader<TestHeadersClient>> {
            /// Runner variant using the real reverse-headers downloader
            /// instead of the trivial test downloader.
            pub(crate) fn with_linear_downloader() -> Self {
                let client = TestHeadersClient::default();
                Self {
                    client: client.clone(),
                    channel: watch::channel(B256::ZERO),
                    downloader_factory: Box::new(move || {
                        ReverseHeadersDownloaderBuilder::default()
                            .stream_batch_size(500)
                            .build(client.clone(), Arc::new(TestConsensus::default()))
                    }),
                    db: TestStageDB::default(),
                }
            }
        }

        impl<D: HeaderDownloader> HeadersTestRunner<D> {
            /// Asserts that no header-related table contains entries above
            /// `block`.
            pub(crate) fn check_no_header_entry_above(
                &self,
                block: BlockNumber,
            ) -> Result<(), TestRunnerError> {
                self.db
                    .ensure_no_entry_above_by_value::<tables::HeaderNumbers, _>(block, |val| val)?;
                self.db.ensure_no_entry_above::<tables::CanonicalHeaders, _>(block, |key| key)?;
                self.db.ensure_no_entry_above::<tables::Headers, _>(block, |key| key)?;
                self.db.ensure_no_entry_above::<tables::HeaderTerminalDifficulties, _>(
                    block,
                    |num| num,
                )?;
                Ok(())
            }

            /// Broadcasts a new tip hash to the stage under test.
            pub(crate) fn send_tip(&self, tip: B256) {
                self.channel.0.send(tip).expect("failed to send tip");
            }
        }
    }

    stage_test_suite!(HeadersTestRunner, headers);

    /// Executes the stage with the reverse downloader, then appends extra
    /// static-file headers and verifies that unwinding removes them and
    /// everything above the unwind target.
    #[tokio::test]
    async fn execute_with_linear_downloader_unwind() {
        let mut runner = HeadersTestRunner::with_linear_downloader();
        let (checkpoint, previous_stage) = (1000, 1200);
        let input = ExecInput {
            target: Some(previous_stage),
            checkpoint: Some(StageCheckpoint::new(checkpoint)),
        };
        let headers = runner.seed_execution(input).expect("failed to seed execution");
        let rx = runner.execute(input);

        // Reverse order: the downloader walks from the tip backwards.
        runner.client.extend(headers.iter().rev().map(|h| h.clone_header())).await;

        let tip = headers.last().unwrap();
        runner.send_tip(tip.hash());

        let result = rx.await.unwrap();
        runner.db().factory.static_file_provider().commit().unwrap();
        assert_matches!(result, Ok(ExecOutput { checkpoint: StageCheckpoint {
            block_number,
            stage_checkpoint: Some(StageUnitCheckpoint::Headers(HeadersCheckpoint {
                block_range: CheckpointBlockRange {
                    from,
                    to
                },
                progress: EntitiesCheckpoint {
                    processed,
                    total,
                }
            }))
        }, done: true }) if block_number == tip.number &&
            from == checkpoint && to == previous_stage &&
            processed == checkpoint + headers.len() as u64 - 1 && total == tip.number
        );
        assert!(runner.validate_execution(input, result.ok()).is_ok(), "validation failed");
        // Collectors must be cleared after a successful execute.
        assert!(runner.stage().hash_collector.is_empty());
        assert!(runner.stage().header_collector.is_empty());

        // Append extra headers directly to static files so unwind has
        // something above the target to prune.
        let sealed_headers = random_header_range(
            &mut generators::rng(),
            tip.number + 1..tip.number + 10,
            tip.hash(),
        );

        let provider = runner.db().factory.database_provider_rw().unwrap();
        let static_file_provider = provider.static_file_provider();
        let mut writer = static_file_provider.latest_writer(StaticFileSegment::Headers).unwrap();
        for header in sealed_headers {
            // TTD = parent TTD + own difficulty (genesis: own difficulty).
            let ttd = if header.number() == 0 {
                header.difficulty()
            } else {
                let parent_block_number = header.number() - 1;
                let parent_ttd =
                    provider.header_td_by_number(parent_block_number).unwrap().unwrap_or_default();
                parent_ttd + header.difficulty()
            };

            writer.append_header(header.header(), ttd, &header.hash()).unwrap();
        }
        drop(writer);

        provider.commit().unwrap();

        let unwind_input = UnwindInput {
            checkpoint: StageCheckpoint::new(tip.number + 10),
            unwind_to: tip.number,
            bad_block: None,
        };

        let unwind_output = runner.unwind(unwind_input).await.unwrap();
        assert_eq!(unwind_output.checkpoint.block_number, tip.number);

        assert!(runner.validate_unwind(unwind_input).is_ok());
    }

    /// Executes the stage with the reverse downloader and checks the
    /// resulting checkpoint and persisted data.
    #[tokio::test]
    async fn execute_with_linear_downloader() {
        let mut runner = HeadersTestRunner::with_linear_downloader();
        let (checkpoint, previous_stage) = (1000, 1200);
        let input = ExecInput {
            target: Some(previous_stage),
            checkpoint: Some(StageCheckpoint::new(checkpoint)),
        };
        let headers = runner.seed_execution(input).expect("failed to seed execution");
        let rx = runner.execute(input);

        // Reverse order: the downloader walks from the tip backwards.
        runner.client.extend(headers.iter().rev().map(|h| h.clone_header())).await;

        let tip = headers.last().unwrap();
        runner.send_tip(tip.hash());

        let result = rx.await.unwrap();
        runner.db().factory.static_file_provider().commit().unwrap();
        assert_matches!(result, Ok(ExecOutput { checkpoint: StageCheckpoint {
            block_number,
            stage_checkpoint: Some(StageUnitCheckpoint::Headers(HeadersCheckpoint {
                block_range: CheckpointBlockRange {
                    from,
                    to
                },
                progress: EntitiesCheckpoint {
                    processed,
                    total,
                }
            }))
        }, done: true }) if block_number == tip.number &&
            from == checkpoint && to == previous_stage &&
            processed == checkpoint + headers.len() as u64 - 1 && total == tip.number
        );
        assert!(runner.validate_execution(input, result.ok()).is_ok(), "validation failed");
        // Collectors must be cleared after a successful execute.
        assert!(runner.stage().hash_collector.is_empty());
        assert!(runner.stage().header_collector.is_empty());
    }
}