1use alloy_consensus::BlockHeader;
2use alloy_primitives::{BlockHash, BlockNumber, Bytes, B256};
3use futures_util::StreamExt;
4use reth_config::config::EtlConfig;
5use reth_db_api::{
6 cursor::{DbCursorRO, DbCursorRW},
7 table::Value,
8 tables,
9 transaction::{DbTx, DbTxMut},
10 DbTxUnwindExt, RawKey, RawTable, RawValue,
11};
12use reth_etl::Collector;
13use reth_network_p2p::headers::{
14 downloader::{HeaderDownloader, HeaderSyncGap, SyncTarget},
15 error::HeadersDownloaderError,
16};
17use reth_primitives_traits::{
18 serde_bincode_compat, FullBlockHeader, HeaderTy, NodePrimitives, SealedHeader,
19};
20use reth_provider::{
21 providers::StaticFileWriter, BlockHashReader, DBProvider, HeaderSyncGapProvider,
22 StaticFileProviderFactory,
23};
24use reth_stages_api::{
25 CheckpointBlockRange, EntitiesCheckpoint, ExecInput, ExecOutput, HeadersCheckpoint, Stage,
26 StageCheckpoint, StageError, StageId, UnwindInput, UnwindOutput,
27};
28use reth_static_file_types::StaticFileSegment;
29use std::task::{ready, Context, Poll};
30
31use tokio::sync::watch;
32use tracing::*;
33
/// The headers stage.
///
/// Downloads headers from the network towards the tip received on `tip`,
/// buffering them in ETL collectors, and later flushes them in ascending
/// order to static files and the `HeaderNumbers` database index.
#[derive(Debug)]
pub struct HeaderStage<Provider, Downloader: HeaderDownloader> {
    /// Provider used to resolve the local tip header / sync gap.
    provider: Provider,
    /// Strategy for downloading headers.
    downloader: Downloader,
    /// Receiver carrying the current chain tip hash to sync towards.
    tip: watch::Receiver<B256>,
    /// Cached sync gap (local head -> target), reused across polls so the
    /// downloader is only re-targeted when the gap changes.
    sync_gap: Option<HeaderSyncGap<Downloader::Header>>,
    /// ETL collector of header hash -> block number entries.
    hash_collector: Collector<BlockHash, BlockNumber>,
    /// ETL collector of block number -> bincode-encoded sealed header.
    header_collector: Collector<BlockNumber, Bytes>,
    /// Set once the downloader has reached the block after the local head,
    /// i.e. the ETL collectors hold the full gap and can be written out.
    is_etl_ready: bool,
}
64
impl<Provider, Downloader> HeaderStage<Provider, Downloader>
where
    Downloader: HeaderDownloader,
{
    /// Creates a new header stage.
    ///
    /// The configured ETL file size is split evenly between the hash and
    /// header collectors.
    pub fn new(
        database: Provider,
        downloader: Downloader,
        tip: watch::Receiver<B256>,
        etl_config: EtlConfig,
    ) -> Self {
        Self {
            provider: database,
            downloader,
            tip,
            sync_gap: None,
            hash_collector: Collector::new(etl_config.file_size / 2, etl_config.dir.clone()),
            header_collector: Collector::new(etl_config.file_size / 2, etl_config.dir),
            is_etl_ready: false,
        }
    }

    /// Resets all ETL state: the cached sync gap, both collectors and the
    /// readiness flag. Called on unwind and on download errors.
    fn clear_etl_state(&mut self) {
        self.sync_gap = None;
        self.hash_collector.clear();
        self.header_collector.clear();
        self.is_etl_ready = false;
    }

    /// Flushes the buffered headers to the headers static-file segment and
    /// the hash -> number entries to the `HeaderNumbers` table.
    ///
    /// Returns the number of the last header written (or the previous highest
    /// static-file block if nothing was written).
    fn write_headers<P>(&mut self, provider: &P) -> Result<BlockNumber, StageError>
    where
        P: DBProvider<Tx: DbTxMut> + StaticFileProviderFactory,
        Downloader: HeaderDownloader<Header = <P::Primitives as NodePrimitives>::BlockHeader>,
        <P::Primitives as NodePrimitives>::BlockHeader: Value + FullBlockHeader,
    {
        let total_headers = self.header_collector.len();

        info!(target: "sync::stages::headers", total = total_headers, "Writing headers");

        let static_file_provider = provider.static_file_provider();

        // Fallback return value when the collector only held headers that are
        // skipped below (e.g. genesis).
        let mut last_header_number = static_file_provider
            .get_highest_static_file_block(StaticFileSegment::Headers)
            .unwrap_or_default();

        let mut writer = static_file_provider.latest_writer(StaticFileSegment::Headers)?;
        // Log progress roughly every 10% of the total work.
        let interval = (total_headers / 10).max(1);
        for (index, header) in self.header_collector.iter()?.enumerate() {
            let (_, header_buf) = header?;

            if index > 0 && index.is_multiple_of(interval) && total_headers > 100 {
                info!(target: "sync::stages::headers", progress = %format!("{:.2}%", (index as f64 / total_headers as f64) * 100.0), "Writing headers");
            }

            // Headers were serialized with the bincode-compat wrapper; decode
            // back into a sealed header.
            let sealed_header: SealedHeader<Downloader::Header> =
                bincode::deserialize::<serde_bincode_compat::SealedHeader<'_, _>>(&header_buf)
                    .map_err(|err| StageError::Fatal(Box::new(err)))?
                    .into();

            let (header, header_hash) = sealed_header.split_ref();
            // The genesis header is written elsewhere; never append it here.
            if header.number() == 0 {
                continue
            }
            last_header_number = header.number();

            writer.append_header(header, header_hash)?;
        }

        info!(target: "sync::stages::headers", total = total_headers, "Writing headers hash index");

        let mut cursor_header_numbers =
            provider.tx_ref().cursor_write::<RawTable<tables::HeaderNumbers>>()?;
        // First-sync detection: if the table holds exactly the genesis entry,
        // fold it into the collector and use the faster `append` path below
        // (entries come out of the ETL collector sorted by hash).
        let first_sync = if provider.tx_ref().entries::<RawTable<tables::HeaderNumbers>>()? == 1 &&
            let Some((hash, block_number)) = cursor_header_numbers.last()? &&
            block_number.value()? == 0
        {
            self.hash_collector.insert(hash.key()?, 0)?;
            cursor_header_numbers.delete_current()?;
            true
        } else {
            false
        };

        for (index, hash_to_number) in self.hash_collector.iter()?.enumerate() {
            let (hash, number) = hash_to_number?;

            if index > 0 && index.is_multiple_of(interval) && total_headers > 100 {
                info!(target: "sync::stages::headers", progress = %format!("{:.2}%", (index as f64 / total_headers as f64) * 100.0), "Writing headers hash index");
            }

            if first_sync {
                cursor_header_numbers.append(
                    RawKey::<BlockHash>::from_vec(hash),
                    &RawValue::<BlockNumber>::from_vec(number),
                )?;
            } else {
                // Subsequent syncs may overwrite existing entries.
                cursor_header_numbers.upsert(
                    RawKey::<BlockHash>::from_vec(hash),
                    &RawValue::<BlockNumber>::from_vec(number),
                )?;
            }
        }

        Ok(last_header_number)
    }
}
187
188impl<Provider, P, D> Stage<Provider> for HeaderStage<P, D>
189where
190 Provider: DBProvider<Tx: DbTxMut> + StaticFileProviderFactory,
191 P: HeaderSyncGapProvider<Header = <Provider::Primitives as NodePrimitives>::BlockHeader>,
192 D: HeaderDownloader<Header = <Provider::Primitives as NodePrimitives>::BlockHeader>,
193 <Provider::Primitives as NodePrimitives>::BlockHeader: FullBlockHeader + Value,
194{
195 fn id(&self) -> StageId {
197 StageId::Headers
198 }
199
200 fn poll_execute_ready(
201 &mut self,
202 cx: &mut Context<'_>,
203 input: ExecInput,
204 ) -> Poll<Result<(), StageError>> {
205 let current_checkpoint = input.checkpoint();
206
207 if self.is_etl_ready {
209 return Poll::Ready(Ok(()))
210 }
211
212 let local_head = self.provider.local_tip_header(current_checkpoint.block_number)?;
214 let target = SyncTarget::Tip(*self.tip.borrow());
215 let gap = HeaderSyncGap { local_head, target };
216 let tip = gap.target.tip();
217
218 if gap.is_closed() {
220 info!(
221 target: "sync::stages::headers",
222 checkpoint = %current_checkpoint.block_number,
223 target = ?tip,
224 "Target block already reached"
225 );
226 self.is_etl_ready = true;
227 self.sync_gap = Some(gap);
228 return Poll::Ready(Ok(()))
229 }
230
231 debug!(target: "sync::stages::headers", ?tip, head = ?gap.local_head.hash(), "Commencing sync");
232 let local_head_number = gap.local_head.number();
233
234 if self.sync_gap != Some(gap.clone()) {
236 self.sync_gap = Some(gap.clone());
237 self.downloader.update_sync_gap(gap.local_head, gap.target);
238 }
239
240 loop {
242 match ready!(self.downloader.poll_next_unpin(cx)) {
243 Some(Ok(headers)) => {
244 info!(target: "sync::stages::headers", total = headers.len(), from_block = headers.first().map(|h| h.number()), to_block = headers.last().map(|h| h.number()), "Received headers");
245 for header in headers {
246 let header_number = header.number();
247
248 self.hash_collector.insert(header.hash(), header_number)?;
249 self.header_collector.insert(
250 header_number,
251 Bytes::from(
252 bincode::serialize(&serde_bincode_compat::SealedHeader::from(
253 &header,
254 ))
255 .map_err(|err| StageError::Fatal(Box::new(err)))?,
256 ),
257 )?;
258
259 if header_number == local_head_number + 1 {
262 self.is_etl_ready = true;
263 return Poll::Ready(Ok(()))
264 }
265 }
266 }
267 Some(Err(HeadersDownloaderError::DetachedHead { local_head, header, error })) => {
268 error!(target: "sync::stages::headers", %error, "Cannot attach header to head");
269 self.clear_etl_state();
270 return Poll::Ready(Err(StageError::DetachedHead {
271 local_head: Box::new(local_head.block_with_parent()),
272 header: Box::new(header.block_with_parent()),
273 error,
274 }))
275 }
276 None => {
277 self.clear_etl_state();
278 return Poll::Ready(Err(StageError::ChannelClosed))
279 }
280 }
281 }
282 }
283
284 fn execute(&mut self, provider: &Provider, input: ExecInput) -> Result<ExecOutput, StageError> {
287 let current_checkpoint = input.checkpoint();
288
289 if self.sync_gap.take().ok_or(StageError::MissingSyncGap)?.is_closed() {
290 self.is_etl_ready = false;
291 return Ok(ExecOutput::done(current_checkpoint))
292 }
293
294 if !self.is_etl_ready {
296 return Err(StageError::MissingDownloadBuffer)
297 }
298
299 self.is_etl_ready = false;
301
302 let to_be_processed = self.hash_collector.len() as u64;
304 let last_header_number = self.write_headers(provider)?;
305
306 self.hash_collector.clear();
308 self.header_collector.clear();
309
310 Ok(ExecOutput {
311 checkpoint: StageCheckpoint::new(last_header_number).with_headers_stage_checkpoint(
312 HeadersCheckpoint {
313 block_range: CheckpointBlockRange {
314 from: input.checkpoint().block_number,
315 to: last_header_number,
316 },
317 progress: EntitiesCheckpoint {
318 processed: input.checkpoint().block_number + to_be_processed,
319 total: last_header_number,
320 },
321 },
322 ),
323 done: true,
326 })
327 }
328
329 fn unwind(
331 &mut self,
332 provider: &Provider,
333 input: UnwindInput,
334 ) -> Result<UnwindOutput, StageError> {
335 self.clear_etl_state();
336
337 provider
341 .tx_ref()
342 .unwind_table_by_walker::<tables::CanonicalHeaders, tables::HeaderNumbers>(
343 (input.unwind_to + 1)..,
344 )?;
345 provider.tx_ref().unwind_table_by_num::<tables::CanonicalHeaders>(input.unwind_to)?;
346 let unfinalized_headers_unwound = provider.tx_ref().unwind_table_by_num::<tables::Headers<
347 HeaderTy<Provider::Primitives>,
348 >>(input.unwind_to)?;
349
350 let static_file_provider = provider.static_file_provider();
353 let highest_block = static_file_provider
354 .get_highest_static_file_block(StaticFileSegment::Headers)
355 .unwrap_or_default();
356 let static_file_headers_to_unwind = highest_block - input.unwind_to;
357 for block_number in (input.unwind_to + 1)..=highest_block {
358 let hash = static_file_provider.block_hash(block_number)?;
359 if let Some(header_hash) = hash {
365 provider.tx_ref().delete::<tables::HeaderNumbers>(header_hash, None)?;
366 }
367 }
368
369 let mut writer = static_file_provider.latest_writer(StaticFileSegment::Headers)?;
371 writer.prune_headers(static_file_headers_to_unwind)?;
372
373 let stage_checkpoint =
376 input.checkpoint.headers_stage_checkpoint().map(|stage_checkpoint| HeadersCheckpoint {
377 block_range: stage_checkpoint.block_range,
378 progress: EntitiesCheckpoint {
379 processed: stage_checkpoint.progress.processed.saturating_sub(
380 static_file_headers_to_unwind + unfinalized_headers_unwound as u64,
381 ),
382 total: stage_checkpoint.progress.total,
383 },
384 });
385
386 let mut checkpoint = StageCheckpoint::new(input.unwind_to);
387 if let Some(stage_checkpoint) = stage_checkpoint {
388 checkpoint = checkpoint.with_headers_stage_checkpoint(stage_checkpoint);
389 }
390
391 Ok(UnwindOutput { checkpoint })
392 }
393}
394
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::{
        stage_test_suite, ExecuteStageTestRunner, StageTestRunner, UnwindStageTestRunner,
    };
    use alloy_primitives::B256;
    use assert_matches::assert_matches;
    use reth_provider::{DatabaseProviderFactory, ProviderFactory, StaticFileProviderFactory};
    use reth_stages_api::StageUnitCheckpoint;
    use reth_testing_utils::generators::{self, random_header, random_header_range};
    use std::sync::Arc;
    use test_runner::HeadersTestRunner;

    mod test_runner {
        use super::*;
        use crate::test_utils::{TestRunnerError, TestStageDB};
        use reth_consensus::test_utils::TestConsensus;
        use reth_downloaders::headers::reverse_headers::{
            ReverseHeadersDownloader, ReverseHeadersDownloaderBuilder,
        };
        use reth_network_p2p::test_utils::{TestHeaderDownloader, TestHeadersClient};
        use reth_provider::{test_utils::MockNodeTypesWithDB, BlockNumReader, HeaderProvider};
        use tokio::sync::watch;

        /// Test harness wiring a headers client, a tip channel, a downloader
        /// factory and a test database together for the header stage.
        pub(crate) struct HeadersTestRunner<D: HeaderDownloader> {
            pub(crate) client: TestHeadersClient,
            // Tip channel; `.0` sends the tip, `.1` is handed to the stage.
            channel: (watch::Sender<B256>, watch::Receiver<B256>),
            // Factory so each `stage()` call gets a fresh downloader.
            downloader_factory: Box<dyn Fn() -> D + Send + Sync + 'static>,
            db: TestStageDB,
        }

        impl Default for HeadersTestRunner<TestHeaderDownloader> {
            fn default() -> Self {
                let client = TestHeadersClient::default();
                Self {
                    client: client.clone(),
                    channel: watch::channel(B256::ZERO),

                    downloader_factory: Box::new(move || {
                        TestHeaderDownloader::new(client.clone(), 1000, 1000)
                    }),
                    db: TestStageDB::default(),
                }
            }
        }

        impl<D: HeaderDownloader<Header = alloy_consensus::Header> + 'static> StageTestRunner
            for HeadersTestRunner<D>
        {
            type S = HeaderStage<ProviderFactory<MockNodeTypesWithDB>, D>;

            fn db(&self) -> &TestStageDB {
                &self.db
            }

            /// Builds a fresh stage instance from the runner's components.
            fn stage(&self) -> Self::S {
                HeaderStage::new(
                    self.db.factory.clone(),
                    (*self.downloader_factory)(),
                    self.channel.1.clone(),
                    EtlConfig::default(),
                )
            }
        }

        impl<D: HeaderDownloader<Header = alloy_consensus::Header> + 'static> ExecuteStageTestRunner
            for HeadersTestRunner<D>
        {
            type Seed = Vec<SealedHeader>;

            /// Seeds the database with headers up to the checkpoint and
            /// returns the headers expected to be downloaded afterwards.
            fn seed_execution(&mut self, input: ExecInput) -> Result<Self::Seed, TestRunnerError> {
                let mut rng = generators::rng();
                let start = input.checkpoint().block_number;
                let headers = random_header_range(&mut rng, 0..start + 1, B256::ZERO);
                let head = headers.last().cloned().unwrap();
                self.db.insert_headers(headers.iter())?;

                // use previous checkpoint as seed size
                let end = input.target.unwrap_or_default() + 1;

                if start + 1 >= end {
                    return Ok(Vec::default())
                }

                let mut headers = random_header_range(&mut rng, start + 1..end, head.hash());
                headers.insert(0, head);
                Ok(headers)
            }

            /// Validates that every header up to the output checkpoint is
            /// resolvable by hash and number, and that hashes are consistent.
            fn validate_execution(
                &self,
                input: ExecInput,
                output: Option<ExecOutput>,
            ) -> Result<(), TestRunnerError> {
                let initial_checkpoint = input.checkpoint().block_number;
                match output {
                    Some(output) if output.checkpoint.block_number > initial_checkpoint => {
                        let provider = self.db.factory.provider()?;

                        for block_num in initial_checkpoint..output.checkpoint.block_number {
                            // look up the header hash
                            let hash = provider.block_hash(block_num)?.expect("no header hash");

                            // validate the header number
                            assert_eq!(provider.block_number(hash)?, Some(block_num));

                            // validate the header
                            let header = provider.header_by_number(block_num)?;
                            assert!(header.is_some());
                            let header = SealedHeader::seal_slow(header.unwrap());
                            assert_eq!(header.hash(), hash);
                        }
                    }
                    _ => self.check_no_header_entry_above(initial_checkpoint)?,
                };
                Ok(())
            }

            /// Feeds the seeded headers to the client and broadcasts the tip.
            async fn after_execution(&self, headers: Self::Seed) -> Result<(), TestRunnerError> {
                self.client.extend(headers.iter().map(|h| h.clone_header())).await;
                let tip = if headers.is_empty() {
                    let tip = random_header(&mut generators::rng(), 0, None);
                    self.db.insert_headers(std::iter::once(&tip))?;
                    tip.hash()
                } else {
                    headers.last().unwrap().hash()
                };
                self.send_tip(tip);
                Ok(())
            }
        }

        impl<D: HeaderDownloader<Header = alloy_consensus::Header> + 'static> UnwindStageTestRunner
            for HeadersTestRunner<D>
        {
            fn validate_unwind(&self, input: UnwindInput) -> Result<(), TestRunnerError> {
                self.check_no_header_entry_above(input.unwind_to)
            }
        }

        impl HeadersTestRunner<ReverseHeadersDownloader<TestHeadersClient>> {
            /// Runner variant using the real reverse-headers downloader.
            pub(crate) fn with_linear_downloader() -> Self {
                let client = TestHeadersClient::default();
                Self {
                    client: client.clone(),
                    channel: watch::channel(B256::ZERO),
                    downloader_factory: Box::new(move || {
                        ReverseHeadersDownloaderBuilder::default()
                            .stream_batch_size(500)
                            .build(client.clone(), Arc::new(TestConsensus::default()))
                    }),
                    db: TestStageDB::default(),
                }
            }
        }

        impl<D: HeaderDownloader> HeadersTestRunner<D> {
            /// Asserts the headers-related tables hold no entries above `block`.
            pub(crate) fn check_no_header_entry_above(
                &self,
                block: BlockNumber,
            ) -> Result<(), TestRunnerError> {
                self.db
                    .ensure_no_entry_above_by_value::<tables::HeaderNumbers, _>(block, |val| val)?;
                self.db.ensure_no_entry_above::<tables::CanonicalHeaders, _>(block, |key| key)?;
                self.db.ensure_no_entry_above::<tables::Headers, _>(block, |key| key)?;
                Ok(())
            }

            /// Broadcasts a new tip hash to the stage's watch channel.
            pub(crate) fn send_tip(&self, tip: B256) {
                self.channel.0.send(tip).expect("failed to send tip");
            }
        }
    }

    stage_test_suite!(HeadersTestRunner, headers);

    /// Execute the stage with a linear downloader, then extend static files
    /// past the tip and unwind back to it.
    #[tokio::test]
    async fn execute_with_linear_downloader_unwind() {
        let mut runner = HeadersTestRunner::with_linear_downloader();
        let (checkpoint, previous_stage) = (1000, 1200);
        let input = ExecInput {
            target: Some(previous_stage),
            checkpoint: Some(StageCheckpoint::new(checkpoint)),
        };
        let headers = runner.seed_execution(input).expect("failed to seed execution");
        let rx = runner.execute(input);

        runner.client.extend(headers.iter().rev().map(|h| h.clone_header())).await;

        // skip `after_execution` hook for linear downloader
        let tip = headers.last().unwrap();
        runner.send_tip(tip.hash());

        let result = rx.await.unwrap();
        runner.db().factory.static_file_provider().commit().unwrap();
        assert_matches!(result, Ok(ExecOutput { checkpoint: StageCheckpoint {
            block_number,
            stage_checkpoint: Some(StageUnitCheckpoint::Headers(HeadersCheckpoint {
                block_range: CheckpointBlockRange {
                    from,
                    to
                },
                progress: EntitiesCheckpoint {
                    processed,
                    total,
                }
            }))
        }, done: true }) if block_number == tip.number &&
            from == checkpoint && to == previous_stage &&
            processed == checkpoint + headers.len() as u64 - 1 && total == tip.number
        );
        assert!(runner.validate_execution(input, result.ok()).is_ok(), "validation failed");
        assert!(runner.stage().hash_collector.is_empty());
        assert!(runner.stage().header_collector.is_empty());

        // let's insert some blocks using append_headers to simulate the case where synced headers
        // are growing beyond tip
        let sealed_headers = random_header_range(
            &mut generators::rng(),
            tip.number + 1..tip.number + 10,
            tip.hash(),
        );

        let provider = runner.db().factory.database_provider_rw().unwrap();
        let static_file_provider = provider.static_file_provider();
        let mut writer = static_file_provider.latest_writer(StaticFileSegment::Headers).unwrap();
        for header in sealed_headers {
            writer.append_header(header.header(), &header.hash()).unwrap();
        }
        drop(writer);

        provider.commit().unwrap();

        // unwind the appended headers back down to the original tip
        let unwind_input = UnwindInput {
            checkpoint: StageCheckpoint::new(tip.number + 10),
            unwind_to: tip.number,
            bad_block: None,
        };

        let unwind_output = runner.unwind(unwind_input).await.unwrap();
        assert_eq!(unwind_output.checkpoint.block_number, tip.number);

        // validate the unwind, ensure that the tables are cleaned up
        assert!(runner.validate_unwind(unwind_input).is_ok());
    }

    /// Execute the stage with a linear downloader end-to-end and validate the
    /// resulting checkpoint and table contents.
    #[tokio::test]
    async fn execute_with_linear_downloader() {
        let mut runner = HeadersTestRunner::with_linear_downloader();
        let (checkpoint, previous_stage) = (1000, 1200);
        let input = ExecInput {
            target: Some(previous_stage),
            checkpoint: Some(StageCheckpoint::new(checkpoint)),
        };
        let headers = runner.seed_execution(input).expect("failed to seed execution");
        let rx = runner.execute(input);

        runner.client.extend(headers.iter().rev().map(|h| h.clone_header())).await;

        // skip `after_execution` hook for linear downloader
        let tip = headers.last().unwrap();
        runner.send_tip(tip.hash());

        let result = rx.await.unwrap();
        runner.db().factory.static_file_provider().commit().unwrap();
        assert_matches!(result, Ok(ExecOutput { checkpoint: StageCheckpoint {
            block_number,
            stage_checkpoint: Some(StageUnitCheckpoint::Headers(HeadersCheckpoint {
                block_range: CheckpointBlockRange {
                    from,
                    to
                },
                progress: EntitiesCheckpoint {
                    processed,
                    total,
                }
            }))
        }, done: true }) if block_number == tip.number &&
            from == checkpoint && to == previous_stage &&
            processed == checkpoint + headers.len() as u64 - 1 && total == tip.number
        );
        assert!(runner.validate_execution(input, result.ok()).is_ok(), "validation failed");
        assert!(runner.stage().hash_collector.is_empty());
        assert!(runner.stage().header_collector.is_empty());
    }
}