1use alloy_primitives::{Address, TxNumber};
2use reth_config::config::SenderRecoveryConfig;
3use reth_consensus::ConsensusError;
4use reth_db::static_file::TransactionMask;
5use reth_db_api::{
6 cursor::DbCursorRW,
7 table::Value,
8 tables,
9 transaction::{DbTx, DbTxMut},
10 DbTxUnwindExt, RawValue,
11};
12use reth_primitives_traits::{GotExpected, NodePrimitives, SignedTransaction};
13use reth_provider::{
14 BlockReader, DBProvider, HeaderProvider, ProviderError, PruneCheckpointReader,
15 StaticFileProviderFactory, StatsReader,
16};
17use reth_prune_types::PruneSegment;
18use reth_stages_api::{
19 BlockErrorKind, EntitiesCheckpoint, ExecInput, ExecOutput, Stage, StageCheckpoint, StageError,
20 StageId, UnwindInput, UnwindOutput,
21};
22use reth_static_file_types::StaticFileSegment;
23use std::{fmt::Debug, ops::Range, sync::mpsc};
24use thiserror::Error;
25use tracing::*;
26
27const BATCH_SIZE: usize = 100_000;
31
32const WORKER_CHUNK_SIZE: usize = 100;
34
35type RecoveryResultSender = mpsc::Sender<Result<(u64, Address), Box<SenderRecoveryStageError>>>;
37
/// The sender recovery stage.
///
/// Recovers the signer of every transaction in the processed range and stores it in the
/// `TransactionSenders` table.
#[derive(Debug, Clone)]
pub struct SenderRecoveryStage {
    /// Threshold handed to `next_block_range_with_transaction_threshold` (execute) and
    /// `unwind_block_range_with_threshold` (unwind); it bounds how much work a single call does.
    pub commit_threshold: u64,
}
47
48impl SenderRecoveryStage {
49 pub const fn new(config: SenderRecoveryConfig) -> Self {
51 Self { commit_threshold: config.commit_threshold }
52 }
53}
54
55impl Default for SenderRecoveryStage {
56 fn default() -> Self {
57 Self { commit_threshold: 5_000_000 }
58 }
59}
60
61impl<Provider> Stage<Provider> for SenderRecoveryStage
62where
63 Provider: DBProvider<Tx: DbTxMut>
64 + BlockReader
65 + StaticFileProviderFactory<Primitives: NodePrimitives<SignedTx: Value + SignedTransaction>>
66 + StatsReader
67 + PruneCheckpointReader,
68{
69 fn id(&self) -> StageId {
71 StageId::SenderRecovery
72 }
73
74 fn execute(&mut self, provider: &Provider, input: ExecInput) -> Result<ExecOutput, StageError> {
79 if input.target_reached() {
80 return Ok(ExecOutput::done(input.checkpoint()))
81 }
82
83 let (tx_range, block_range, is_final_range) =
84 input.next_block_range_with_transaction_threshold(provider, self.commit_threshold)?;
85 let end_block = *block_range.end();
86
87 if tx_range.is_empty() {
89 info!(target: "sync::stages::sender_recovery", ?tx_range, "Target transaction already reached");
90 return Ok(ExecOutput {
91 checkpoint: StageCheckpoint::new(end_block)
92 .with_entities_stage_checkpoint(stage_checkpoint(provider)?),
93 done: is_final_range,
94 })
95 }
96
97 let mut senders_cursor = provider.tx_ref().cursor_write::<tables::TransactionSenders>()?;
99
100 info!(target: "sync::stages::sender_recovery", ?tx_range, "Recovering senders");
101
102 let batch = tx_range
104 .clone()
105 .step_by(BATCH_SIZE)
106 .map(|start| start..std::cmp::min(start + BATCH_SIZE as u64, tx_range.end))
107 .collect::<Vec<Range<u64>>>();
108
109 let tx_batch_sender = setup_range_recovery(provider);
110
111 for range in batch {
112 recover_range(range, provider, tx_batch_sender.clone(), &mut senders_cursor)?;
113 }
114
115 Ok(ExecOutput {
116 checkpoint: StageCheckpoint::new(end_block)
117 .with_entities_stage_checkpoint(stage_checkpoint(provider)?),
118 done: is_final_range,
119 })
120 }
121
122 fn unwind(
124 &mut self,
125 provider: &Provider,
126 input: UnwindInput,
127 ) -> Result<UnwindOutput, StageError> {
128 let (_, unwind_to, _) = input.unwind_block_range_with_threshold(self.commit_threshold);
129
130 let latest_tx_id = provider
132 .block_body_indices(unwind_to)?
133 .ok_or(ProviderError::BlockBodyIndicesNotFound(unwind_to))?
134 .last_tx_num();
135 provider.tx_ref().unwind_table_by_num::<tables::TransactionSenders>(latest_tx_id)?;
136
137 Ok(UnwindOutput {
138 checkpoint: StageCheckpoint::new(unwind_to)
139 .with_entities_stage_checkpoint(stage_checkpoint(provider)?),
140 })
141 }
142}
143
144fn recover_range<Provider, CURSOR>(
145 tx_range: Range<u64>,
146 provider: &Provider,
147 tx_batch_sender: mpsc::Sender<Vec<(Range<u64>, RecoveryResultSender)>>,
148 senders_cursor: &mut CURSOR,
149) -> Result<(), StageError>
150where
151 Provider: DBProvider + HeaderProvider + StaticFileProviderFactory,
152 CURSOR: DbCursorRW<tables::TransactionSenders>,
153{
154 debug!(target: "sync::stages::sender_recovery", ?tx_range, "Sending batch for processing");
155
156 let (chunks, receivers): (Vec<_>, Vec<_>) = tx_range
158 .clone()
159 .step_by(WORKER_CHUNK_SIZE)
160 .map(|start| {
161 let range = start..std::cmp::min(start + WORKER_CHUNK_SIZE as u64, tx_range.end);
162 let (tx, rx) = mpsc::channel();
163 ((range, tx), rx)
165 })
166 .unzip();
167
168 if let Some(err) = tx_batch_sender.send(chunks).err() {
169 return Err(StageError::Fatal(err.into()));
170 }
171
172 debug!(target: "sync::stages::sender_recovery", ?tx_range, "Appending recovered senders to the database");
173
174 let mut processed_transactions = 0;
175 for channel in receivers {
176 while let Ok(recovered) = channel.recv() {
177 let (tx_id, sender) = match recovered {
178 Ok(result) => result,
179 Err(error) => {
180 return match *error {
181 SenderRecoveryStageError::FailedRecovery(err) => {
182 let block_number = provider
184 .tx_ref()
185 .get::<tables::TransactionBlocks>(err.tx)?
186 .ok_or(ProviderError::BlockNumberForTransactionIndexNotFound)?;
187
188 let sealed_header =
191 provider.sealed_header(block_number)?.ok_or_else(|| {
192 ProviderError::HeaderNotFound(block_number.into())
193 })?;
194
195 Err(StageError::Block {
196 block: Box::new(sealed_header.block_with_parent()),
197 error: BlockErrorKind::Validation(
198 ConsensusError::TransactionSignerRecoveryError,
199 ),
200 })
201 }
202 SenderRecoveryStageError::StageError(err) => Err(err),
203 SenderRecoveryStageError::RecoveredSendersMismatch(expectation) => {
204 Err(StageError::Fatal(
205 SenderRecoveryStageError::RecoveredSendersMismatch(expectation)
206 .into(),
207 ))
208 }
209 }
210 }
211 };
212 senders_cursor.append(tx_id, &sender)?;
213 processed_transactions += 1;
214 }
215 }
216 debug!(target: "sync::stages::sender_recovery", ?tx_range, "Finished recovering senders batch");
217
218 let expected = tx_range.end - tx_range.start;
220 if processed_transactions != expected {
221 return Err(StageError::Fatal(
222 SenderRecoveryStageError::RecoveredSendersMismatch(GotExpected {
223 got: processed_transactions,
224 expected,
225 })
226 .into(),
227 ));
228 }
229 Ok(())
230}
231
/// Spawns the background worker used for sender recovery and returns the channel for
/// submitting work to it.
///
/// The worker runs on a dedicated OS thread and keeps receiving batches until every clone of the
/// returned sender is dropped. A batch is a list of `(range, result_sender)` pairs. For each pair,
/// the worker reads the transactions in `range` from the `Transactions` static-file segment, then
/// spawns a rayon task that recovers each sender and sends the results on `result_sender`.
fn setup_range_recovery<Provider>(
    provider: &Provider,
) -> mpsc::Sender<Vec<(Range<u64>, RecoveryResultSender)>>
where
    Provider: DBProvider
        + HeaderProvider
        + StaticFileProviderFactory<Primitives: NodePrimitives<SignedTx: Value + SignedTransaction>>,
{
    let (tx_sender, tx_receiver) = mpsc::channel::<Vec<(Range<u64>, RecoveryResultSender)>>();
    let static_file_provider = provider.static_file_provider();

    // Static-file reads happen on this thread; signer recovery is offloaded to rayon tasks.
    std::thread::spawn(move || {
        // Ends once all senders returned by this function have been dropped.
        while let Ok(chunks) = tx_receiver.recv() {
            for (chunk_range, recovered_senders_tx) in chunks {
                // Fetch the chunk as `RawValue`s; decoding happens later via `value()` in the
                // rayon task.
                let chunk = match static_file_provider.fetch_range_with_predicate(
                    StaticFileSegment::Transactions,
                    chunk_range.clone(),
                    |cursor, number| {
                        Ok(cursor
                            .get_one::<TransactionMask<
                                RawValue<<Provider::Primitives as NodePrimitives>::SignedTx>,
                            >>(number.into())?
                            .map(|tx| (number, tx)))
                    },
                    |_| true,
                ) {
                    Ok(chunk) => chunk,
                    Err(err) => {
                        // Report the failure on this chunk's channel (a disconnected receiver is
                        // ignored) and abandon the rest of the batch. Breaking drops the remaining
                        // chunk senders, which closes their channels.
                        let _ = recovered_senders_tx
                            .send(Err(Box::new(SenderRecoveryStageError::StageError(err.into()))));
                        break
                    }
                };

                rayon::spawn(move || {
                    // Scratch buffer reused for every transaction of this chunk.
                    let mut rlp_buf = Vec::with_capacity(128);
                    for (number, tx) in chunk {
                        let res = tx
                            .value()
                            .map_err(|err| {
                                Box::new(SenderRecoveryStageError::StageError(err.into()))
                            })
                            .and_then(|tx| recover_sender((number, tx), &mut rlp_buf));

                        let is_err = res.is_err();

                        // A send error only means the receiver already hung up; ignore it.
                        let _ = recovered_senders_tx.send(res);

                        // Stop this chunk at the first failure.
                        if is_err {
                            break
                        }
                    }
                });
            }
        }
    });
    tx_sender
}
306
307#[inline]
308fn recover_sender<T: SignedTransaction>(
309 (tx_id, tx): (TxNumber, T),
310 rlp_buf: &mut Vec<u8>,
311) -> Result<(u64, Address), Box<SenderRecoveryStageError>> {
312 rlp_buf.clear();
313 let sender = tx.recover_signer_unchecked_with_buf(rlp_buf).map_err(|_| {
319 SenderRecoveryStageError::FailedRecovery(FailedSenderRecoveryError { tx: tx_id })
320 })?;
321
322 Ok((tx_id, sender))
323}
324
325fn stage_checkpoint<Provider>(provider: &Provider) -> Result<EntitiesCheckpoint, StageError>
326where
327 Provider: StatsReader + StaticFileProviderFactory + PruneCheckpointReader,
328{
329 let pruned_entries = provider
330 .get_prune_checkpoint(PruneSegment::SenderRecovery)?
331 .and_then(|checkpoint| checkpoint.tx_number)
332 .unwrap_or_default();
333 Ok(EntitiesCheckpoint {
334 processed: provider.count_entries::<tables::TransactionSenders>()? as u64 + pruned_entries,
338 total: provider.static_file_provider().count_entries::<tables::Transactions>()? as u64,
342 })
343}
344
/// Errors raised while recovering senders. `recover_range` converts them into [`StageError`]s.
#[derive(Error, Debug)]
#[error(transparent)]
enum SenderRecoveryStageError {
    /// The signer of a transaction could not be recovered.
    #[error(transparent)]
    FailedRecovery(#[from] FailedSenderRecoveryError),

    /// The number of recovered senders differs from the number of transactions requested.
    #[error("mismatched sender count during recovery: {_0}")]
    RecoveredSendersMismatch(GotExpected<u64>),

    /// Any other stage error hit during recovery (e.g. a static-file read or decoding failure).
    #[error(transparent)]
    StageError(#[from] StageError),
}
360
/// Identifies the transaction whose signer could not be recovered.
#[derive(Error, Debug)]
#[error("sender recovery failed for transaction {tx}")]
struct FailedSenderRecoveryError {
    /// Number of the transaction that failed recovery.
    tx: TxNumber,
}
367
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::{
        stage_test_suite_ext, ExecuteStageTestRunner, StageTestRunner, StorageKind,
        TestRunnerError, TestStageDB, UnwindStageTestRunner,
    };
    use alloy_primitives::{BlockNumber, B256};
    use assert_matches::assert_matches;
    use reth_db_api::cursor::DbCursorRO;
    use reth_ethereum_primitives::{Block, TransactionSigned};
    use reth_primitives_traits::{SealedBlock, SignedTransaction};
    use reth_provider::{
        providers::StaticFileWriter, BlockBodyIndicesProvider, DatabaseProviderFactory,
        PruneCheckpointWriter, StaticFileProviderFactory, TransactionsProvider,
    };
    use reth_prune_types::{PruneCheckpoint, PruneMode};
    use reth_stages_api::StageUnitCheckpoint;
    use reth_testing_utils::generators::{
        self, random_block, random_block_range, BlockParams, BlockRangeParams,
    };

    // Generates the shared execute/unwind test suite for this runner.
    stage_test_suite_ext!(SenderRecoveryTestRunner, sender_recovery);

    /// Runs the stage over a range in which exactly one block holds a single transaction, and
    /// checks that exactly one sender is recovered.
    #[tokio::test]
    async fn execute_single_transaction() {
        let (previous_stage, stage_progress) = (500, 100);
        let mut rng = generators::rng();

        let runner = SenderRecoveryTestRunner::default();
        let input = ExecInput {
            target: Some(previous_stage),
            checkpoint: Some(StageCheckpoint::new(stage_progress)),
        };

        // Only block `stage_progress + 10` gets a transaction; all other blocks are empty.
        let non_empty_block_number = stage_progress + 10;
        let blocks = (stage_progress..=input.target())
            .map(|number| {
                random_block(
                    &mut rng,
                    number,
                    BlockParams {
                        tx_count: Some((number == non_empty_block_number) as u8),
                        ..Default::default()
                    },
                )
            })
            .collect::<Vec<_>>();
        runner
            .db
            .insert_blocks(blocks.iter(), StorageKind::Static)
            .expect("failed to insert blocks");

        let rx = runner.execute(input);

        let result = rx.await.unwrap();
        assert_matches!(
            result,
            Ok(ExecOutput { checkpoint: StageCheckpoint {
                block_number,
                stage_checkpoint: Some(StageUnitCheckpoint::Entities(EntitiesCheckpoint {
                    processed: 1,
                    total: 1
                }))
            }, done: true }) if block_number == previous_stage
        );

        assert!(runner.validate_execution(input, result.ok()).is_ok(), "execution validation");
    }

    /// Runs the stage with a small threshold so that it stops early, then completes the range
    /// with an unlimited threshold in a second run.
    #[tokio::test]
    async fn execute_intermediate_commit() {
        let mut rng = generators::rng();

        let threshold = 10;
        let mut runner = SenderRecoveryTestRunner::default();
        runner.set_threshold(threshold);
        let (stage_progress, previous_stage) = (1000, 1100);
        // Seed once with the full block range.
        let seed = random_block_range(
            &mut rng,
            stage_progress + 1..=previous_stage,
            BlockRangeParams { parent: Some(B256::ZERO), tx_count: 0..4, ..Default::default() },
        );
        runner
            .db
            .insert_blocks(seed.iter(), StorageKind::Static)
            .expect("failed to seed execution");

        let total_transactions = runner
            .db
            .factory
            .static_file_provider()
            .count_entries::<tables::Transactions>()
            .unwrap() as u64;

        let first_input = ExecInput {
            target: Some(previous_stage),
            checkpoint: Some(StageCheckpoint::new(stage_progress)),
        };

        // First run.
        let result = runner.execute(first_input).await.unwrap();
        // Expected checkpoint: the first block at which the cumulative transaction count exceeds
        // `threshold`.
        let mut tx_count = 0;
        let expected_progress = seed
            .iter()
            .find(|x| {
                tx_count += x.transaction_count();
                tx_count as u64 > threshold
            })
            .map(|x| x.number)
            .unwrap_or(previous_stage);
        assert_matches!(result, Ok(_));
        assert_eq!(
            result.unwrap(),
            ExecOutput {
                checkpoint: StageCheckpoint::new(expected_progress).with_entities_stage_checkpoint(
                    EntitiesCheckpoint {
                        processed: runner.db.table::<tables::TransactionSenders>().unwrap().len()
                            as u64,
                        total: total_transactions
                    }
                ),
                done: false
            }
        );

        // Second run, without a threshold, must finish the range.
        runner.set_threshold(u64::MAX);
        let second_input = ExecInput {
            target: Some(previous_stage),
            checkpoint: Some(StageCheckpoint::new(expected_progress)),
        };
        let result = runner.execute(second_input).await.unwrap();
        assert_matches!(result, Ok(_));
        assert_eq!(
            result.as_ref().unwrap(),
            &ExecOutput {
                checkpoint: StageCheckpoint::new(previous_stage).with_entities_stage_checkpoint(
                    EntitiesCheckpoint { processed: total_transactions, total: total_transactions }
                ),
                done: true
            }
        );

        assert!(runner.validate_execution(first_input, result.ok()).is_ok(), "validation failed");
    }

    /// Checks that `stage_checkpoint` counts senders covered by the prune checkpoint as
    /// processed, in addition to the senders stored in the table.
    #[test]
    fn stage_checkpoint_pruned() {
        let db = TestStageDB::default();
        let mut rng = generators::rng();

        let blocks = random_block_range(
            &mut rng,
            0..=100,
            BlockRangeParams { parent: Some(B256::ZERO), tx_count: 0..10, ..Default::default() },
        );
        db.insert_blocks(blocks.iter(), StorageKind::Static).expect("insert blocks");

        let max_pruned_block = 30;
        let max_processed_block = 70;

        // Store senders only for blocks (max_pruned_block, max_processed_block]; the earlier ones
        // are represented by the prune checkpoint below.
        let mut tx_senders = Vec::new();
        let mut tx_number = 0;
        for block in &blocks[..=max_processed_block] {
            for transaction in &block.body().transactions {
                if block.number > max_pruned_block {
                    tx_senders
                        .push((tx_number, transaction.recover_signer().expect("recover signer")));
                }
                tx_number += 1;
            }
        }
        db.insert_transaction_senders(tx_senders).expect("insert tx hash numbers");

        let provider = db.factory.provider_rw().unwrap();
        provider
            .save_prune_checkpoint(
                PruneSegment::SenderRecovery,
                PruneCheckpoint {
                    block_number: Some(max_pruned_block),
                    tx_number: Some(
                        blocks[..=max_pruned_block as usize]
                            .iter()
                            .map(|block| block.transaction_count() as u64)
                            .sum(),
                    ),
                    prune_mode: PruneMode::Full,
                },
            )
            .expect("save stage checkpoint");
        provider.commit().expect("commit");

        let provider = db.factory.database_provider_rw().unwrap();
        assert_eq!(
            stage_checkpoint(&provider).expect("stage checkpoint"),
            EntitiesCheckpoint {
                processed: blocks[..=max_processed_block]
                    .iter()
                    .map(|block| block.transaction_count() as u64)
                    .sum(),
                total: blocks.iter().map(|block| block.transaction_count() as u64).sum()
            }
        );
    }

    /// Test harness for the sender recovery stage: a test database plus the commit threshold
    /// used to build the stage.
    struct SenderRecoveryTestRunner {
        db: TestStageDB,
        threshold: u64,
    }

    impl Default for SenderRecoveryTestRunner {
        fn default() -> Self {
            Self { threshold: 1000, db: TestStageDB::default() }
        }
    }

    impl SenderRecoveryTestRunner {
        fn set_threshold(&mut self, threshold: u64) {
            self.threshold = threshold;
        }

        /// Asserts that no sender is stored past the last transaction of `block`. If `block` has
        /// no body indices, asserts that the senders table is empty.
        fn ensure_no_senders_by_block(&self, block: BlockNumber) -> Result<(), TestRunnerError> {
            let body_result = self
                .db
                .factory
                .provider_rw()?
                .block_body_indices(block)?
                .ok_or(ProviderError::BlockBodyIndicesNotFound(block));
            match body_result {
                Ok(body) => self.db.ensure_no_entry_above::<tables::TransactionSenders, _>(
                    body.last_tx_num(),
                    |key| key,
                )?,
                Err(_) => {
                    assert!(self.db.table_is_empty::<tables::TransactionSenders>()?);
                }
            };

            Ok(())
        }
    }

    impl StageTestRunner for SenderRecoveryTestRunner {
        type S = SenderRecoveryStage;

        fn db(&self) -> &TestStageDB {
            &self.db
        }

        fn stage(&self) -> Self::S {
            SenderRecoveryStage { commit_threshold: self.threshold }
        }
    }

    impl ExecuteStageTestRunner for SenderRecoveryTestRunner {
        type Seed = Vec<SealedBlock<Block>>;

        /// Inserts random blocks (0 or 1 transactions each) for the input's block range.
        fn seed_execution(&mut self, input: ExecInput) -> Result<Self::Seed, TestRunnerError> {
            let mut rng = generators::rng();
            let stage_progress = input.checkpoint().block_number;
            let end = input.target();

            let blocks = random_block_range(
                &mut rng,
                stage_progress..=end,
                BlockRangeParams { parent: Some(B256::ZERO), tx_count: 0..2, ..Default::default() },
            );
            self.db.insert_blocks(blocks.iter(), StorageKind::Static)?;
            Ok(blocks)
        }

        /// On success, checks that every stored sender matches the signer recovered from the
        /// transaction. Without output, checks that no senders were written past the checkpoint.
        fn validate_execution(
            &self,
            input: ExecInput,
            output: Option<ExecOutput>,
        ) -> Result<(), TestRunnerError> {
            match output {
                Some(output) => {
                    let provider = self.db.factory.provider()?;
                    let start_block = input.next_block();
                    let end_block = output.checkpoint.block_number;

                    if start_block > end_block {
                        return Ok(())
                    }

                    // NOTE(review): `seek_exact` positions the cursor on `start_block`, and
                    // `next()` then advances past it. As a result, `start_block` itself is not
                    // checked, and iteration runs to the end of the table rather than stopping at
                    // `end_block`. Confirm this is intended.
                    let mut body_cursor =
                        provider.tx_ref().cursor_read::<tables::BlockBodyIndices>()?;
                    body_cursor.seek_exact(start_block)?;

                    while let Some((_, body)) = body_cursor.next()? {
                        for tx_id in body.tx_num_range() {
                            let transaction: TransactionSigned = provider
                                .transaction_by_id_unhashed(tx_id)?
                                .expect("no transaction entry");
                            let signer =
                                transaction.recover_signer().expect("failed to recover signer");
                            assert_eq!(Some(signer), provider.transaction_sender(tx_id)?)
                        }
                    }
                }
                None => self.ensure_no_senders_by_block(input.checkpoint().block_number)?,
            };

            Ok(())
        }
    }

    impl UnwindStageTestRunner for SenderRecoveryTestRunner {
        fn validate_unwind(&self, input: UnwindInput) -> Result<(), TestRunnerError> {
            self.ensure_no_senders_by_block(input.unwind_to)
        }
    }
}