1use reth_db_api::{table::Value, transaction::DbTxMut};
2use reth_primitives_traits::NodePrimitives;
3use reth_provider::{
4 BlockReader, ChainStateBlockReader, DBProvider, PruneCheckpointReader, PruneCheckpointWriter,
5 RocksDBProviderFactory, StageCheckpointReader, StaticFileProviderFactory,
6};
7use reth_prune::{
8 PruneMode, PruneModes, PruneSegment, PrunerBuilder, SegmentOutput, SegmentOutputCheckpoint,
9};
10use reth_stages_api::{
11 ExecInput, ExecOutput, Stage, StageCheckpoint, StageError, StageId, UnwindInput, UnwindOutput,
12};
13use reth_storage_api::{ChangeSetReader, StorageChangeSetReader, StorageSettingsCache};
14use tracing::info;
15
16#[derive(Debug)]
28pub struct PruneStage {
29 prune_modes: PruneModes,
30 commit_threshold: usize,
31}
32
33impl PruneStage {
34 pub const fn new(prune_modes: PruneModes, commit_threshold: usize) -> Self {
36 Self { prune_modes, commit_threshold }
37 }
38}
39
40impl<Provider> Stage<Provider> for PruneStage
41where
42 Provider: DBProvider<Tx: DbTxMut>
43 + PruneCheckpointReader
44 + PruneCheckpointWriter
45 + BlockReader
46 + ChainStateBlockReader
47 + StageCheckpointReader
48 + StaticFileProviderFactory<
49 Primitives: NodePrimitives<SignedTx: Value, Receipt: Value, BlockHeader: Value>,
50 > + StorageSettingsCache
51 + ChangeSetReader
52 + StorageChangeSetReader
53 + RocksDBProviderFactory,
54{
55 fn id(&self) -> StageId {
56 StageId::Prune
57 }
58
59 fn execute(&mut self, provider: &Provider, input: ExecInput) -> Result<ExecOutput, StageError> {
60 let mut pruner = PrunerBuilder::default()
61 .segments(self.prune_modes.clone())
62 .delete_limit(self.commit_threshold)
63 .build::<Provider>(provider.static_file_provider());
64
65 let result = pruner.run_with_provider(provider, input.target())?;
66 if result.progress.is_finished() {
67 Ok(ExecOutput { checkpoint: StageCheckpoint::new(input.target()), done: true })
68 } else {
69 if let Some((last_segment, last_segment_output)) = result.segments.last() {
70 match last_segment_output {
71 SegmentOutput {
72 progress,
73 pruned,
74 checkpoint:
75 checkpoint @ Some(SegmentOutputCheckpoint { block_number: Some(_), .. }),
76 } => {
77 info!(
78 target: "sync::stages::prune::exec",
79 ?last_segment,
80 ?progress,
81 ?pruned,
82 ?checkpoint,
83 "Last segment has more data to prune"
84 )
85 }
86 SegmentOutput { progress, pruned, checkpoint: _ } => {
87 info!(
88 target: "sync::stages::prune::exec",
89 ?last_segment,
90 ?progress,
91 ?pruned,
92 "Last segment has more data to prune"
93 )
94 }
95 }
96 }
97 Ok(ExecOutput { checkpoint: input.checkpoint(), done: false })
100 }
101 }
102
103 fn unwind(
104 &mut self,
105 provider: &Provider,
106 input: UnwindInput,
107 ) -> Result<UnwindOutput, StageError> {
108 let prune_checkpoints = provider.get_prune_checkpoints()?;
111 let unwind_to_last_tx =
112 provider.block_body_indices(input.unwind_to)?.map(|i| i.last_tx_num());
113
114 for (segment, mut checkpoint) in prune_checkpoints {
115 if let Some(block) = checkpoint.block_number &&
117 input.unwind_to < block
118 {
119 checkpoint.block_number = Some(input.unwind_to);
120 checkpoint.tx_number = unwind_to_last_tx;
121 provider.save_prune_checkpoint(segment, checkpoint)?;
122 }
123 }
124 Ok(UnwindOutput { checkpoint: StageCheckpoint::new(input.unwind_to) })
125 }
126}
127
128#[derive(Debug)]
136pub struct PruneSenderRecoveryStage(PruneStage);
137
138impl PruneSenderRecoveryStage {
139 pub fn new(prune_mode: PruneMode, commit_threshold: usize) -> Self {
141 Self(PruneStage::new(
142 PruneModes { sender_recovery: Some(prune_mode), ..PruneModes::default() },
143 commit_threshold,
144 ))
145 }
146}
147
148impl<Provider> Stage<Provider> for PruneSenderRecoveryStage
149where
150 Provider: DBProvider<Tx: DbTxMut>
151 + PruneCheckpointReader
152 + PruneCheckpointWriter
153 + BlockReader
154 + ChainStateBlockReader
155 + StageCheckpointReader
156 + StaticFileProviderFactory<
157 Primitives: NodePrimitives<SignedTx: Value, Receipt: Value, BlockHeader: Value>,
158 > + StorageSettingsCache
159 + ChangeSetReader
160 + StorageChangeSetReader
161 + RocksDBProviderFactory,
162{
163 fn id(&self) -> StageId {
164 StageId::PruneSenderRecovery
165 }
166
167 fn execute(&mut self, provider: &Provider, input: ExecInput) -> Result<ExecOutput, StageError> {
168 let mut result = self.0.execute(provider, input)?;
169
170 if !result.done {
172 let checkpoint = provider
173 .get_prune_checkpoint(PruneSegment::SenderRecovery)?
174 .ok_or(StageError::MissingPruneCheckpoint(PruneSegment::SenderRecovery))?;
175
176 result.checkpoint = StageCheckpoint::new(checkpoint.block_number.unwrap_or_default());
179 }
180
181 Ok(result)
182 }
183
184 fn unwind(
185 &mut self,
186 provider: &Provider,
187 input: UnwindInput,
188 ) -> Result<UnwindOutput, StageError> {
189 self.0.unwind(provider, input)
190 }
191}
192
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::{
        stage_test_suite_ext, ExecuteStageTestRunner, StageTestRunner, StorageKind,
        TestRunnerError, TestStageDB, UnwindStageTestRunner,
    };
    use alloy_primitives::B256;
    use reth_ethereum_primitives::Block;
    use reth_primitives_traits::{SealedBlock, SignerRecoverable};
    use reth_provider::{
        providers::StaticFileWriter, TransactionsProvider, TransactionsProviderExt,
    };
    use reth_prune::PruneMode;
    use reth_testing_utils::generators::{self, random_block_range, BlockRangeParams};

    stage_test_suite_ext!(PruneTestRunner, prune);

    /// Test runner for [`PruneStage`], backed by a test stage database.
    #[derive(Default)]
    struct PruneTestRunner {
        db: TestStageDB,
    }

    impl StageTestRunner for PruneTestRunner {
        type S = PruneStage;

        fn db(&self) -> &TestStageDB {
            &self.db
        }

        /// Stage under test: full sender recovery pruning with the largest commit threshold.
        fn stage(&self) -> Self::S {
            let prune_modes =
                PruneModes { sender_recovery: Some(PruneMode::Full), ..Default::default() };
            PruneStage::new(prune_modes, usize::MAX)
        }
    }

    impl ExecuteStageTestRunner for PruneTestRunner {
        type Seed = Vec<SealedBlock<Block>>;

        /// Seeds random blocks from the input checkpoint through the target, plus a recovered
        /// sender for each of their transactions.
        fn seed_execution(&mut self, input: ExecInput) -> Result<Self::Seed, TestRunnerError> {
            let mut rng = generators::rng();
            let block_range = input.checkpoint().block_number..=input.target();
            let params =
                BlockRangeParams { parent: Some(B256::ZERO), tx_count: 1..3, ..Default::default() };
            let blocks = random_block_range(&mut rng, block_range, params);
            self.db.insert_blocks(blocks.iter(), StorageKind::Static)?;

            // Sender keys are the running transaction index across all seeded blocks.
            let senders = blocks
                .iter()
                .flat_map(|block| block.body().transactions.iter())
                .enumerate()
                .map(|(tx_index, tx)| {
                    (tx_index as u64, tx.recover_signer().expect("failed to recover signer"))
                });
            self.db.insert_transaction_senders(senders)?;

            Ok(blocks)
        }

        /// Checks that a completed execution left no senders in the executed block range and
        /// that the stage checkpoint matches the sender recovery prune checkpoint.
        fn validate_execution(
            &self,
            input: ExecInput,
            output: Option<ExecOutput>,
        ) -> Result<(), TestRunnerError> {
            let Some(output) = output else { return Ok(()) };

            let start_block = input.next_block();
            let end_block = output.checkpoint.block_number;
            if start_block > end_block {
                return Ok(())
            }

            let provider = self.db.factory.provider()?;

            assert!(output.done);
            let pruned_block = provider
                .get_prune_checkpoint(PruneSegment::SenderRecovery)?
                .expect("prune checkpoint must exist")
                .block_number
                .unwrap_or_default();
            assert_eq!(output.checkpoint.block_number, pruned_block);

            // No sender may remain for any transaction in the executed block range.
            let tx_range = provider.transaction_range_by_block_range(start_block..=end_block)?;
            let senders = self.db.factory.provider()?.senders_by_tx_range(tx_range)?;
            assert!(senders.is_empty());

            Ok(())
        }
    }

    impl UnwindStageTestRunner for PruneTestRunner {
        /// Nothing is validated after an unwind.
        fn validate_unwind(&self, _input: UnwindInput) -> Result<(), TestRunnerError> {
            Ok(())
        }
    }
}