1use crate::PruneLimiter;
2use reth_db_api::{
3 cursor::{DbCursorRO, DbCursorRW, RangeWalker},
4 table::{DupSort, Table, TableRow},
5 transaction::{DbTx, DbTxMut},
6 DatabaseError,
7};
8use std::{fmt::Debug, ops::RangeBounds};
9use tracing::debug;
10
/// Outcome of advancing a range walker by a single row while pruning.
#[derive(Clone, Copy, Debug)]
pub(crate) struct PruneStepResult {
    /// `true` when the walker had no further row to yield.
    done: bool,
    /// `true` when the visited row was removed from the table.
    deleted: bool,
}
19
20pub(crate) trait DbTxPruneExt: DbTxMut + DbTx {
21 fn clear_table<T: Table>(&self) -> Result<usize, DatabaseError> {
26 let count = self.entries::<T>()?;
27 <Self as DbTxMut>::clear::<T>(self)?;
28 Ok(count)
29 }
30
31 fn prune_table_with_iterator<T: Table>(
35 &self,
36 keys: impl IntoIterator<Item = T::Key>,
37 limiter: &mut PruneLimiter,
38 mut delete_callback: impl FnMut(TableRow<T>),
39 ) -> Result<(usize, bool), DatabaseError> {
40 let mut cursor = self.cursor_write::<T>()?;
41 let mut keys = keys.into_iter().peekable();
42
43 let mut deleted_entries = 0;
44
45 let mut done = true;
46 while keys.peek().is_some() {
47 if limiter.is_limit_reached() {
48 debug!(
49 target: "providers::db",
50 ?limiter,
51 deleted_entries_limit = %limiter.is_deleted_entries_limit_reached(),
52 time_limit = %limiter.is_time_limit_reached(),
53 table = %T::NAME,
54 "Pruning limit reached"
55 );
56 done = false;
57 break
58 }
59
60 let key = keys.next().expect("peek() said Some");
61 let row = cursor.seek_exact(key)?;
62 if let Some(row) = row {
63 cursor.delete_current()?;
64 limiter.increment_deleted_entries_count();
65 deleted_entries += 1;
66 delete_callback(row);
67 }
68 }
69
70 Ok((deleted_entries, done))
71 }
72
73 fn prune_table_with_range<T: Table>(
77 &self,
78 keys: impl RangeBounds<T::Key> + Clone + Debug,
79 limiter: &mut PruneLimiter,
80 mut skip_filter: impl FnMut(&TableRow<T>) -> bool,
81 mut delete_callback: impl FnMut(TableRow<T>),
82 ) -> Result<(usize, bool), DatabaseError> {
83 let mut cursor = self.cursor_write::<T>()?;
84 let mut walker = cursor.walk_range(keys)?;
85
86 let mut deleted_entries = 0;
87
88 let done = loop {
89 if limiter.is_limit_reached() {
92 debug!(
93 target: "providers::db",
94 ?limiter,
95 deleted_entries_limit = %limiter.is_deleted_entries_limit_reached(),
96 time_limit = %limiter.is_time_limit_reached(),
97 table = %T::NAME,
98 "Pruning limit reached"
99 );
100 break false
101 }
102
103 let result = self.prune_table_with_range_step(
104 &mut walker,
105 limiter,
106 &mut skip_filter,
107 &mut delete_callback,
108 )?;
109
110 if result.deleted {
111 deleted_entries += 1;
112 }
113
114 if result.done {
115 break true
116 }
117 };
118
119 Ok((deleted_entries, done))
120 }
121
122 fn prune_table_with_range_step<T: Table>(
128 &self,
129 walker: &mut RangeWalker<'_, T, Self::CursorMut<T>>,
130 limiter: &mut PruneLimiter,
131 skip_filter: &mut impl FnMut(&TableRow<T>) -> bool,
132 delete_callback: &mut impl FnMut(TableRow<T>),
133 ) -> Result<PruneStepResult, DatabaseError> {
134 let Some(res) = walker.next() else {
135 return Ok(PruneStepResult { done: true, deleted: false })
136 };
137
138 let row = res?;
139
140 if skip_filter(&row) {
141 Ok(PruneStepResult { done: false, deleted: false })
142 } else {
143 walker.delete_current()?;
144 limiter.increment_deleted_entries_count();
145 delete_callback(row);
146 Ok(PruneStepResult { done: false, deleted: true })
147 }
148 }
149
150 #[expect(unused)]
154 fn prune_dupsort_table_with_range<T: DupSort>(
155 &self,
156 keys: impl RangeBounds<T::Key> + Clone + Debug,
157 limiter: &mut PruneLimiter,
158 mut delete_callback: impl FnMut(TableRow<T>),
159 ) -> Result<(usize, bool), DatabaseError> {
160 let starting_entries = self.entries::<T>()?;
161 let mut cursor = self.cursor_dup_write::<T>()?;
162 let mut walker = cursor.walk_range(keys)?;
163
164 let done = loop {
165 if limiter.is_limit_reached() {
166 debug!(
167 target: "providers::db",
168 ?limiter,
169 deleted_entries_limit = %limiter.is_deleted_entries_limit_reached(),
170 time_limit = %limiter.is_time_limit_reached(),
171 table = %T::NAME,
172 "Pruning limit reached"
173 );
174 break false
175 }
176
177 let Some(res) = walker.next() else { break true };
178 let row = res?;
179
180 walker.delete_current_duplicates()?;
181 limiter.increment_deleted_entries_count();
182 delete_callback(row);
183 };
184
185 debug!(
186 target: "providers::db",
187 table=?T::NAME,
188 cursor_current=?cursor.current(),
189 "done walking",
190 );
191
192 let ending_entries = self.entries::<T>()?;
193
194 Ok((starting_entries - ending_entries, done))
195 }
196}
197
198impl<Tx> DbTxPruneExt for Tx where Tx: DbTxMut + DbTx {}
199
#[cfg(test)]
mod tests {
    use super::DbTxPruneExt;
    use crate::PruneLimiter;
    use reth_db_api::tables;
    use reth_primitives_traits::SignerRecoverable;
    use reth_provider::{DBProvider, DatabaseProviderFactory};
    use reth_stages::test_utils::{StorageKind, TestStageDB};
    use reth_testing_utils::generators::{self, random_block_range, BlockRangeParams};
    use std::{cell::Cell, ops::RangeInclusive};

