1use alloy_consensus::BlockHeader;
2use alloy_primitives::{BlockHash, BlockNumber, Bytes, B256};
3use futures_util::StreamExt;
4use reth_config::config::EtlConfig;
5use reth_db_api::{
6 cursor::{DbCursorRO, DbCursorRW},
7 table::Value,
8 tables,
9 transaction::{DbTx, DbTxMut},
10 DbTxUnwindExt, RawKey, RawTable, RawValue,
11};
12use reth_etl::Collector;
13use reth_network_p2p::headers::{
14 downloader::{HeaderDownloader, HeaderSyncGap, SyncTarget},
15 error::HeadersDownloaderError,
16};
17use reth_primitives_traits::{
18 serde_bincode_compat, FullBlockHeader, HeaderTy, NodePrimitives, SealedHeader,
19};
20use reth_provider::{
21 providers::StaticFileWriter, BlockHashReader, DBProvider, HeaderSyncGapProvider,
22 StaticFileProviderFactory,
23};
24use reth_stages_api::{
25 CheckpointBlockRange, EntitiesCheckpoint, ExecInput, ExecOutput, HeadersCheckpoint, Stage,
26 StageCheckpoint, StageError, StageId, UnwindInput, UnwindOutput,
27};
28use reth_static_file_types::StaticFileSegment;
29use std::task::{ready, Context, Poll};
30
31use tokio::sync::watch;
32use tracing::*;
33
34#[derive(Debug)]
46pub struct HeaderStage<Provider, Downloader: HeaderDownloader> {
47 provider: Provider,
49 downloader: Downloader,
51 tip: watch::Receiver<B256>,
55 sync_gap: Option<HeaderSyncGap<Downloader::Header>>,
57 hash_collector: Collector<BlockHash, BlockNumber>,
59 header_collector: Collector<BlockNumber, Bytes>,
61 is_etl_ready: bool,
63}
64
65impl<Provider, Downloader> HeaderStage<Provider, Downloader>
68where
69 Downloader: HeaderDownloader,
70{
71 pub fn new(
73 database: Provider,
74 downloader: Downloader,
75 tip: watch::Receiver<B256>,
76 etl_config: EtlConfig,
77 ) -> Self {
78 Self {
79 provider: database,
80 downloader,
81 tip,
82 sync_gap: None,
83 hash_collector: Collector::new(etl_config.file_size / 2, etl_config.dir.clone()),
84 header_collector: Collector::new(etl_config.file_size / 2, etl_config.dir),
85 is_etl_ready: false,
86 }
87 }
88
89 fn write_headers<P>(&mut self, provider: &P) -> Result<BlockNumber, StageError>
94 where
95 P: DBProvider<Tx: DbTxMut> + StaticFileProviderFactory,
96 Downloader: HeaderDownloader<Header = <P::Primitives as NodePrimitives>::BlockHeader>,
97 <P::Primitives as NodePrimitives>::BlockHeader: Value + FullBlockHeader,
98 {
99 let total_headers = self.header_collector.len();
100
101 info!(target: "sync::stages::headers", total = total_headers, "Writing headers");
102
103 let static_file_provider = provider.static_file_provider();
104
105 let mut last_header_number = static_file_provider
108 .get_highest_static_file_block(StaticFileSegment::Headers)
109 .unwrap_or_default();
110
111 let mut writer = static_file_provider.latest_writer(StaticFileSegment::Headers)?;
114 let interval = (total_headers / 10).max(1);
115 for (index, header) in self.header_collector.iter()?.enumerate() {
116 let (_, header_buf) = header?;
117
118 if index > 0 && index.is_multiple_of(interval) && total_headers > 100 {
119 info!(target: "sync::stages::headers", progress = %format!("{:.2}%", (index as f64 / total_headers as f64) * 100.0), "Writing headers");
120 }
121
122 let sealed_header: SealedHeader<Downloader::Header> =
123 bincode::deserialize::<serde_bincode_compat::SealedHeader<'_, _>>(&header_buf)
124 .map_err(|err| StageError::Fatal(Box::new(err)))?
125 .into();
126
127 let (header, header_hash) = sealed_header.split_ref();
128 if header.number() == 0 {
129 continue
130 }
131 last_header_number = header.number();
132
133 writer.append_header(header, header_hash)?;
135 }
136
137 info!(target: "sync::stages::headers", total = total_headers, "Writing headers hash index");
138
139 let mut cursor_header_numbers =
140 provider.tx_ref().cursor_write::<RawTable<tables::HeaderNumbers>>()?;
141 let first_sync = if provider.tx_ref().entries::<RawTable<tables::HeaderNumbers>>()? == 1 &&
144 let Some((hash, block_number)) = cursor_header_numbers.last()? &&
145 block_number.value()? == 0
146 {
147 self.hash_collector.insert(hash.key()?, 0)?;
148 cursor_header_numbers.delete_current()?;
149 true
150 } else {
151 false
152 };
153
154 for (index, hash_to_number) in self.hash_collector.iter()?.enumerate() {
157 let (hash, number) = hash_to_number?;
158
159 if index > 0 && index.is_multiple_of(interval) && total_headers > 100 {
160 info!(target: "sync::stages::headers", progress = %format!("{:.2}%", (index as f64 / total_headers as f64) * 100.0), "Writing headers hash index");
161 }
162
163 if first_sync {
164 cursor_header_numbers.append(
165 RawKey::<BlockHash>::from_vec(hash),
166 &RawValue::<BlockNumber>::from_vec(number),
167 )?;
168 } else {
169 cursor_header_numbers.upsert(
170 RawKey::<BlockHash>::from_vec(hash),
171 &RawValue::<BlockNumber>::from_vec(number),
172 )?;
173 }
174 }
175
176 Ok(last_header_number)
177 }
178}
179
180impl<Provider, P, D> Stage<Provider> for HeaderStage<P, D>
181where
182 Provider: DBProvider<Tx: DbTxMut> + StaticFileProviderFactory,
183 P: HeaderSyncGapProvider<Header = <Provider::Primitives as NodePrimitives>::BlockHeader>,
184 D: HeaderDownloader<Header = <Provider::Primitives as NodePrimitives>::BlockHeader>,
185 <Provider::Primitives as NodePrimitives>::BlockHeader: FullBlockHeader + Value,
186{
    /// Return the id of this stage, [`StageId::Headers`].
    fn id(&self) -> StageId {
        StageId::Headers
    }
191
192 fn poll_execute_ready(
193 &mut self,
194 cx: &mut Context<'_>,
195 input: ExecInput,
196 ) -> Poll<Result<(), StageError>> {
197 let current_checkpoint = input.checkpoint();
198
199 if self.is_etl_ready {
201 return Poll::Ready(Ok(()))
202 }
203
204 let local_head = self.provider.local_tip_header(current_checkpoint.block_number)?;
206 let target = SyncTarget::Tip(*self.tip.borrow());
207 let gap = HeaderSyncGap { local_head, target };
208 let tip = gap.target.tip();
209
210 if gap.is_closed() {
212 info!(
213 target: "sync::stages::headers",
214 checkpoint = %current_checkpoint.block_number,
215 target = ?tip,
216 "Target block already reached"
217 );
218 self.is_etl_ready = true;
219 self.sync_gap = Some(gap);
220 return Poll::Ready(Ok(()))
221 }
222
223 debug!(target: "sync::stages::headers", ?tip, head = ?gap.local_head.hash(), "Commencing sync");
224 let local_head_number = gap.local_head.number();
225
226 if self.sync_gap != Some(gap.clone()) {
228 self.sync_gap = Some(gap.clone());
229 self.downloader.update_sync_gap(gap.local_head, gap.target);
230 }
231
232 loop {
234 match ready!(self.downloader.poll_next_unpin(cx)) {
235 Some(Ok(headers)) => {
236 info!(target: "sync::stages::headers", total = headers.len(), from_block = headers.first().map(|h| h.number()), to_block = headers.last().map(|h| h.number()), "Received headers");
237 for header in headers {
238 let header_number = header.number();
239
240 self.hash_collector.insert(header.hash(), header_number)?;
241 self.header_collector.insert(
242 header_number,
243 Bytes::from(
244 bincode::serialize(&serde_bincode_compat::SealedHeader::from(
245 &header,
246 ))
247 .map_err(|err| StageError::Fatal(Box::new(err)))?,
248 ),
249 )?;
250
251 if header_number == local_head_number + 1 {
254 self.is_etl_ready = true;
255 return Poll::Ready(Ok(()))
256 }
257 }
258 }
259 Some(Err(HeadersDownloaderError::DetachedHead { local_head, header, error })) => {
260 error!(target: "sync::stages::headers", %error, "Cannot attach header to head");
261 self.sync_gap = None;
262 return Poll::Ready(Err(StageError::DetachedHead {
263 local_head: Box::new(local_head.block_with_parent()),
264 header: Box::new(header.block_with_parent()),
265 error,
266 }))
267 }
268 None => {
269 self.sync_gap = None;
270 return Poll::Ready(Err(StageError::ChannelClosed))
271 }
272 }
273 }
274 }
275
276 fn execute(&mut self, provider: &Provider, input: ExecInput) -> Result<ExecOutput, StageError> {
279 let current_checkpoint = input.checkpoint();
280
281 if self.sync_gap.take().ok_or(StageError::MissingSyncGap)?.is_closed() {
282 self.is_etl_ready = false;
283 return Ok(ExecOutput::done(current_checkpoint))
284 }
285
286 if !self.is_etl_ready {
288 return Err(StageError::MissingDownloadBuffer)
289 }
290
291 self.is_etl_ready = false;
293
294 let to_be_processed = self.hash_collector.len() as u64;
296 let last_header_number = self.write_headers(provider)?;
297
298 self.hash_collector.clear();
300 self.header_collector.clear();
301
302 Ok(ExecOutput {
303 checkpoint: StageCheckpoint::new(last_header_number).with_headers_stage_checkpoint(
304 HeadersCheckpoint {
305 block_range: CheckpointBlockRange {
306 from: input.checkpoint().block_number,
307 to: last_header_number,
308 },
309 progress: EntitiesCheckpoint {
310 processed: input.checkpoint().block_number + to_be_processed,
311 total: last_header_number,
312 },
313 },
314 ),
315 done: true,
318 })
319 }
320
321 fn unwind(
323 &mut self,
324 provider: &Provider,
325 input: UnwindInput,
326 ) -> Result<UnwindOutput, StageError> {
327 self.sync_gap.take();
328
329 provider
333 .tx_ref()
334 .unwind_table_by_walker::<tables::CanonicalHeaders, tables::HeaderNumbers>(
335 (input.unwind_to + 1)..,
336 )?;
337 provider.tx_ref().unwind_table_by_num::<tables::CanonicalHeaders>(input.unwind_to)?;
338 let unfinalized_headers_unwound = provider.tx_ref().unwind_table_by_num::<tables::Headers<
339 HeaderTy<Provider::Primitives>,
340 >>(input.unwind_to)?;
341
342 let static_file_provider = provider.static_file_provider();
345 let highest_block = static_file_provider
346 .get_highest_static_file_block(StaticFileSegment::Headers)
347 .unwrap_or_default();
348 let static_file_headers_to_unwind = highest_block - input.unwind_to;
349 for block_number in (input.unwind_to + 1)..=highest_block {
350 let hash = static_file_provider.block_hash(block_number)?;
351 if let Some(header_hash) = hash {
357 provider.tx_ref().delete::<tables::HeaderNumbers>(header_hash, None)?;
358 }
359 }
360
361 let mut writer = static_file_provider.latest_writer(StaticFileSegment::Headers)?;
363 writer.prune_headers(static_file_headers_to_unwind)?;
364
365 let stage_checkpoint =
368 input.checkpoint.headers_stage_checkpoint().map(|stage_checkpoint| HeadersCheckpoint {
369 block_range: stage_checkpoint.block_range,
370 progress: EntitiesCheckpoint {
371 processed: stage_checkpoint.progress.processed.saturating_sub(
372 static_file_headers_to_unwind + unfinalized_headers_unwound as u64,
373 ),
374 total: stage_checkpoint.progress.total,
375 },
376 });
377
378 let mut checkpoint = StageCheckpoint::new(input.unwind_to);
379 if let Some(stage_checkpoint) = stage_checkpoint {
380 checkpoint = checkpoint.with_headers_stage_checkpoint(stage_checkpoint);
381 }
382
383 Ok(UnwindOutput { checkpoint })
384 }
385}
386
387#[cfg(test)]
388mod tests {
389 use super::*;
390 use crate::test_utils::{
391 stage_test_suite, ExecuteStageTestRunner, StageTestRunner, UnwindStageTestRunner,
392 };
393 use alloy_primitives::B256;
394 use assert_matches::assert_matches;
395 use reth_provider::{DatabaseProviderFactory, ProviderFactory, StaticFileProviderFactory};
396 use reth_stages_api::StageUnitCheckpoint;
397 use reth_testing_utils::generators::{self, random_header, random_header_range};
398 use std::sync::Arc;
399 use test_runner::HeadersTestRunner;
400
401 mod test_runner {
402 use super::*;
403 use crate::test_utils::{TestRunnerError, TestStageDB};
404 use reth_consensus::test_utils::TestConsensus;
405 use reth_downloaders::headers::reverse_headers::{
406 ReverseHeadersDownloader, ReverseHeadersDownloaderBuilder,
407 };
408 use reth_network_p2p::test_utils::{TestHeaderDownloader, TestHeadersClient};
409 use reth_provider::{test_utils::MockNodeTypesWithDB, BlockNumReader, HeaderProvider};
410 use tokio::sync::watch;
411
412 pub(crate) struct HeadersTestRunner<D: HeaderDownloader> {
413 pub(crate) client: TestHeadersClient,
414 channel: (watch::Sender<B256>, watch::Receiver<B256>),
415 downloader_factory: Box<dyn Fn() -> D + Send + Sync + 'static>,
416 db: TestStageDB,
417 }
418
419 impl Default for HeadersTestRunner<TestHeaderDownloader> {
420 fn default() -> Self {
421 let client = TestHeadersClient::default();
422 Self {
423 client: client.clone(),
424 channel: watch::channel(B256::ZERO),
425
426 downloader_factory: Box::new(move || {
427 TestHeaderDownloader::new(client.clone(), 1000, 1000)
428 }),
429 db: TestStageDB::default(),
430 }
431 }
432 }
433
434 impl<D: HeaderDownloader<Header = alloy_consensus::Header> + 'static> StageTestRunner
435 for HeadersTestRunner<D>
436 {
437 type S = HeaderStage<ProviderFactory<MockNodeTypesWithDB>, D>;
438
439 fn db(&self) -> &TestStageDB {
440 &self.db
441 }
442
443 fn stage(&self) -> Self::S {
444 HeaderStage::new(
445 self.db.factory.clone(),
446 (*self.downloader_factory)(),
447 self.channel.1.clone(),
448 EtlConfig::default(),
449 )
450 }
451 }
452
453 impl<D: HeaderDownloader<Header = alloy_consensus::Header> + 'static> ExecuteStageTestRunner
454 for HeadersTestRunner<D>
455 {
456 type Seed = Vec<SealedHeader>;
457
458 fn seed_execution(&mut self, input: ExecInput) -> Result<Self::Seed, TestRunnerError> {
459 let mut rng = generators::rng();
460 let start = input.checkpoint().block_number;
461 let headers = random_header_range(&mut rng, 0..start + 1, B256::ZERO);
462 let head = headers.last().cloned().unwrap();
463 self.db.insert_headers(headers.iter())?;
464
465 let end = input.target.unwrap_or_default() + 1;
467
468 if start + 1 >= end {
469 return Ok(Vec::default())
470 }
471
472 let mut headers = random_header_range(&mut rng, start + 1..end, head.hash());
473 headers.insert(0, head);
474 Ok(headers)
475 }
476
477 fn validate_execution(
479 &self,
480 input: ExecInput,
481 output: Option<ExecOutput>,
482 ) -> Result<(), TestRunnerError> {
483 let initial_checkpoint = input.checkpoint().block_number;
484 match output {
485 Some(output) if output.checkpoint.block_number > initial_checkpoint => {
486 let provider = self.db.factory.provider()?;
487
488 for block_num in initial_checkpoint..output.checkpoint.block_number {
489 let hash = provider.block_hash(block_num)?.expect("no header hash");
491
492 assert_eq!(provider.block_number(hash)?, Some(block_num));
494
495 let header = provider.header_by_number(block_num)?;
497 assert!(header.is_some());
498 let header = SealedHeader::seal_slow(header.unwrap());
499 assert_eq!(header.hash(), hash);
500 }
501 }
502 _ => self.check_no_header_entry_above(initial_checkpoint)?,
503 };
504 Ok(())
505 }
506
507 async fn after_execution(&self, headers: Self::Seed) -> Result<(), TestRunnerError> {
508 self.client.extend(headers.iter().map(|h| h.clone_header())).await;
509 let tip = if headers.is_empty() {
510 let tip = random_header(&mut generators::rng(), 0, None);
511 self.db.insert_headers(std::iter::once(&tip))?;
512 tip.hash()
513 } else {
514 headers.last().unwrap().hash()
515 };
516 self.send_tip(tip);
517 Ok(())
518 }
519 }
520
521 impl<D: HeaderDownloader<Header = alloy_consensus::Header> + 'static> UnwindStageTestRunner
522 for HeadersTestRunner<D>
523 {
524 fn validate_unwind(&self, input: UnwindInput) -> Result<(), TestRunnerError> {
525 self.check_no_header_entry_above(input.unwind_to)
526 }
527 }
528
529 impl HeadersTestRunner<ReverseHeadersDownloader<TestHeadersClient>> {
530 pub(crate) fn with_linear_downloader() -> Self {
531 let client = TestHeadersClient::default();
532 Self {
533 client: client.clone(),
534 channel: watch::channel(B256::ZERO),
535 downloader_factory: Box::new(move || {
536 ReverseHeadersDownloaderBuilder::default()
537 .stream_batch_size(500)
538 .build(client.clone(), Arc::new(TestConsensus::default()))
539 }),
540 db: TestStageDB::default(),
541 }
542 }
543 }
544
545 impl<D: HeaderDownloader> HeadersTestRunner<D> {
546 pub(crate) fn check_no_header_entry_above(
547 &self,
548 block: BlockNumber,
549 ) -> Result<(), TestRunnerError> {
550 self.db
551 .ensure_no_entry_above_by_value::<tables::HeaderNumbers, _>(block, |val| val)?;
552 self.db.ensure_no_entry_above::<tables::CanonicalHeaders, _>(block, |key| key)?;
553 self.db.ensure_no_entry_above::<tables::Headers, _>(block, |key| key)?;
554 Ok(())
555 }
556
557 pub(crate) fn send_tip(&self, tip: B256) {
558 self.channel.0.send(tip).expect("failed to send tip");
559 }
560 }
561 }
562
563 stage_test_suite!(HeadersTestRunner, headers);
564
565 #[tokio::test]
568 async fn execute_with_linear_downloader_unwind() {
569 let mut runner = HeadersTestRunner::with_linear_downloader();
570 let (checkpoint, previous_stage) = (1000, 1200);
571 let input = ExecInput {
572 target: Some(previous_stage),
573 checkpoint: Some(StageCheckpoint::new(checkpoint)),
574 };
575 let headers = runner.seed_execution(input).expect("failed to seed execution");
576 let rx = runner.execute(input);
577
578 runner.client.extend(headers.iter().rev().map(|h| h.clone_header())).await;
579
580 let tip = headers.last().unwrap();
582 runner.send_tip(tip.hash());
583
584 let result = rx.await.unwrap();
585 runner.db().factory.static_file_provider().commit().unwrap();
586 assert_matches!(result, Ok(ExecOutput { checkpoint: StageCheckpoint {
587 block_number,
588 stage_checkpoint: Some(StageUnitCheckpoint::Headers(HeadersCheckpoint {
589 block_range: CheckpointBlockRange {
590 from,
591 to
592 },
593 progress: EntitiesCheckpoint {
594 processed,
595 total,
596 }
597 }))
598 }, done: true }) if block_number == tip.number &&
599 from == checkpoint && to == previous_stage &&
600 processed == checkpoint + headers.len() as u64 - 1 && total == tip.number
602 );
603 assert!(runner.validate_execution(input, result.ok()).is_ok(), "validation failed");
604 assert!(runner.stage().hash_collector.is_empty());
605 assert!(runner.stage().header_collector.is_empty());
606
607 let sealed_headers = random_header_range(
609 &mut generators::rng(),
610 tip.number + 1..tip.number + 10,
611 tip.hash(),
612 );
613
614 let provider = runner.db().factory.database_provider_rw().unwrap();
615 let static_file_provider = provider.static_file_provider();
616 let mut writer = static_file_provider.latest_writer(StaticFileSegment::Headers).unwrap();
617 for header in sealed_headers {
618 writer.append_header(header.header(), &header.hash()).unwrap();
619 }
620 drop(writer);
621
622 provider.commit().unwrap();
623
624 let unwind_input = UnwindInput {
626 checkpoint: StageCheckpoint::new(tip.number + 10),
627 unwind_to: tip.number,
628 bad_block: None,
629 };
630
631 let unwind_output = runner.unwind(unwind_input).await.unwrap();
632 assert_eq!(unwind_output.checkpoint.block_number, tip.number);
633
634 assert!(runner.validate_unwind(unwind_input).is_ok());
636 }
637
638 #[tokio::test]
640 async fn execute_with_linear_downloader() {
641 let mut runner = HeadersTestRunner::with_linear_downloader();
642 let (checkpoint, previous_stage) = (1000, 1200);
643 let input = ExecInput {
644 target: Some(previous_stage),
645 checkpoint: Some(StageCheckpoint::new(checkpoint)),
646 };
647 let headers = runner.seed_execution(input).expect("failed to seed execution");
648 let rx = runner.execute(input);
649
650 runner.client.extend(headers.iter().rev().map(|h| h.clone_header())).await;
651
652 let tip = headers.last().unwrap();
654 runner.send_tip(tip.hash());
655
656 let result = rx.await.unwrap();
657 runner.db().factory.static_file_provider().commit().unwrap();
658 assert_matches!(result, Ok(ExecOutput { checkpoint: StageCheckpoint {
659 block_number,
660 stage_checkpoint: Some(StageUnitCheckpoint::Headers(HeadersCheckpoint {
661 block_range: CheckpointBlockRange {
662 from,
663 to
664 },
665 progress: EntitiesCheckpoint {
666 processed,
667 total,
668 }
669 }))
670 }, done: true }) if block_number == tip.number &&
671 from == checkpoint && to == previous_stage &&
672 processed == checkpoint + headers.len() as u64 - 1 && total == tip.number
674 );
675 assert!(runner.validate_execution(input, result.ok()).is_ok(), "validation failed");
676 assert!(runner.stage().hash_collector.is_empty());
677 assert!(runner.stage().header_collector.is_empty());
678 }
679}