// reth_stages/stages/headers.rs
1use alloy_consensus::BlockHeader;
2use alloy_primitives::{BlockHash, BlockNumber, Bytes, B256};
3use futures_util::StreamExt;
4use reth_config::config::EtlConfig;
5use reth_db_api::{
6    cursor::{DbCursorRO, DbCursorRW},
7    table::Value,
8    tables,
9    transaction::{DbTx, DbTxMut},
10    DbTxUnwindExt, RawKey, RawTable, RawValue,
11};
12use reth_etl::Collector;
13use reth_network_p2p::headers::{
14    downloader::{HeaderDownloader, HeaderSyncGap, SyncTarget},
15    error::HeadersDownloaderError,
16};
17use reth_primitives_traits::{
18    serde_bincode_compat, FullBlockHeader, HeaderTy, NodePrimitives, SealedHeader,
19};
20use reth_provider::{
21    providers::StaticFileWriter, BlockHashReader, DBProvider, HeaderSyncGapProvider,
22    StaticFileProviderFactory,
23};
24use reth_stages_api::{
25    CheckpointBlockRange, EntitiesCheckpoint, ExecInput, ExecOutput, HeadersCheckpoint, Stage,
26    StageCheckpoint, StageError, StageId, UnwindInput, UnwindOutput,
27};
28use reth_static_file_types::StaticFileSegment;
29use std::task::{ready, Context, Poll};
30
31use tokio::sync::watch;
32use tracing::*;
33
/// The headers stage.
///
/// The headers stage downloads all block headers from the highest block in storage to
/// the perceived highest block on the network.
///
/// The headers are processed and data is inserted into static files, as well as into the
/// [`HeaderNumbers`][reth_db_api::tables::HeaderNumbers] table.
///
/// NOTE: This stage downloads headers in reverse and pushes them to the ETL [`Collector`]. It then
/// proceeds to push them sequentially to static files. The stage checkpoint is not updated until
/// this stage is done.
#[derive(Debug)]
pub struct HeaderStage<Provider, Downloader: HeaderDownloader> {
    /// Database handle.
    provider: Provider,
    /// Strategy for downloading the headers.
    downloader: Downloader,
    /// The tip for the stage.
    ///
    /// This determines the sync target of the stage (set by the pipeline).
    tip: watch::Receiver<B256>,
    /// Current sync gap.
    ///
    /// Set in `poll_execute_ready` and consumed (`take`n) by `execute`.
    sync_gap: Option<HeaderSyncGap<Downloader::Header>>,
    /// ETL collector with `HeaderHash` -> `BlockNumber`, used to build the
    /// [`HeaderNumbers`][reth_db_api::tables::HeaderNumbers] index.
    hash_collector: Collector<BlockHash, BlockNumber>,
    /// ETL collector with `BlockNumber` -> `BincodeSealedHeader` (bincode-compat serialized
    /// sealed headers, later written to static files).
    header_collector: Collector<BlockNumber, Bytes>,
    /// Returns true if the ETL collector has all necessary headers to fill the gap.
    is_etl_ready: bool,
}
64
// === impl HeaderStage ===