    /// Seeds a fresh test database with random blocks over `block_numbers` and one
    /// `TransactionSenders` row per transaction, keyed `0..n` in block order.
    ///
    /// Returns the database together with the number of sender rows inserted.
    fn seed_transaction_senders(
        block_numbers: RangeInclusive<u64>,
        params: BlockRangeParams,
    ) -> (TestStageDB, usize) {
        let db = TestStageDB::default();
        let mut rng = generators::rng();

        let blocks = random_block_range(&mut rng, block_numbers, params);
        db.insert_blocks(blocks.iter(), StorageKind::Database(None)).expect("insert blocks");

        let tx_senders: Vec<_> = blocks
            .iter()
            .flat_map(|block| block.body().transactions.iter())
            .enumerate()
            .map(|(tx_number, transaction)| {
                (tx_number as u64, transaction.recover_signer().expect("recover signer"))
            })
            .collect();
        let total = tx_senders.len();
        db.insert_transaction_senders(tx_senders).expect("insert transaction senders");

        (db, total)
    }

    /// Wraps `keys` in an iterator that bumps `calls` on every `next()` invocation, including
    /// the final one that returns `None`.
    fn counting_keys(keys: Vec<u64>, calls: &Cell<usize>) -> impl Iterator<Item = u64> + '_ {
        let mut inner = keys.into_iter();
        std::iter::from_fn(move || {
            calls.set(calls.get() + 1);
            inner.next()
        })
    }

    /// When the deleted-entries limit stops the pruner, it must have pulled exactly one key
    /// beyond the ones it pruned (the key it peeked to learn that work remains).
    #[test]
    fn prune_table_with_iterator_early_exit_does_not_overconsume() {
        let (db, total) = seed_transaction_senders(
            1..=3,
            BlockRangeParams {
                parent: Some(alloy_primitives::B256::ZERO),
                tx_count: 2..3,
                ..Default::default()
            },
        );

        let provider = db.factory.database_provider_rw().unwrap();

        let calls = Cell::new(0);
        let keys: Vec<u64> = (0..total as u64).collect();
        let mut limiter = PruneLimiter::default().set_deleted_entries_limit(2);

        let (pruned, done) = provider
            .tx_ref()
            .prune_table_with_iterator::<tables::TransactionSenders>(
                counting_keys(keys, &calls),
                &mut limiter,
                |_| {},
            )
            .expect("prune");

        assert_eq!(pruned, 2);
        assert!(!done);
        assert_eq!(calls.get(), pruned + 1);

        provider.commit().expect("commit");
        assert_eq!(db.table::<tables::TransactionSenders>().unwrap().len(), total - 2);
    }

    /// With an effectively unlimited limiter, every key is pruned, the iterator is driven to
    /// exhaustion (one trailing `None` call), and the run reports `done`.
    #[test]
    fn prune_table_with_iterator_consumes_to_end_reports_done() {
        let (db, total) = seed_transaction_senders(
            1..=2,
            BlockRangeParams {
                parent: Some(alloy_primitives::B256::ZERO),
                tx_count: 1..2,
                ..Default::default()
            },
        );

        let provider = db.factory.database_provider_rw().unwrap();

        let calls = Cell::new(0);
        let keys: Vec<u64> = (0..total as u64).collect();
        let mut limiter = PruneLimiter::default().set_deleted_entries_limit(usize::MAX);

        let (pruned, done) = provider
            .tx_ref()
            .prune_table_with_iterator::<tables::TransactionSenders>(
                counting_keys(keys, &calls),
                &mut limiter,
                |_| {},
            )
            .expect("prune");

        assert_eq!(pruned, total);
        assert!(done);
        assert_eq!(calls.get(), total + 1);

        provider.commit().expect("commit");
        assert_eq!(db.table::<tables::TransactionSenders>().unwrap().len(), 0);
    }
}
353}