1use alloy_consensus::BlockHeader;
2use alloy_primitives::{BlockHash, BlockNumber, Bytes, B256};
3use futures_util::StreamExt;
4use reth_config::config::EtlConfig;
5use reth_db_api::{
6 cursor::{DbCursorRO, DbCursorRW},
7 table::Value,
8 tables,
9 transaction::{DbTx, DbTxMut},
10 DbTxUnwindExt, RawKey, RawTable, RawValue,
11};
12use reth_etl::Collector;
13use reth_network_p2p::headers::{
14 downloader::{HeaderDownloader, HeaderSyncGap, SyncTarget},
15 error::HeadersDownloaderError,
16};
17use reth_primitives_traits::{serde_bincode_compat, FullBlockHeader, NodePrimitives, SealedHeader};
18use reth_provider::{
19 providers::StaticFileWriter, BlockHashReader, DBProvider, HeaderSyncGapProvider,
20 StaticFileProviderFactory,
21};
22use reth_stages_api::{
23 CheckpointBlockRange, EntitiesCheckpoint, ExecInput, ExecOutput, HeadersCheckpoint, Stage,
24 StageCheckpoint, StageError, StageId, UnwindInput, UnwindOutput,
25};
26use reth_static_file_types::StaticFileSegment;
27use std::task::{ready, Context, Poll};
28
29use tokio::sync::watch;
30use tracing::*;
31
/// The headers stage: downloads headers for a sync gap and buffers them in ETL
/// collectors until they can be written to static files and the database.
///
/// Headers are collected in reverse (tip towards local head) by the downloader,
/// staged on disk via [`Collector`]s, and flushed once the gap is closed.
#[derive(Debug)]
pub struct HeaderStage<Provider, Downloader: HeaderDownloader> {
    // Provider used to compute the sync gap from the local tip header.
    provider: Provider,
    // Strategy for downloading the headers that fill the gap.
    downloader: Downloader,
    // Receiver for the current chain tip hash to sync towards.
    tip: watch::Receiver<B256>,
    // Gap computed during `poll_execute_ready`; consumed by `execute`.
    // `None` until the first poll or after an unwind/error reset.
    sync_gap: Option<HeaderSyncGap<Downloader::Header>>,
    // ETL collector mapping block hash -> block number (HeaderNumbers index).
    hash_collector: Collector<BlockHash, BlockNumber>,
    // ETL collector mapping block number -> bincode-serialized sealed header.
    header_collector: Collector<BlockNumber, Bytes>,
    // Set once the downloaded headers connect to the local head, meaning the
    // collectors hold everything needed and `execute` may write them out.
    is_etl_ready: bool,
}
62
impl<Provider, Downloader> HeaderStage<Provider, Downloader>
where
    Downloader: HeaderDownloader,
{
    /// Creates a new header stage.
    ///
    /// The configured ETL file size is split evenly between the hash and
    /// header collectors, which share the same ETL directory.
    pub fn new(
        database: Provider,
        downloader: Downloader,
        tip: watch::Receiver<B256>,
        etl_config: EtlConfig,
    ) -> Self {
        Self {
            provider: database,
            downloader,
            tip,
            sync_gap: None,
            hash_collector: Collector::new(etl_config.file_size / 2, etl_config.dir.clone()),
            header_collector: Collector::new(etl_config.file_size / 2, etl_config.dir),
            is_etl_ready: false,
        }
    }

    /// Flushes the buffered headers to the `Headers` static-file segment and
    /// writes the hash -> number index into [`tables::HeaderNumbers`].
    ///
    /// Returns the number of the last header written (or the previous highest
    /// static-file block when nothing new was appended). Errors are surfaced
    /// as [`StageError`]; deserialization failures are fatal.
    fn write_headers<P>(&mut self, provider: &P) -> Result<BlockNumber, StageError>
    where
        P: DBProvider<Tx: DbTxMut> + StaticFileProviderFactory,
        Downloader: HeaderDownloader<Header = <P::Primitives as NodePrimitives>::BlockHeader>,
        <P::Primitives as NodePrimitives>::BlockHeader: Value + FullBlockHeader,
    {
        let total_headers = self.header_collector.len();

        info!(target: "sync::stages::headers", total = total_headers, "Writing headers");

        let static_file_provider = provider.static_file_provider();

        // Start from the highest block already present in static files; falls
        // back to 0 (genesis) when the segment is empty.
        let mut last_header_number = static_file_provider
            .get_highest_static_file_block(StaticFileSegment::Headers)
            .unwrap_or_default();

        let mut writer = static_file_provider.latest_writer(StaticFileSegment::Headers)?;
        // Log progress roughly every 10% of the total (at least every entry).
        let interval = (total_headers / 10).max(1);
        // The header collector iterates in ascending block-number order, which
        // is required for appending to the static file.
        for (index, header) in self.header_collector.iter()?.enumerate() {
            let (_, header_buf) = header?;

            if index > 0 && index.is_multiple_of(interval) && total_headers > 100 {
                info!(target: "sync::stages::headers", progress = %format!("{:.2}%", (index as f64 / total_headers as f64) * 100.0), "Writing headers");
            }

            // Headers were serialized with the bincode-compat wrapper; decode
            // back into a sealed header. A decode failure is unrecoverable.
            let sealed_header: SealedHeader<Downloader::Header> =
                bincode::deserialize::<serde_bincode_compat::SealedHeader<'_, _>>(&header_buf)
                    .map_err(|err| StageError::Fatal(Box::new(err)))?
                    .into();

            let (header, header_hash) = sealed_header.split_ref();
            // The genesis header is never (re)written by this stage.
            if header.number() == 0 {
                continue
            }
            last_header_number = header.number();

            writer.append_header(header, header_hash)?;
        }

        info!(target: "sync::stages::headers", total = total_headers, "Writing headers hash index");

        let mut cursor_header_numbers =
            provider.tx_ref().cursor_write::<RawTable<tables::HeaderNumbers>>()?;
        // First-sync detection: if the table holds exactly the genesis entry
        // (one row mapping to block 0), move it into the collector and delete
        // it so the whole index can be written with fast sorted `append`s
        // instead of `upsert`s.
        let first_sync = if provider.tx_ref().entries::<RawTable<tables::HeaderNumbers>>()? == 1 &&
            let Some((hash, block_number)) = cursor_header_numbers.last()? &&
            block_number.value()? == 0
        {
            self.hash_collector.insert(hash.key()?, 0)?;
            cursor_header_numbers.delete_current()?;
            true
        } else {
            false
        };

        // The hash collector iterates in ascending hash order, matching the
        // table's key ordering, so `append` is valid on first sync.
        for (index, hash_to_number) in self.hash_collector.iter()?.enumerate() {
            let (hash, number) = hash_to_number?;

            if index > 0 && index.is_multiple_of(interval) && total_headers > 100 {
                info!(target: "sync::stages::headers", progress = %format!("{:.2}%", (index as f64 / total_headers as f64) * 100.0), "Writing headers hash index");
            }

            if first_sync {
                cursor_header_numbers.append(
                    RawKey::<BlockHash>::from_vec(hash),
                    &RawValue::<BlockNumber>::from_vec(number),
                )?;
            } else {
                cursor_header_numbers.upsert(
                    RawKey::<BlockHash>::from_vec(hash),
                    &RawValue::<BlockNumber>::from_vec(number),
                )?;
            }
        }

        Ok(last_header_number)
    }
}
177
178impl<Provider, P, D> Stage<Provider> for HeaderStage<P, D>
179where
180 Provider: DBProvider<Tx: DbTxMut> + StaticFileProviderFactory,
181 P: HeaderSyncGapProvider<Header = <Provider::Primitives as NodePrimitives>::BlockHeader>,
182 D: HeaderDownloader<Header = <Provider::Primitives as NodePrimitives>::BlockHeader>,
183 <Provider::Primitives as NodePrimitives>::BlockHeader: FullBlockHeader + Value,
184{
185 fn id(&self) -> StageId {
187 StageId::Headers
188 }
189
190 fn poll_execute_ready(
191 &mut self,
192 cx: &mut Context<'_>,
193 input: ExecInput,
194 ) -> Poll<Result<(), StageError>> {
195 let current_checkpoint = input.checkpoint();
196
197 if self.is_etl_ready {
199 return Poll::Ready(Ok(()))
200 }
201
202 let local_head = self.provider.local_tip_header(current_checkpoint.block_number)?;
204 let target = SyncTarget::Tip(*self.tip.borrow());
205 let gap = HeaderSyncGap { local_head, target };
206 let tip = gap.target.tip();
207
208 if gap.is_closed() {
210 info!(
211 target: "sync::stages::headers",
212 checkpoint = %current_checkpoint.block_number,
213 target = ?tip,
214 "Target block already reached"
215 );
216 self.is_etl_ready = true;
217 self.sync_gap = Some(gap);
218 return Poll::Ready(Ok(()))
219 }
220
221 debug!(target: "sync::stages::headers", ?tip, head = ?gap.local_head.hash(), "Commencing sync");
222 let local_head_number = gap.local_head.number();
223
224 if self.sync_gap != Some(gap.clone()) {
226 self.sync_gap = Some(gap.clone());
227 self.downloader.update_sync_gap(gap.local_head, gap.target);
228 }
229
230 loop {
232 match ready!(self.downloader.poll_next_unpin(cx)) {
233 Some(Ok(headers)) => {
234 info!(target: "sync::stages::headers", total = headers.len(), from_block = headers.first().map(|h| h.number()), to_block = headers.last().map(|h| h.number()), "Received headers");
235 for header in headers {
236 let header_number = header.number();
237
238 self.hash_collector.insert(header.hash(), header_number)?;
239 self.header_collector.insert(
240 header_number,
241 Bytes::from(
242 bincode::serialize(&serde_bincode_compat::SealedHeader::from(
243 &header,
244 ))
245 .map_err(|err| StageError::Fatal(Box::new(err)))?,
246 ),
247 )?;
248
249 if header_number == local_head_number + 1 {
252 self.is_etl_ready = true;
253 return Poll::Ready(Ok(()))
254 }
255 }
256 }
257 Some(Err(HeadersDownloaderError::DetachedHead { local_head, header, error })) => {
258 error!(target: "sync::stages::headers", %error, "Cannot attach header to head");
259 self.sync_gap = None;
260 return Poll::Ready(Err(StageError::DetachedHead {
261 local_head: Box::new(local_head.block_with_parent()),
262 header: Box::new(header.block_with_parent()),
263 error,
264 }))
265 }
266 None => {
267 self.sync_gap = None;
268 return Poll::Ready(Err(StageError::ChannelClosed))
269 }
270 }
271 }
272 }
273
274 fn execute(&mut self, provider: &Provider, input: ExecInput) -> Result<ExecOutput, StageError> {
277 let current_checkpoint = input.checkpoint();
278
279 if self.sync_gap.take().ok_or(StageError::MissingSyncGap)?.is_closed() {
280 self.is_etl_ready = false;
281 return Ok(ExecOutput::done(current_checkpoint))
282 }
283
284 if !self.is_etl_ready {
286 return Err(StageError::MissingDownloadBuffer)
287 }
288
289 self.is_etl_ready = false;
291
292 let to_be_processed = self.hash_collector.len() as u64;
294 let last_header_number = self.write_headers(provider)?;
295
296 self.hash_collector.clear();
298 self.header_collector.clear();
299
300 Ok(ExecOutput {
301 checkpoint: StageCheckpoint::new(last_header_number).with_headers_stage_checkpoint(
302 HeadersCheckpoint {
303 block_range: CheckpointBlockRange {
304 from: input.checkpoint().block_number,
305 to: last_header_number,
306 },
307 progress: EntitiesCheckpoint {
308 processed: input.checkpoint().block_number + to_be_processed,
309 total: last_header_number,
310 },
311 },
312 ),
313 done: true,
316 })
317 }
318
319 fn unwind(
321 &mut self,
322 provider: &Provider,
323 input: UnwindInput,
324 ) -> Result<UnwindOutput, StageError> {
325 self.sync_gap.take();
326
327 provider
331 .tx_ref()
332 .unwind_table_by_walker::<tables::CanonicalHeaders, tables::HeaderNumbers>(
333 (input.unwind_to + 1)..,
334 )?;
335 provider.tx_ref().unwind_table_by_num::<tables::CanonicalHeaders>(input.unwind_to)?;
336 let unfinalized_headers_unwound =
337 provider.tx_ref().unwind_table_by_num::<tables::Headers>(input.unwind_to)?;
338
339 let static_file_provider = provider.static_file_provider();
342 let highest_block = static_file_provider
343 .get_highest_static_file_block(StaticFileSegment::Headers)
344 .unwrap_or_default();
345 let static_file_headers_to_unwind = highest_block - input.unwind_to;
346 for block_number in (input.unwind_to + 1)..=highest_block {
347 let hash = static_file_provider.block_hash(block_number)?;
348 if let Some(header_hash) = hash {
354 provider.tx_ref().delete::<tables::HeaderNumbers>(header_hash, None)?;
355 }
356 }
357
358 let mut writer = static_file_provider.latest_writer(StaticFileSegment::Headers)?;
360 writer.prune_headers(static_file_headers_to_unwind)?;
361
362 let stage_checkpoint =
365 input.checkpoint.headers_stage_checkpoint().map(|stage_checkpoint| HeadersCheckpoint {
366 block_range: stage_checkpoint.block_range,
367 progress: EntitiesCheckpoint {
368 processed: stage_checkpoint.progress.processed.saturating_sub(
369 static_file_headers_to_unwind + unfinalized_headers_unwound as u64,
370 ),
371 total: stage_checkpoint.progress.total,
372 },
373 });
374
375 let mut checkpoint = StageCheckpoint::new(input.unwind_to);
376 if let Some(stage_checkpoint) = stage_checkpoint {
377 checkpoint = checkpoint.with_headers_stage_checkpoint(stage_checkpoint);
378 }
379
380 Ok(UnwindOutput { checkpoint })
381 }
382}
383
384#[cfg(test)]
385mod tests {
386 use super::*;
387 use crate::test_utils::{
388 stage_test_suite, ExecuteStageTestRunner, StageTestRunner, UnwindStageTestRunner,
389 };
390 use alloy_primitives::B256;
391 use assert_matches::assert_matches;
392 use reth_provider::{DatabaseProviderFactory, ProviderFactory, StaticFileProviderFactory};
393 use reth_stages_api::StageUnitCheckpoint;
394 use reth_testing_utils::generators::{self, random_header, random_header_range};
395 use std::sync::Arc;
396 use test_runner::HeadersTestRunner;
397
    /// Test harness: a runner that wires a [`HeaderStage`] to a test client,
    /// a tip watch channel, and a test database.
    mod test_runner {
        use super::*;
        use crate::test_utils::{TestRunnerError, TestStageDB};
        use reth_consensus::test_utils::TestConsensus;
        use reth_downloaders::headers::reverse_headers::{
            ReverseHeadersDownloader, ReverseHeadersDownloaderBuilder,
        };
        use reth_network_p2p::test_utils::{TestHeaderDownloader, TestHeadersClient};
        use reth_provider::{test_utils::MockNodeTypesWithDB, BlockNumReader, HeaderProvider};
        use tokio::sync::watch;

        pub(crate) struct HeadersTestRunner<D: HeaderDownloader> {
            // Client the tests feed headers into.
            pub(crate) client: TestHeadersClient,
            // Tip channel; sender half is used by `send_tip`, receiver is
            // cloned into each constructed stage.
            channel: (watch::Sender<B256>, watch::Receiver<B256>),
            // Factory so each `stage()` call gets a fresh downloader.
            downloader_factory: Box<dyn Fn() -> D + Send + Sync + 'static>,
            db: TestStageDB,
        }

        impl Default for HeadersTestRunner<TestHeaderDownloader> {
            fn default() -> Self {
                let client = TestHeadersClient::default();
                Self {
                    client: client.clone(),
                    channel: watch::channel(B256::ZERO),

                    downloader_factory: Box::new(move || {
                        TestHeaderDownloader::new(client.clone(), 1000, 1000)
                    }),
                    db: TestStageDB::default(),
                }
            }
        }

        impl<D: HeaderDownloader<Header = alloy_consensus::Header> + 'static> StageTestRunner
            for HeadersTestRunner<D>
        {
            type S = HeaderStage<ProviderFactory<MockNodeTypesWithDB>, D>;

            fn db(&self) -> &TestStageDB {
                &self.db
            }

            // Builds a fresh stage sharing this runner's db, downloader
            // factory, and tip receiver.
            fn stage(&self) -> Self::S {
                HeaderStage::new(
                    self.db.factory.clone(),
                    (*self.downloader_factory)(),
                    self.channel.1.clone(),
                    EtlConfig::default(),
                )
            }
        }

        impl<D: HeaderDownloader<Header = alloy_consensus::Header> + 'static> ExecuteStageTestRunner
            for HeadersTestRunner<D>
        {
            type Seed = Vec<SealedHeader>;

            /// Seeds the db with headers up to the checkpoint and returns the
            /// headers (head included as the first element) that the stage is
            /// expected to download, or an empty vec when there is no gap.
            fn seed_execution(&mut self, input: ExecInput) -> Result<Self::Seed, TestRunnerError> {
                let mut rng = generators::rng();
                let start = input.checkpoint().block_number;
                let headers = random_header_range(&mut rng, 0..start + 1, B256::ZERO);
                let head = headers.last().cloned().unwrap();
                self.db.insert_headers(headers.iter())?;

                // +1 because the target block is inclusive.
                let end = input.target.unwrap_or_default() + 1;

                if start + 1 >= end {
                    return Ok(Vec::default())
                }

                let mut headers = random_header_range(&mut rng, start + 1..end, head.hash());
                headers.insert(0, head);
                Ok(headers)
            }

            /// After a successful run, every block in the synced range must
            /// have a consistent hash, number index entry, and header whose
            /// recomputed hash matches. Otherwise nothing above the initial
            /// checkpoint may exist.
            fn validate_execution(
                &self,
                input: ExecInput,
                output: Option<ExecOutput>,
            ) -> Result<(), TestRunnerError> {
                let initial_checkpoint = input.checkpoint().block_number;
                match output {
                    Some(output) if output.checkpoint.block_number > initial_checkpoint => {
                        let provider = self.db.factory.provider()?;

                        for block_num in initial_checkpoint..output.checkpoint.block_number {
                            let hash = provider.block_hash(block_num)?.expect("no header hash");

                            assert_eq!(provider.block_number(hash)?, Some(block_num));

                            let header = provider.header_by_number(block_num)?;
                            assert!(header.is_some());
                            let header = SealedHeader::seal_slow(header.unwrap());
                            assert_eq!(header.hash(), hash);
                        }
                    }
                    _ => self.check_no_header_entry_above(initial_checkpoint)?,
                };
                Ok(())
            }

            /// Feeds the seeded headers to the client and publishes the tip so
            /// a pending `execute` can make progress.
            async fn after_execution(&self, headers: Self::Seed) -> Result<(), TestRunnerError> {
                self.client.extend(headers.iter().map(|h| h.clone_header())).await;
                let tip = if headers.is_empty() {
                    // No gap seeded: insert a random genesis header and use it
                    // as the tip so the stage still has a target.
                    let tip = random_header(&mut generators::rng(), 0, None);
                    self.db.insert_headers(std::iter::once(&tip))?;
                    tip.hash()
                } else {
                    headers.last().unwrap().hash()
                };
                self.send_tip(tip);
                Ok(())
            }
        }

        impl<D: HeaderDownloader<Header = alloy_consensus::Header> + 'static> UnwindStageTestRunner
            for HeadersTestRunner<D>
        {
            fn validate_unwind(&self, input: UnwindInput) -> Result<(), TestRunnerError> {
                self.check_no_header_entry_above(input.unwind_to)
            }
        }

        impl HeadersTestRunner<ReverseHeadersDownloader<TestHeadersClient>> {
            /// Runner variant using the real reverse-headers downloader over
            /// the test client (batch size 500, test consensus).
            pub(crate) fn with_linear_downloader() -> Self {
                let client = TestHeadersClient::default();
                Self {
                    client: client.clone(),
                    channel: watch::channel(B256::ZERO),
                    downloader_factory: Box::new(move || {
                        ReverseHeadersDownloaderBuilder::default()
                            .stream_batch_size(500)
                            .build(client.clone(), Arc::new(TestConsensus::default()))
                    }),
                    db: TestStageDB::default(),
                }
            }
        }

        impl<D: HeaderDownloader> HeadersTestRunner<D> {
            /// Asserts that no header-related table has entries above `block`.
            pub(crate) fn check_no_header_entry_above(
                &self,
                block: BlockNumber,
            ) -> Result<(), TestRunnerError> {
                self.db
                    .ensure_no_entry_above_by_value::<tables::HeaderNumbers, _>(block, |val| val)?;
                self.db.ensure_no_entry_above::<tables::CanonicalHeaders, _>(block, |key| key)?;
                self.db.ensure_no_entry_above::<tables::Headers, _>(block, |key| key)?;
                Ok(())
            }

            /// Publishes `tip` on the watch channel the stage listens on.
            pub(crate) fn send_tip(&self, tip: B256) {
                self.channel.0.send(tip).expect("failed to send tip");
            }
        }
    }
559
560 stage_test_suite!(HeadersTestRunner, headers);
561
    /// Runs a full execute with the reverse-headers downloader, verifies the
    /// resulting checkpoint, then extends the static files past the tip and
    /// unwinds back to it, verifying all tables are truncated.
    #[tokio::test]
    async fn execute_with_linear_downloader_unwind() {
        let mut runner = HeadersTestRunner::with_linear_downloader();
        let (checkpoint, previous_stage) = (1000, 1200);
        let input = ExecInput {
            target: Some(previous_stage),
            checkpoint: Some(StageCheckpoint::new(checkpoint)),
        };
        let headers = runner.seed_execution(input).expect("failed to seed execution");
        // Kick off execution first; it waits on headers + tip below.
        let rx = runner.execute(input);

        // Feed headers in reverse (tip -> head), matching download order.
        runner.client.extend(headers.iter().rev().map(|h| h.clone_header())).await;

        let tip = headers.last().unwrap();
        runner.send_tip(tip.hash());

        let result = rx.await.unwrap();
        runner.db().factory.static_file_provider().commit().unwrap();
        // Checkpoint must land on the tip, span checkpoint..=target, and count
        // every seeded header except the pre-existing head as processed.
        assert_matches!(result, Ok(ExecOutput { checkpoint: StageCheckpoint {
            block_number,
            stage_checkpoint: Some(StageUnitCheckpoint::Headers(HeadersCheckpoint {
                block_range: CheckpointBlockRange {
                    from,
                    to
                },
                progress: EntitiesCheckpoint {
                    processed,
                    total,
                }
            }))
        }, done: true }) if block_number == tip.number &&
            from == checkpoint && to == previous_stage &&
            processed == checkpoint + headers.len() as u64 - 1 && total == tip.number
        );
        assert!(runner.validate_execution(input, result.ok()).is_ok(), "validation failed");
        // Collectors must be drained after a successful execute.
        assert!(runner.stage().hash_collector.is_empty());
        assert!(runner.stage().header_collector.is_empty());

        // Append extra headers beyond the tip directly into static files so
        // the subsequent unwind has something to prune.
        let sealed_headers = random_header_range(
            &mut generators::rng(),
            tip.number + 1..tip.number + 10,
            tip.hash(),
        );

        let provider = runner.db().factory.database_provider_rw().unwrap();
        let static_file_provider = provider.static_file_provider();
        let mut writer = static_file_provider.latest_writer(StaticFileSegment::Headers).unwrap();
        for header in sealed_headers {
            writer.append_header(header.header(), &header.hash()).unwrap();
        }
        drop(writer);

        provider.commit().unwrap();

        // Unwind the 9 appended headers back down to the tip.
        let unwind_input = UnwindInput {
            checkpoint: StageCheckpoint::new(tip.number + 10),
            unwind_to: tip.number,
            bad_block: None,
        };

        let unwind_output = runner.unwind(unwind_input).await.unwrap();
        assert_eq!(unwind_output.checkpoint.block_number, tip.number);

        assert!(runner.validate_unwind(unwind_input).is_ok());
    }
634
    /// Executes the stage with the reverse-headers downloader against a seeded
    /// gap and verifies the checkpoint, progress counters, and db contents.
    #[tokio::test]
    async fn execute_with_linear_downloader() {
        let mut runner = HeadersTestRunner::with_linear_downloader();
        let (checkpoint, previous_stage) = (1000, 1200);
        let input = ExecInput {
            target: Some(previous_stage),
            checkpoint: Some(StageCheckpoint::new(checkpoint)),
        };
        let headers = runner.seed_execution(input).expect("failed to seed execution");
        // Kick off execution first; it waits on headers + tip below.
        let rx = runner.execute(input);

        // Feed headers in reverse (tip -> head), matching download order.
        runner.client.extend(headers.iter().rev().map(|h| h.clone_header())).await;

        let tip = headers.last().unwrap();
        runner.send_tip(tip.hash());

        let result = rx.await.unwrap();
        runner.db().factory.static_file_provider().commit().unwrap();
        // Checkpoint must land on the tip, span checkpoint..=target, and count
        // every seeded header except the pre-existing head as processed.
        assert_matches!(result, Ok(ExecOutput { checkpoint: StageCheckpoint {
            block_number,
            stage_checkpoint: Some(StageUnitCheckpoint::Headers(HeadersCheckpoint {
                block_range: CheckpointBlockRange {
                    from,
                    to
                },
                progress: EntitiesCheckpoint {
                    processed,
                    total,
                }
            }))
        }, done: true }) if block_number == tip.number &&
            from == checkpoint && to == previous_stage &&
            processed == checkpoint + headers.len() as u64 - 1 && total == tip.number
        );
        assert!(runner.validate_execution(input, result.ok()).is_ok(), "validation failed");
        // Collectors must be drained after a successful execute.
        assert!(runner.stage().hash_collector.is_empty());
        assert!(runner.stage().header_collector.is_empty());
    }
676}