impl<Provider, Downloader> HeaderStage<Provider, Downloader>
where
    Downloader: HeaderDownloader,
{
    /// Create a new header stage.
    ///
    /// The configured ETL file size budget is split evenly between the hash and the header
    /// collectors.
    pub fn new(
        database: Provider,
        downloader: Downloader,
        tip: watch::Receiver<B256>,
        etl_config: EtlConfig,
    ) -> Self {
        Self {
            provider: database,
            downloader,
            tip,
            sync_gap: None,
            hash_collector: Collector::new(etl_config.file_size / 2, etl_config.dir.clone()),
            header_collector: Collector::new(etl_config.file_size / 2, etl_config.dir),
            is_etl_ready: false,
        }
    }

    /// Clear all ETL state. Called on error paths to prevent buffer pollution on retry.
    fn clear_etl_state(&mut self) {
        self.sync_gap = None;
        self.hash_collector.clear();
        self.header_collector.clear();
        self.is_etl_ready = false;
    }

    /// Write downloaded headers to storage from ETL.
    ///
    /// Writes to static files ( `Header | HeaderTD | HeaderHash` ) and [`tables::HeaderNumbers`]
    /// database table.
    ///
    /// Returns the number of the last header written, which the caller uses as the new stage
    /// checkpoint.
    fn write_headers<P>(&mut self, provider: &P) -> Result<BlockNumber, StageError>
    where
        P: DBProvider<Tx: DbTxMut> + StaticFileProviderFactory,
        Downloader: HeaderDownloader<Header = <P::Primitives as NodePrimitives>::BlockHeader>,
        <P::Primitives as NodePrimitives>::BlockHeader: Value + FullBlockHeader,
    {
        let total_headers = self.header_collector.len();

        info!(target: "sync::stages::headers", total = total_headers, "Writing headers");

        let static_file_provider = provider.static_file_provider();

        // Consistency check of expected headers in static files vs DB is done on provider::sync_gap
        // when poll_execute_ready is polled.
        let mut last_header_number = static_file_provider
            .get_highest_static_file_block(StaticFileSegment::Headers)
            .unwrap_or_default();

        // Although headers were downloaded in reverse order, the collector iterates it in ascending
        // order
        let mut writer = static_file_provider.latest_writer(StaticFileSegment::Headers)?;
        // Emit a progress log roughly every 10% of the total (only for runs > 100 headers).
        let interval = (total_headers / 10).max(1);
        for (index, header) in self.header_collector.iter()?.enumerate() {
            let (_, header_buf) = header?;

            if index > 0 && index.is_multiple_of(interval) && total_headers > 100 {
                info!(target: "sync::stages::headers", progress = %format!("{:.2}%", (index as f64 / total_headers as f64) * 100.0), "Writing headers");
            }

            // Headers were buffered in bincode-compat form; restore the sealed header here.
            let sealed_header: SealedHeader<Downloader::Header> =
                bincode::deserialize::<serde_bincode_compat::SealedHeader<'_, _>>(&header_buf)
                    .map_err(|err| StageError::Fatal(Box::new(err)))?
                    .into();

            let (header, header_hash) = sealed_header.split_ref();
            // Never append the genesis header to static files — presumably it is already in
            // storage from chain initialization; only its hash index entry is handled below.
            if header.number() == 0 {
                continue
            }
            last_header_number = header.number();

            // Append to Headers segment
            writer.append_header(header, header_hash)?;
        }

        info!(target: "sync::stages::headers", total = total_headers, "Writing headers hash index");

        let mut cursor_header_numbers =
            provider.tx_ref().cursor_write::<RawTable<tables::HeaderNumbers>>()?;
        // If we only have the genesis block hash, then we are at first sync, and we can remove it,
        // add it to the collector and use tx.append on all hashes.
        let first_sync = if provider.tx_ref().entries::<RawTable<tables::HeaderNumbers>>()? == 1 &&
            let Some((hash, block_number)) = cursor_header_numbers.last()? &&
            block_number.value()? == 0
        {
            self.hash_collector.insert(hash.key()?, 0)?;
            cursor_header_numbers.delete_current()?;
            true
        } else {
            false
        };

        // Since ETL sorts all entries by hashes, we are either appending (first sync) or inserting
        // in order (further syncs).
        for (index, hash_to_number) in self.hash_collector.iter()?.enumerate() {
            let (hash, number) = hash_to_number?;

            if index > 0 && index.is_multiple_of(interval) && total_headers > 100 {
                info!(target: "sync::stages::headers", progress = %format!("{:.2}%", (index as f64 / total_headers as f64) * 100.0), "Writing headers hash index");
            }

            if first_sync {
                // Table was emptied above and ETL yields hashes sorted, so the cheaper `append`
                // is valid.
                cursor_header_numbers.append(
                    RawKey::<BlockHash>::from_vec(hash),
                    &RawValue::<BlockNumber>::from_vec(number),
                )?;
            } else {
                cursor_header_numbers.upsert(
                    RawKey::<BlockHash>::from_vec(hash),
                    &RawValue::<BlockNumber>::from_vec(number),
                )?;
            }
        }

        Ok(last_header_number)
    }
}
187
impl<Provider, P, D> Stage<Provider> for HeaderStage<P, D>
where
    Provider: DBProvider<Tx: DbTxMut> + StaticFileProviderFactory,
    P: HeaderSyncGapProvider<Header = <Provider::Primitives as NodePrimitives>::BlockHeader>,
    D: HeaderDownloader<Header = <Provider::Primitives as NodePrimitives>::BlockHeader>,
    <Provider::Primitives as NodePrimitives>::BlockHeader: FullBlockHeader + Value,
{
    /// Return the id of the stage
    fn id(&self) -> StageId {
        StageId::Headers
    }

    /// Drive the downloader until the whole gap `(local head, tip]` is buffered on disk in the
    /// ETL collectors.
    ///
    /// Only returns `Poll::Ready(Ok(()))` once `is_etl_ready` is set (gap fully buffered, or
    /// already closed); on downloader errors the ETL state is cleared so a retry starts clean.
    fn poll_execute_ready(
        &mut self,
        cx: &mut Context<'_>,
        input: ExecInput,
    ) -> Poll<Result<(), StageError>> {
        let current_checkpoint = input.checkpoint();

        // Return if stage has already completed the gap on the ETL files
        if self.is_etl_ready {
            return Poll::Ready(Ok(()))
        }

        // Lookup the head and tip of the sync range
        let local_head = self.provider.local_tip_header(current_checkpoint.block_number)?;
        let target = SyncTarget::Tip(*self.tip.borrow());
        let gap = HeaderSyncGap { local_head, target };
        let tip = gap.target.tip();

        // Nothing to sync
        if gap.is_closed() {
            info!(
                target: "sync::stages::headers",
                checkpoint = %current_checkpoint.block_number,
                target = ?tip,
                "Target block already reached"
            );
            self.is_etl_ready = true;
            // Keep the (closed) gap around so `execute` can detect it and finish early.
            self.sync_gap = Some(gap);
            return Poll::Ready(Ok(()))
        }

        debug!(target: "sync::stages::headers", ?tip, head = ?gap.local_head.hash(), "Commencing sync");
        let local_head_number = gap.local_head.number();

        // let the downloader know what to sync
        //
        // Only re-arm the downloader when the gap actually changed, so repeated polls don't
        // restart an in-flight download.
        if self.sync_gap != Some(gap.clone()) {
            self.sync_gap = Some(gap.clone());
            self.downloader.update_sync_gap(gap.local_head, gap.target);
        }

        // We only want to stop once we have all the headers on ETL filespace (disk).
        loop {
            match ready!(self.downloader.poll_next_unpin(cx)) {
                Some(Ok(headers)) => {
                    info!(target: "sync::stages::headers", total = headers.len(), from_block = headers.first().map(|h| h.number()), to_block = headers.last().map(|h| h.number()), "Received headers");
                    for header in headers {
                        let header_number = header.number();

                        self.hash_collector.insert(header.hash(), header_number)?;
                        // Buffer the sealed header in bincode-compat form; `write_headers`
                        // deserializes it again when flushing to static files.
                        self.header_collector.insert(
                            header_number,
                            Bytes::from(
                                bincode::serialize(&serde_bincode_compat::SealedHeader::from(
                                    &header,
                                ))
                                .map_err(|err| StageError::Fatal(Box::new(err)))?,
                            ),
                        )?;

                        // Headers are downloaded in reverse, so if we reach here, we know we have
                        // filled the gap.
                        if header_number == local_head_number + 1 {
                            self.is_etl_ready = true;
                            return Poll::Ready(Ok(()))
                        }
                    }
                }
                Some(Err(HeadersDownloaderError::DetachedHead { local_head, header, error })) => {
                    error!(target: "sync::stages::headers", %error, "Cannot attach header to head");
                    self.clear_etl_state();
                    return Poll::Ready(Err(StageError::DetachedHead {
                        local_head: Box::new(local_head.block_with_parent()),
                        header: Box::new(header.block_with_parent()),
                        error,
                    }))
                }
                None => {
                    // Downloader stream terminated without filling the gap.
                    self.clear_etl_state();
                    return Poll::Ready(Err(StageError::ChannelClosed))
                }
            }
        }
    }

    /// Download the headers in reverse order (falling block numbers)
    /// starting from the tip of the chain
    fn execute(&mut self, provider: &Provider, input: ExecInput) -> Result<ExecOutput, StageError> {
        let current_checkpoint = input.checkpoint();

        // `poll_execute_ready` must have run first and stored the gap; `take` consumes it so the
        // next run recomputes it.
        if self.sync_gap.take().ok_or(StageError::MissingSyncGap)?.is_closed() {
            self.is_etl_ready = false;
            return Ok(ExecOutput::done(current_checkpoint))
        }

        // We should be here only after we have downloaded all headers into the disk buffer (ETL).
        if !self.is_etl_ready {
            return Err(StageError::MissingDownloadBuffer)
        }

        // Reset flag
        self.is_etl_ready = false;

        // Write the headers and related tables to DB from ETL space
        let to_be_processed = self.hash_collector.len() as u64;
        let last_header_number = self.write_headers(provider)?;

        // Clear ETL collectors
        self.hash_collector.clear();
        self.header_collector.clear();

        Ok(ExecOutput {
            checkpoint: StageCheckpoint::new(last_header_number).with_headers_stage_checkpoint(
                HeadersCheckpoint {
                    block_range: CheckpointBlockRange {
                        from: input.checkpoint().block_number,
                        to: last_header_number,
                    },
                    progress: EntitiesCheckpoint {
                        // Entities processed = previous checkpoint + everything flushed this run.
                        processed: input.checkpoint().block_number + to_be_processed,
                        total: last_header_number,
                    },
                },
            ),
            // We only reach here if all headers have been downloaded by ETL, and pushed to DB all
            // in one stage run.
            done: true,
        })
    }

    /// Unwind the stage.
    fn unwind(
        &mut self,
        provider: &Provider,
        input: UnwindInput,
    ) -> Result<UnwindOutput, StageError> {
        // Drop any partially-buffered download; it is stale after an unwind.
        self.clear_etl_state();

        // First unwind the db tables, until the unwind_to block number. use the walker to unwind
        // HeaderNumbers based on the index in CanonicalHeaders
        // unwind from the next block number since the unwind_to block is exclusive
        provider
            .tx_ref()
            .unwind_table_by_walker::<tables::CanonicalHeaders, tables::HeaderNumbers>(
                (input.unwind_to + 1)..,
            )?;
        provider.tx_ref().unwind_table_by_num::<tables::CanonicalHeaders>(input.unwind_to)?;
        let unfinalized_headers_unwound = provider.tx_ref().unwind_table_by_num::<tables::Headers<
            HeaderTy<Provider::Primitives>,
        >>(input.unwind_to)?;

        // determine how many headers to unwind from the static files based on the highest block and
        // the unwind_to block
        let static_file_provider = provider.static_file_provider();
        let highest_block = static_file_provider
            .get_highest_static_file_block(StaticFileSegment::Headers)
            .unwrap_or_default();
        // NOTE(review): assumes highest_block >= unwind_to (would underflow otherwise) —
        // presumably guaranteed by the pipeline; confirm before relying on it.
        let static_file_headers_to_unwind = highest_block - input.unwind_to;
        for block_number in (input.unwind_to + 1)..=highest_block {
            let hash = static_file_provider.block_hash(block_number)?;
            // we have to delete from HeaderNumbers here as well as in the above unwind, since that
            // mapping contains entries for both headers in the db and headers in static files
            //
            // so if we are unwinding past the lowest block in the db, we have to iterate through
            // the HeaderNumbers entries that we'll delete in static files below
            if let Some(header_hash) = hash {
                provider.tx_ref().delete::<tables::HeaderNumbers>(header_hash, None)?;
            }
        }

        // Now unwind the static files until the unwind_to block number
        let mut writer = static_file_provider.latest_writer(StaticFileSegment::Headers)?;
        writer.prune_headers(static_file_headers_to_unwind)?;

        // Set the stage checkpoint entities processed based on how much we unwound - we add the
        // headers unwound from static files and db
        let stage_checkpoint =
            input.checkpoint.headers_stage_checkpoint().map(|stage_checkpoint| HeadersCheckpoint {
                block_range: stage_checkpoint.block_range,
                progress: EntitiesCheckpoint {
                    processed: stage_checkpoint.progress.processed.saturating_sub(
                        static_file_headers_to_unwind + unfinalized_headers_unwound as u64,
                    ),
                    total: stage_checkpoint.progress.total,
                },
            });

        let mut checkpoint = StageCheckpoint::new(input.unwind_to);
        if let Some(stage_checkpoint) = stage_checkpoint {
            checkpoint = checkpoint.with_headers_stage_checkpoint(stage_checkpoint);
        }

        Ok(UnwindOutput { checkpoint })
    }
}
394
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::{
        stage_test_suite, ExecuteStageTestRunner, StageTestRunner, UnwindStageTestRunner,
    };
    use alloy_primitives::B256;
    use assert_matches::assert_matches;
    use reth_provider::{DatabaseProviderFactory, ProviderFactory, StaticFileProviderFactory};
    use reth_stages_api::StageUnitCheckpoint;
    use reth_testing_utils::generators::{self, random_header, random_header_range};
    use std::sync::Arc;
    use test_runner::HeadersTestRunner;

    /// Test harness wiring a [`HeaderStage`] to a mock DB, a tip channel, and a configurable
    /// downloader factory.
    mod test_runner {
        use super::*;
        use crate::test_utils::{TestRunnerError, TestStageDB};
        use reth_consensus::test_utils::TestConsensus;
        use reth_downloaders::headers::reverse_headers::{
            ReverseHeadersDownloader, ReverseHeadersDownloaderBuilder,
        };
        use reth_network_p2p::test_utils::{TestHeaderDownloader, TestHeadersClient};
        use reth_provider::{test_utils::MockNodeTypesWithDB, BlockNumReader, HeaderProvider};
        use tokio::sync::watch;

        pub(crate) struct HeadersTestRunner<D: HeaderDownloader> {
            // Client shared with the downloader so tests can feed it headers.
            pub(crate) client: TestHeadersClient,
            // Tip channel: sender kept by the test, receiver handed to each stage instance.
            channel: (watch::Sender<B256>, watch::Receiver<B256>),
            // Factory so every `stage()` call gets a fresh downloader.
            downloader_factory: Box<dyn Fn() -> D + Send + Sync + 'static>,
            db: TestStageDB,
        }

        impl Default for HeadersTestRunner<TestHeaderDownloader> {
            fn default() -> Self {
                let client = TestHeadersClient::default();
                Self {
                    client: client.clone(),
                    channel: watch::channel(B256::ZERO),

                    downloader_factory: Box::new(move || {
                        TestHeaderDownloader::new(client.clone(), 1000, 1000)
                    }),
                    db: TestStageDB::default(),
                }
            }
        }

        impl<D: HeaderDownloader<Header = alloy_consensus::Header> + 'static> StageTestRunner
            for HeadersTestRunner<D>
        {
            type S = HeaderStage<ProviderFactory<MockNodeTypesWithDB>, D>;

            fn db(&self) -> &TestStageDB {
                &self.db
            }

            fn stage(&self) -> Self::S {
                HeaderStage::new(
                    self.db.factory.clone(),
                    (*self.downloader_factory)(),
                    self.channel.1.clone(),
                    EtlConfig::default(),
                )
            }
        }

        impl<D: HeaderDownloader<Header = alloy_consensus::Header> + 'static> ExecuteStageTestRunner
            for HeadersTestRunner<D>
        {
            type Seed = Vec<SealedHeader>;

            /// Seed storage with headers up to the checkpoint, then return the gap headers
            /// (head included as first element) for the downloader to serve.
            fn seed_execution(&mut self, input: ExecInput) -> Result<Self::Seed, TestRunnerError> {
                let mut rng = generators::rng();
                let start = input.checkpoint().block_number;
                let headers = random_header_range(&mut rng, 0..start + 1, B256::ZERO);
                let head = headers.last().cloned().unwrap();
                self.db.insert_headers(headers.iter())?;

                // use previous checkpoint as seed size
                let end = input.target.unwrap_or_default() + 1;

                if start + 1 >= end {
                    return Ok(Vec::default())
                }

                let mut headers = random_header_range(&mut rng, start + 1..end, head.hash());
                headers.insert(0, head);
                Ok(headers)
            }

            /// Validate stored headers
            fn validate_execution(
                &self,
                input: ExecInput,
                output: Option<ExecOutput>,
            ) -> Result<(), TestRunnerError> {
                let initial_checkpoint = input.checkpoint().block_number;
                match output {
                    Some(output) if output.checkpoint.block_number > initial_checkpoint => {
                        let provider = self.db.factory.provider()?;

                        for block_num in initial_checkpoint..output.checkpoint.block_number {
                            // look up the header hash
                            let hash = provider.block_hash(block_num)?.expect("no header hash");

                            // validate the header number
                            assert_eq!(provider.block_number(hash)?, Some(block_num));

                            // validate the header
                            let header = provider.header_by_number(block_num)?;
                            assert!(header.is_some());
                            let header = SealedHeader::seal_slow(header.unwrap());
                            assert_eq!(header.hash(), hash);
                        }
                    }
                    // No progress — make sure nothing was written past the checkpoint.
                    _ => self.check_no_header_entry_above(initial_checkpoint)?,
                };
                Ok(())
            }

            async fn after_execution(&self, headers: Self::Seed) -> Result<(), TestRunnerError> {
                self.client.extend(headers.iter().map(|h| h.clone_header())).await;
                let tip = if headers.is_empty() {
                    let tip = random_header(&mut generators::rng(), 0, None);
                    self.db.insert_headers(std::iter::once(&tip))?;
                    tip.hash()
                } else {
                    headers.last().unwrap().hash()
                };
                self.send_tip(tip);
                Ok(())
            }
        }

        impl<D: HeaderDownloader<Header = alloy_consensus::Header> + 'static> UnwindStageTestRunner
            for HeadersTestRunner<D>
        {
            fn validate_unwind(&self, input: UnwindInput) -> Result<(), TestRunnerError> {
                self.check_no_header_entry_above(input.unwind_to)
            }
        }

        impl HeadersTestRunner<ReverseHeadersDownloader<TestHeadersClient>> {
            /// Runner variant backed by the real reverse-headers downloader (batch size 500).
            pub(crate) fn with_linear_downloader() -> Self {
                let client = TestHeadersClient::default();
                Self {
                    client: client.clone(),
                    channel: watch::channel(B256::ZERO),
                    downloader_factory: Box::new(move || {
                        ReverseHeadersDownloaderBuilder::default()
                            .stream_batch_size(500)
                            .build(client.clone(), Arc::new(TestConsensus::default()))
                    }),
                    db: TestStageDB::default(),
                }
            }
        }

        impl<D: HeaderDownloader> HeadersTestRunner<D> {
            /// Assert no header-related table contains entries above `block`.
            pub(crate) fn check_no_header_entry_above(
                &self,
                block: BlockNumber,
            ) -> Result<(), TestRunnerError> {
                self.db
                    .ensure_no_entry_above_by_value::<tables::HeaderNumbers, _>(block, |val| val)?;
                self.db.ensure_no_entry_above::<tables::CanonicalHeaders, _>(block, |key| key)?;
                self.db.ensure_no_entry_above::<tables::Headers, _>(block, |key| key)?;
                Ok(())
            }

            pub(crate) fn send_tip(&self, tip: B256) {
                self.channel.0.send(tip).expect("failed to send tip");
            }
        }
    }

    stage_test_suite!(HeadersTestRunner, headers);

    /// Execute the stage with linear downloader, unwinds, and ensures that the database tables
    /// along with the static files are cleaned up.
    #[tokio::test]
    async fn execute_with_linear_downloader_unwind() {
        let mut runner = HeadersTestRunner::with_linear_downloader();
        let (checkpoint, previous_stage) = (1000, 1200);
        let input = ExecInput {
            target: Some(previous_stage),
            checkpoint: Some(StageCheckpoint::new(checkpoint)),
        };
        let headers = runner.seed_execution(input).expect("failed to seed execution");
        let rx = runner.execute(input);

        // Feed the client in reverse order, matching how the stage downloads.
        runner.client.extend(headers.iter().rev().map(|h| h.clone_header())).await;

        // skip `after_execution` hook for linear downloader
        let tip = headers.last().unwrap();
        runner.send_tip(tip.hash());

        let result = rx.await.unwrap();
        runner.db().factory.static_file_provider().commit().unwrap();
        assert_matches!(result, Ok(ExecOutput { checkpoint: StageCheckpoint {
            block_number,
            stage_checkpoint: Some(StageUnitCheckpoint::Headers(HeadersCheckpoint {
                block_range: CheckpointBlockRange {
                    from,
                    to
                },
                progress: EntitiesCheckpoint {
                    processed,
                    total,
                }
            }))
        }, done: true }) if block_number == tip.number &&
            from == checkpoint && to == previous_stage &&
            // -1 because we don't need to download the local head
            processed == checkpoint + headers.len() as u64 - 1 && total == tip.number
        );
        assert!(runner.validate_execution(input, result.ok()).is_ok(), "validation failed");
        assert!(runner.stage().hash_collector.is_empty());
        assert!(runner.stage().header_collector.is_empty());

        // let's insert some blocks using append_blocks_with_state
        let sealed_headers = random_header_range(
            &mut generators::rng(),
            tip.number + 1..tip.number + 10,
            tip.hash(),
        );

        let provider = runner.db().factory.database_provider_rw().unwrap();
        let static_file_provider = provider.static_file_provider();
        let mut writer = static_file_provider.latest_writer(StaticFileSegment::Headers).unwrap();
        for header in sealed_headers {
            writer.append_header(header.header(), &header.hash()).unwrap();
        }
        // Writer must be dropped before committing the provider.
        drop(writer);

        provider.commit().unwrap();

        // now we can unwind 10 blocks
        let unwind_input = UnwindInput {
            checkpoint: StageCheckpoint::new(tip.number + 10),
            unwind_to: tip.number,
            bad_block: None,
        };

        let unwind_output = runner.unwind(unwind_input).await.unwrap();
        assert_eq!(unwind_output.checkpoint.block_number, tip.number);

        // validate the unwind, ensure that the tables are cleaned up
        assert!(runner.validate_unwind(unwind_input).is_ok());
    }

    /// Execute the stage with linear downloader
    #[tokio::test]
    async fn execute_with_linear_downloader() {
        let mut runner = HeadersTestRunner::with_linear_downloader();
        let (checkpoint, previous_stage) = (1000, 1200);
        let input = ExecInput {
            target: Some(previous_stage),
            checkpoint: Some(StageCheckpoint::new(checkpoint)),
        };
        let headers = runner.seed_execution(input).expect("failed to seed execution");
        let rx = runner.execute(input);

        // Feed the client in reverse order, matching how the stage downloads.
        runner.client.extend(headers.iter().rev().map(|h| h.clone_header())).await;

        // skip `after_execution` hook for linear downloader
        let tip = headers.last().unwrap();
        runner.send_tip(tip.hash());

        let result = rx.await.unwrap();
        runner.db().factory.static_file_provider().commit().unwrap();
        assert_matches!(result, Ok(ExecOutput { checkpoint: StageCheckpoint {
            block_number,
            stage_checkpoint: Some(StageUnitCheckpoint::Headers(HeadersCheckpoint {
                block_range: CheckpointBlockRange {
                    from,
                    to
                },
                progress: EntitiesCheckpoint {
                    processed,
                    total,
                }
            }))
        }, done: true }) if block_number == tip.number &&
            from == checkpoint && to == previous_stage &&
            // -1 because we don't need to download the local head
            processed == checkpoint + headers.len() as u64 - 1 && total == tip.number
        );
        assert!(runner.validate_execution(input, result.ok()).is_ok(), "validation failed");
        assert!(runner.stage().hash_collector.is_empty());
        assert!(runner.stage().header_collector.is_empty());
    }
}
687}