1use reth_db_api::{table::Value, transaction::DbTxMut};
2use reth_primitives_traits::NodePrimitives;
3use reth_provider::{
4 BlockReader, ChainStateBlockReader, DBProvider, PruneCheckpointReader, PruneCheckpointWriter,
5 StageCheckpointReader, StaticFileProviderFactory, StorageSettingsCache,
6};
7use reth_prune::{
8 PruneMode, PruneModes, PruneSegment, PrunerBuilder, SegmentOutput, SegmentOutputCheckpoint,
9};
10use reth_stages_api::{
11 ExecInput, ExecOutput, Stage, StageCheckpoint, StageError, StageId, UnwindInput, UnwindOutput,
12};
13use tracing::info;
14
/// Stage that runs the pruner over the segments configured in its [`PruneModes`].
///
/// `execute` only reports the stage target as its checkpoint once the pruner says it has
/// finished. `unwind` does not touch pruned data; it only rewrites stored prune checkpoints
/// that are ahead of the unwind target.
#[derive(Debug)]
pub struct PruneStage {
    /// Per-segment prune configuration, cloned into `PrunerBuilder::segments` on every
    /// `execute`.
    prune_modes: PruneModes,
    /// Passed to `PrunerBuilder::delete_limit` on every `execute`; presumably caps how many
    /// entries a single pruner run deletes — confirm against `PrunerBuilder`.
    commit_threshold: usize,
}
31
32impl PruneStage {
33 pub const fn new(prune_modes: PruneModes, commit_threshold: usize) -> Self {
35 Self { prune_modes, commit_threshold }
36 }
37}
38
39impl<Provider> Stage<Provider> for PruneStage
40where
41 Provider: DBProvider<Tx: DbTxMut>
42 + PruneCheckpointReader
43 + PruneCheckpointWriter
44 + BlockReader
45 + ChainStateBlockReader
46 + StageCheckpointReader
47 + StaticFileProviderFactory<
48 Primitives: NodePrimitives<SignedTx: Value, Receipt: Value, BlockHeader: Value>,
49 > + StorageSettingsCache,
50{
51 fn id(&self) -> StageId {
52 StageId::Prune
53 }
54
55 fn execute(&mut self, provider: &Provider, input: ExecInput) -> Result<ExecOutput, StageError> {
56 let mut pruner = PrunerBuilder::default()
57 .segments(self.prune_modes.clone())
58 .delete_limit(self.commit_threshold)
59 .build::<Provider>(provider.static_file_provider());
60
61 let result = pruner.run_with_provider(provider, input.target())?;
62 if result.progress.is_finished() {
63 Ok(ExecOutput { checkpoint: StageCheckpoint::new(input.target()), done: true })
64 } else {
65 if let Some((last_segment, last_segment_output)) = result.segments.last() {
66 match last_segment_output {
67 SegmentOutput {
68 progress,
69 pruned,
70 checkpoint:
71 checkpoint @ Some(SegmentOutputCheckpoint { block_number: Some(_), .. }),
72 } => {
73 info!(
74 target: "sync::stages::prune::exec",
75 ?last_segment,
76 ?progress,
77 ?pruned,
78 ?checkpoint,
79 "Last segment has more data to prune"
80 )
81 }
82 SegmentOutput { progress, pruned, checkpoint: _ } => {
83 info!(
84 target: "sync::stages::prune::exec",
85 ?last_segment,
86 ?progress,
87 ?pruned,
88 "Last segment has more data to prune"
89 )
90 }
91 }
92 }
93 Ok(ExecOutput { checkpoint: input.checkpoint(), done: false })
96 }
97 }
98
99 fn unwind(
100 &mut self,
101 provider: &Provider,
102 input: UnwindInput,
103 ) -> Result<UnwindOutput, StageError> {
104 let prune_checkpoints = provider.get_prune_checkpoints()?;
107 let unwind_to_last_tx =
108 provider.block_body_indices(input.unwind_to)?.map(|i| i.last_tx_num());
109
110 for (segment, mut checkpoint) in prune_checkpoints {
111 if let Some(block) = checkpoint.block_number &&
113 input.unwind_to < block
114 {
115 checkpoint.block_number = Some(input.unwind_to);
116 checkpoint.tx_number = unwind_to_last_tx;
117 provider.save_prune_checkpoint(segment, checkpoint)?;
118 }
119 }
120 Ok(UnwindOutput { checkpoint: StageCheckpoint::new(input.unwind_to) })
121 }
122}
123
/// Prune stage restricted to the sender recovery segment.
///
/// Wraps a [`PruneStage`] whose prune modes only set `sender_recovery`. Its `Stage`
/// implementation delegates to the inner stage. While pruning is unfinished, it reports the
/// sender recovery prune checkpoint as its own checkpoint.
#[derive(Debug)]
pub struct PruneSenderRecoveryStage(PruneStage);
130
131impl PruneSenderRecoveryStage {
132 pub fn new(prune_mode: PruneMode, commit_threshold: usize) -> Self {
134 Self(PruneStage::new(
135 PruneModes { sender_recovery: Some(prune_mode), ..PruneModes::default() },
136 commit_threshold,
137 ))
138 }
139}
140
141impl<Provider> Stage<Provider> for PruneSenderRecoveryStage
142where
143 Provider: DBProvider<Tx: DbTxMut>
144 + PruneCheckpointReader
145 + PruneCheckpointWriter
146 + BlockReader
147 + ChainStateBlockReader
148 + StageCheckpointReader
149 + StaticFileProviderFactory<
150 Primitives: NodePrimitives<SignedTx: Value, Receipt: Value, BlockHeader: Value>,
151 > + StorageSettingsCache,
152{
153 fn id(&self) -> StageId {
154 StageId::PruneSenderRecovery
155 }
156
157 fn execute(&mut self, provider: &Provider, input: ExecInput) -> Result<ExecOutput, StageError> {
158 let mut result = self.0.execute(provider, input)?;
159
160 if !result.done {
162 let checkpoint = provider
163 .get_prune_checkpoint(PruneSegment::SenderRecovery)?
164 .ok_or(StageError::MissingPruneCheckpoint(PruneSegment::SenderRecovery))?;
165
166 result.checkpoint = StageCheckpoint::new(checkpoint.block_number.unwrap_or_default());
169 }
170
171 Ok(result)
172 }
173
174 fn unwind(
175 &mut self,
176 provider: &Provider,
177 input: UnwindInput,
178 ) -> Result<UnwindOutput, StageError> {
179 self.0.unwind(provider, input)
180 }
181}
182
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::{
        stage_test_suite_ext, ExecuteStageTestRunner, StageTestRunner, StorageKind,
        TestRunnerError, TestStageDB, UnwindStageTestRunner,
    };
    use alloy_primitives::B256;
    use reth_ethereum_primitives::Block;
    use reth_primitives_traits::{SealedBlock, SignerRecoverable};
    use reth_provider::{
        providers::StaticFileWriter, TransactionsProvider, TransactionsProviderExt,
    };
    use reth_prune::PruneMode;
    use reth_testing_utils::generators::{self, random_block_range, BlockRangeParams};

    stage_test_suite_ext!(PruneTestRunner, prune);

    /// Test runner driving a [`PruneStage`] that fully prunes transaction senders.
    #[derive(Default)]
    struct PruneTestRunner {
        db: TestStageDB,
    }

    impl StageTestRunner for PruneTestRunner {
        type S = PruneStage;

        fn db(&self) -> &TestStageDB {
            &self.db
        }

        fn stage(&self) -> Self::S {
            PruneStage {
                prune_modes: PruneModes {
                    sender_recovery: Some(PruneMode::Full),
                    ..Default::default()
                },
                // `usize::MAX` effectively removes the delete limit; `validate_execution`
                // expects a single run to finish.
                commit_threshold: usize::MAX,
            }
        }
    }

    impl ExecuteStageTestRunner for PruneTestRunner {
        type Seed = Vec<SealedBlock<Block>>;

        /// Inserts random blocks with 1-2 transactions each for the input range into static
        /// files. Also stores the recovered sender of every transaction, keyed by sequential
        /// transaction numbers starting at 0.
        fn seed_execution(&mut self, input: ExecInput) -> Result<Self::Seed, TestRunnerError> {
            let mut rng = generators::rng();
            let blocks = random_block_range(
                &mut rng,
                input.checkpoint().block_number..=input.target(),
                BlockRangeParams { parent: Some(B256::ZERO), tx_count: 1..3, ..Default::default() },
            );
            self.db.insert_blocks(blocks.iter(), StorageKind::Static)?;
            self.db.insert_transaction_senders(
                blocks.iter().flat_map(|block| block.body().transactions.iter()).enumerate().map(
                    |(i, tx)| (i as u64, tx.recover_signer().expect("failed to recover signer")),
                ),
            )?;
            Ok(blocks)
        }

        /// Checks that execution finished and that the stage checkpoint matches the
        /// `SenderRecovery` prune checkpoint. Also checks that no senders remain for the
        /// executed block range.
        fn validate_execution(
            &self,
            input: ExecInput,
            output: Option<ExecOutput>,
        ) -> Result<(), TestRunnerError> {
            if let Some(output) = output {
                let start_block = input.next_block();
                let end_block = output.checkpoint.block_number;

                if start_block > end_block {
                    return Ok(())
                }

                let provider = self.db.factory.provider()?;

                assert!(output.done);
                assert_eq!(
                    output.checkpoint.block_number,
                    provider
                        .get_prune_checkpoint(PruneSegment::SenderRecovery)?
                        .expect("prune checkpoint must exist")
                        .block_number
                        .unwrap_or_default()
                );

                // Reuse the provider opened above rather than opening a second one just to
                // read the senders.
                let tx_range =
                    provider.transaction_range_by_block_range(start_block..=end_block)?;
                let senders = provider.senders_by_tx_range(tx_range)?;
                assert!(senders.is_empty());
            }
            Ok(())
        }
    }

    impl UnwindStageTestRunner for PruneTestRunner {
        fn validate_unwind(&self, _input: UnwindInput) -> Result<(), TestRunnerError> {
            Ok(())
        }
    }
}