// reth_stages/stages/headers.rs — the headers stage.

1use alloy_consensus::BlockHeader;
2use alloy_primitives::{BlockHash, BlockNumber, Bytes, B256};
3use alloy_rlp::Decodable;
4use futures_util::StreamExt;
5use reth_config::config::EtlConfig;
6use reth_db_api::{
7    cursor::{DbCursorRO, DbCursorRW},
8    table::Value,
9    tables,
10    transaction::{DbTx, DbTxMut},
11    DbTxUnwindExt, RawKey, RawTable, RawValue,
12};
13use reth_etl::Collector;
14use reth_network_p2p::headers::{
15    downloader::{HeaderDownloader, HeaderSyncGap, SyncTarget},
16    error::HeadersDownloaderError,
17};
18use reth_primitives_traits::{FullBlockHeader, HeaderTy, NodePrimitives, SealedHeader};
19use reth_provider::{
20    providers::StaticFileWriter, BlockHashReader, DBProvider, HeaderSyncGapProvider,
21    StaticFileProviderFactory,
22};
23use reth_stages_api::{
24    CheckpointBlockRange, EntitiesCheckpoint, ExecInput, ExecOutput, HeadersCheckpoint, Stage,
25    StageCheckpoint, StageError, StageId, UnwindInput, UnwindOutput,
26};
27use reth_static_file_types::StaticFileSegment;
28use std::task::{ready, Context, Poll};
29
30use tokio::sync::watch;
31use tracing::*;
32
/// The headers stage.
///
/// The headers stage downloads all block headers from the highest block in storage to
/// the perceived highest block on the network.
///
/// The headers are processed and data is inserted into static files, as well as into the
/// [`HeaderNumbers`][reth_db_api::tables::HeaderNumbers] table.
///
/// NOTE: This stage downloads headers in reverse and pushes them to the ETL [`Collector`]. It then
/// proceeds to push them sequentially to static files. The stage checkpoint is not updated until
/// this stage is done.
#[derive(Debug)]
pub struct HeaderStage<Provider, Downloader: HeaderDownloader> {
    /// Database handle.
    provider: Provider,
    /// Strategy for downloading the headers
    downloader: Downloader,
    /// The tip for the stage.
    ///
    /// This determines the sync target of the stage (set by the pipeline).
    tip: watch::Receiver<B256>,
    /// Current sync gap.
    ///
    /// Cached between polls so the downloader is only re-notified when the gap changes
    /// (see the comparison in `poll_execute_ready`).
    sync_gap: Option<HeaderSyncGap<Downloader::Header>>,
    /// ETL collector with `HeaderHash` -> `BlockNumber`
    hash_collector: Collector<BlockHash, BlockNumber>,
    /// ETL collector with `BlockNumber` -> `RLP-encoded SealedHeader`
    header_collector: Collector<BlockNumber, Bytes>,
    /// `true` once the ETL collectors hold all headers necessary to fill the sync gap.
    is_etl_ready: bool,
}
63
64// === impl HeaderStage ===
65
impl<Provider, Downloader> HeaderStage<Provider, Downloader>
where
    Downloader: HeaderDownloader,
{
    /// Create a new header stage
    pub fn new(
        database: Provider,
        downloader: Downloader,
        tip: watch::Receiver<B256>,
        etl_config: EtlConfig,
    ) -> Self {
        Self {
            provider: database,
            downloader,
            tip,
            sync_gap: None,
            // The configured ETL file size budget is split evenly between the two collectors.
            hash_collector: Collector::new(etl_config.file_size / 2, etl_config.dir.clone()),
            header_collector: Collector::new(etl_config.file_size / 2, etl_config.dir),
            is_etl_ready: false,
        }
    }

    /// Clear all ETL state. Called on error paths to prevent buffer pollution on retry.
    fn clear_etl_state(&mut self) {
        self.sync_gap = None;
        self.hash_collector.clear();
        self.header_collector.clear();
        self.is_etl_ready = false;
    }

    /// Write downloaded headers to storage from ETL.
    ///
    /// Writes to static files ( `Header | HeaderTD | HeaderHash` ) and [`tables::HeaderNumbers`]
    /// database table.
    ///
    /// Returns the number of the last header appended to the static files (or the previous
    /// highest static file block if nothing new was written).
    fn write_headers<P>(&mut self, provider: &P) -> Result<BlockNumber, StageError>
    where
        P: DBProvider<Tx: DbTxMut> + StaticFileProviderFactory,
        Downloader: HeaderDownloader<Header = <P::Primitives as NodePrimitives>::BlockHeader>,
        <P::Primitives as NodePrimitives>::BlockHeader: Value + FullBlockHeader,
    {
        let total_headers = self.header_collector.len();

        info!(target: "sync::stages::headers", total = total_headers, "Writing headers");

        let static_file_provider = provider.static_file_provider();

        // Consistency check of expected headers in static files vs DB is done on provider::sync_gap
        // when poll_execute_ready is polled.
        let mut last_header_number = static_file_provider
            .get_highest_static_file_block(StaticFileSegment::Headers)
            .unwrap_or_default();

        // Although headers were downloaded in reverse order, the collector iterates it in ascending
        // order
        let mut writer = static_file_provider.latest_writer(StaticFileSegment::Headers)?;
        // Emit a progress log roughly every 10% of the total (only for larger batches, below).
        let interval = (total_headers / 10).max(1);
        for (index, header) in self.header_collector.iter()?.enumerate() {
            let (_, header_buf) = header?;

            if index > 0 && index.is_multiple_of(interval) && total_headers > 100 {
                info!(target: "sync::stages::headers", progress = %format!("{:.2}%", (index as f64 / total_headers as f64) * 100.0), "Writing headers");
            }

            // Decode the RLP bytes buffered during download; `new_unhashed` defers hashing.
            let sealed_header: SealedHeader<Downloader::Header> = SealedHeader::new_unhashed(
                Decodable::decode(&mut header_buf.as_slice())
                    .map_err(|err| StageError::Fatal(Box::new(err)))?,
            );

            let (header, header_hash) = sealed_header.split_ref();
            // Never re-append the genesis header (cf. the genesis entry handling below).
            if header.number() == 0 {
                continue
            }
            last_header_number = header.number();

            // Append to Headers segment
            writer.append_header(header, header_hash)?;
        }

        info!(target: "sync::stages::headers", total = total_headers, "Writing headers hash index");

        let mut cursor_header_numbers =
            provider.tx_ref().cursor_write::<RawTable<tables::HeaderNumbers>>()?;
        // If we only have the genesis block hash, then we are at first sync, and we can remove it,
        // add it to the collector and use tx.append on all hashes.
        let first_sync = if provider.tx_ref().entries::<RawTable<tables::HeaderNumbers>>()? == 1 &&
            let Some((hash, block_number)) = cursor_header_numbers.last()? &&
            block_number.value()? == 0
        {
            self.hash_collector.insert(hash.key()?, 0)?;
            cursor_header_numbers.delete_current()?;
            true
        } else {
            false
        };

        // Since ETL sorts all entries by hashes, we are either appending (first sync) or inserting
        // in order (further syncs).
        for (index, hash_to_number) in self.hash_collector.iter()?.enumerate() {
            let (hash, number) = hash_to_number?;

            if index > 0 && index.is_multiple_of(interval) && total_headers > 100 {
                info!(target: "sync::stages::headers", progress = %format!("{:.2}%", (index as f64 / total_headers as f64) * 100.0), "Writing headers hash index");
            }

            if first_sync {
                // Keys arrive pre-sorted from the ETL collector, so a plain append is valid
                // (and avoids the lookup an upsert performs).
                cursor_header_numbers.append(
                    RawKey::<BlockHash>::from_vec(hash),
                    &RawValue::<BlockNumber>::from_vec(number),
                )?;
            } else {
                cursor_header_numbers.upsert(
                    RawKey::<BlockHash>::from_vec(hash),
                    &RawValue::<BlockNumber>::from_vec(number),
                )?;
            }
        }

        Ok(last_header_number)
    }
}
186
187impl<Provider, P, D> Stage<Provider> for HeaderStage<P, D>
188where
189    Provider: DBProvider<Tx: DbTxMut> + StaticFileProviderFactory,
190    P: HeaderSyncGapProvider<Header = <Provider::Primitives as NodePrimitives>::BlockHeader>,
191    D: HeaderDownloader<Header = <Provider::Primitives as NodePrimitives>::BlockHeader>,
192    <Provider::Primitives as NodePrimitives>::BlockHeader: FullBlockHeader + Value,
193{
194    /// Return the id of the stage
195    fn id(&self) -> StageId {
196        StageId::Headers
197    }
198
199    fn poll_execute_ready(
200        &mut self,
201        cx: &mut Context<'_>,
202        input: ExecInput,
203    ) -> Poll<Result<(), StageError>> {
204        let current_checkpoint = input.checkpoint();
205
206        // Return if stage has already completed the gap on the ETL files
207        if self.is_etl_ready {
208            return Poll::Ready(Ok(()))
209        }
210
211        // Lookup the head and tip of the sync range
212        let local_head = self.provider.local_tip_header(current_checkpoint.block_number)?;
213        let target = SyncTarget::Tip(*self.tip.borrow());
214        let gap = HeaderSyncGap { local_head, target };
215        let tip = gap.target.tip();
216
217        // Nothing to sync
218        if gap.is_closed() {
219            info!(
220                target: "sync::stages::headers",
221                checkpoint = %current_checkpoint.block_number,
222                target = ?tip,
223                "Target block already reached"
224            );
225            self.is_etl_ready = true;
226            self.sync_gap = Some(gap);
227            return Poll::Ready(Ok(()))
228        }
229
230        debug!(target: "sync::stages::headers", ?tip, head = ?gap.local_head.hash(), "Commencing sync");
231        let local_head_number = gap.local_head.number();
232
233        // let the downloader know what to sync
234        if self.sync_gap != Some(gap.clone()) {
235            self.sync_gap = Some(gap.clone());
236            self.downloader.update_sync_gap(gap.local_head, gap.target);
237        }
238
239        // We only want to stop once we have all the headers on ETL filespace (disk).
240        loop {
241            match ready!(self.downloader.poll_next_unpin(cx)) {
242                Some(Ok(headers)) => {
243                    info!(target: "sync::stages::headers", total = headers.len(), from_block = headers.first().map(|h| h.number()), to_block = headers.last().map(|h| h.number()), "Received headers");
244                    for header in headers {
245                        let header_number = header.number();
246
247                        self.hash_collector.insert(header.hash(), header_number)?;
248                        self.header_collector
249                            .insert(header_number, Bytes::from(alloy_rlp::encode(&*header)))?;
250
251                        // Headers are downloaded in reverse, so if we reach here, we know we have
252                        // filled the gap.
253                        if header_number == local_head_number + 1 {
254                            self.is_etl_ready = true;
255                            return Poll::Ready(Ok(()))
256                        }
257                    }
258                }
259                Some(Err(HeadersDownloaderError::DetachedHead { local_head, header, error })) => {
260                    error!(target: "sync::stages::headers", %error, "Cannot attach header to head");
261                    self.clear_etl_state();
262                    return Poll::Ready(Err(StageError::DetachedHead {
263                        local_head: Box::new(local_head.block_with_parent()),
264                        header: Box::new(header.block_with_parent()),
265                        error,
266                    }))
267                }
268                None => {
269                    self.clear_etl_state();
270                    return Poll::Ready(Err(StageError::ChannelClosed))
271                }
272            }
273        }
274    }
275
276    /// Download the headers in reverse order (falling block numbers)
277    /// starting from the tip of the chain
278    fn execute(&mut self, provider: &Provider, input: ExecInput) -> Result<ExecOutput, StageError> {
279        let current_checkpoint = input.checkpoint();
280
281        if self.sync_gap.take().ok_or(StageError::MissingSyncGap)?.is_closed() {
282            self.is_etl_ready = false;
283            return Ok(ExecOutput::done(current_checkpoint))
284        }
285
286        // We should be here only after we have downloaded all headers into the disk buffer (ETL).
287        if !self.is_etl_ready {
288            return Err(StageError::MissingDownloadBuffer)
289        }
290
291        // Reset flag
292        self.is_etl_ready = false;
293
294        // Write the headers and related tables to DB from ETL space
295        let to_be_processed = self.hash_collector.len() as u64;
296        let last_header_number = self.write_headers(provider)?;
297
298        // Clear ETL collectors
299        self.hash_collector.clear();
300        self.header_collector.clear();
301
302        Ok(ExecOutput {
303            checkpoint: StageCheckpoint::new(last_header_number).with_headers_stage_checkpoint(
304                HeadersCheckpoint {
305                    block_range: CheckpointBlockRange {
306                        from: input.checkpoint().block_number,
307                        to: last_header_number,
308                    },
309                    progress: EntitiesCheckpoint {
310                        processed: input.checkpoint().block_number + to_be_processed,
311                        total: last_header_number,
312                    },
313                },
314            ),
315            // We only reach here if all headers have been downloaded by ETL, and pushed to DB all
316            // in one stage run.
317            done: true,
318        })
319    }
320
321    /// Unwind the stage.
322    fn unwind(
323        &mut self,
324        provider: &Provider,
325        input: UnwindInput,
326    ) -> Result<UnwindOutput, StageError> {
327        self.clear_etl_state();
328
329        // First unwind the db tables, until the unwind_to block number. use the walker to unwind
330        // HeaderNumbers based on the index in CanonicalHeaders
331        // unwind from the next block number since the unwind_to block is exclusive
332        provider
333            .tx_ref()
334            .unwind_table_by_walker::<tables::CanonicalHeaders, tables::HeaderNumbers>(
335                (input.unwind_to + 1)..,
336            )?;
337        provider.tx_ref().unwind_table_by_num::<tables::CanonicalHeaders>(input.unwind_to)?;
338        let unfinalized_headers_unwound = provider.tx_ref().unwind_table_by_num::<tables::Headers<
339            HeaderTy<Provider::Primitives>,
340        >>(input.unwind_to)?;
341
342        // determine how many headers to unwind from the static files based on the highest block and
343        // the unwind_to block
344        let static_file_provider = provider.static_file_provider();
345        let highest_block = static_file_provider
346            .get_highest_static_file_block(StaticFileSegment::Headers)
347            .unwrap_or_default();
348        let static_file_headers_to_unwind = highest_block - input.unwind_to;
349        for block_number in (input.unwind_to + 1)..=highest_block {
350            let hash = static_file_provider.block_hash(block_number)?;
351            // we have to delete from HeaderNumbers here as well as in the above unwind, since that
352            // mapping contains entries for both headers in the db and headers in static files
353            //
354            // so if we are unwinding past the lowest block in the db, we have to iterate through
355            // the HeaderNumbers entries that we'll delete in static files below
356            if let Some(header_hash) = hash {
357                provider.tx_ref().delete::<tables::HeaderNumbers>(header_hash, None)?;
358            }
359        }
360
361        // Now unwind the static files until the unwind_to block number
362        let mut writer = static_file_provider.latest_writer(StaticFileSegment::Headers)?;
363        writer.prune_headers(static_file_headers_to_unwind)?;
364
365        // Set the stage checkpoint entities processed based on how much we unwound - we add the
366        // headers unwound from static files and db
367        let stage_checkpoint =
368            input.checkpoint.headers_stage_checkpoint().map(|stage_checkpoint| HeadersCheckpoint {
369                block_range: stage_checkpoint.block_range,
370                progress: EntitiesCheckpoint {
371                    processed: stage_checkpoint.progress.processed.saturating_sub(
372                        static_file_headers_to_unwind + unfinalized_headers_unwound as u64,
373                    ),
374                    total: stage_checkpoint.progress.total,
375                },
376            });
377
378        let mut checkpoint = StageCheckpoint::new(input.unwind_to);
379        if let Some(stage_checkpoint) = stage_checkpoint {
380            checkpoint = checkpoint.with_headers_stage_checkpoint(stage_checkpoint);
381        }
382
383        Ok(UnwindOutput { checkpoint })
384    }
385}
386
387#[cfg(test)]
388mod tests {
389    use super::*;
390    use crate::test_utils::{
391        stage_test_suite, ExecuteStageTestRunner, StageTestRunner, UnwindStageTestRunner,
392    };
393    use alloy_primitives::B256;
394    use assert_matches::assert_matches;
395    use reth_provider::{DatabaseProviderFactory, ProviderFactory, StaticFileProviderFactory};
396    use reth_stages_api::StageUnitCheckpoint;
397    use reth_testing_utils::generators::{self, random_header, random_header_range};
398    use std::sync::Arc;
399    use test_runner::HeadersTestRunner;
400
    mod test_runner {
        use super::*;
        use crate::test_utils::{TestRunnerError, TestStageDB};
        use reth_consensus::test_utils::TestConsensus;
        use reth_downloaders::headers::reverse_headers::{
            ReverseHeadersDownloader, ReverseHeadersDownloaderBuilder,
        };
        use reth_network_p2p::test_utils::{TestHeaderDownloader, TestHeadersClient};
        use reth_provider::{test_utils::MockNodeTypesWithDB, BlockNumReader, HeaderProvider};
        use tokio::sync::watch;

        /// Test harness wiring a [`HeaderStage`] to an in-memory test database and a
        /// configurable header downloader.
        pub(crate) struct HeadersTestRunner<D: HeaderDownloader> {
            /// Client that test cases feed headers into.
            pub(crate) client: TestHeadersClient,
            // Sender/receiver pair simulating the pipeline setting the sync target (tip).
            channel: (watch::Sender<B256>, watch::Receiver<B256>),
            // Builds a fresh downloader for every `stage()` call.
            downloader_factory: Box<dyn Fn() -> D + Send + Sync + 'static>,
            db: TestStageDB,
        }

        impl Default for HeadersTestRunner<TestHeaderDownloader> {
            fn default() -> Self {
                let client = TestHeadersClient::default();
                Self {
                    client: client.clone(),
                    channel: watch::channel(B256::ZERO),

                    downloader_factory: Box::new(move || {
                        TestHeaderDownloader::new(client.clone(), 1000, 1000)
                    }),
                    db: TestStageDB::default(),
                }
            }
        }

        impl<D: HeaderDownloader<Header = alloy_consensus::Header> + 'static> StageTestRunner
            for HeadersTestRunner<D>
        {
            type S = HeaderStage<ProviderFactory<MockNodeTypesWithDB>, D>;

            fn db(&self) -> &TestStageDB {
                &self.db
            }

            // NOTE: each call constructs a brand-new stage instance via the factory.
            fn stage(&self) -> Self::S {
                HeaderStage::new(
                    self.db.factory.clone(),
                    (*self.downloader_factory)(),
                    self.channel.1.clone(),
                    EtlConfig::default(),
                )
            }
        }

        impl<D: HeaderDownloader<Header = alloy_consensus::Header> + 'static> ExecuteStageTestRunner
            for HeadersTestRunner<D>
        {
            type Seed = Vec<SealedHeader>;

            /// Seeds the database with headers up to the checkpoint and returns the headers
            /// (head included) that the downloader should serve for the gap up to the target.
            fn seed_execution(&mut self, input: ExecInput) -> Result<Self::Seed, TestRunnerError> {
                let mut rng = generators::rng();
                let start = input.checkpoint().block_number;
                let headers = random_header_range(&mut rng, 0..start + 1, B256::ZERO);
                let head = headers.last().cloned().unwrap();
                self.db.insert_headers(headers.iter())?;

                // use previous checkpoint as seed size
                let end = input.target.unwrap_or_default() + 1;

                if start + 1 >= end {
                    return Ok(Vec::default())
                }

                let mut headers = random_header_range(&mut rng, start + 1..end, head.hash());
                headers.insert(0, head);
                Ok(headers)
            }

            /// Validate stored headers
            fn validate_execution(
                &self,
                input: ExecInput,
                output: Option<ExecOutput>,
            ) -> Result<(), TestRunnerError> {
                let initial_checkpoint = input.checkpoint().block_number;
                match output {
                    Some(output) if output.checkpoint.block_number > initial_checkpoint => {
                        let provider = self.db.factory.provider()?;

                        for block_num in initial_checkpoint..output.checkpoint.block_number {
                            // look up the header hash
                            let hash = provider.block_hash(block_num)?.expect("no header hash");

                            // validate the header number
                            assert_eq!(provider.block_number(hash)?, Some(block_num));

                            // validate the header
                            let header = provider.header_by_number(block_num)?;
                            assert!(header.is_some());
                            let header = SealedHeader::seal_slow(header.unwrap());
                            assert_eq!(header.hash(), hash);
                        }
                    }
                    // No progress was made: nothing above the checkpoint may have been written.
                    _ => self.check_no_header_entry_above(initial_checkpoint)?,
                };
                Ok(())
            }

            /// Feeds the seeded headers to the client and broadcasts the resulting tip.
            async fn after_execution(&self, headers: Self::Seed) -> Result<(), TestRunnerError> {
                self.client.extend(headers.iter().map(|h| h.clone_header())).await;
                let tip = if headers.is_empty() {
                    // Empty seed: materialize a standalone genesis-height header as the tip.
                    let tip = random_header(&mut generators::rng(), 0, None);
                    self.db.insert_headers(std::iter::once(&tip))?;
                    tip.hash()
                } else {
                    headers.last().unwrap().hash()
                };
                self.send_tip(tip);
                Ok(())
            }
        }

        impl<D: HeaderDownloader<Header = alloy_consensus::Header> + 'static> UnwindStageTestRunner
            for HeadersTestRunner<D>
        {
            fn validate_unwind(&self, input: UnwindInput) -> Result<(), TestRunnerError> {
                self.check_no_header_entry_above(input.unwind_to)
            }
        }

        impl HeadersTestRunner<ReverseHeadersDownloader<TestHeadersClient>> {
            /// Runner variant backed by the real reverse headers downloader.
            pub(crate) fn with_linear_downloader() -> Self {
                let client = TestHeadersClient::default();
                Self {
                    client: client.clone(),
                    channel: watch::channel(B256::ZERO),
                    downloader_factory: Box::new(move || {
                        ReverseHeadersDownloaderBuilder::default()
                            .stream_batch_size(500)
                            .build(client.clone(), Arc::new(TestConsensus::default()))
                    }),
                    db: TestStageDB::default(),
                }
            }
        }

        impl<D: HeaderDownloader> HeadersTestRunner<D> {
            /// Asserts that no header-related table contains an entry above `block`.
            pub(crate) fn check_no_header_entry_above(
                &self,
                block: BlockNumber,
            ) -> Result<(), TestRunnerError> {
                self.db
                    .ensure_no_entry_above_by_value::<tables::HeaderNumbers, _>(block, |val| val)?;
                self.db.ensure_no_entry_above::<tables::CanonicalHeaders, _>(block, |key| key)?;
                self.db.ensure_no_entry_above::<tables::Headers, _>(block, |key| key)?;
                Ok(())
            }

            /// Broadcasts `tip` as the new sync target over the watch channel.
            pub(crate) fn send_tip(&self, tip: B256) {
                self.channel.0.send(tip).expect("failed to send tip");
            }
        }
    }
562
    // Generates the shared execute/unwind test suite (from `crate::test_utils`) for this stage.
    stage_test_suite!(HeadersTestRunner, headers);
564
    /// Execute the stage with linear downloader, unwinds, and ensures that the database tables
    /// along with the static files are cleaned up.
    #[tokio::test]
    async fn execute_with_linear_downloader_unwind() {
        let mut runner = HeadersTestRunner::with_linear_downloader();
        let (checkpoint, previous_stage) = (1000, 1200);
        let input = ExecInput {
            target: Some(previous_stage),
            checkpoint: Some(StageCheckpoint::new(checkpoint)),
        };
        let headers = runner.seed_execution(input).expect("failed to seed execution");
        let rx = runner.execute(input);

        // Serve the seeded headers to the client in reverse — the stage downloads falling
        // block numbers from the tip.
        runner.client.extend(headers.iter().rev().map(|h| h.clone_header())).await;

        // skip `after_execution` hook for linear downloader
        let tip = headers.last().unwrap();
        runner.send_tip(tip.hash());

        let result = rx.await.unwrap();
        runner.db().factory.static_file_provider().commit().unwrap();
        assert_matches!(result, Ok(ExecOutput { checkpoint: StageCheckpoint {
            block_number,
            stage_checkpoint: Some(StageUnitCheckpoint::Headers(HeadersCheckpoint {
                block_range: CheckpointBlockRange {
                    from,
                    to
                },
                progress: EntitiesCheckpoint {
                    processed,
                    total,
                }
            }))
        }, done: true }) if block_number == tip.number &&
            from == checkpoint && to == previous_stage &&
            // -1 because we don't need to download the local head
            processed == checkpoint + headers.len() as u64 - 1 && total == tip.number
        );
        assert!(runner.validate_execution(input, result.ok()).is_ok(), "validation failed");
        // NOTE(review): `stage()` constructs a brand-new stage, so these asserts check a fresh
        // instance's collectors, not the one that just ran — consider asserting on the run stage.
        assert!(runner.stage().hash_collector.is_empty());
        assert!(runner.stage().header_collector.is_empty());

        // let's insert some blocks using append_blocks_with_state
        let sealed_headers = random_header_range(
            &mut generators::rng(),
            tip.number + 1..tip.number + 10,
            tip.hash(),
        );

        // Write the extra headers directly into static files to exercise the unwind path.
        let provider = runner.db().factory.database_provider_rw().unwrap();
        let static_file_provider = provider.static_file_provider();
        let mut writer = static_file_provider.latest_writer(StaticFileSegment::Headers).unwrap();
        for header in sealed_headers {
            writer.append_header(header.header(), &header.hash()).unwrap();
        }
        // Drop the writer so the provider commit below isn't blocked by it.
        drop(writer);

        provider.commit().unwrap();

        // now we can unwind 10 blocks
        let unwind_input = UnwindInput {
            checkpoint: StageCheckpoint::new(tip.number + 10),
            unwind_to: tip.number,
            bad_block: None,
        };

        let unwind_output = runner.unwind(unwind_input).await.unwrap();
        assert_eq!(unwind_output.checkpoint.block_number, tip.number);

        // validate the unwind, ensure that the tables are cleaned up
        assert!(runner.validate_unwind(unwind_input).is_ok());
    }
637
    /// Execute the stage with linear downloader
    #[tokio::test]
    async fn execute_with_linear_downloader() {
        let mut runner = HeadersTestRunner::with_linear_downloader();
        let (checkpoint, previous_stage) = (1000, 1200);
        let input = ExecInput {
            target: Some(previous_stage),
            checkpoint: Some(StageCheckpoint::new(checkpoint)),
        };
        let headers = runner.seed_execution(input).expect("failed to seed execution");
        let rx = runner.execute(input);

        // Serve the seeded headers to the client in reverse — the stage downloads falling
        // block numbers from the tip.
        runner.client.extend(headers.iter().rev().map(|h| h.clone_header())).await;

        // skip `after_execution` hook for linear downloader
        let tip = headers.last().unwrap();
        runner.send_tip(tip.hash());

        let result = rx.await.unwrap();
        runner.db().factory.static_file_provider().commit().unwrap();
        assert_matches!(result, Ok(ExecOutput { checkpoint: StageCheckpoint {
            block_number,
            stage_checkpoint: Some(StageUnitCheckpoint::Headers(HeadersCheckpoint {
                block_range: CheckpointBlockRange {
                    from,
                    to
                },
                progress: EntitiesCheckpoint {
                    processed,
                    total,
                }
            }))
        }, done: true }) if block_number == tip.number &&
            from == checkpoint && to == previous_stage &&
            // -1 because we don't need to download the local head
            processed == checkpoint + headers.len() as u64 - 1 && total == tip.number
        );
        assert!(runner.validate_execution(input, result.ok()).is_ok(), "validation failed");
        // NOTE(review): `stage()` constructs a brand-new stage, so these asserts check a fresh
        // instance's collectors, not the one that just ran — consider asserting on the run stage.
        assert!(runner.stage().hash_collector.is_empty());
        assert!(runner.stage().header_collector.is_empty());
    }
679}