1use reth_db_api::{table::Value, transaction::DbTxMut};
2use reth_primitives_traits::NodePrimitives;
3use reth_provider::{
4 BlockReader, ChainStateBlockReader, DBProvider, PruneCheckpointReader, PruneCheckpointWriter,
5 StaticFileProviderFactory,
6};
7use reth_prune::{
8 PruneMode, PruneModes, PruneSegment, PrunerBuilder, SegmentOutput, SegmentOutputCheckpoint,
9};
10use reth_stages_api::{
11 ExecInput, ExecOutput, Stage, StageCheckpoint, StageError, StageId, UnwindInput, UnwindOutput,
12};
13use tracing::info;
14
15#[derive(Debug)]
27pub struct PruneStage {
28 prune_modes: PruneModes,
29 commit_threshold: usize,
30}
31
32impl PruneStage {
33 pub const fn new(prune_modes: PruneModes, commit_threshold: usize) -> Self {
35 Self { prune_modes, commit_threshold }
36 }
37}
38
39impl<Provider> Stage<Provider> for PruneStage
40where
41 Provider: DBProvider<Tx: DbTxMut>
42 + PruneCheckpointReader
43 + PruneCheckpointWriter
44 + BlockReader
45 + ChainStateBlockReader
46 + StaticFileProviderFactory<
47 Primitives: NodePrimitives<SignedTx: Value, Receipt: Value, BlockHeader: Value>,
48 >,
49{
50 fn id(&self) -> StageId {
51 StageId::Prune
52 }
53
54 fn execute(&mut self, provider: &Provider, input: ExecInput) -> Result<ExecOutput, StageError> {
55 let mut pruner = PrunerBuilder::default()
56 .segments(self.prune_modes.clone())
57 .delete_limit(self.commit_threshold)
58 .build::<Provider>(provider.static_file_provider());
59
60 let result = pruner.run_with_provider(provider, input.target())?;
61 if result.progress.is_finished() {
62 Ok(ExecOutput { checkpoint: StageCheckpoint::new(input.target()), done: true })
63 } else {
64 if let Some((last_segment, last_segment_output)) = result.segments.last() {
65 match last_segment_output {
66 SegmentOutput {
67 progress,
68 pruned,
69 checkpoint:
70 checkpoint @ Some(SegmentOutputCheckpoint { block_number: Some(_), .. }),
71 } => {
72 info!(
73 target: "sync::stages::prune::exec",
74 ?last_segment,
75 ?progress,
76 ?pruned,
77 ?checkpoint,
78 "Last segment has more data to prune"
79 )
80 }
81 SegmentOutput { progress, pruned, checkpoint: _ } => {
82 info!(
83 target: "sync::stages::prune::exec",
84 ?last_segment,
85 ?progress,
86 ?pruned,
87 "Last segment has more data to prune"
88 )
89 }
90 }
91 }
92 Ok(ExecOutput { checkpoint: input.checkpoint(), done: false })
95 }
96 }
97
98 fn unwind(
99 &mut self,
100 provider: &Provider,
101 input: UnwindInput,
102 ) -> Result<UnwindOutput, StageError> {
103 let prune_checkpoints = provider.get_prune_checkpoints()?;
106 let unwind_to_last_tx =
107 provider.block_body_indices(input.unwind_to)?.map(|i| i.last_tx_num());
108
109 for (segment, mut checkpoint) in prune_checkpoints {
110 if let Some(block) = checkpoint.block_number &&
112 input.unwind_to < block
113 {
114 checkpoint.block_number = Some(input.unwind_to);
115 checkpoint.tx_number = unwind_to_last_tx;
116 provider.save_prune_checkpoint(segment, checkpoint)?;
117 }
118 }
119 Ok(UnwindOutput { checkpoint: StageCheckpoint::new(input.unwind_to) })
120 }
121}
122
123#[derive(Debug)]
128pub struct PruneSenderRecoveryStage(PruneStage);
129
130impl PruneSenderRecoveryStage {
131 pub fn new(prune_mode: PruneMode, commit_threshold: usize) -> Self {
133 Self(PruneStage::new(
134 PruneModes { sender_recovery: Some(prune_mode), ..PruneModes::default() },
135 commit_threshold,
136 ))
137 }
138}
139
140impl<Provider> Stage<Provider> for PruneSenderRecoveryStage
141where
142 Provider: DBProvider<Tx: DbTxMut>
143 + PruneCheckpointReader
144 + PruneCheckpointWriter
145 + BlockReader
146 + ChainStateBlockReader
147 + StaticFileProviderFactory<
148 Primitives: NodePrimitives<SignedTx: Value, Receipt: Value, BlockHeader: Value>,
149 >,
150{
151 fn id(&self) -> StageId {
152 StageId::PruneSenderRecovery
153 }
154
155 fn execute(&mut self, provider: &Provider, input: ExecInput) -> Result<ExecOutput, StageError> {
156 let mut result = self.0.execute(provider, input)?;
157
158 if !result.done {
160 let checkpoint = provider
161 .get_prune_checkpoint(PruneSegment::SenderRecovery)?
162 .ok_or(StageError::MissingPruneCheckpoint(PruneSegment::SenderRecovery))?;
163
164 result.checkpoint = StageCheckpoint::new(checkpoint.block_number.unwrap_or_default());
167 }
168
169 Ok(result)
170 }
171
172 fn unwind(
173 &mut self,
174 provider: &Provider,
175 input: UnwindInput,
176 ) -> Result<UnwindOutput, StageError> {
177 self.0.unwind(provider, input)
178 }
179}
180
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::{
        stage_test_suite_ext, ExecuteStageTestRunner, StageTestRunner, StorageKind,
        TestRunnerError, TestStageDB, UnwindStageTestRunner,
    };
    use alloy_primitives::B256;
    use reth_ethereum_primitives::Block;
    use reth_primitives_traits::{SealedBlock, SignerRecoverable};
    use reth_provider::{
        providers::StaticFileWriter, TransactionsProvider, TransactionsProviderExt,
    };
    use reth_prune::PruneMode;
    use reth_testing_utils::generators::{self, random_block_range, BlockRangeParams};

    stage_test_suite_ext!(PruneTestRunner, prune);

    /// Test runner for a [`PruneStage`] that fully prunes sender recovery data.
    #[derive(Default)]
    struct PruneTestRunner {
        db: TestStageDB,
    }

    impl StageTestRunner for PruneTestRunner {
        type S = PruneStage;

        fn db(&self) -> &TestStageDB {
            &self.db
        }

        fn stage(&self) -> Self::S {
            // `usize::MAX` as the delete limit.
            PruneStage::new(
                PruneModes { sender_recovery: Some(PruneMode::Full), ..Default::default() },
                usize::MAX,
            )
        }
    }

    impl ExecuteStageTestRunner for PruneTestRunner {
        type Seed = Vec<SealedBlock<Block>>;

        /// Inserts random blocks for the input range into static files, along with the
        /// recovered sender of every transaction, so the stage has senders to prune.
        fn seed_execution(&mut self, input: ExecInput) -> Result<Self::Seed, TestRunnerError> {
            let mut rng = generators::rng();
            let blocks = random_block_range(
                &mut rng,
                input.checkpoint().block_number..=input.target(),
                BlockRangeParams { parent: Some(B256::ZERO), tx_count: 1..3, ..Default::default() },
            );
            self.db.insert_blocks(blocks.iter(), StorageKind::Static)?;
            self.db.insert_transaction_senders(
                blocks.iter().flat_map(|block| block.body().transactions.iter()).enumerate().map(
                    |(i, tx)| (i as u64, tx.recover_signer().expect("failed to recover signer")),
                ),
            )?;
            Ok(blocks)
        }

        /// Checks three things:
        /// - the stage finished;
        /// - its checkpoint matches the sender recovery prune checkpoint;
        /// - no senders remain for the executed block range.
        fn validate_execution(
            &self,
            input: ExecInput,
            output: Option<ExecOutput>,
        ) -> Result<(), TestRunnerError> {
            if let Some(output) = output {
                let start_block = input.next_block();
                let end_block = output.checkpoint.block_number;

                // Empty range: nothing to validate.
                if start_block > end_block {
                    return Ok(())
                }

                let provider = self.db.factory.provider()?;

                assert!(output.done);
                assert_eq!(
                    output.checkpoint.block_number,
                    provider
                        .get_prune_checkpoint(PruneSegment::SenderRecovery)?
                        .expect("prune checkpoint must exist")
                        .block_number
                        .unwrap_or_default()
                );

                // Reuse the provider opened above instead of opening a second one.
                let tx_range =
                    provider.transaction_range_by_block_range(start_block..=end_block)?;
                let senders = provider.senders_by_tx_range(tx_range)?;
                assert!(senders.is_empty());
            }
            Ok(())
        }
    }

    impl UnwindStageTestRunner for PruneTestRunner {
        fn validate_unwind(&self, _input: UnwindInput) -> Result<(), TestRunnerError> {
            Ok(())
        }
    }
}