1use reth_db_api::{table::Value, transaction::DbTxMut};
2use reth_primitives_traits::NodePrimitives;
3use reth_provider::{
4 BlockReader, DBProvider, PruneCheckpointReader, PruneCheckpointWriter,
5 StaticFileProviderFactory,
6};
7use reth_prune::{
8 PruneMode, PruneModes, PruneSegment, PrunerBuilder, SegmentOutput, SegmentOutputCheckpoint,
9};
10use reth_stages_api::{
11 ExecInput, ExecOutput, Stage, StageCheckpoint, StageError, StageId, UnwindInput, UnwindOutput,
12};
13use tracing::info;
14
15#[derive(Debug)]
27pub struct PruneStage {
28 prune_modes: PruneModes,
29 commit_threshold: usize,
30}
31
32impl PruneStage {
33 pub const fn new(prune_modes: PruneModes, commit_threshold: usize) -> Self {
35 Self { prune_modes, commit_threshold }
36 }
37}
38
39impl<Provider> Stage<Provider> for PruneStage
40where
41 Provider: DBProvider<Tx: DbTxMut>
42 + PruneCheckpointReader
43 + PruneCheckpointWriter
44 + BlockReader
45 + StaticFileProviderFactory<
46 Primitives: NodePrimitives<SignedTx: Value, Receipt: Value, BlockHeader: Value>,
47 >,
48{
49 fn id(&self) -> StageId {
50 StageId::Prune
51 }
52
53 fn execute(&mut self, provider: &Provider, input: ExecInput) -> Result<ExecOutput, StageError> {
54 let mut pruner = PrunerBuilder::default()
55 .segments(self.prune_modes.clone())
56 .delete_limit(self.commit_threshold)
57 .build::<Provider>(provider.static_file_provider());
58
59 let result = pruner.run_with_provider(provider, input.target())?;
60 if result.progress.is_finished() {
61 Ok(ExecOutput { checkpoint: StageCheckpoint::new(input.target()), done: true })
62 } else {
63 if let Some((last_segment, last_segment_output)) = result.segments.last() {
64 match last_segment_output {
65 SegmentOutput {
66 progress,
67 pruned,
68 checkpoint:
69 checkpoint @ Some(SegmentOutputCheckpoint { block_number: Some(_), .. }),
70 } => {
71 info!(
72 target: "sync::stages::prune::exec",
73 ?last_segment,
74 ?progress,
75 ?pruned,
76 ?checkpoint,
77 "Last segment has more data to prune"
78 )
79 }
80 SegmentOutput { progress, pruned, checkpoint: _ } => {
81 info!(
82 target: "sync::stages::prune::exec",
83 ?last_segment,
84 ?progress,
85 ?pruned,
86 "Last segment has more data to prune"
87 )
88 }
89 }
90 }
91 Ok(ExecOutput { checkpoint: input.checkpoint(), done: false })
94 }
95 }
96
97 fn unwind(
98 &mut self,
99 provider: &Provider,
100 input: UnwindInput,
101 ) -> Result<UnwindOutput, StageError> {
102 let prune_checkpoints = provider.get_prune_checkpoints()?;
105 for (segment, mut checkpoint) in prune_checkpoints {
106 checkpoint.block_number = Some(input.unwind_to);
107 provider.save_prune_checkpoint(segment, checkpoint)?;
108 }
109 Ok(UnwindOutput { checkpoint: StageCheckpoint::new(input.unwind_to) })
110 }
111}
112
113#[derive(Debug)]
118pub struct PruneSenderRecoveryStage(PruneStage);
119
120impl PruneSenderRecoveryStage {
121 pub fn new(prune_mode: PruneMode, commit_threshold: usize) -> Self {
123 Self(PruneStage::new(
124 PruneModes { sender_recovery: Some(prune_mode), ..PruneModes::none() },
125 commit_threshold,
126 ))
127 }
128}
129
130impl<Provider> Stage<Provider> for PruneSenderRecoveryStage
131where
132 Provider: DBProvider<Tx: DbTxMut>
133 + PruneCheckpointReader
134 + PruneCheckpointWriter
135 + BlockReader
136 + StaticFileProviderFactory<
137 Primitives: NodePrimitives<SignedTx: Value, Receipt: Value, BlockHeader: Value>,
138 >,
139{
140 fn id(&self) -> StageId {
141 StageId::PruneSenderRecovery
142 }
143
144 fn execute(&mut self, provider: &Provider, input: ExecInput) -> Result<ExecOutput, StageError> {
145 let mut result = self.0.execute(provider, input)?;
146
147 if !result.done {
149 let checkpoint = provider
150 .get_prune_checkpoint(PruneSegment::SenderRecovery)?
151 .ok_or(StageError::MissingPruneCheckpoint(PruneSegment::SenderRecovery))?;
152
153 result.checkpoint = StageCheckpoint::new(checkpoint.block_number.unwrap_or_default());
156 }
157
158 Ok(result)
159 }
160
161 fn unwind(
162 &mut self,
163 provider: &Provider,
164 input: UnwindInput,
165 ) -> Result<UnwindOutput, StageError> {
166 self.0.unwind(provider, input)
167 }
168}
169
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::{
        stage_test_suite_ext, ExecuteStageTestRunner, StageTestRunner, StorageKind,
        TestRunnerError, TestStageDB, UnwindStageTestRunner,
    };
    use alloy_primitives::B256;
    use reth_ethereum_primitives::Block;
    use reth_primitives_traits::{SealedBlock, SignerRecoverable};
    use reth_provider::{
        providers::StaticFileWriter, TransactionsProvider, TransactionsProviderExt,
    };
    use reth_prune::PruneMode;
    use reth_testing_utils::generators::{self, random_block_range, BlockRangeParams};

    stage_test_suite_ext!(PruneTestRunner, prune);

    /// Test runner for [`PruneStage`] configured to fully prune transaction senders.
    #[derive(Default)]
    struct PruneTestRunner {
        db: TestStageDB,
    }

    impl StageTestRunner for PruneTestRunner {
        type S = PruneStage;

        fn db(&self) -> &TestStageDB {
            &self.db
        }

        fn stage(&self) -> Self::S {
            PruneStage {
                prune_modes: PruneModes {
                    sender_recovery: Some(PruneMode::Full),
                    ..Default::default()
                },
                // No delete limit, so a single run is expected to finish.
                commit_threshold: usize::MAX,
            }
        }
    }

    impl ExecuteStageTestRunner for PruneTestRunner {
        type Seed = Vec<SealedBlock<Block>>;

        /// Inserts random blocks for the input range, plus a recovered sender for every
        /// transaction.
        fn seed_execution(&mut self, input: ExecInput) -> Result<Self::Seed, TestRunnerError> {
            let mut rng = generators::rng();
            let blocks = random_block_range(
                &mut rng,
                input.checkpoint().block_number..=input.target(),
                BlockRangeParams { parent: Some(B256::ZERO), tx_count: 1..3, ..Default::default() },
            );
            self.db.insert_blocks(blocks.iter(), StorageKind::Static)?;
            self.db.insert_transaction_senders(
                blocks.iter().flat_map(|block| block.body().transactions.iter()).enumerate().map(
                    |(i, tx)| (i as u64, tx.recover_signer().expect("failed to recover signer")),
                ),
            )?;
            Ok(blocks)
        }

        /// Checks that the stage finished and reported the sender recovery prune checkpoint.
        /// Also checks that no senders remain for the executed block range.
        fn validate_execution(
            &self,
            input: ExecInput,
            output: Option<ExecOutput>,
        ) -> Result<(), TestRunnerError> {
            if let Some(output) = output {
                let start_block = input.next_block();
                let end_block = output.checkpoint.block_number;

                if start_block > end_block {
                    return Ok(())
                }

                let provider = self.db.factory.provider()?;

                assert!(output.done);
                assert_eq!(
                    output.checkpoint.block_number,
                    provider
                        .get_prune_checkpoint(PruneSegment::SenderRecovery)?
                        .expect("prune checkpoint must exist")
                        .block_number
                        .unwrap_or_default()
                );

                // Reuse the same provider so both lookups see one consistent view.
                let tx_range =
                    provider.transaction_range_by_block_range(start_block..=end_block)?;
                let senders = provider.senders_by_tx_range(tx_range)?;
                assert!(senders.is_empty());
            }
            Ok(())
        }
    }

    impl UnwindStageTestRunner for PruneTestRunner {
        fn validate_unwind(&self, _input: UnwindInput) -> Result<(), TestRunnerError> {
            Ok(())
        }
    }
}