
reth_stages/stages/sender_recovery.rs

use alloy_primitives::{Address, BlockNumber, TxNumber};
use reth_config::config::SenderRecoveryConfig;
use reth_consensus::ConsensusError;
use reth_db::static_file::TransactionMask;
use reth_db_api::{
    cursor::DbCursorRW,
    table::Value,
    tables,
    transaction::{DbTx, DbTxMut},
    RawValue,
};
use reth_primitives_traits::{
    FastInstant as Instant, GotExpected, NodePrimitives, SignedTransaction,
};
use reth_provider::{
    BlockReader, DBProvider, EitherWriter, HeaderProvider, ProviderError, PruneCheckpointReader,
    PruneCheckpointWriter, StaticFileProviderFactory, StatsReader, StorageSettingsCache,
    TransactionsProvider,
};
use reth_prune_types::{PruneCheckpoint, PruneMode, PrunePurpose, PruneSegment};
use reth_stages_api::{
    BlockErrorKind, EntitiesCheckpoint, ExecInput, ExecOutput, Stage, StageCheckpoint, StageError,
    StageId, UnwindInput, UnwindOutput,
};
use reth_static_file_types::StaticFileSegment;
use std::{fmt::Debug, ops::Range, sync::mpsc};
use thiserror::Error;
use tracing::*;

/// Maximum number of transactions to read from disk at one time before we flush their senders to
/// disk. Since each rayon worker will hold at most 100 transactions (`WORKER_CHUNK_SIZE`), we
/// effectively cap each batch at 1000 channels in memory.
const BATCH_SIZE: usize = 100_000;

/// Maximum number of senders to recover per rayon worker job.
const WORKER_CHUNK_SIZE: usize = 100;
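// Illustrative addition (not in the original source): the comment on `BATCH_SIZE` implies that
// one full batch splits into `BATCH_SIZE / WORKER_CHUNK_SIZE` worker chunks, i.e. at most 1000
// bounded channels alive per batch. A minimal compile-time sketch of that relationship,
// assuming the two constants keep their current values:
const _: () = assert!(BATCH_SIZE / WORKER_CHUNK_SIZE == 1_000);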

/// Type alias for a sender that transmits the result of sender recovery.
type RecoveryResultSender = mpsc::SyncSender<Result<(u64, Address), Box<SenderRecoveryStageError>>>;

/// The sender recovery stage iterates over existing transactions,
/// recovers each transaction's signer and stores the result
/// in the [`TransactionSenders`][reth_db_api::tables::TransactionSenders] table.
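/// # Example
///
/// A minimal construction sketch (an illustration added here, assuming `SenderRecoveryConfig`
/// implements `Default`; it is not taken from this crate's tests):
///
/// ```ignore
/// use reth_config::config::SenderRecoveryConfig;
///
/// // Build the stage with the default commit threshold and no prune mode.
/// let stage = SenderRecoveryStage::new(SenderRecoveryConfig::default(), None);
/// assert!(stage.prune_mode.is_none());
/// ```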
#[derive(Clone, Debug)]
pub struct SenderRecoveryStage {
    /// The number of inserted items after which control flow is returned to the pipeline
    /// for a commit.
    pub commit_threshold: u64,
    /// Prune mode for sender recovery. When set to `PruneMode::Full`, the stage will
    /// fast-forward its checkpoint to skip all work, since senders will be recovered
    /// inline by the execution stage instead.
    pub prune_mode: Option<PruneMode>,
}

impl SenderRecoveryStage {
    /// Create new instance of [`SenderRecoveryStage`].
    pub const fn new(config: SenderRecoveryConfig, prune_mode: Option<PruneMode>) -> Self {
        Self { commit_threshold: config.commit_threshold, prune_mode }
    }
}

impl Default for SenderRecoveryStage {
    fn default() -> Self {
        Self { commit_threshold: 5_000_000, prune_mode: None }
    }
}

impl<Provider> Stage<Provider> for SenderRecoveryStage
where
    Provider: DBProvider<Tx: DbTxMut>
        + BlockReader
        + StaticFileProviderFactory<Primitives: NodePrimitives<SignedTx: Value + SignedTransaction>>
        + StatsReader
        + PruneCheckpointReader
        + PruneCheckpointWriter
        + StorageSettingsCache,
{
    /// Return the id of the stage
    fn id(&self) -> StageId {
        StageId::SenderRecovery
    }

    /// Retrieve the range of transactions to iterate over by querying
    /// [`BlockBodyIndices`][reth_db_api::tables::BlockBodyIndices],
    /// collect transactions within that range, recover the signer for each transaction and store
    /// the entries in the [`TransactionSenders`][reth_db_api::tables::TransactionSenders] table or
    /// static files, depending on configuration.
    fn execute(
        &mut self,
        provider: &Provider,
        mut input: ExecInput,
    ) -> Result<ExecOutput, StageError> {
        // TODO: when senders are fully pruned, batch recover in execution stage instead of per-tx
        // fallback
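        // Illustrative sketch of the fast-forward below (numbers are made up): if the configured
        // prune mode resolves to a prunable target of block 999 while the current checkpoint sits
        // at block 500, the checkpoint jumps straight to 999 and no senders are recovered for
        // blocks 501..=999.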
        if let Some((target_prunable_block, prune_mode)) = self
            .prune_mode
            .map(|mode| {
                mode.prune_target_block(
                    input.target(),
                    PruneSegment::SenderRecovery,
                    PrunePurpose::User,
                )
            })
            .transpose()?
            .flatten() &&
            target_prunable_block > input.checkpoint().block_number
        {
            input.checkpoint = Some(StageCheckpoint::new(target_prunable_block));

            if provider.get_prune_checkpoint(PruneSegment::SenderRecovery)?.is_none() {
                let target_prunable_tx_number = provider
                    .block_body_indices(target_prunable_block)?
                    .ok_or(ProviderError::BlockBodyIndicesNotFound(target_prunable_block))?
                    .last_tx_num();

                provider.save_prune_checkpoint(
                    PruneSegment::SenderRecovery,
                    PruneCheckpoint {
                        block_number: Some(target_prunable_block),
                        tx_number: Some(target_prunable_tx_number),
                        prune_mode,
                    },
                )?;
            }
        }

        if input.target_reached() {
            return Ok(ExecOutput::done(input.checkpoint()))
        }

        let Some(range_output) =
            input.next_block_range_with_transaction_threshold(provider, self.commit_threshold)?
        else {
            info!(target: "sync::stages::sender_recovery", "No transaction senders to recover");
            EitherWriter::new_senders(
                provider,
                provider
                    .static_file_provider()
                    .get_highest_static_file_block(StaticFileSegment::TransactionSenders)
                    .unwrap_or_default(),
            )?
            .ensure_at_block(input.target())?;
            return Ok(ExecOutput {
                checkpoint: StageCheckpoint::new(input.target())
                    .with_entities_stage_checkpoint(stage_checkpoint(provider)?),
                done: true,
            })
        };
        let end_block = *range_output.block_range.end();

        let mut writer = EitherWriter::new_senders(provider, *range_output.block_range.start())?;

        info!(target: "sync::stages::sender_recovery", tx_range = ?range_output.tx_range, "Recovering senders");

        // Iterate over transactions in batches, recover the senders and append them
        let batch = range_output
            .tx_range
            .clone()
            .step_by(BATCH_SIZE)
            .map(|start| start..std::cmp::min(start + BATCH_SIZE as u64, range_output.tx_range.end))
            .collect::<Vec<Range<u64>>>();

        let tx_batch_sender = setup_range_recovery(provider);

        let start = Instant::now();
        let block_body_indices =
            provider.block_body_indices_range(range_output.block_range.clone())?;
        let block_body_indices_elapsed = start.elapsed();
        let mut blocks_with_indices = range_output.block_range.zip(block_body_indices).peekable();

        for range in batch {
            // Pair each transaction number with its block number
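            // Worked example (made-up numbers): if block 10 covers txs 100..=102, block 11 is
            // empty and block 12 covers tx 103, then txs 100, 101, 102 map to block 10 and tx 103
            // maps to block 12; the empty block 11 is skipped by the peek/next loop below.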
            let start = Instant::now();
            let block_numbers = range.clone().fold(Vec::new(), |mut block_numbers, tx| {
                while let Some((block, index)) = blocks_with_indices.peek() {
                    if index.contains_tx(tx) {
                        block_numbers.push(*block);
                        return block_numbers
                    }
                    blocks_with_indices.next();
                }
                block_numbers
            });
            let fold_elapsed = start.elapsed();
            debug!(target: "sync::stages::sender_recovery", ?block_body_indices_elapsed, ?fold_elapsed, len = block_numbers.len(), "Calculated block numbers");
            recover_range(range, block_numbers, provider, tx_batch_sender.clone(), &mut writer)?;
        }

        // Advance the static file header to the end of this range to account for empty blocks.
        writer.ensure_at_block(end_block)?;
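        // For example (illustrative): if the last transaction in the range lives in block 97 but
        // the range ends at an empty block 100, the senders writer still has to be advanced to
        // block 100 so its highest block matches the checkpoint returned below.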

        Ok(ExecOutput {
            checkpoint: StageCheckpoint::new(end_block)
                .with_entities_stage_checkpoint(stage_checkpoint(provider)?),
            done: range_output.is_final_range,
        })
    }

    /// Unwind the stage.
    fn unwind(
        &mut self,
        provider: &Provider,
        input: UnwindInput,
    ) -> Result<UnwindOutput, StageError> {
        let (_, unwind_to, _) = input.unwind_block_range_with_threshold(self.commit_threshold);

        if self.prune_mode.is_none_or(|mode| !mode.is_full()) {
            // Look up the next tx id after the unwind_to block (first tx to remove)
            let unwind_tx_from = provider
                .block_body_indices(unwind_to)?
                .ok_or(ProviderError::BlockBodyIndicesNotFound(unwind_to))?
                .next_tx_num();
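            // For instance (illustrative): when unwinding to block 90, `unwind_tx_from` is the
            // next tx number after block 90's last transaction, so every sender entry from that
            // tx onward is pruned below.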

            EitherWriter::new_senders(provider, unwind_to)?
                .prune_senders(unwind_tx_from, unwind_to)?;
        }

        Ok(UnwindOutput {
            checkpoint: StageCheckpoint::new(unwind_to)
                .with_entities_stage_checkpoint(stage_checkpoint(provider)?),
        })
    }
}

fn recover_range<Provider, CURSOR>(
    tx_range: Range<TxNumber>,
    block_numbers: Vec<BlockNumber>,
    provider: &Provider,
    tx_batch_sender: mpsc::Sender<Vec<(Range<u64>, RecoveryResultSender)>>,
    writer: &mut EitherWriter<'_, CURSOR, Provider::Primitives>,
) -> Result<(), StageError>
where
    Provider: DBProvider + HeaderProvider + TransactionsProvider + StaticFileProviderFactory,
    CURSOR: DbCursorRW<tables::TransactionSenders>,
{
    debug_assert_eq!(
        tx_range.clone().count(),
        block_numbers.len(),
        "Transaction range and block numbers count mismatch"
    );

    debug!(target: "sync::stages::sender_recovery", ?tx_range, "Sending batch for processing");

    // Preallocate channels for each chunk in the batch
    let (chunks, receivers): (Vec<_>, Vec<_>) = tx_range
        .clone()
        .step_by(WORKER_CHUNK_SIZE)
        .map(|start| {
            let range = start..std::cmp::min(start + WORKER_CHUNK_SIZE as u64, tx_range.end);
            let (tx, rx) = mpsc::sync_channel((range.end - range.start) as usize);
            // The range and channel sender will be sent to a rayon worker
            ((range, tx), rx)
        })
        .unzip();
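    // Note (added commentary): each `sync_channel` is sized to its chunk length, so a rayon
    // worker can buffer every result for its chunk without blocking, even while the loop below
    // is still draining earlier chunks in order.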

    if let Err(err) = tx_batch_sender.send(chunks) {
        return Err(StageError::Fatal(err.into()));
    }

    debug!(target: "sync::stages::sender_recovery", ?tx_range, "Appending recovered senders to the database");

    let mut processed_transactions = 0;
    let mut block_numbers = block_numbers.into_iter();
    for channel in receivers {
        while let Ok(recovered) = channel.recv() {
            let (tx_id, sender) = match recovered {
                Ok(result) => result,
                Err(error) => {
                    return match *error {
                        SenderRecoveryStageError::FailedRecovery(err) => {
                            // get the block number for the bad transaction
                            let block_number = provider
                                .tx_ref()
                                .get::<tables::TransactionBlocks>(err.tx)?
                                .ok_or(ProviderError::BlockNumberForTransactionIndexNotFound)?;

                            // fetch the sealed header so we can use it in the sender recovery
                            // unwind
                            let sealed_header =
                                provider.sealed_header(block_number)?.ok_or_else(|| {
                                    ProviderError::HeaderNotFound(block_number.into())
                                })?;

                            Err(StageError::Block {
                                block: Box::new(sealed_header.block_with_parent()),
                                error: BlockErrorKind::Validation(
                                    ConsensusError::TransactionSignerRecoveryError,
                                ),
                            })
                        }
                        SenderRecoveryStageError::StageError(err) => Err(err),
                        SenderRecoveryStageError::RecoveredSendersMismatch(expectation) => {
                            Err(StageError::Fatal(
                                SenderRecoveryStageError::RecoveredSendersMismatch(expectation)
                                    .into(),
                            ))
                        }
                    }
                }
            };

            let new_block_number = block_numbers
                .next()
                .expect("block numbers iterator has the same length as the number of transactions");
            writer.ensure_at_block(new_block_number)?;
            writer.append_sender(tx_id, &sender)?;
            processed_transactions += 1;
        }
    }
    debug!(target: "sync::stages::sender_recovery", ?tx_range, "Finished recovering senders batch");

    // Fail-safe to ensure that we do not proceed without having recovered all senders.
    let expected = tx_range.end - tx_range.start;
    if processed_transactions != expected {
        return Err(StageError::Fatal(
            SenderRecoveryStageError::RecoveredSendersMismatch(GotExpected {
                got: processed_transactions,
                expected,
            })
            .into(),
        ));
    }
    Ok(())
}

/// Spawns a thread to handle the recovery of transaction senders for
/// specified chunks of a given batch. It processes incoming ranges, fetching and recovering
/// transactions in parallel using the global rayon pool.
fn setup_range_recovery<Provider>(
    provider: &Provider,
) -> mpsc::Sender<Vec<(Range<u64>, RecoveryResultSender)>>
where
    Provider: DBProvider
        + HeaderProvider
        + StaticFileProviderFactory<Primitives: NodePrimitives<SignedTx: Value + SignedTransaction>>,
{
    let (tx_sender, tx_receiver) = mpsc::channel::<Vec<(Range<u64>, RecoveryResultSender)>>();
    let static_file_provider = provider.static_file_provider();

    // We do not use `tokio::task::spawn_blocking` because, during a shutdown,
    // there will be a timeout grace period in which Tokio does not allow spawning
    // additional blocking tasks. This would cause this function to return
    // `SenderRecoveryStageError::RecoveredSendersMismatch` at the end.
    //
    // However, using `std::thread::spawn` allows us to utilize the timeout grace
    // period to complete some work without throwing errors during the shutdown.
    reth_tasks::spawn_os_thread("sender-recovery", move || {
        while let Ok(chunks) = tx_receiver.recv() {
            for (chunk_range, recovered_senders_tx) in chunks {
                // Read the raw value, and let the rayon worker decompress & decode it.
                let chunk = match static_file_provider.fetch_range_with_predicate(
                    StaticFileSegment::Transactions,
                    chunk_range.clone(),
                    |cursor, number| {
                        Ok(cursor
                            .get_one::<TransactionMask<
                                RawValue<<Provider::Primitives as NodePrimitives>::SignedTx>,
                            >>(number.into())?
                            .map(|tx| (number, tx)))
                    },
                    |_| true,
                ) {
                    Ok(chunk) => chunk,
                    Err(err) => {
                        // We exit early since we could not process this chunk.
                        let _ = recovered_senders_tx
                            .send(Err(Box::new(SenderRecoveryStageError::StageError(err.into()))));
                        break
                    }
                };

                // Spawn the task onto the global rayon pool
                // This task will send the results through the channel after it has read the
                // transaction and calculated the sender.
                rayon::spawn(move || {
                    let mut rlp_buf = Vec::with_capacity(128);
                    for (number, tx) in chunk {
                        let res = tx
                            .value()
                            .map_err(|err| {
                                Box::new(SenderRecoveryStageError::StageError(err.into()))
                            })
                            .and_then(|tx| recover_sender((number, tx), &mut rlp_buf));

                        let is_err = res.is_err();

                        let _ = recovered_senders_tx.send(res);

                        // Finish early
                        if is_err {
                            break
                        }
                    }
                });
            }
        }
    });
    tx_sender
}

#[inline]
fn recover_sender<T: SignedTransaction>(
    (tx_id, tx): (TxNumber, T),
    rlp_buf: &mut Vec<u8>,
) -> Result<(u64, Address), Box<SenderRecoveryStageError>> {
    rlp_buf.clear();
    // We use unchecked recovery (`recover_unchecked_with_buf`) because transactions that run
    // through the pipeline are known to be valid - this means that we do not need to check
    // whether the `s` value is greater than `secp256k1n / 2`, as required after EIP-2. There
    // are pre-homestead transactions with large `s` values, so a checked recovery here would
    // not be backwards-compatible.
    let sender = tx.recover_unchecked_with_buf(rlp_buf).map_err(|_| {
        SenderRecoveryStageError::FailedRecovery(FailedSenderRecoveryError { tx: tx_id })
    })?;

    Ok((tx_id, sender))
}

fn stage_checkpoint<Provider>(provider: &Provider) -> Result<EntitiesCheckpoint, StageError>
where
    Provider: StatsReader + StaticFileProviderFactory + PruneCheckpointReader,
{
    let pruned_entries = provider
        .get_prune_checkpoint(PruneSegment::SenderRecovery)?
        .and_then(|checkpoint| checkpoint.tx_number)
        .unwrap_or_default();
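    // Illustrative arithmetic (made-up numbers): if the prune checkpoint accounts for 30 pruned
    // senders and the table currently holds 40, `processed` below reports 70, while `total` is
    // the full transaction count taken from static files.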
    Ok(EntitiesCheckpoint {
        // If the `TransactionSenders` table was pruned, the number of entries in it will not
        // match the actual number of processed transactions. To fix that, we add the
        // number of pruned `TransactionSenders` entries.
        processed: provider.count_entries::<tables::TransactionSenders>()? as u64 + pruned_entries,
        // Count only static file entries. If we counted the database entries too, we might have
        // duplicates. We're sure that the static files have all entries that the database has,
        // because we run the `StaticFileProducer` before starting the pipeline.
        total: provider.static_file_provider().count_entries::<tables::Transactions>()? as u64,
    })
}

#[derive(Error, Debug)]
#[error(transparent)]
enum SenderRecoveryStageError {
    /// A transaction failed sender recovery
    #[error(transparent)]
    FailedRecovery(#[from] FailedSenderRecoveryError),

    /// Number of recovered senders does not match
    #[error("mismatched sender count during recovery: {_0}")]
    RecoveredSendersMismatch(GotExpected<u64>),

    /// A different type of stage error occurred
    #[error(transparent)]
    StageError(#[from] StageError),
}

#[derive(Error, Debug)]
#[error("sender recovery failed for transaction {tx}")]
struct FailedSenderRecoveryError {
    /// The transaction that failed sender recovery
    tx: TxNumber,
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::{
        stage_test_suite_ext, ExecuteStageTestRunner, StageTestRunner, StorageKind,
        TestRunnerError, TestStageDB, UnwindStageTestRunner,
    };
    use alloy_primitives::{BlockNumber, B256};
    use assert_matches::assert_matches;
    use reth_db_api::{cursor::DbCursorRO, models::StorageSettings};
    use reth_ethereum_primitives::{Block, TransactionSigned};
    use reth_primitives_traits::{SealedBlock, SignerRecoverable};
    use reth_provider::{
        providers::StaticFileWriter, BlockBodyIndicesProvider, DatabaseProviderFactory,
        PruneCheckpointWriter, StaticFileProviderFactory, TransactionsProvider,
    };
    use reth_prune_types::{PruneCheckpoint, PruneMode};
    use reth_stages_api::StageUnitCheckpoint;
    use reth_static_file_types::StaticFileSegment;
    use reth_testing_utils::generators::{
        self, random_block, random_block_range, BlockParams, BlockRangeParams,
    };

    stage_test_suite_ext!(SenderRecoveryTestRunner, sender_recovery);

    /// Execute a block range with a single transaction
    #[tokio::test]
    async fn execute_single_transaction() {
        let (previous_stage, stage_progress) = (500, 100);
        let mut rng = generators::rng();

        // Set up the runner
        let runner = SenderRecoveryTestRunner::default();
        let input = ExecInput {
            target: Some(previous_stage),
            checkpoint: Some(StageCheckpoint::new(stage_progress)),
        };

        // Insert blocks with a single transaction at block `stage_progress + 10`
        let non_empty_block_number = stage_progress + 10;
        let blocks = (stage_progress..=input.target())
            .map(|number| {
                random_block(
                    &mut rng,
                    number,
                    BlockParams {
                        tx_count: Some((number == non_empty_block_number) as u8),
                        ..Default::default()
                    },
                )
            })
            .collect::<Vec<_>>();
        runner
            .db
            .insert_blocks(blocks.iter(), StorageKind::Static)
            .expect("failed to insert blocks");

        let rx = runner.execute(input);

        // Assert the successful result
        let result = rx.await.unwrap();
        assert_matches!(
            result,
            Ok(ExecOutput { checkpoint: StageCheckpoint {
                block_number,
                stage_checkpoint: Some(StageUnitCheckpoint::Entities(EntitiesCheckpoint {
                    processed: 1,
                    total: 1
                }))
            }, done: true }) if block_number == previous_stage
        );

        // Validate the stage execution
        assert!(runner.validate_execution(input, result.ok()).is_ok(), "execution validation");
    }

    /// Ensure the static file header is advanced through trailing empty blocks.
    #[tokio::test]
    async fn execute_advances_static_file_for_trailing_empty_blocks() {
        let (stage_progress, target) = (0, 3);
        let mut rng = generators::rng();

        let runner = SenderRecoveryTestRunner::default();
        runner.db.factory.set_storage_settings_cache(StorageSettings::v2());
        let input = ExecInput {
            target: Some(target),
            checkpoint: Some(StageCheckpoint::new(stage_progress)),
        };

        let non_empty_block_number = stage_progress + 1;
        let blocks = (stage_progress..=input.target())
            .map(|number| {
                random_block(
                    &mut rng,
                    number,
                    BlockParams {
                        tx_count: Some((number == non_empty_block_number) as u8),
                        ..Default::default()
                    },
                )
            })
            .collect::<Vec<_>>();
        runner
            .db
            .insert_blocks(blocks.iter(), StorageKind::Static)
            .expect("failed to insert blocks");

        let result = runner.execute(input).await.unwrap();
        assert_matches!(result, Ok(ExecOutput { checkpoint, done: true }) if checkpoint.block_number == target);

        let highest_block = runner
            .db
            .factory
            .static_file_provider()
            .get_highest_static_file_block(StaticFileSegment::TransactionSenders);
        assert_eq!(Some(target), highest_block);
    }

    /// Execute the stage twice with input range that exceeds the commit threshold
    #[tokio::test]
    async fn execute_intermediate_commit() {
        let mut rng = generators::rng();

        let threshold = 10;
        let mut runner = SenderRecoveryTestRunner::default();
        runner.set_threshold(threshold);
        let (stage_progress, previous_stage) = (1000, 1100); // input exceeds threshold

        // Manually seed once with full input range
        let seed = random_block_range(
            &mut rng,
            stage_progress + 1..=previous_stage,
            BlockRangeParams { parent: Some(B256::ZERO), tx_count: 0..4, ..Default::default() },
        ); // set tx count range high enough to hit the threshold
        runner
            .db
            .insert_blocks(seed.iter(), StorageKind::Static)
            .expect("failed to seed execution");

        let total_transactions = runner
            .db
            .factory
            .static_file_provider()
            .count_entries::<tables::Transactions>()
            .unwrap() as u64;

        let first_input = ExecInput {
            target: Some(previous_stage),
            checkpoint: Some(StageCheckpoint::new(stage_progress)),
        };

        // Execute first time
        let result = runner.execute(first_input).await.unwrap();
        let mut tx_count = 0;
        let expected_progress = seed
            .iter()
            .find(|x| {
                tx_count += x.transaction_count();
                tx_count as u64 > threshold
            })
            .map(|x| x.number)
            .unwrap_or(previous_stage);
        assert_matches!(result, Ok(_));
        assert_eq!(
            result.unwrap(),
            ExecOutput {
                checkpoint: StageCheckpoint::new(expected_progress).with_entities_stage_checkpoint(
                    EntitiesCheckpoint {
                        processed: runner.db.count_entries::<tables::TransactionSenders>().unwrap()
                            as u64,
                        total: total_transactions
                    }
                ),
                done: false
            }
        );

        // Execute second time to completion
        runner.set_threshold(u64::MAX);
        let second_input = ExecInput {
            target: Some(previous_stage),
            checkpoint: Some(StageCheckpoint::new(expected_progress)),
        };
        let result = runner.execute(second_input).await.unwrap();
        assert_matches!(result, Ok(_));
        assert_eq!(
            result.as_ref().unwrap(),
            &ExecOutput {
                checkpoint: StageCheckpoint::new(previous_stage).with_entities_stage_checkpoint(
                    EntitiesCheckpoint { processed: total_transactions, total: total_transactions }
                ),
                done: true
            }
        );

        assert!(runner.validate_execution(first_input, result.ok()).is_ok(), "validation failed");
    }

    #[test]
    fn stage_checkpoint_pruned() {
        let db = TestStageDB::default();
        let mut rng = generators::rng();

        let blocks = random_block_range(
            &mut rng,
            0..=100,
            BlockRangeParams { parent: Some(B256::ZERO), tx_count: 0..10, ..Default::default() },
        );
        db.insert_blocks(blocks.iter(), StorageKind::Static).expect("insert blocks");

        let max_pruned_block = 30;
        let max_processed_block = 70;

        let mut tx_senders = Vec::new();
        let mut tx_number = 0;
        for block in &blocks[..=max_processed_block] {
            for transaction in &block.body().transactions {
                if block.number > max_pruned_block {
                    tx_senders
                        .push((tx_number, transaction.recover_signer().expect("recover signer")));
                }
                tx_number += 1;
            }
        }
        db.insert_transaction_senders(tx_senders).expect("insert transaction senders");

        let provider = db.factory.provider_rw().unwrap();
        provider
            .save_prune_checkpoint(
                PruneSegment::SenderRecovery,
                PruneCheckpoint {
                    block_number: Some(max_pruned_block),
                    tx_number: Some(
                        blocks[..=max_pruned_block as usize]
                            .iter()
                            .map(|block| block.transaction_count() as u64)
                            .sum(),
                    ),
                    prune_mode: PruneMode::Full,
                },
            )
            .expect("save stage checkpoint");
        provider.commit().expect("commit");

        let provider = db.factory.database_provider_rw().unwrap();
        assert_eq!(
            stage_checkpoint(&provider).expect("stage checkpoint"),
            EntitiesCheckpoint {
                processed: blocks[..=max_processed_block]
                    .iter()
                    .map(|block| block.transaction_count() as u64)
                    .sum(),
                total: blocks.iter().map(|block| block.transaction_count() as u64).sum()
            }
        );
    }

    struct SenderRecoveryTestRunner {
        db: TestStageDB,
        threshold: u64,
    }

    impl Default for SenderRecoveryTestRunner {
        fn default() -> Self {
            Self { threshold: 1000, db: TestStageDB::default() }
        }
    }

    impl SenderRecoveryTestRunner {
        fn set_threshold(&mut self, threshold: u64) {
            self.threshold = threshold;
        }

        /// # Panics
        ///
        /// 1. If there are any entries in the [`tables::TransactionSenders`] table above a given
        ///    block number.
        /// 2. If there is no requested block entry in the bodies table, but
        ///    [`tables::TransactionSenders`] is not empty.
        fn ensure_no_senders_by_block(&self, block: BlockNumber) -> Result<(), TestRunnerError> {
            let body_result = self
                .db
                .factory
                .provider_rw()?
                .block_body_indices(block)?
                .ok_or(ProviderError::BlockBodyIndicesNotFound(block));
            match body_result {
                Ok(body) => self.db.ensure_no_entry_above::<tables::TransactionSenders, _>(
                    body.last_tx_num(),
                    |key| key,
                )?,
                Err(_) => {
                    assert!(self.db.table_is_empty::<tables::TransactionSenders>()?);
                }
            };

            Ok(())
        }
    }

    impl StageTestRunner for SenderRecoveryTestRunner {
        type S = SenderRecoveryStage;

        fn db(&self) -> &TestStageDB {
            &self.db
        }

        fn stage(&self) -> Self::S {
            SenderRecoveryStage { commit_threshold: self.threshold, prune_mode: None }
        }
    }

    impl ExecuteStageTestRunner for SenderRecoveryTestRunner {
        type Seed = Vec<SealedBlock<Block>>;

        fn seed_execution(&mut self, input: ExecInput) -> Result<Self::Seed, TestRunnerError> {
            let mut rng = generators::rng();
            let stage_progress = input.checkpoint().block_number;
            let end = input.target();

            let blocks = random_block_range(
                &mut rng,
                stage_progress..=end,
                BlockRangeParams { parent: Some(B256::ZERO), tx_count: 0..2, ..Default::default() },
            );
            self.db.insert_blocks(blocks.iter(), StorageKind::Static)?;
            Ok(blocks)
        }

        fn validate_execution(
            &self,
            input: ExecInput,
            output: Option<ExecOutput>,
        ) -> Result<(), TestRunnerError> {
            match output {
                Some(output) => {
                    let provider = self.db.factory.provider()?;
                    let start_block = input.next_block();
                    let end_block = output.checkpoint.block_number;

                    if start_block > end_block {
                        return Ok(())
                    }

                    let mut body_cursor =
                        provider.tx_ref().cursor_read::<tables::BlockBodyIndices>()?;
                    body_cursor.seek_exact(start_block)?;

                    while let Some((_, body)) = body_cursor.next()? {
                        for tx_id in body.tx_num_range() {
                            let transaction: TransactionSigned = provider
                                .transaction_by_id_unhashed(tx_id)?
                                .expect("no transaction entry");
                            let signer =
                                transaction.recover_signer().expect("failed to recover signer");
                            assert_eq!(Some(signer), provider.transaction_sender(tx_id)?)
                        }
                    }
                }
                None => self.ensure_no_senders_by_block(input.checkpoint().block_number)?,
            };

            Ok(())
        }
    }

    impl UnwindStageTestRunner for SenderRecoveryTestRunner {
        fn validate_unwind(&self, input: UnwindInput) -> Result<(), TestRunnerError> {
            self.ensure_no_senders_by_block(input.unwind_to)
        }
    }
}