reth_stages/stages/
headers.rs

1use alloy_consensus::BlockHeader;
2use alloy_eips::{eip1898::BlockWithParent, NumHash};
3use alloy_primitives::{BlockHash, BlockNumber, Bytes, B256};
4use futures_util::StreamExt;
5use reth_config::config::EtlConfig;
6use reth_consensus::HeaderValidator;
7use reth_db_api::{
8    cursor::{DbCursorRO, DbCursorRW},
9    table::Value,
10    tables,
11    transaction::{DbTx, DbTxMut},
12    DbTxUnwindExt, RawKey, RawTable, RawValue,
13};
14use reth_etl::Collector;
15use reth_network_p2p::headers::{
16    downloader::{HeaderDownloader, HeaderSyncGap, SyncTarget},
17    error::HeadersDownloaderError,
18};
19use reth_primitives_traits::{serde_bincode_compat, FullBlockHeader, NodePrimitives, SealedHeader};
20use reth_provider::{
21    providers::StaticFileWriter, BlockHashReader, DBProvider, HeaderProvider,
22    HeaderSyncGapProvider, StaticFileProviderFactory,
23};
24use reth_stages_api::{
25    BlockErrorKind, CheckpointBlockRange, EntitiesCheckpoint, ExecInput, ExecOutput,
26    HeadersCheckpoint, Stage, StageCheckpoint, StageError, StageId, UnwindInput, UnwindOutput,
27};
28use reth_static_file_types::StaticFileSegment;
29use reth_storage_errors::provider::ProviderError;
30use std::{
31    sync::Arc,
32    task::{ready, Context, Poll},
33};
34use tokio::sync::watch;
35use tracing::*;
36
37/// The headers stage.
38///
39/// The headers stage downloads all block headers from the highest block in storage to
40/// the perceived highest block on the network.
41///
42/// The headers are processed and data is inserted into static files, as well as into the
43/// [`HeaderNumbers`][reth_db_api::tables::HeaderNumbers] table.
44///
45/// NOTE: This stage downloads headers in reverse and pushes them to the ETL [`Collector`]. It then
46/// proceeds to push them sequentially to static files. The stage checkpoint is not updated until
47/// this stage is done.
48#[derive(Debug)]
49pub struct HeaderStage<Provider, Downloader: HeaderDownloader> {
50    /// Database handle.
51    provider: Provider,
52    /// Strategy for downloading the headers
53    downloader: Downloader,
54    /// The tip for the stage.
55    ///
56    /// This determines the sync target of the stage (set by the pipeline).
57    tip: watch::Receiver<B256>,
58    /// Consensus client implementation
59    consensus: Arc<dyn HeaderValidator<Downloader::Header>>,
60    /// Current sync gap.
61    sync_gap: Option<HeaderSyncGap<Downloader::Header>>,
62    /// ETL collector with `HeaderHash` -> `BlockNumber`
63    hash_collector: Collector<BlockHash, BlockNumber>,
64    /// ETL collector with `BlockNumber` -> `BincodeSealedHeader`
65    header_collector: Collector<BlockNumber, Bytes>,
66    /// Returns true if the ETL collector has all necessary headers to fill the gap.
67    is_etl_ready: bool,
68}
69
70// === impl HeaderStage ===
71
72impl<Provider, Downloader> HeaderStage<Provider, Downloader>
73where
74    Downloader: HeaderDownloader,
75{
76    /// Create a new header stage
77    pub fn new(
78        database: Provider,
79        downloader: Downloader,
80        tip: watch::Receiver<B256>,
81        consensus: Arc<dyn HeaderValidator<Downloader::Header>>,
82        etl_config: EtlConfig,
83    ) -> Self {
84        Self {
85            provider: database,
86            downloader,
87            tip,
88            consensus,
89            sync_gap: None,
90            hash_collector: Collector::new(etl_config.file_size / 2, etl_config.dir.clone()),
91            header_collector: Collector::new(etl_config.file_size / 2, etl_config.dir),
92            is_etl_ready: false,
93        }
94    }
95
96    /// Write downloaded headers to storage from ETL.
97    ///
98    /// Writes to static files ( `Header | HeaderTD | HeaderHash` ) and [`tables::HeaderNumbers`]
99    /// database table.
100    fn write_headers<P>(&mut self, provider: &P) -> Result<BlockNumber, StageError>
101    where
102        P: DBProvider<Tx: DbTxMut> + StaticFileProviderFactory,
103        Downloader: HeaderDownloader<Header = <P::Primitives as NodePrimitives>::BlockHeader>,
104        <P::Primitives as NodePrimitives>::BlockHeader: Value + FullBlockHeader,
105    {
106        let total_headers = self.header_collector.len();
107
108        info!(target: "sync::stages::headers", total = total_headers, "Writing headers");
109
110        let static_file_provider = provider.static_file_provider();
111
112        // Consistency check of expected headers in static files vs DB is done on provider::sync_gap
113        // when poll_execute_ready is polled.
114        let mut last_header_number = static_file_provider
115            .get_highest_static_file_block(StaticFileSegment::Headers)
116            .unwrap_or_default();
117
118        // Find the latest total difficulty
119        let mut td = static_file_provider
120            .header_td_by_number(last_header_number)?
121            .ok_or(ProviderError::TotalDifficultyNotFound(last_header_number))?;
122
123        // Although headers were downloaded in reverse order, the collector iterates it in ascending
124        // order
125        let mut writer = static_file_provider.latest_writer(StaticFileSegment::Headers)?;
126        let interval = (total_headers / 10).max(1);
127        for (index, header) in self.header_collector.iter()?.enumerate() {
128            let (_, header_buf) = header?;
129
130            if index > 0 && index % interval == 0 && total_headers > 100 {
131                info!(target: "sync::stages::headers", progress = %format!("{:.2}%", (index as f64 / total_headers as f64) * 100.0), "Writing headers");
132            }
133
134            let sealed_header: SealedHeader<Downloader::Header> =
135                bincode::deserialize::<serde_bincode_compat::SealedHeader<'_, _>>(&header_buf)
136                    .map_err(|err| StageError::Fatal(Box::new(err)))?
137                    .into();
138
139            let (header, header_hash) = sealed_header.split_ref();
140            if header.number() == 0 {
141                continue
142            }
143            last_header_number = header.number();
144
145            // Increase total difficulty
146            td += header.difficulty();
147
148            self.consensus.validate_header(&sealed_header).map_err(|error| StageError::Block {
149                block: Box::new(BlockWithParent::new(
150                    sealed_header.parent_hash(),
151                    NumHash::new(sealed_header.number(), sealed_header.hash()),
152                )),
153                error: BlockErrorKind::Validation(error),
154            })?;
155
156            // Append to Headers segment
157            writer.append_header(header, td, header_hash)?;
158        }
159
160        info!(target: "sync::stages::headers", total = total_headers, "Writing headers hash index");
161
162        let mut cursor_header_numbers =
163            provider.tx_ref().cursor_write::<RawTable<tables::HeaderNumbers>>()?;
164        let mut first_sync = false;
165
166        // If we only have the genesis block hash, then we are at first sync, and we can remove it,
167        // add it to the collector and use tx.append on all hashes.
168        if provider.tx_ref().entries::<RawTable<tables::HeaderNumbers>>()? == 1 {
169            if let Some((hash, block_number)) = cursor_header_numbers.last()? {
170                if block_number.value()? == 0 {
171                    self.hash_collector.insert(hash.key()?, 0)?;
172                    cursor_header_numbers.delete_current()?;
173                    first_sync = true;
174                }
175            }
176        }
177
178        // Since ETL sorts all entries by hashes, we are either appending (first sync) or inserting
179        // in order (further syncs).
180        for (index, hash_to_number) in self.hash_collector.iter()?.enumerate() {
181            let (hash, number) = hash_to_number?;
182
183            if index > 0 && index % interval == 0 && total_headers > 100 {
184                info!(target: "sync::stages::headers", progress = %format!("{:.2}%", (index as f64 / total_headers as f64) * 100.0), "Writing headers hash index");
185            }
186
187            if first_sync {
188                cursor_header_numbers.append(
189                    RawKey::<BlockHash>::from_vec(hash),
190                    &RawValue::<BlockNumber>::from_vec(number),
191                )?;
192            } else {
193                cursor_header_numbers.upsert(
194                    RawKey::<BlockHash>::from_vec(hash),
195                    &RawValue::<BlockNumber>::from_vec(number),
196                )?;
197            }
198        }
199
200        Ok(last_header_number)
201    }
202}
203
204impl<Provider, P, D> Stage<Provider> for HeaderStage<P, D>
205where
206    Provider: DBProvider<Tx: DbTxMut> + StaticFileProviderFactory,
207    P: HeaderSyncGapProvider<Header = <Provider::Primitives as NodePrimitives>::BlockHeader>,
208    D: HeaderDownloader<Header = <Provider::Primitives as NodePrimitives>::BlockHeader>,
209    <Provider::Primitives as NodePrimitives>::BlockHeader: FullBlockHeader + Value,
210{
211    /// Return the id of the stage
212    fn id(&self) -> StageId {
213        StageId::Headers
214    }
215
216    fn poll_execute_ready(
217        &mut self,
218        cx: &mut Context<'_>,
219        input: ExecInput,
220    ) -> Poll<Result<(), StageError>> {
221        let current_checkpoint = input.checkpoint();
222
223        // Return if stage has already completed the gap on the ETL files
224        if self.is_etl_ready {
225            return Poll::Ready(Ok(()))
226        }
227
228        // Lookup the head and tip of the sync range
229        let local_head = self.provider.local_tip_header(current_checkpoint.block_number)?;
230        let target = SyncTarget::Tip(*self.tip.borrow());
231        let gap = HeaderSyncGap { local_head, target };
232        let tip = gap.target.tip();
233        self.sync_gap = Some(gap.clone());
234
235        // Nothing to sync
236        if gap.is_closed() {
237            info!(
238                target: "sync::stages::headers",
239                checkpoint = %current_checkpoint.block_number,
240                target = ?tip,
241                "Target block already reached"
242            );
243            self.is_etl_ready = true;
244            return Poll::Ready(Ok(()))
245        }
246
247        debug!(target: "sync::stages::headers", ?tip, head = ?gap.local_head.hash(), "Commencing sync");
248        let local_head_number = gap.local_head.number();
249
250        // let the downloader know what to sync
251        self.downloader.update_sync_gap(gap.local_head, gap.target);
252
253        // We only want to stop once we have all the headers on ETL filespace (disk).
254        loop {
255            match ready!(self.downloader.poll_next_unpin(cx)) {
256                Some(Ok(headers)) => {
257                    info!(target: "sync::stages::headers", total = headers.len(), from_block = headers.first().map(|h| h.number()), to_block = headers.last().map(|h| h.number()), "Received headers");
258                    for header in headers {
259                        let header_number = header.number();
260
261                        self.hash_collector.insert(header.hash(), header_number)?;
262                        self.header_collector.insert(
263                            header_number,
264                            Bytes::from(
265                                bincode::serialize(&serde_bincode_compat::SealedHeader::from(
266                                    &header,
267                                ))
268                                .map_err(|err| StageError::Fatal(Box::new(err)))?,
269                            ),
270                        )?;
271
272                        // Headers are downloaded in reverse, so if we reach here, we know we have
273                        // filled the gap.
274                        if header_number == local_head_number + 1 {
275                            self.is_etl_ready = true;
276                            return Poll::Ready(Ok(()))
277                        }
278                    }
279                }
280                Some(Err(HeadersDownloaderError::DetachedHead { local_head, header, error })) => {
281                    error!(target: "sync::stages::headers", %error, "Cannot attach header to head");
282                    return Poll::Ready(Err(StageError::DetachedHead {
283                        local_head: Box::new(local_head.block_with_parent()),
284                        header: Box::new(header.block_with_parent()),
285                        error,
286                    }))
287                }
288                None => return Poll::Ready(Err(StageError::ChannelClosed)),
289            }
290        }
291    }
292
293    /// Download the headers in reverse order (falling block numbers)
294    /// starting from the tip of the chain
295    fn execute(&mut self, provider: &Provider, input: ExecInput) -> Result<ExecOutput, StageError> {
296        let current_checkpoint = input.checkpoint();
297
298        if self.sync_gap.as_ref().ok_or(StageError::MissingSyncGap)?.is_closed() {
299            self.is_etl_ready = false;
300            return Ok(ExecOutput::done(current_checkpoint))
301        }
302
303        // We should be here only after we have downloaded all headers into the disk buffer (ETL).
304        if !self.is_etl_ready {
305            return Err(StageError::MissingDownloadBuffer)
306        }
307
308        // Reset flag
309        self.is_etl_ready = false;
310
311        // Write the headers and related tables to DB from ETL space
312        let to_be_processed = self.hash_collector.len() as u64;
313        let last_header_number = self.write_headers(provider)?;
314
315        // Clear ETL collectors
316        self.hash_collector.clear();
317        self.header_collector.clear();
318
319        Ok(ExecOutput {
320            checkpoint: StageCheckpoint::new(last_header_number).with_headers_stage_checkpoint(
321                HeadersCheckpoint {
322                    block_range: CheckpointBlockRange {
323                        from: input.checkpoint().block_number,
324                        to: last_header_number,
325                    },
326                    progress: EntitiesCheckpoint {
327                        processed: input.checkpoint().block_number + to_be_processed,
328                        total: last_header_number,
329                    },
330                },
331            ),
332            // We only reach here if all headers have been downloaded by ETL, and pushed to DB all
333            // in one stage run.
334            done: true,
335        })
336    }
337
338    /// Unwind the stage.
339    fn unwind(
340        &mut self,
341        provider: &Provider,
342        input: UnwindInput,
343    ) -> Result<UnwindOutput, StageError> {
344        self.sync_gap.take();
345
346        // First unwind the db tables, until the unwind_to block number. use the walker to unwind
347        // HeaderNumbers based on the index in CanonicalHeaders
348        // unwind from the next block number since the unwind_to block is exclusive
349        provider
350            .tx_ref()
351            .unwind_table_by_walker::<tables::CanonicalHeaders, tables::HeaderNumbers>(
352                (input.unwind_to + 1)..,
353            )?;
354        provider.tx_ref().unwind_table_by_num::<tables::CanonicalHeaders>(input.unwind_to)?;
355        provider
356            .tx_ref()
357            .unwind_table_by_num::<tables::HeaderTerminalDifficulties>(input.unwind_to)?;
358        let unfinalized_headers_unwound =
359            provider.tx_ref().unwind_table_by_num::<tables::Headers>(input.unwind_to)?;
360
361        // determine how many headers to unwind from the static files based on the highest block and
362        // the unwind_to block
363        let static_file_provider = provider.static_file_provider();
364        let highest_block = static_file_provider
365            .get_highest_static_file_block(StaticFileSegment::Headers)
366            .unwrap_or_default();
367        let static_file_headers_to_unwind = highest_block - input.unwind_to;
368        for block_number in (input.unwind_to + 1)..=highest_block {
369            let hash = static_file_provider.block_hash(block_number)?;
370            // we have to delete from HeaderNumbers here as well as in the above unwind, since that
371            // mapping contains entries for both headers in the db and headers in static files
372            //
373            // so if we are unwinding past the lowest block in the db, we have to iterate through
374            // the HeaderNumbers entries that we'll delete in static files below
375            if let Some(header_hash) = hash {
376                provider.tx_ref().delete::<tables::HeaderNumbers>(header_hash, None)?;
377            }
378        }
379
380        // Now unwind the static files until the unwind_to block number
381        let mut writer = static_file_provider.latest_writer(StaticFileSegment::Headers)?;
382        writer.prune_headers(static_file_headers_to_unwind)?;
383
384        // Set the stage checkpoint entities processed based on how much we unwound - we add the
385        // headers unwound from static files and db
386        let stage_checkpoint =
387            input.checkpoint.headers_stage_checkpoint().map(|stage_checkpoint| HeadersCheckpoint {
388                block_range: stage_checkpoint.block_range,
389                progress: EntitiesCheckpoint {
390                    processed: stage_checkpoint.progress.processed.saturating_sub(
391                        static_file_headers_to_unwind + unfinalized_headers_unwound as u64,
392                    ),
393                    total: stage_checkpoint.progress.total,
394                },
395            });
396
397        let mut checkpoint = StageCheckpoint::new(input.unwind_to);
398        if let Some(stage_checkpoint) = stage_checkpoint {
399            checkpoint = checkpoint.with_headers_stage_checkpoint(stage_checkpoint);
400        }
401
402        Ok(UnwindOutput { checkpoint })
403    }
404}
405
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::{
        stage_test_suite, ExecuteStageTestRunner, StageTestRunner, UnwindStageTestRunner,
    };
    use alloy_primitives::B256;
    use assert_matches::assert_matches;
    use reth_ethereum_primitives::BlockBody;
    use reth_execution_types::ExecutionOutcome;
    use reth_primitives_traits::{RecoveredBlock, SealedBlock};
    use reth_provider::{BlockWriter, ProviderFactory, StaticFileProviderFactory};
    use reth_stages_api::StageUnitCheckpoint;
    use reth_testing_utils::generators::{self, random_header, random_header_range};
    use reth_trie::{updates::TrieUpdates, HashedPostStateSorted};
    use test_runner::HeadersTestRunner;

    /// Test harness that wires [`HeaderStage`] into the shared stage test runner traits.
    mod test_runner {
        use super::*;
        use crate::test_utils::{TestRunnerError, TestStageDB};
        use reth_consensus::test_utils::TestConsensus;
        use reth_downloaders::headers::reverse_headers::{
            ReverseHeadersDownloader, ReverseHeadersDownloaderBuilder,
        };
        use reth_network_p2p::test_utils::{TestHeaderDownloader, TestHeadersClient};
        use reth_provider::{test_utils::MockNodeTypesWithDB, BlockNumReader};
        use tokio::sync::watch;

        /// Runner for the headers stage, generic over the header downloader `D`.
        pub(crate) struct HeadersTestRunner<D: HeaderDownloader> {
            /// Headers client shared with every downloader built by `downloader_factory`.
            pub(crate) client: TestHeadersClient,
            /// Channel used to publish the sync tip to the stage (see `send_tip`).
            channel: (watch::Sender<B256>, watch::Receiver<B256>),
            /// Builds a fresh downloader each time `stage()` is called.
            downloader_factory: Box<dyn Fn() -> D + Send + Sync + 'static>,
            /// Test database the stage reads from and writes to.
            db: TestStageDB,
            /// Consensus passed to every stage instance.
            consensus: Arc<TestConsensus>,
        }

        /// Default runner backed by a [`TestHeaderDownloader`].
        impl Default for HeadersTestRunner<TestHeaderDownloader> {
            fn default() -> Self {
                let client = TestHeadersClient::default();
                Self {
                    client: client.clone(),
                    channel: watch::channel(B256::ZERO),
                    consensus: Arc::new(TestConsensus::default()),
                    downloader_factory: Box::new(move || {
                        TestHeaderDownloader::new(
                            client.clone(),
                            Arc::new(TestConsensus::default()),
                            1000,
                            1000,
                        )
                    }),
                    db: TestStageDB::default(),
                }
            }
        }

        impl<D: HeaderDownloader<Header = alloy_consensus::Header> + 'static> StageTestRunner
            for HeadersTestRunner<D>
        {
            type S = HeaderStage<ProviderFactory<MockNodeTypesWithDB>, D>;

            fn db(&self) -> &TestStageDB {
                &self.db
            }

            /// Builds a new stage over the test DB, with a fresh downloader and the tip receiver.
            fn stage(&self) -> Self::S {
                HeaderStage::new(
                    self.db.factory.clone(),
                    (*self.downloader_factory)(),
                    self.channel.1.clone(),
                    self.consensus.clone(),
                    EtlConfig::default(),
                )
            }
        }

        impl<D: HeaderDownloader<Header = alloy_consensus::Header> + 'static> ExecuteStageTestRunner
            for HeadersTestRunner<D>
        {
            type Seed = Vec<SealedHeader>;

            /// Inserts random headers `0..=checkpoint` (with total difficulty) into the DB.
            ///
            /// Returns the chain that still has to be downloaded: the local head followed by new
            /// random headers up to the target (inclusive). Returns an empty vec when the target
            /// is not above the checkpoint.
            fn seed_execution(&mut self, input: ExecInput) -> Result<Self::Seed, TestRunnerError> {
                let mut rng = generators::rng();
                let start = input.checkpoint().block_number;
                let headers = random_header_range(&mut rng, 0..start + 1, B256::ZERO);
                let head = headers.last().cloned().unwrap();
                self.db.insert_headers_with_td(headers.iter())?;

                // use previous checkpoint as seed size
                let end = input.target.unwrap_or_default() + 1;

                if start + 1 >= end {
                    return Ok(Vec::default())
                }

                // New headers chain off the current local head.
                let mut headers = random_header_range(&mut rng, start + 1..end, head.hash());
                headers.insert(0, head);
                Ok(headers)
            }

            /// Validate stored headers
            ///
            /// If the checkpoint advanced, every block from the initial checkpoint up to (but not
            /// including) the output checkpoint is checked for:
            /// - a canonical hash,
            /// - a hash -> number index entry,
            /// - a header whose hash matches,
            /// - the expected running total difficulty.
            ///
            /// Otherwise, checks that nothing was written above the initial checkpoint.
            fn validate_execution(
                &self,
                input: ExecInput,
                output: Option<ExecOutput>,
            ) -> Result<(), TestRunnerError> {
                let initial_checkpoint = input.checkpoint().block_number;
                match output {
                    Some(output) if output.checkpoint.block_number > initial_checkpoint => {
                        let provider = self.db.factory.provider()?;
                        // Start from the TD of the block preceding the first validated block.
                        let mut td = provider
                            .header_td_by_number(initial_checkpoint.saturating_sub(1))?
                            .unwrap_or_default();

                        for block_num in initial_checkpoint..output.checkpoint.block_number {
                            // look up the header hash
                            let hash = provider.block_hash(block_num)?.expect("no header hash");

                            // validate the header number
                            assert_eq!(provider.block_number(hash)?, Some(block_num));

                            // validate the header
                            let header = provider.header_by_number(block_num)?;
                            assert!(header.is_some());
                            let header = SealedHeader::seal_slow(header.unwrap());
                            assert_eq!(header.hash(), hash);

                            // validate the header total difficulty
                            td += header.difficulty;
                            assert_eq!(provider.header_td_by_number(block_num)?, Some(td));
                        }
                    }
                    _ => self.check_no_header_entry_above(initial_checkpoint)?,
                };
                Ok(())
            }

            /// Serves the seeded headers through the client and publishes the sync tip.
            ///
            /// With no seeded headers, a freshly generated random header is inserted into the DB
            /// and its hash is used as the tip.
            async fn after_execution(&self, headers: Self::Seed) -> Result<(), TestRunnerError> {
                self.client.extend(headers.iter().map(|h| h.clone_header())).await;
                let tip = if headers.is_empty() {
                    let tip = random_header(&mut generators::rng(), 0, None);
                    self.db.insert_headers(std::iter::once(&tip))?;
                    tip.hash()
                } else {
                    headers.last().unwrap().hash()
                };
                self.send_tip(tip);
                Ok(())
            }
        }

        impl<D: HeaderDownloader<Header = alloy_consensus::Header> + 'static> UnwindStageTestRunner
            for HeadersTestRunner<D>
        {
            /// Checks that no header-related entries remain above `unwind_to`.
            fn validate_unwind(&self, input: UnwindInput) -> Result<(), TestRunnerError> {
                self.check_no_header_entry_above(input.unwind_to)
            }
        }

        impl HeadersTestRunner<ReverseHeadersDownloader<TestHeadersClient>> {
            /// Runner backed by a [`ReverseHeadersDownloader`] with a stream batch size of 500.
            pub(crate) fn with_linear_downloader() -> Self {
                let client = TestHeadersClient::default();
                Self {
                    client: client.clone(),
                    channel: watch::channel(B256::ZERO),
                    downloader_factory: Box::new(move || {
                        ReverseHeadersDownloaderBuilder::default()
                            .stream_batch_size(500)
                            .build(client.clone(), Arc::new(TestConsensus::default()))
                    }),
                    db: TestStageDB::default(),
                    consensus: Arc::new(TestConsensus::default()),
                }
            }
        }

        impl<D: HeaderDownloader> HeadersTestRunner<D> {
            /// Checks that no entries above `block` remain in any of these tables:
            /// - `HeaderNumbers`, checked by value,
            /// - `CanonicalHeaders`, `Headers` and `HeaderTerminalDifficulties`, checked by key.
            pub(crate) fn check_no_header_entry_above(
                &self,
                block: BlockNumber,
            ) -> Result<(), TestRunnerError> {
                self.db
                    .ensure_no_entry_above_by_value::<tables::HeaderNumbers, _>(block, |val| val)?;
                self.db.ensure_no_entry_above::<tables::CanonicalHeaders, _>(block, |key| key)?;
                self.db.ensure_no_entry_above::<tables::Headers, _>(block, |key| key)?;
                self.db.ensure_no_entry_above::<tables::HeaderTerminalDifficulties, _>(
                    block,
                    |num| num,
                )?;
                Ok(())
            }

            /// Publishes `tip` as the sync target; panics if the send fails.
            pub(crate) fn send_tip(&self, tip: B256) {
                self.channel.0.send(tip).expect("failed to send tip");
            }
        }
    }

    // Generates the shared stage test suite for `HeadersTestRunner`.
    stage_test_suite!(HeadersTestRunner, headers);

    /// Execute the stage with linear downloader, unwinds, and ensures that the database tables
    /// along with the static files are cleaned up.
    #[tokio::test]
    async fn execute_with_linear_downloader_unwind() {
        let mut runner = HeadersTestRunner::with_linear_downloader();
        let (checkpoint, previous_stage) = (1000, 1200);
        let input = ExecInput {
            target: Some(previous_stage),
            checkpoint: Some(StageCheckpoint::new(checkpoint)),
        };
        let headers = runner.seed_execution(input).expect("failed to seed execution");
        let rx = runner.execute(input);

        // Serve the headers newest-first, matching the reverse download order.
        runner.client.extend(headers.iter().rev().map(|h| h.clone_header())).await;

        // skip `after_execution` hook for linear downloader
        let tip = headers.last().unwrap();
        runner.send_tip(tip.hash());

        let result = rx.await.unwrap();
        runner.db().factory.static_file_provider().commit().unwrap();
        assert_matches!(result, Ok(ExecOutput { checkpoint: StageCheckpoint {
            block_number,
            stage_checkpoint: Some(StageUnitCheckpoint::Headers(HeadersCheckpoint {
                block_range: CheckpointBlockRange {
                    from,
                    to
                },
                progress: EntitiesCheckpoint {
                    processed,
                    total,
                }
            }))
        }, done: true }) if block_number == tip.number &&
            from == checkpoint && to == previous_stage &&
            // -1 because we don't need to download the local head
            processed == checkpoint + headers.len() as u64 - 1 && total == tip.number
        );
        assert!(runner.validate_execution(input, result.ok()).is_ok(), "validation failed");
        assert!(runner.stage().hash_collector.is_empty());
        assert!(runner.stage().header_collector.is_empty());

        // let's insert some blocks using append_blocks_with_state
        let sealed_headers =
            random_header_range(&mut generators::rng(), tip.number..tip.number + 10, tip.hash());

        // make them sealed blocks with senders by converting them to empty blocks
        let sealed_blocks = sealed_headers
            .iter()
            .map(|header| {
                RecoveredBlock::new_sealed(
                    SealedBlock::from_sealed_parts(header.clone(), BlockBody::default()),
                    vec![],
                )
            })
            .collect();

        // append the blocks
        let provider = runner.db().factory.provider_rw().unwrap();
        provider
            .append_blocks_with_state(
                sealed_blocks,
                &ExecutionOutcome::default(),
                HashedPostStateSorted::default(),
                TrieUpdates::default(),
            )
            .unwrap();
        provider.commit().unwrap();

        // now we can unwind 10 blocks
        let unwind_input = UnwindInput {
            checkpoint: StageCheckpoint::new(tip.number + 10),
            unwind_to: tip.number,
            bad_block: None,
        };

        let unwind_output = runner.unwind(unwind_input).await.unwrap();
        assert_eq!(unwind_output.checkpoint.block_number, tip.number);

        // validate the unwind, ensure that the tables are cleaned up
        assert!(runner.validate_unwind(unwind_input).is_ok());
    }

    /// Execute the stage with linear downloader
    #[tokio::test]
    async fn execute_with_linear_downloader() {
        let mut runner = HeadersTestRunner::with_linear_downloader();
        let (checkpoint, previous_stage) = (1000, 1200);
        let input = ExecInput {
            target: Some(previous_stage),
            checkpoint: Some(StageCheckpoint::new(checkpoint)),
        };
        let headers = runner.seed_execution(input).expect("failed to seed execution");
        let rx = runner.execute(input);

        // Serve the headers newest-first, matching the reverse download order.
        runner.client.extend(headers.iter().rev().map(|h| h.clone_header())).await;

        // skip `after_execution` hook for linear downloader
        let tip = headers.last().unwrap();
        runner.send_tip(tip.hash());

        let result = rx.await.unwrap();
        runner.db().factory.static_file_provider().commit().unwrap();
        assert_matches!(result, Ok(ExecOutput { checkpoint: StageCheckpoint {
            block_number,
            stage_checkpoint: Some(StageUnitCheckpoint::Headers(HeadersCheckpoint {
                block_range: CheckpointBlockRange {
                    from,
                    to
                },
                progress: EntitiesCheckpoint {
                    processed,
                    total,
                }
            }))
        }, done: true }) if block_number == tip.number &&
            from == checkpoint && to == previous_stage &&
            // -1 because we don't need to download the local head
            processed == checkpoint + headers.len() as u64 - 1 && total == tip.number
        );
        assert!(runner.validate_execution(input, result.ok()).is_ok(), "validation failed");
        // A successful `execute` must leave both ETL collectors empty.
        assert!(runner.stage().hash_collector.is_empty());
        assert!(runner.stage().header_collector.is_empty());
    }
}