1use reth_db_api::{table::Value, transaction::DbTxMut};
2use reth_primitives_traits::NodePrimitives;
3use reth_provider::{
4 BlockReader, ChainStateBlockReader, DBProvider, PruneCheckpointReader, PruneCheckpointWriter,
5 StageCheckpointReader, StaticFileProviderFactory, StorageSettingsCache,
6};
7use reth_prune::{
8 PruneMode, PruneModes, PruneSegment, PrunerBuilder, SegmentOutput, SegmentOutputCheckpoint,
9};
10use reth_stages_api::{
11 ExecInput, ExecOutput, Stage, StageCheckpoint, StageError, StageId, UnwindInput, UnwindOutput,
12};
13use tracing::info;
14
15#[derive(Debug)]
27pub struct PruneStage {
28 prune_modes: PruneModes,
29 commit_threshold: usize,
30}
31
32impl PruneStage {
33 pub const fn new(prune_modes: PruneModes, commit_threshold: usize) -> Self {
35 Self { prune_modes, commit_threshold }
36 }
37}
38
39impl<Provider> Stage<Provider> for PruneStage
40where
41 Provider: DBProvider<Tx: DbTxMut>
42 + PruneCheckpointReader
43 + PruneCheckpointWriter
44 + BlockReader
45 + ChainStateBlockReader
46 + StageCheckpointReader
47 + StaticFileProviderFactory<
48 Primitives: NodePrimitives<SignedTx: Value, Receipt: Value, BlockHeader: Value>,
49 > + StorageSettingsCache,
50{
51 fn id(&self) -> StageId {
52 StageId::Prune
53 }
54
55 fn execute(&mut self, provider: &Provider, input: ExecInput) -> Result<ExecOutput, StageError> {
56 let mut pruner = PrunerBuilder::default()
57 .segments(self.prune_modes.clone())
58 .delete_limit(self.commit_threshold)
59 .build::<Provider>(provider.static_file_provider());
60
61 let result = pruner.run_with_provider(provider, input.target())?;
62 if result.progress.is_finished() {
63 Ok(ExecOutput { checkpoint: StageCheckpoint::new(input.target()), done: true })
64 } else {
65 if let Some((last_segment, last_segment_output)) = result.segments.last() {
66 match last_segment_output {
67 SegmentOutput {
68 progress,
69 pruned,
70 checkpoint:
71 checkpoint @ Some(SegmentOutputCheckpoint { block_number: Some(_), .. }),
72 } => {
73 info!(
74 target: "sync::stages::prune::exec",
75 ?last_segment,
76 ?progress,
77 ?pruned,
78 ?checkpoint,
79 "Last segment has more data to prune"
80 )
81 }
82 SegmentOutput { progress, pruned, checkpoint: _ } => {
83 info!(
84 target: "sync::stages::prune::exec",
85 ?last_segment,
86 ?progress,
87 ?pruned,
88 "Last segment has more data to prune"
89 )
90 }
91 }
92 }
93 Ok(ExecOutput { checkpoint: input.checkpoint(), done: false })
96 }
97 }
98
99 fn unwind(
100 &mut self,
101 provider: &Provider,
102 input: UnwindInput,
103 ) -> Result<UnwindOutput, StageError> {
104 let prune_checkpoints = provider.get_prune_checkpoints()?;
107 let unwind_to_last_tx =
108 provider.block_body_indices(input.unwind_to)?.map(|i| i.last_tx_num());
109
110 for (segment, mut checkpoint) in prune_checkpoints {
111 if let Some(block) = checkpoint.block_number &&
113 input.unwind_to < block
114 {
115 checkpoint.block_number = Some(input.unwind_to);
116 checkpoint.tx_number = unwind_to_last_tx;
117 provider.save_prune_checkpoint(segment, checkpoint)?;
118 }
119 }
120 Ok(UnwindOutput { checkpoint: StageCheckpoint::new(input.unwind_to) })
121 }
122}
123
124#[derive(Debug)]
132pub struct PruneSenderRecoveryStage(PruneStage);
133
134impl PruneSenderRecoveryStage {
135 pub fn new(prune_mode: PruneMode, commit_threshold: usize) -> Self {
137 Self(PruneStage::new(
138 PruneModes { sender_recovery: Some(prune_mode), ..PruneModes::default() },
139 commit_threshold,
140 ))
141 }
142}
143
144impl<Provider> Stage<Provider> for PruneSenderRecoveryStage
145where
146 Provider: DBProvider<Tx: DbTxMut>
147 + PruneCheckpointReader
148 + PruneCheckpointWriter
149 + BlockReader
150 + ChainStateBlockReader
151 + StageCheckpointReader
152 + StaticFileProviderFactory<
153 Primitives: NodePrimitives<SignedTx: Value, Receipt: Value, BlockHeader: Value>,
154 > + StorageSettingsCache,
155{
156 fn id(&self) -> StageId {
157 StageId::PruneSenderRecovery
158 }
159
160 fn execute(&mut self, provider: &Provider, input: ExecInput) -> Result<ExecOutput, StageError> {
161 let mut result = self.0.execute(provider, input)?;
162
163 if !result.done {
165 let checkpoint = provider
166 .get_prune_checkpoint(PruneSegment::SenderRecovery)?
167 .ok_or(StageError::MissingPruneCheckpoint(PruneSegment::SenderRecovery))?;
168
169 result.checkpoint = StageCheckpoint::new(checkpoint.block_number.unwrap_or_default());
172 }
173
174 Ok(result)
175 }
176
177 fn unwind(
178 &mut self,
179 provider: &Provider,
180 input: UnwindInput,
181 ) -> Result<UnwindOutput, StageError> {
182 self.0.unwind(provider, input)
183 }
184}
185
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::{
        stage_test_suite_ext, ExecuteStageTestRunner, StageTestRunner, StorageKind,
        TestRunnerError, TestStageDB, UnwindStageTestRunner,
    };
    use alloy_primitives::B256;
    use reth_ethereum_primitives::Block;
    use reth_primitives_traits::{SealedBlock, SignerRecoverable};
    use reth_provider::{
        providers::StaticFileWriter, TransactionsProvider, TransactionsProviderExt,
    };
    use reth_prune::PruneMode;
    use reth_testing_utils::generators::{self, random_block_range, BlockRangeParams};

    stage_test_suite_ext!(PruneTestRunner, prune);

    /// Test runner for a [`PruneStage`] that fully prunes sender recovery data.
    #[derive(Default)]
    struct PruneTestRunner {
        db: TestStageDB,
    }

    impl StageTestRunner for PruneTestRunner {
        type S = PruneStage;

        fn db(&self) -> &TestStageDB {
            &self.db
        }

        fn stage(&self) -> Self::S {
            PruneStage {
                prune_modes: PruneModes {
                    sender_recovery: Some(PruneMode::Full),
                    ..Default::default()
                },
                // Effectively no delete limit, so one execution is expected to finish.
                commit_threshold: usize::MAX,
            }
        }
    }

    impl ExecuteStageTestRunner for PruneTestRunner {
        type Seed = Vec<SealedBlock<Block>>;

        /// Inserts random blocks with 1-2 transactions each for the input range, using
        /// `StorageKind::Static`, plus the recovered sender of every transaction.
        fn seed_execution(&mut self, input: ExecInput) -> Result<Self::Seed, TestRunnerError> {
            let mut rng = generators::rng();
            let blocks = random_block_range(
                &mut rng,
                input.checkpoint().block_number..=input.target(),
                BlockRangeParams { parent: Some(B256::ZERO), tx_count: 1..3, ..Default::default() },
            );
            self.db.insert_blocks(blocks.iter(), StorageKind::Static)?;
            // Sender entries are keyed by a running index starting at 0; this assumes the
            // inserted transactions are numbered from 0 as well (fresh test database).
            self.db.insert_transaction_senders(
                blocks.iter().flat_map(|block| block.body().transactions.iter()).enumerate().map(
                    |(i, tx)| (i as u64, tx.recover_signer().expect("failed to recover signer")),
                ),
            )?;
            Ok(blocks)
        }

        /// Checks that the stage finished and that its checkpoint matches the sender recovery
        /// prune checkpoint. Also checks that no senders remain in the executed block range.
        fn validate_execution(
            &self,
            input: ExecInput,
            output: Option<ExecOutput>,
        ) -> Result<(), TestRunnerError> {
            if let Some(output) = output {
                let start_block = input.next_block();
                let end_block = output.checkpoint.block_number;

                if start_block > end_block {
                    return Ok(())
                }

                let provider = self.db.factory.provider()?;

                assert!(output.done);
                assert_eq!(
                    output.checkpoint.block_number,
                    provider
                        .get_prune_checkpoint(PruneSegment::SenderRecovery)?
                        .expect("prune checkpoint must exist")
                        .block_number
                        .unwrap_or_default()
                );

                // Verify that the senders are pruned, using the provider already opened above.
                let tx_range =
                    provider.transaction_range_by_block_range(start_block..=end_block)?;
                let senders = provider.senders_by_tx_range(tx_range)?;
                assert!(senders.is_empty());
            }
            Ok(())
        }
    }

    impl UnwindStageTestRunner for PruneTestRunner {
        fn validate_unwind(&self, _input: UnwindInput) -> Result<(), TestRunnerError> {
            Ok(())
        }
    }
}