use alloy_primitives::{Address, BlockNumber, TxNumber};
use reth_config::config::SenderRecoveryConfig;
use reth_consensus::ConsensusError;
use reth_db::static_file::TransactionMask;
use reth_db_api::{
    cursor::DbCursorRW,
    table::Value,
    tables,
    transaction::{DbTx, DbTxMut},
    DbTxUnwindExt, RawValue,
};
use reth_primitives_traits::{GotExpected, NodePrimitives, SignedTransaction};
use reth_provider::{
    BlockReader, DBProvider, EitherWriter, HeaderProvider, ProviderError, PruneCheckpointReader,
    StaticFileProviderFactory, StatsReader, StorageSettingsCache, TransactionsProvider,
};
use reth_prune_types::PruneSegment;
use reth_stages_api::{
    BlockErrorKind, EntitiesCheckpoint, ExecInput, ExecOutput, Stage, StageCheckpoint, StageError,
    StageId, UnwindInput, UnwindOutput,
};
use reth_static_file_types::StaticFileSegment;
use std::{fmt::Debug, ops::Range, sync::mpsc, time::Instant};
use thiserror::Error;
use tracing::*;

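/// Maximum number of transactions to process in a single recovery batch.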
const BATCH_SIZE: usize = 100_000;

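/// Number of transactions handed to a single recovery worker at a time.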
const WORKER_CHUNK_SIZE: usize = 100;

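/// Sending end of the channel through which a worker returns the recovered sender (or the
/// recovery error) for each transaction number.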
type RecoveryResultSender = mpsc::Sender<Result<(u64, Address), Box<SenderRecoveryStageError>>>;

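/// The sender recovery stage iterates over existing transactions, recovers the signer of each
/// transaction and stores the result through an [`EitherWriter`], i.e. either in the
/// [`tables::TransactionSenders`] table or in the corresponding static file segment.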
#[derive(Clone, Debug)]
pub struct SenderRecoveryStage {
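    /// The number of transactions to process before the stage yields back to the pipeline for a
    /// commit.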
    pub commit_threshold: u64,
}

impl SenderRecoveryStage {
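    /// Creates a new [`SenderRecoveryStage`] from the given stage configuration.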
    pub const fn new(config: SenderRecoveryConfig) -> Self {
        Self { commit_threshold: config.commit_threshold }
    }
}

impl Default for SenderRecoveryStage {
    fn default() -> Self {
        Self { commit_threshold: 5_000_000 }
    }
}

impl<Provider> Stage<Provider> for SenderRecoveryStage
where
    Provider: DBProvider<Tx: DbTxMut>
        + BlockReader
        + StaticFileProviderFactory<Primitives: NodePrimitives<SignedTx: Value + SignedTransaction>>
        + StatsReader
        + PruneCheckpointReader
        + StorageSettingsCache,
{
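    /// Returns the ID of the stage.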
    fn id(&self) -> StageId {
        StageId::SenderRecovery
    }

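    /// Determines the next block and transaction range to process (bounded by
    /// `commit_threshold`), recovers the sender of every transaction in that range and appends
    /// the results through an [`EitherWriter`].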
    fn execute(&mut self, provider: &Provider, input: ExecInput) -> Result<ExecOutput, StageError> {
        if input.target_reached() {
            return Ok(ExecOutput::done(input.checkpoint()))
        }

        let Some(range_output) =
            input.next_block_range_with_transaction_threshold(provider, self.commit_threshold)?
        else {
            info!(target: "sync::stages::sender_recovery", "No transaction senders to recover");
            EitherWriter::new_senders(
                provider,
                provider
                    .static_file_provider()
                    .get_highest_static_file_block(StaticFileSegment::TransactionSenders)
                    .unwrap_or_default(),
            )?
            .ensure_at_block(input.target())?;
            return Ok(ExecOutput {
                checkpoint: StageCheckpoint::new(input.target())
                    .with_entities_stage_checkpoint(stage_checkpoint(provider)?),
                done: true,
            })
        };
        let end_block = *range_output.block_range.end();

        let mut writer = EitherWriter::new_senders(provider, *range_output.block_range.start())?;

        info!(target: "sync::stages::sender_recovery", tx_range = ?range_output.tx_range, "Recovering senders");

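        // Split the transaction range into batches of at most `BATCH_SIZE` transactions.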
        let batch = range_output
            .tx_range
            .clone()
            .step_by(BATCH_SIZE)
            .map(|start| start..std::cmp::min(start + BATCH_SIZE as u64, range_output.tx_range.end))
            .collect::<Vec<Range<u64>>>();

        let tx_batch_sender = setup_range_recovery(provider);

        let start = Instant::now();
        let block_body_indices =
            provider.block_body_indices_range(range_output.block_range.clone())?;
        let block_body_indices_elapsed = start.elapsed();
        let mut blocks_with_indices = range_output.block_range.zip(block_body_indices).peekable();

        for range in batch {
            let start = Instant::now();
            let block_numbers = range.clone().fold(Vec::new(), |mut block_numbers, tx| {
                while let Some((block, index)) = blocks_with_indices.peek() {
                    if index.contains_tx(tx) {
                        block_numbers.push(*block);
                        return block_numbers
                    }
                    blocks_with_indices.next();
                }
                block_numbers
            });
            let fold_elapsed = start.elapsed();
            debug!(target: "sync::stages::sender_recovery", ?block_body_indices_elapsed, ?fold_elapsed, len = block_numbers.len(), "Calculated block numbers");
            recover_range(range, block_numbers, provider, tx_batch_sender.clone(), &mut writer)?;
        }

        Ok(ExecOutput {
            checkpoint: StageCheckpoint::new(end_block)
                .with_entities_stage_checkpoint(stage_checkpoint(provider)?),
            done: range_output.is_final_range,
        })
    }

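    /// Unwinds the stage by removing all recovered senders above the last transaction of the
    /// unwind target block.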
    fn unwind(
        &mut self,
        provider: &Provider,
        input: UnwindInput,
    ) -> Result<UnwindOutput, StageError> {
        let (_, unwind_to, _) = input.unwind_block_range_with_threshold(self.commit_threshold);

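        // Look up the last transaction id of the unwind target block; all sender entries above it
        // are removed.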
        let latest_tx_id = provider
            .block_body_indices(unwind_to)?
            .ok_or(ProviderError::BlockBodyIndicesNotFound(unwind_to))?
            .last_tx_num();
        provider.tx_ref().unwind_table_by_num::<tables::TransactionSenders>(latest_tx_id)?;

        Ok(UnwindOutput {
            checkpoint: StageCheckpoint::new(unwind_to)
                .with_entities_stage_checkpoint(stage_checkpoint(provider)?),
        })
    }
}

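/// Recovers the senders for the given transaction range and appends them through the writer,
/// using `block_numbers` (one entry per transaction) to keep the writer positioned at the block
/// each transaction belongs to.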
fn recover_range<Provider, CURSOR>(
    tx_range: Range<TxNumber>,
    block_numbers: Vec<BlockNumber>,
    provider: &Provider,
    tx_batch_sender: mpsc::Sender<Vec<(Range<u64>, RecoveryResultSender)>>,
    writer: &mut EitherWriter<'_, CURSOR, Provider::Primitives>,
) -> Result<(), StageError>
where
    Provider: DBProvider + HeaderProvider + TransactionsProvider + StaticFileProviderFactory,
    CURSOR: DbCursorRW<tables::TransactionSenders>,
{
    debug_assert_eq!(
        tx_range.clone().count(),
        block_numbers.len(),
        "Transaction range and block numbers count mismatch"
    );

    debug!(target: "sync::stages::sender_recovery", ?tx_range, "Sending batch for processing");

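    // Split the batch into `WORKER_CHUNK_SIZE` chunks, each with its own result channel.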
    let (chunks, receivers): (Vec<_>, Vec<_>) = tx_range
        .clone()
        .step_by(WORKER_CHUNK_SIZE)
        .map(|start| {
            let range = start..std::cmp::min(start + WORKER_CHUNK_SIZE as u64, tx_range.end);
            let (tx, rx) = mpsc::channel();
            ((range, tx), rx)
        })
        .unzip();

    if let Some(err) = tx_batch_sender.send(chunks).err() {
        return Err(StageError::Fatal(err.into()));
    }

    debug!(target: "sync::stages::sender_recovery", ?tx_range, "Appending recovered senders to the database");

    let mut processed_transactions = 0;
    let mut block_numbers = block_numbers.into_iter();
    for channel in receivers {
        while let Ok(recovered) = channel.recv() {
            let (tx_id, sender) = match recovered {
                Ok(result) => result,
                Err(error) => {
                    return match *error {
                        SenderRecoveryStageError::FailedRecovery(err) => {
                            let block_number = provider
                                .tx_ref()
                                .get::<tables::TransactionBlocks>(err.tx)?
                                .ok_or(ProviderError::BlockNumberForTransactionIndexNotFound)?;

                            let sealed_header =
                                provider.sealed_header(block_number)?.ok_or_else(|| {
                                    ProviderError::HeaderNotFound(block_number.into())
                                })?;

                            Err(StageError::Block {
                                block: Box::new(sealed_header.block_with_parent()),
                                error: BlockErrorKind::Validation(
                                    ConsensusError::TransactionSignerRecoveryError,
                                ),
                            })
                        }
                        SenderRecoveryStageError::StageError(err) => Err(err),
                        SenderRecoveryStageError::RecoveredSendersMismatch(expectation) => {
                            Err(StageError::Fatal(
                                SenderRecoveryStageError::RecoveredSendersMismatch(expectation)
                                    .into(),
                            ))
                        }
                    }
                }
            };

            let new_block_number = block_numbers
                .next()
                .expect("block numbers iterator has the same length as the number of transactions");
            writer.ensure_at_block(new_block_number)?;
            writer.append_sender(tx_id, &sender)?;
            processed_transactions += 1;
        }
    }
    debug!(target: "sync::stages::sender_recovery", ?tx_range, "Finished recovering senders batch");

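    // Fail-safe: every transaction in the range must have produced a recovered sender.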
    let expected = tx_range.end - tx_range.start;
    if processed_transactions != expected {
        return Err(StageError::Fatal(
            SenderRecoveryStageError::RecoveredSendersMismatch(GotExpected {
                got: processed_transactions,
                expected,
            })
            .into(),
        ));
    }
    Ok(())
}

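/// Spawns the background worker for sender recovery and returns the channel through which
/// batches of transaction-range chunks are submitted for processing.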
fn setup_range_recovery<Provider>(
    provider: &Provider,
) -> mpsc::Sender<Vec<(Range<u64>, RecoveryResultSender)>>
where
    Provider: DBProvider
        + HeaderProvider
        + StaticFileProviderFactory<Primitives: NodePrimitives<SignedTx: Value + SignedTransaction>>,
{
    let (tx_sender, tx_receiver) = mpsc::channel::<Vec<(Range<u64>, RecoveryResultSender)>>();
    let static_file_provider = provider.static_file_provider();

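    // A dedicated OS thread receives the chunked ranges, reads the corresponding transactions
    // from static files and fans the actual signer recovery out to the rayon thread pool.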
    std::thread::spawn(move || {
        while let Ok(chunks) = tx_receiver.recv() {
            for (chunk_range, recovered_senders_tx) in chunks {
                let chunk = match static_file_provider.fetch_range_with_predicate(
                    StaticFileSegment::Transactions,
                    chunk_range.clone(),
                    |cursor, number| {
                        Ok(cursor
                            .get_one::<TransactionMask<
                                RawValue<<Provider::Primitives as NodePrimitives>::SignedTx>,
                            >>(number.into())?
                            .map(|tx| (number, tx)))
                    },
                    |_| true,
                ) {
                    Ok(chunk) => chunk,
                    Err(err) => {
                        let _ = recovered_senders_tx
                            .send(Err(Box::new(SenderRecoveryStageError::StageError(err.into()))));
                        break
                    }
                };

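                // Recover the senders for this chunk on the rayon pool, streaming each result
                // back through the chunk's channel.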
                rayon::spawn(move || {
                    let mut rlp_buf = Vec::with_capacity(128);
                    for (number, tx) in chunk {
                        let res = tx
                            .value()
                            .map_err(|err| {
                                Box::new(SenderRecoveryStageError::StageError(err.into()))
                            })
                            .and_then(|tx| recover_sender((number, tx), &mut rlp_buf));

                        let is_err = res.is_err();

                        let _ = recovered_senders_tx.send(res);

                        if is_err {
                            break
                        }
                    }
                });
            }
        }
    });
    tx_sender
}

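/// Recovers the sender of a single transaction, reusing `rlp_buf` as scratch space for the
/// encoding.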
#[inline]
fn recover_sender<T: SignedTransaction>(
    (tx_id, tx): (TxNumber, T),
    rlp_buf: &mut Vec<u8>,
) -> Result<(u64, Address), Box<SenderRecoveryStageError>> {
    rlp_buf.clear();
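    // Recover the signer using the cheaper unchecked recovery; a failure is reported together
    // with the offending transaction number.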
    let sender = tx.recover_unchecked_with_buf(rlp_buf).map_err(|_| {
        SenderRecoveryStageError::FailedRecovery(FailedSenderRecoveryError { tx: tx_id })
    })?;

    Ok((tx_id, sender))
}

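/// Computes the [`EntitiesCheckpoint`] for this stage, accounting for entries that were pruned
/// from the `TransactionSenders` table.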
fn stage_checkpoint<Provider>(provider: &Provider) -> Result<EntitiesCheckpoint, StageError>
where
    Provider: StatsReader + StaticFileProviderFactory + PruneCheckpointReader,
{
    let pruned_entries = provider
        .get_prune_checkpoint(PruneSegment::SenderRecovery)?
        .and_then(|checkpoint| checkpoint.tx_number)
        .unwrap_or_default();
    Ok(EntitiesCheckpoint {
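        // If the `TransactionSenders` table was pruned, its entry count no longer matches the
        // number of processed transactions, so the pruned entries are added back.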
        processed: provider.count_entries::<tables::TransactionSenders>()? as u64 + pruned_entries,
        // The total is counted from the transactions stored in static files.
        total: provider.static_file_provider().count_entries::<tables::Transactions>()? as u64,
    })
}

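/// Errors produced while recovering transaction senders in this stage.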
#[derive(Error, Debug)]
#[error(transparent)]
enum SenderRecoveryStageError {
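    /// A transaction failed sender recovery.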
    #[error(transparent)]
    FailedRecovery(#[from] FailedSenderRecoveryError),

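    /// The number of recovered senders does not match the expected number.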
393 #[error("mismatched sender count during recovery: {_0}")]
395 RecoveredSendersMismatch(GotExpected<u64>),
396
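    /// A different stage error occurred.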
    #[error(transparent)]
    StageError(#[from] StageError),
}

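/// Error for a transaction whose sender could not be recovered.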
#[derive(Error, Debug)]
#[error("sender recovery failed for transaction {tx}")]
struct FailedSenderRecoveryError {
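    /// The number of the transaction that failed sender recovery.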
    tx: TxNumber,
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::{
        stage_test_suite_ext, ExecuteStageTestRunner, StageTestRunner, StorageKind,
        TestRunnerError, TestStageDB, UnwindStageTestRunner,
    };
    use alloy_primitives::{BlockNumber, B256};
    use assert_matches::assert_matches;
    use reth_db_api::cursor::DbCursorRO;
    use reth_ethereum_primitives::{Block, TransactionSigned};
    use reth_primitives_traits::{SealedBlock, SignerRecoverable};
    use reth_provider::{
        providers::StaticFileWriter, BlockBodyIndicesProvider, DatabaseProviderFactory,
        PruneCheckpointWriter, StaticFileProviderFactory, TransactionsProvider,
    };
    use reth_prune_types::{PruneCheckpoint, PruneMode};
    use reth_stages_api::StageUnitCheckpoint;
    use reth_testing_utils::generators::{
        self, random_block, random_block_range, BlockParams, BlockRangeParams,
    };

    stage_test_suite_ext!(SenderRecoveryTestRunner, sender_recovery);

    #[tokio::test]
    async fn execute_single_transaction() {
        let (previous_stage, stage_progress) = (500, 100);
        let mut rng = generators::rng();

        let runner = SenderRecoveryTestRunner::default();
        let input = ExecInput {
            target: Some(previous_stage),
            checkpoint: Some(StageCheckpoint::new(stage_progress)),
        };

        let non_empty_block_number = stage_progress + 10;
        let blocks = (stage_progress..=input.target())
            .map(|number| {
                random_block(
                    &mut rng,
                    number,
                    BlockParams {
                        tx_count: Some((number == non_empty_block_number) as u8),
                        ..Default::default()
                    },
                )
            })
            .collect::<Vec<_>>();
        runner
            .db
            .insert_blocks(blocks.iter(), StorageKind::Static)
            .expect("failed to insert blocks");

        let rx = runner.execute(input);

        let result = rx.await.unwrap();
        assert_matches!(
            result,
            Ok(ExecOutput { checkpoint: StageCheckpoint {
                block_number,
                stage_checkpoint: Some(StageUnitCheckpoint::Entities(EntitiesCheckpoint {
                    processed: 1,
                    total: 1
                }))
            }, done: true }) if block_number == previous_stage
        );

        assert!(runner.validate_execution(input, result.ok()).is_ok(), "execution validation");
    }

    #[tokio::test]
    async fn execute_intermediate_commit() {
        let mut rng = generators::rng();

        let threshold = 10;
        let mut runner = SenderRecoveryTestRunner::default();
        runner.set_threshold(threshold);
        let (stage_progress, previous_stage) = (1000, 1100);

        let seed = random_block_range(
            &mut rng,
            stage_progress + 1..=previous_stage,
            BlockRangeParams { parent: Some(B256::ZERO), tx_count: 0..4, ..Default::default() },
        );
        runner
            .db
            .insert_blocks(seed.iter(), StorageKind::Static)
            .expect("failed to seed execution");

        let total_transactions = runner
            .db
            .factory
            .static_file_provider()
            .count_entries::<tables::Transactions>()
            .unwrap() as u64;

        let first_input = ExecInput {
            target: Some(previous_stage),
            checkpoint: Some(StageCheckpoint::new(stage_progress)),
        };

        let result = runner.execute(first_input).await.unwrap();
        let mut tx_count = 0;
        let expected_progress = seed
            .iter()
            .find(|x| {
                tx_count += x.transaction_count();
                tx_count as u64 > threshold
            })
            .map(|x| x.number)
            .unwrap_or(previous_stage);
        assert_matches!(result, Ok(_));
        assert_eq!(
            result.unwrap(),
            ExecOutput {
                checkpoint: StageCheckpoint::new(expected_progress).with_entities_stage_checkpoint(
                    EntitiesCheckpoint {
                        processed: runner.db.count_entries::<tables::TransactionSenders>().unwrap()
                            as u64,
                        total: total_transactions
                    }
                ),
                done: false
            }
        );

        runner.set_threshold(u64::MAX);
        let second_input = ExecInput {
            target: Some(previous_stage),
            checkpoint: Some(StageCheckpoint::new(expected_progress)),
        };
        let result = runner.execute(second_input).await.unwrap();
        assert_matches!(result, Ok(_));
        assert_eq!(
            result.as_ref().unwrap(),
            &ExecOutput {
                checkpoint: StageCheckpoint::new(previous_stage).with_entities_stage_checkpoint(
                    EntitiesCheckpoint { processed: total_transactions, total: total_transactions }
                ),
                done: true
            }
        );

        assert!(runner.validate_execution(first_input, result.ok()).is_ok(), "validation failed");
    }

    #[test]
    fn stage_checkpoint_pruned() {
        let db = TestStageDB::default();
        let mut rng = generators::rng();

        let blocks = random_block_range(
            &mut rng,
            0..=100,
            BlockRangeParams { parent: Some(B256::ZERO), tx_count: 0..10, ..Default::default() },
        );
        db.insert_blocks(blocks.iter(), StorageKind::Static).expect("insert blocks");

        let max_pruned_block = 30;
        let max_processed_block = 70;

        let mut tx_senders = Vec::new();
        let mut tx_number = 0;
        for block in &blocks[..=max_processed_block] {
            for transaction in &block.body().transactions {
                if block.number > max_pruned_block {
                    tx_senders
                        .push((tx_number, transaction.recover_signer().expect("recover signer")));
                }
                tx_number += 1;
            }
        }
        db.insert_transaction_senders(tx_senders).expect("insert transaction senders");

        let provider = db.factory.provider_rw().unwrap();
        provider
            .save_prune_checkpoint(
                PruneSegment::SenderRecovery,
                PruneCheckpoint {
                    block_number: Some(max_pruned_block),
                    tx_number: Some(
                        blocks[..=max_pruned_block as usize]
                            .iter()
                            .map(|block| block.transaction_count() as u64)
                            .sum(),
                    ),
                    prune_mode: PruneMode::Full,
                },
            )
            .expect("save stage checkpoint");
        provider.commit().expect("commit");

        let provider = db.factory.database_provider_rw().unwrap();
        assert_eq!(
            stage_checkpoint(&provider).expect("stage checkpoint"),
            EntitiesCheckpoint {
                processed: blocks[..=max_processed_block]
                    .iter()
                    .map(|block| block.transaction_count() as u64)
                    .sum(),
                total: blocks.iter().map(|block| block.transaction_count() as u64).sum()
            }
        );
    }

    struct SenderRecoveryTestRunner {
        db: TestStageDB,
        threshold: u64,
    }

    impl Default for SenderRecoveryTestRunner {
        fn default() -> Self {
            Self { threshold: 1000, db: TestStageDB::default() }
        }
    }

    impl SenderRecoveryTestRunner {
        fn set_threshold(&mut self, threshold: u64) {
            self.threshold = threshold;
        }

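        /// Asserts that the `TransactionSenders` table contains no entries above the last
        /// transaction of the given block, or that the table is empty if the block has no body
        /// indices entry.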
        fn ensure_no_senders_by_block(&self, block: BlockNumber) -> Result<(), TestRunnerError> {
            let body_result = self
                .db
                .factory
                .provider_rw()?
                .block_body_indices(block)?
                .ok_or(ProviderError::BlockBodyIndicesNotFound(block));
            match body_result {
                Ok(body) => self.db.ensure_no_entry_above::<tables::TransactionSenders, _>(
                    body.last_tx_num(),
                    |key| key,
                )?,
                Err(_) => {
                    assert!(self.db.table_is_empty::<tables::TransactionSenders>()?);
                }
            };

            Ok(())
        }
    }

    impl StageTestRunner for SenderRecoveryTestRunner {
        type S = SenderRecoveryStage;

        fn db(&self) -> &TestStageDB {
            &self.db
        }

        fn stage(&self) -> Self::S {
            SenderRecoveryStage { commit_threshold: self.threshold }
        }
    }

    impl ExecuteStageTestRunner for SenderRecoveryTestRunner {
        type Seed = Vec<SealedBlock<Block>>;

        fn seed_execution(&mut self, input: ExecInput) -> Result<Self::Seed, TestRunnerError> {
            let mut rng = generators::rng();
            let stage_progress = input.checkpoint().block_number;
            let end = input.target();

            let blocks = random_block_range(
                &mut rng,
                stage_progress..=end,
                BlockRangeParams { parent: Some(B256::ZERO), tx_count: 0..2, ..Default::default() },
            );
            self.db.insert_blocks(blocks.iter(), StorageKind::Static)?;
            Ok(blocks)
        }

        fn validate_execution(
            &self,
            input: ExecInput,
            output: Option<ExecOutput>,
        ) -> Result<(), TestRunnerError> {
            match output {
                Some(output) => {
                    let provider = self.db.factory.provider()?;
                    let start_block = input.next_block();
                    let end_block = output.checkpoint.block_number;

                    if start_block > end_block {
                        return Ok(())
                    }

                    let mut body_cursor =
                        provider.tx_ref().cursor_read::<tables::BlockBodyIndices>()?;
                    body_cursor.seek_exact(start_block)?;

                    while let Some((_, body)) = body_cursor.next()? {
                        for tx_id in body.tx_num_range() {
                            let transaction: TransactionSigned = provider
                                .transaction_by_id_unhashed(tx_id)?
                                .expect("no transaction entry");
                            let signer =
                                transaction.recover_signer().expect("failed to recover signer");
                            assert_eq!(Some(signer), provider.transaction_sender(tx_id)?)
                        }
                    }
                }
                None => self.ensure_no_senders_by_block(input.checkpoint().block_number)?,
            };

            Ok(())
        }
    }

    impl UnwindStageTestRunner for SenderRecoveryTestRunner {
        fn validate_unwind(&self, input: UnwindInput) -> Result<(), TestRunnerError> {
            self.ensure_no_senders_by_block(input.unwind_to)
        }
    }
}