reth_stages/stages/sender_recovery.rs

use alloy_primitives::{Address, BlockNumber, TxNumber};
use reth_config::config::SenderRecoveryConfig;
use reth_consensus::ConsensusError;
use reth_db::static_file::TransactionMask;
use reth_db_api::{
    cursor::DbCursorRW,
    table::Value,
    tables,
    transaction::{DbTx, DbTxMut},
    RawValue,
};
use reth_primitives_traits::{GotExpected, NodePrimitives, SignedTransaction};
use reth_provider::{
    BlockReader, DBProvider, EitherWriter, HeaderProvider, ProviderError, PruneCheckpointReader,
    PruneCheckpointWriter, StaticFileProviderFactory, StatsReader, StorageSettingsCache,
    TransactionsProvider,
};
use reth_prune_types::{PruneCheckpoint, PruneMode, PrunePurpose, PruneSegment};
use reth_stages_api::{
    BlockErrorKind, EntitiesCheckpoint, ExecInput, ExecOutput, Stage, StageCheckpoint, StageError,
    StageId, UnwindInput, UnwindOutput,
};
use reth_static_file_types::StaticFileSegment;
use std::{fmt::Debug, ops::Range, sync::mpsc, time::Instant};
use thiserror::Error;
use tracing::*;

/// Maximum number of transactions to read from disk at one time before we flush their senders to
/// disk. Since each rayon worker will hold at most 100 transactions (`WORKER_CHUNK_SIZE`), we
/// effectively cap each batch at 1000 channels in memory.
const BATCH_SIZE: usize = 100_000;

/// Maximum number of senders to recover per rayon worker job.
const WORKER_CHUNK_SIZE: usize = 100;
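
// Illustrative arithmetic: a full batch holds BATCH_SIZE / WORKER_CHUNK_SIZE =
// 100_000 / 100 = 1_000 worker chunks, i.e. at most 1_000 result channels in memory at once.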

/// Type alias for a sender that transmits the result of sender recovery.
type RecoveryResultSender = mpsc::Sender<Result<(u64, Address), Box<SenderRecoveryStageError>>>;

/// The sender recovery stage iterates over existing transactions,
/// recovers each transaction's signer, and stores the results
/// in the [`TransactionSenders`][reth_db_api::tables::TransactionSenders] table
/// or in static files, depending on configuration.
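///
/// # Example
///
/// A minimal construction sketch (assuming `SenderRecoveryConfig` implements `Default`,
/// as it does in `reth_config`):
///
/// ```ignore
/// use reth_config::config::SenderRecoveryConfig;
///
/// // Stage with the default commit threshold and no pruning of recovered senders.
/// let stage = SenderRecoveryStage::new(SenderRecoveryConfig::default(), None);
/// ```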
#[derive(Clone, Debug)]
pub struct SenderRecoveryStage {
    /// The number of inserted items after which control
    /// flow will be returned to the pipeline for a commit
    pub commit_threshold: u64,
    /// Prune mode for sender recovery. When set to `PruneMode::Full`, the stage will
    /// fast-forward its checkpoint to skip all work, since senders will be recovered
    /// inline by the execution stage instead.
    pub prune_mode: Option<PruneMode>,
}

impl SenderRecoveryStage {
    /// Create a new instance of [`SenderRecoveryStage`].
    pub const fn new(config: SenderRecoveryConfig, prune_mode: Option<PruneMode>) -> Self {
        Self { commit_threshold: config.commit_threshold, prune_mode }
    }
}

impl Default for SenderRecoveryStage {
    fn default() -> Self {
        Self { commit_threshold: 5_000_000, prune_mode: None }
    }
}

impl<Provider> Stage<Provider> for SenderRecoveryStage
where
    Provider: DBProvider<Tx: DbTxMut>
        + BlockReader
        + StaticFileProviderFactory<Primitives: NodePrimitives<SignedTx: Value + SignedTransaction>>
        + StatsReader
        + PruneCheckpointReader
        + PruneCheckpointWriter
        + StorageSettingsCache,
{
    /// Return the id of the stage
    fn id(&self) -> StageId {
        StageId::SenderRecovery
    }

    /// Retrieve the range of transactions to iterate over by querying
    /// [`BlockBodyIndices`][reth_db_api::tables::BlockBodyIndices],
    /// collect transactions within that range, recover the signer for each transaction, and
    /// store the entries in the [`TransactionSenders`][reth_db_api::tables::TransactionSenders]
    /// table or static files, depending on configuration.
    fn execute(
        &mut self,
        provider: &Provider,
        mut input: ExecInput,
    ) -> Result<ExecOutput, StageError> {
        // TODO: when senders are fully pruned, batch recover in execution stage instead of per-tx
        // fallback
        if let Some((target_prunable_block, prune_mode)) = self
            .prune_mode
            .map(|mode| {
                mode.prune_target_block(
                    input.target(),
                    PruneSegment::SenderRecovery,
                    PrunePurpose::User,
                )
            })
            .transpose()?
            .flatten() &&
            target_prunable_block > input.checkpoint().block_number
        {
            input.checkpoint = Some(StageCheckpoint::new(target_prunable_block));

            if provider.get_prune_checkpoint(PruneSegment::SenderRecovery)?.is_none() {
                let target_prunable_tx_number = provider
                    .block_body_indices(target_prunable_block)?
                    .ok_or(ProviderError::BlockBodyIndicesNotFound(target_prunable_block))?
                    .last_tx_num();

                provider.save_prune_checkpoint(
                    PruneSegment::SenderRecovery,
                    PruneCheckpoint {
                        block_number: Some(target_prunable_block),
                        tx_number: Some(target_prunable_tx_number),
                        prune_mode,
                    },
                )?;
            }
        }

        if input.target_reached() {
            return Ok(ExecOutput::done(input.checkpoint()))
        }

        let Some(range_output) =
            input.next_block_range_with_transaction_threshold(provider, self.commit_threshold)?
        else {
            info!(target: "sync::stages::sender_recovery", "No transaction senders to recover");
            EitherWriter::new_senders(
                provider,
                provider
                    .static_file_provider()
                    .get_highest_static_file_block(StaticFileSegment::TransactionSenders)
                    .unwrap_or_default(),
            )?
            .ensure_at_block(input.target())?;
            return Ok(ExecOutput {
                checkpoint: StageCheckpoint::new(input.target())
                    .with_entities_stage_checkpoint(stage_checkpoint(provider)?),
                done: true,
            })
        };
        let end_block = *range_output.block_range.end();

        let mut writer = EitherWriter::new_senders(provider, *range_output.block_range.start())?;

        info!(target: "sync::stages::sender_recovery", tx_range = ?range_output.tx_range, "Recovering senders");

        // Iterate over transactions in batches, recover the senders and append them
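        // Illustrative: a tx_range of 0..250_000 with BATCH_SIZE = 100_000 yields the
        // sub-ranges [0..100_000, 100_000..200_000, 200_000..250_000].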
        let batch = range_output
            .tx_range
            .clone()
            .step_by(BATCH_SIZE)
            .map(|start| start..std::cmp::min(start + BATCH_SIZE as u64, range_output.tx_range.end))
            .collect::<Vec<Range<u64>>>();

        let tx_batch_sender = setup_range_recovery(provider);

        let start = Instant::now();
        let block_body_indices =
            provider.block_body_indices_range(range_output.block_range.clone())?;
        let block_body_indices_elapsed = start.elapsed();
        let mut blocks_with_indices = range_output.block_range.zip(block_body_indices).peekable();

        for range in batch {
            // Pair each transaction number with its block number
            let start = Instant::now();
            let block_numbers = range.clone().fold(Vec::new(), |mut block_numbers, tx| {
                while let Some((block, index)) = blocks_with_indices.peek() {
                    if index.contains_tx(tx) {
                        block_numbers.push(*block);
                        return block_numbers
                    }
                    blocks_with_indices.next();
                }
                block_numbers
            });
            let fold_elapsed = start.elapsed();
            debug!(target: "sync::stages::sender_recovery", ?block_body_indices_elapsed, ?fold_elapsed, len = block_numbers.len(), "Calculated block numbers");
            recover_range(range, block_numbers, provider, tx_batch_sender.clone(), &mut writer)?;
        }

        // Advance the static file header to the end of this range to account for empty blocks.
        writer.ensure_at_block(end_block)?;

        Ok(ExecOutput {
            checkpoint: StageCheckpoint::new(end_block)
                .with_entities_stage_checkpoint(stage_checkpoint(provider)?),
            done: range_output.is_final_range,
        })
    }

    /// Unwind the stage.
    fn unwind(
        &mut self,
        provider: &Provider,
        input: UnwindInput,
    ) -> Result<UnwindOutput, StageError> {
        let (_, unwind_to, _) = input.unwind_block_range_with_threshold(self.commit_threshold);

        if self.prune_mode.is_none_or(|mode| !mode.is_full()) {
            // Look up the next tx id after the `unwind_to` block (the first tx to remove)
            let unwind_tx_from = provider
                .block_body_indices(unwind_to)?
                .ok_or(ProviderError::BlockBodyIndicesNotFound(unwind_to))?
                .next_tx_num();

            EitherWriter::new_senders(provider, unwind_to)?
                .prune_senders(unwind_tx_from, unwind_to)?;
        }

        Ok(UnwindOutput {
            checkpoint: StageCheckpoint::new(unwind_to)
                .with_entities_stage_checkpoint(stage_checkpoint(provider)?),
        })
    }
}

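/// Recovers senders for the given transaction range and appends them via `writer`.
///
/// The range is split into `WORKER_CHUNK_SIZE` chunks that are dispatched to the recovery
/// thread spawned by [`setup_range_recovery`]; results are drained from the per-chunk
/// channels in order, so senders are written sequentially by transaction number.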
fn recover_range<Provider, CURSOR>(
    tx_range: Range<TxNumber>,
    block_numbers: Vec<BlockNumber>,
    provider: &Provider,
    tx_batch_sender: mpsc::Sender<Vec<(Range<u64>, RecoveryResultSender)>>,
    writer: &mut EitherWriter<'_, CURSOR, Provider::Primitives>,
) -> Result<(), StageError>
where
    Provider: DBProvider + HeaderProvider + TransactionsProvider + StaticFileProviderFactory,
    CURSOR: DbCursorRW<tables::TransactionSenders>,
{
    debug_assert_eq!(
        tx_range.clone().count(),
        block_numbers.len(),
        "Transaction range and block numbers count mismatch"
    );

    debug!(target: "sync::stages::sender_recovery", ?tx_range, "Sending batch for processing");

    // Preallocate channels for each chunk in the batch
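    // Illustrative: a full 100_000-tx batch yields 1_000 chunks of WORKER_CHUNK_SIZE (100)
    // transactions, each paired with its own mpsc channel for results.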
    let (chunks, receivers): (Vec<_>, Vec<_>) = tx_range
        .clone()
        .step_by(WORKER_CHUNK_SIZE)
        .map(|start| {
            let range = start..std::cmp::min(start + WORKER_CHUNK_SIZE as u64, tx_range.end);
            let (tx, rx) = mpsc::channel();
            // The range and channel sender will be sent to a rayon worker
            ((range, tx), rx)
        })
        .unzip();

    if let Err(err) = tx_batch_sender.send(chunks) {
        return Err(StageError::Fatal(err.into()));
    }

    debug!(target: "sync::stages::sender_recovery", ?tx_range, "Appending recovered senders to the database");

    let mut processed_transactions = 0;
    let mut block_numbers = block_numbers.into_iter();
    for channel in receivers {
        while let Ok(recovered) = channel.recv() {
            let (tx_id, sender) = match recovered {
                Ok(result) => result,
                Err(error) => {
                    return match *error {
                        SenderRecoveryStageError::FailedRecovery(err) => {
                            // get the block number for the bad transaction
                            let block_number = provider
                                .tx_ref()
                                .get::<tables::TransactionBlocks>(err.tx)?
                                .ok_or(ProviderError::BlockNumberForTransactionIndexNotFound)?;

                            // fetch the sealed header so we can use it in the sender recovery
                            // unwind
                            let sealed_header =
                                provider.sealed_header(block_number)?.ok_or_else(|| {
                                    ProviderError::HeaderNotFound(block_number.into())
                                })?;

                            Err(StageError::Block {
                                block: Box::new(sealed_header.block_with_parent()),
                                error: BlockErrorKind::Validation(
                                    ConsensusError::TransactionSignerRecoveryError,
                                ),
                            })
                        }
                        SenderRecoveryStageError::StageError(err) => Err(err),
                        SenderRecoveryStageError::RecoveredSendersMismatch(expectation) => {
                            Err(StageError::Fatal(
                                SenderRecoveryStageError::RecoveredSendersMismatch(expectation)
                                    .into(),
                            ))
                        }
                    }
                }
            };

            let new_block_number = block_numbers
                .next()
                .expect("block numbers iterator has the same length as the number of transactions");
            writer.ensure_at_block(new_block_number)?;
            writer.append_sender(tx_id, &sender)?;
            processed_transactions += 1;
        }
    }
    debug!(target: "sync::stages::sender_recovery", ?tx_range, "Finished recovering senders batch");

    // Fail safe to ensure that we do not proceed without having recovered all senders.
    let expected = tx_range.end - tx_range.start;
    if processed_transactions != expected {
        return Err(StageError::Fatal(
            SenderRecoveryStageError::RecoveredSendersMismatch(GotExpected {
                got: processed_transactions,
                expected,
            })
            .into(),
        ));
    }
    Ok(())
}

/// Spawns a thread to handle the recovery of transaction senders for
/// specified chunks of a given batch. It processes incoming ranges, fetching and recovering
/// transactions in parallel using the global rayon pool.
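///
/// The returned sender accepts batches of `(chunk_range, result_sender)` pairs. Senders for
/// each chunk flow back through that chunk's dedicated channel, which lets the caller (see
/// [`recover_range`]) drain the channels in order and append senders sequentially.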
fn setup_range_recovery<Provider>(
    provider: &Provider,
) -> mpsc::Sender<Vec<(Range<u64>, RecoveryResultSender)>>
where
    Provider: DBProvider
        + HeaderProvider
        + StaticFileProviderFactory<Primitives: NodePrimitives<SignedTx: Value + SignedTransaction>>,
{
    let (tx_sender, tx_receiver) = mpsc::channel::<Vec<(Range<u64>, RecoveryResultSender)>>();
    let static_file_provider = provider.static_file_provider();

    // We do not use `tokio::task::spawn_blocking` because, during a shutdown,
    // there will be a timeout grace period in which Tokio does not allow spawning
    // additional blocking tasks. This would cause the stage to fail with
    // `SenderRecoveryStageError::RecoveredSendersMismatch` at the end.
    //
    // However, using `std::thread::spawn` allows us to utilize the timeout grace
    // period to complete some work without throwing errors during the shutdown.
    reth_tasks::spawn_os_thread("sender-recovery", move || {
        while let Ok(chunks) = tx_receiver.recv() {
            for (chunk_range, recovered_senders_tx) in chunks {
                // Read the raw value, and let the rayon worker decompress & decode it.
                let chunk = match static_file_provider.fetch_range_with_predicate(
                    StaticFileSegment::Transactions,
                    chunk_range.clone(),
                    |cursor, number| {
                        Ok(cursor
                            .get_one::<TransactionMask<
                                RawValue<<Provider::Primitives as NodePrimitives>::SignedTx>,
                            >>(number.into())?
                            .map(|tx| (number, tx)))
                    },
                    |_| true,
                ) {
                    Ok(chunk) => chunk,
                    Err(err) => {
                        // We exit early since we could not process this chunk.
                        let _ = recovered_senders_tx
                            .send(Err(Box::new(SenderRecoveryStageError::StageError(err.into()))));
                        break
                    }
                };

                // Spawn the task onto the global rayon pool.
                // This task will send the results through the channel after it has read the
                // transaction and calculated the sender.
                rayon::spawn(move || {
                    let mut rlp_buf = Vec::with_capacity(128);
                    for (number, tx) in chunk {
                        let res = tx
                            .value()
                            .map_err(|err| {
                                Box::new(SenderRecoveryStageError::StageError(err.into()))
                            })
                            .and_then(|tx| recover_sender((number, tx), &mut rlp_buf));

                        let is_err = res.is_err();

                        let _ = recovered_senders_tx.send(res);

                        // Finish early
                        if is_err {
                            break
                        }
                    }
                });
            }
        }
    });
    tx_sender
}

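/// Recovers the signer address for a single transaction, reusing `rlp_buf` as an encoding
/// scratch buffer across calls to avoid reallocations.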
#[inline]
fn recover_sender<T: SignedTransaction>(
    (tx_id, tx): (TxNumber, T),
    rlp_buf: &mut Vec<u8>,
) -> Result<(u64, Address), Box<SenderRecoveryStageError>> {
    rlp_buf.clear();
    // We use unchecked recovery (`recover_unchecked_with_buf`) because transactions processed
    // in the pipeline are known to be valid. This means that we do not need to check whether
    // the `s` value is greater than `secp256k1n / 2` for post-EIP-2 transactions. There are
    // pre-homestead transactions which have large `s` values, so using a checked recovery
    // method here would not be backwards-compatible.
    let sender = tx.recover_unchecked_with_buf(rlp_buf).map_err(|_| {
        SenderRecoveryStageError::FailedRecovery(FailedSenderRecoveryError { tx: tx_id })
    })?;

    Ok((tx_id, sender))
}

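/// Computes the [`EntitiesCheckpoint`] for this stage: the number of senders recovered so far
/// (database entries plus pruned entries) out of the total number of transactions in static
/// files.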
fn stage_checkpoint<Provider>(provider: &Provider) -> Result<EntitiesCheckpoint, StageError>
where
    Provider: StatsReader + StaticFileProviderFactory + PruneCheckpointReader,
{
    let pruned_entries = provider
        .get_prune_checkpoint(PruneSegment::SenderRecovery)?
        .and_then(|checkpoint| checkpoint.tx_number)
        .unwrap_or_default();
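    // Illustrative arithmetic: with 30 pruned sender entries and 70 rows currently in the
    // table, `processed` below reports 30 + 70 = 100, matching the transactions processed.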
    Ok(EntitiesCheckpoint {
        // If the `TransactionSenders` table was pruned, the number of entries in it will not
        // match the actual number of processed transactions. To fix that, we add the
        // number of pruned `TransactionSenders` entries.
        processed: provider.count_entries::<tables::TransactionSenders>()? as u64 + pruned_entries,
        // Count only static file entries. If we count the database entries too, we may have
        // duplicates. We're sure that the static files have all entries that the database has,
        // because we run the `StaticFileProducer` before starting the pipeline.
        total: provider.static_file_provider().count_entries::<tables::Transactions>()? as u64,
    })
}

#[derive(Error, Debug)]
#[error(transparent)]
enum SenderRecoveryStageError {
    /// A transaction failed sender recovery
    #[error(transparent)]
    FailedRecovery(#[from] FailedSenderRecoveryError),

    /// Number of recovered senders does not match
    #[error("mismatched sender count during recovery: {_0}")]
    RecoveredSendersMismatch(GotExpected<u64>),

    /// A different type of stage error occurred
    #[error(transparent)]
    StageError(#[from] StageError),
}

#[derive(Error, Debug)]
#[error("sender recovery failed for transaction {tx}")]
struct FailedSenderRecoveryError {
    /// The transaction that failed sender recovery
    tx: TxNumber,
}

459
460#[cfg(test)]
461mod tests {
462    use super::*;
463    use crate::test_utils::{
464        stage_test_suite_ext, ExecuteStageTestRunner, StageTestRunner, StorageKind,
465        TestRunnerError, TestStageDB, UnwindStageTestRunner,
466    };
467    use alloy_primitives::{BlockNumber, B256};
468    use assert_matches::assert_matches;
469    use reth_db_api::{cursor::DbCursorRO, models::StorageSettings};
470    use reth_ethereum_primitives::{Block, TransactionSigned};
471    use reth_primitives_traits::{SealedBlock, SignerRecoverable};
472    use reth_provider::{
473        providers::StaticFileWriter, BlockBodyIndicesProvider, DatabaseProviderFactory,
474        PruneCheckpointWriter, StaticFileProviderFactory, TransactionsProvider,
475    };
476    use reth_prune_types::{PruneCheckpoint, PruneMode};
477    use reth_stages_api::StageUnitCheckpoint;
478    use reth_static_file_types::StaticFileSegment;
479    use reth_testing_utils::generators::{
480        self, random_block, random_block_range, BlockParams, BlockRangeParams,
481    };
482
483    stage_test_suite_ext!(SenderRecoveryTestRunner, sender_recovery);
484
485    /// Execute a block range with a single transaction
486    #[tokio::test]
487    async fn execute_single_transaction() {
488        let (previous_stage, stage_progress) = (500, 100);
489        let mut rng = generators::rng();
490
491        // Set up the runner
492        let runner = SenderRecoveryTestRunner::default();
493        let input = ExecInput {
494            target: Some(previous_stage),
495            checkpoint: Some(StageCheckpoint::new(stage_progress)),
496        };
497
498        // Insert blocks with a single transaction at block `stage_progress + 10`
499        let non_empty_block_number = stage_progress + 10;
500        let blocks = (stage_progress..=input.target())
501            .map(|number| {
502                random_block(
503                    &mut rng,
504                    number,
505                    BlockParams {
506                        tx_count: Some((number == non_empty_block_number) as u8),
507                        ..Default::default()
508                    },
509                )
510            })
511            .collect::<Vec<_>>();
512        runner
513            .db
514            .insert_blocks(blocks.iter(), StorageKind::Static)
515            .expect("failed to insert blocks");
516
517        let rx = runner.execute(input);
518
519        // Assert the successful result
520        let result = rx.await.unwrap();
521        assert_matches!(
522            result,
523            Ok(ExecOutput { checkpoint: StageCheckpoint {
524                block_number,
525                stage_checkpoint: Some(StageUnitCheckpoint::Entities(EntitiesCheckpoint {
526                    processed: 1,
527                    total: 1
528                }))
529            }, done: true }) if block_number == previous_stage
530        );
531
532        // Validate the stage execution
533        assert!(runner.validate_execution(input, result.ok()).is_ok(), "execution validation");
534    }
535
    /// Ensure the static file header advances through trailing empty blocks.
    #[tokio::test]
    async fn execute_advances_static_file_for_trailing_empty_blocks() {
        let (stage_progress, target) = (0, 3);
        let mut rng = generators::rng();

        let runner = SenderRecoveryTestRunner::default();
        runner.db.factory.set_storage_settings_cache(StorageSettings::v2());
        let input = ExecInput {
            target: Some(target),
            checkpoint: Some(StageCheckpoint::new(stage_progress)),
        };

        let non_empty_block_number = stage_progress + 1;
        let blocks = (stage_progress..=input.target())
            .map(|number| {
                random_block(
                    &mut rng,
                    number,
                    BlockParams {
                        tx_count: Some((number == non_empty_block_number) as u8),
                        ..Default::default()
                    },
                )
            })
            .collect::<Vec<_>>();
        runner
            .db
            .insert_blocks(blocks.iter(), StorageKind::Static)
            .expect("failed to insert blocks");

        let result = runner.execute(input).await.unwrap();
        assert_matches!(result, Ok(ExecOutput { checkpoint, done: true }) if checkpoint.block_number == target);

        let highest_block = runner
            .db
            .factory
            .static_file_provider()
            .get_highest_static_file_block(StaticFileSegment::TransactionSenders);
        assert_eq!(Some(target), highest_block);
    }

    /// Execute the stage twice with an input range that exceeds the commit threshold
    #[tokio::test]
    async fn execute_intermediate_commit() {
        let mut rng = generators::rng();

        let threshold = 10;
        let mut runner = SenderRecoveryTestRunner::default();
        runner.set_threshold(threshold);
        let (stage_progress, previous_stage) = (1000, 1100); // input exceeds threshold

        // Manually seed once with full input range
        let seed = random_block_range(
            &mut rng,
            stage_progress + 1..=previous_stage,
            BlockRangeParams { parent: Some(B256::ZERO), tx_count: 0..4, ..Default::default() },
        ); // set tx count range high enough to hit the threshold
        runner
            .db
            .insert_blocks(seed.iter(), StorageKind::Static)
            .expect("failed to seed execution");

        let total_transactions = runner
            .db
            .factory
            .static_file_provider()
            .count_entries::<tables::Transactions>()
            .unwrap() as u64;

        let first_input = ExecInput {
            target: Some(previous_stage),
            checkpoint: Some(StageCheckpoint::new(stage_progress)),
        };

        // Execute first time
        let result = runner.execute(first_input).await.unwrap();
        let mut tx_count = 0;
        let expected_progress = seed
            .iter()
            .find(|x| {
                tx_count += x.transaction_count();
                tx_count as u64 > threshold
            })
            .map(|x| x.number)
            .unwrap_or(previous_stage);
        assert_matches!(result, Ok(_));
        assert_eq!(
            result.unwrap(),
            ExecOutput {
                checkpoint: StageCheckpoint::new(expected_progress).with_entities_stage_checkpoint(
                    EntitiesCheckpoint {
                        processed: runner.db.count_entries::<tables::TransactionSenders>().unwrap()
                            as u64,
                        total: total_transactions
                    }
                ),
                done: false
            }
        );

        // Execute second time to completion
        runner.set_threshold(u64::MAX);
        let second_input = ExecInput {
            target: Some(previous_stage),
            checkpoint: Some(StageCheckpoint::new(expected_progress)),
        };
        let result = runner.execute(second_input).await.unwrap();
        assert_matches!(result, Ok(_));
        assert_eq!(
            result.as_ref().unwrap(),
            &ExecOutput {
                checkpoint: StageCheckpoint::new(previous_stage).with_entities_stage_checkpoint(
                    EntitiesCheckpoint { processed: total_transactions, total: total_transactions }
                ),
                done: true
            }
        );

        assert!(runner.validate_execution(first_input, result.ok()).is_ok(), "validation failed");
    }

    #[test]
    fn stage_checkpoint_pruned() {
        let db = TestStageDB::default();
        let mut rng = generators::rng();

        let blocks = random_block_range(
            &mut rng,
            0..=100,
            BlockRangeParams { parent: Some(B256::ZERO), tx_count: 0..10, ..Default::default() },
        );
        db.insert_blocks(blocks.iter(), StorageKind::Static).expect("insert blocks");

        let max_pruned_block = 30;
        let max_processed_block = 70;

        let mut tx_senders = Vec::new();
        let mut tx_number = 0;
        for block in &blocks[..=max_processed_block] {
            for transaction in &block.body().transactions {
                if block.number > max_pruned_block {
                    tx_senders
                        .push((tx_number, transaction.recover_signer().expect("recover signer")));
                }
                tx_number += 1;
            }
        }
        db.insert_transaction_senders(tx_senders).expect("insert transaction senders");

        let provider = db.factory.provider_rw().unwrap();
        provider
            .save_prune_checkpoint(
                PruneSegment::SenderRecovery,
                PruneCheckpoint {
                    block_number: Some(max_pruned_block),
                    tx_number: Some(
                        blocks[..=max_pruned_block as usize]
                            .iter()
                            .map(|block| block.transaction_count() as u64)
                            .sum(),
                    ),
                    prune_mode: PruneMode::Full,
                },
            )
            .expect("save stage checkpoint");
        provider.commit().expect("commit");

        let provider = db.factory.database_provider_rw().unwrap();
        assert_eq!(
            stage_checkpoint(&provider).expect("stage checkpoint"),
            EntitiesCheckpoint {
                processed: blocks[..=max_processed_block]
                    .iter()
                    .map(|block| block.transaction_count() as u64)
                    .sum(),
                total: blocks.iter().map(|block| block.transaction_count() as u64).sum()
            }
        );
    }

    struct SenderRecoveryTestRunner {
        db: TestStageDB,
        threshold: u64,
    }

    impl Default for SenderRecoveryTestRunner {
        fn default() -> Self {
            Self { threshold: 1000, db: TestStageDB::default() }
        }
    }

    impl SenderRecoveryTestRunner {
        fn set_threshold(&mut self, threshold: u64) {
            self.threshold = threshold;
        }

        /// # Panics
        ///
        /// 1. If there are any entries in the [`tables::TransactionSenders`] table above a given
        ///    block number.
        /// 2. If there is no requested block entry in the bodies table, but
        ///    [`tables::TransactionSenders`] is not empty.
        fn ensure_no_senders_by_block(&self, block: BlockNumber) -> Result<(), TestRunnerError> {
            let body_result = self
                .db
                .factory
                .provider_rw()?
                .block_body_indices(block)?
                .ok_or(ProviderError::BlockBodyIndicesNotFound(block));
            match body_result {
                Ok(body) => self.db.ensure_no_entry_above::<tables::TransactionSenders, _>(
                    body.last_tx_num(),
                    |key| key,
                )?,
                Err(_) => {
                    assert!(self.db.table_is_empty::<tables::TransactionSenders>()?);
                }
            };

            Ok(())
        }
    }

    impl StageTestRunner for SenderRecoveryTestRunner {
        type S = SenderRecoveryStage;

        fn db(&self) -> &TestStageDB {
            &self.db
        }

        fn stage(&self) -> Self::S {
            SenderRecoveryStage { commit_threshold: self.threshold, prune_mode: None }
        }
    }

    impl ExecuteStageTestRunner for SenderRecoveryTestRunner {
        type Seed = Vec<SealedBlock<Block>>;

        fn seed_execution(&mut self, input: ExecInput) -> Result<Self::Seed, TestRunnerError> {
            let mut rng = generators::rng();
            let stage_progress = input.checkpoint().block_number;
            let end = input.target();

            let blocks = random_block_range(
                &mut rng,
                stage_progress..=end,
                BlockRangeParams { parent: Some(B256::ZERO), tx_count: 0..2, ..Default::default() },
            );
            self.db.insert_blocks(blocks.iter(), StorageKind::Static)?;
            Ok(blocks)
        }

        fn validate_execution(
            &self,
            input: ExecInput,
            output: Option<ExecOutput>,
        ) -> Result<(), TestRunnerError> {
            match output {
                Some(output) => {
                    let provider = self.db.factory.provider()?;
                    let start_block = input.next_block();
                    let end_block = output.checkpoint.block_number;

                    if start_block > end_block {
                        return Ok(())
                    }

                    let mut body_cursor =
                        provider.tx_ref().cursor_read::<tables::BlockBodyIndices>()?;
                    body_cursor.seek_exact(start_block)?;

                    while let Some((_, body)) = body_cursor.next()? {
                        for tx_id in body.tx_num_range() {
                            let transaction: TransactionSigned = provider
                                .transaction_by_id_unhashed(tx_id)?
                                .expect("no transaction entry");
                            let signer =
                                transaction.recover_signer().expect("failed to recover signer");
                            assert_eq!(Some(signer), provider.transaction_sender(tx_id)?)
                        }
                    }
                }
                None => self.ensure_no_senders_by_block(input.checkpoint().block_number)?,
            };

            Ok(())
        }
    }

    impl UnwindStageTestRunner for SenderRecoveryTestRunner {
        fn validate_unwind(&self, input: UnwindInput) -> Result<(), TestRunnerError> {
            self.ensure_no_senders_by_block(input.unwind_to)
        }
    }
}