1use reth_db_api::{table::Value, transaction::DbTxMut};
2use reth_primitives_traits::NodePrimitives;
3use reth_provider::{
4 BlockReader, DBProvider, PruneCheckpointReader, PruneCheckpointWriter,
5 StaticFileProviderFactory,
6};
7use reth_prune::{
8 PruneMode, PruneModes, PruneSegment, PrunerBuilder, SegmentOutput, SegmentOutputCheckpoint,
9};
10use reth_stages_api::{
11 ExecInput, ExecOutput, Stage, StageCheckpoint, StageError, StageId, UnwindInput, UnwindOutput,
12};
13use tracing::info;
14
15#[derive(Debug)]
27pub struct PruneStage {
28 prune_modes: PruneModes,
29 commit_threshold: usize,
30}
31
32impl PruneStage {
33 pub const fn new(prune_modes: PruneModes, commit_threshold: usize) -> Self {
35 Self { prune_modes, commit_threshold }
36 }
37}
38
39impl<Provider> Stage<Provider> for PruneStage
40where
41 Provider: DBProvider<Tx: DbTxMut>
42 + PruneCheckpointReader
43 + PruneCheckpointWriter
44 + BlockReader
45 + StaticFileProviderFactory<Primitives: NodePrimitives<SignedTx: Value, Receipt: Value>>,
46{
47 fn id(&self) -> StageId {
48 StageId::Prune
49 }
50
51 fn execute(&mut self, provider: &Provider, input: ExecInput) -> Result<ExecOutput, StageError> {
52 let mut pruner = PrunerBuilder::default()
53 .segments(self.prune_modes.clone())
54 .delete_limit(self.commit_threshold)
55 .build::<Provider>(provider.static_file_provider());
56
57 let result = pruner.run_with_provider(provider, input.target())?;
58 if result.progress.is_finished() {
59 Ok(ExecOutput { checkpoint: StageCheckpoint::new(input.target()), done: true })
60 } else {
61 if let Some((last_segment, last_segment_output)) = result.segments.last() {
62 match last_segment_output {
63 SegmentOutput {
64 progress,
65 pruned,
66 checkpoint:
67 checkpoint @ Some(SegmentOutputCheckpoint { block_number: Some(_), .. }),
68 } => {
69 info!(
70 target: "sync::stages::prune::exec",
71 ?last_segment,
72 ?progress,
73 ?pruned,
74 ?checkpoint,
75 "Last segment has more data to prune"
76 )
77 }
78 SegmentOutput { progress, pruned, checkpoint: _ } => {
79 info!(
80 target: "sync::stages::prune::exec",
81 ?last_segment,
82 ?progress,
83 ?pruned,
84 "Last segment has more data to prune"
85 )
86 }
87 }
88 }
89 Ok(ExecOutput { checkpoint: input.checkpoint(), done: false })
92 }
93 }
94
95 fn unwind(
96 &mut self,
97 provider: &Provider,
98 input: UnwindInput,
99 ) -> Result<UnwindOutput, StageError> {
100 let prune_checkpoints = provider.get_prune_checkpoints()?;
103 for (segment, mut checkpoint) in prune_checkpoints {
104 checkpoint.block_number = Some(input.unwind_to);
105 provider.save_prune_checkpoint(segment, checkpoint)?;
106 }
107 Ok(UnwindOutput { checkpoint: StageCheckpoint::new(input.unwind_to) })
108 }
109}
110
111#[derive(Debug)]
116pub struct PruneSenderRecoveryStage(PruneStage);
117
118impl PruneSenderRecoveryStage {
119 pub fn new(prune_mode: PruneMode, commit_threshold: usize) -> Self {
121 Self(PruneStage::new(
122 PruneModes { sender_recovery: Some(prune_mode), ..PruneModes::none() },
123 commit_threshold,
124 ))
125 }
126}
127
128impl<Provider> Stage<Provider> for PruneSenderRecoveryStage
129where
130 Provider: DBProvider<Tx: DbTxMut>
131 + PruneCheckpointReader
132 + PruneCheckpointWriter
133 + BlockReader
134 + StaticFileProviderFactory<Primitives: NodePrimitives<SignedTx: Value, Receipt: Value>>,
135{
136 fn id(&self) -> StageId {
137 StageId::PruneSenderRecovery
138 }
139
140 fn execute(&mut self, provider: &Provider, input: ExecInput) -> Result<ExecOutput, StageError> {
141 let mut result = self.0.execute(provider, input)?;
142
143 if !result.done {
145 let checkpoint = provider
146 .get_prune_checkpoint(PruneSegment::SenderRecovery)?
147 .ok_or(StageError::MissingPruneCheckpoint(PruneSegment::SenderRecovery))?;
148
149 result.checkpoint = StageCheckpoint::new(checkpoint.block_number.unwrap_or_default());
152 }
153
154 Ok(result)
155 }
156
157 fn unwind(
158 &mut self,
159 provider: &Provider,
160 input: UnwindInput,
161 ) -> Result<UnwindOutput, StageError> {
162 self.0.unwind(provider, input)
163 }
164}
165
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::{
        stage_test_suite_ext, ExecuteStageTestRunner, StageTestRunner, StorageKind,
        TestRunnerError, TestStageDB, UnwindStageTestRunner,
    };
    use alloy_primitives::B256;
    use reth_ethereum_primitives::Block;
    use reth_primitives_traits::SealedBlock;
    use reth_provider::{
        providers::StaticFileWriter, TransactionsProvider, TransactionsProviderExt,
    };
    use reth_prune::PruneMode;
    use reth_testing_utils::generators::{self, random_block_range, BlockRangeParams};

    stage_test_suite_ext!(PruneTestRunner, prune);

    /// Test runner driving a [`PruneStage`] that fully prunes sender recovery data.
    #[derive(Default)]
    struct PruneTestRunner {
        db: TestStageDB,
    }

    impl StageTestRunner for PruneTestRunner {
        type S = PruneStage;

        fn db(&self) -> &TestStageDB {
            &self.db
        }

        fn stage(&self) -> Self::S {
            PruneStage {
                prune_modes: PruneModes {
                    sender_recovery: Some(PruneMode::Full),
                    ..Default::default()
                },
                // No delete limit in practice, so a single execution can finish pruning.
                commit_threshold: usize::MAX,
            }
        }
    }

    impl ExecuteStageTestRunner for PruneTestRunner {
        type Seed = Vec<SealedBlock<Block>>;

        /// Seeds random blocks over `checkpoint..=target` (1-2 transactions each) into static
        /// files, plus a sender row per transaction.
        ///
        /// The sender tx numbers are taken from the enumeration index.
        fn seed_execution(&mut self, input: ExecInput) -> Result<Self::Seed, TestRunnerError> {
            let mut rng = generators::rng();
            let blocks = random_block_range(
                &mut rng,
                input.checkpoint().block_number..=input.target(),
                BlockRangeParams { parent: Some(B256::ZERO), tx_count: 1..3, ..Default::default() },
            );
            self.db.insert_blocks(blocks.iter(), StorageKind::Static)?;
            self.db.insert_transaction_senders(
                blocks.iter().flat_map(|block| block.body().transactions.iter()).enumerate().map(
                    |(i, tx)| (i as u64, tx.recover_signer().expect("failed to recover signer")),
                ),
            )?;
            Ok(blocks)
        }

        /// Checks the executed output against the stored prune state.
        ///
        /// - The stage must report `done`.
        /// - Its checkpoint must match the sender recovery prune checkpoint.
        /// - No senders may remain for the executed block range.
        fn validate_execution(
            &self,
            input: ExecInput,
            output: Option<ExecOutput>,
        ) -> Result<(), TestRunnerError> {
            let Some(output) = output else { return Ok(()) };

            let start_block = input.next_block();
            let end_block = output.checkpoint.block_number;
            if start_block > end_block {
                return Ok(())
            }

            let provider = self.db.factory.provider()?;

            assert!(output.done);
            assert_eq!(
                output.checkpoint.block_number,
                provider
                    .get_prune_checkpoint(PruneSegment::SenderRecovery)?
                    .expect("prune checkpoint must exist")
                    .block_number
                    .unwrap_or_default()
            );

            // Reuse the already-open provider instead of opening a second read transaction.
            let tx_range = provider.transaction_range_by_block_range(start_block..=end_block)?;
            let senders = provider.senders_by_tx_range(tx_range)?;
            assert!(senders.is_empty());

            Ok(())
        }
    }

    impl UnwindStageTestRunner for PruneTestRunner {
        fn validate_unwind(&self, _input: UnwindInput) -> Result<(), TestRunnerError> {
            Ok(())
        }
    }
}