1use crate::PruneLimiter;
2use reth_db_api::{
3 cursor::{DbCursorRO, DbCursorRW, RangeWalker},
4 table::{DupSort, Table, TableRow},
5 transaction::{DbTx, DbTxMut},
6 DatabaseError,
7};
8use std::{fmt::Debug, ops::RangeBounds};
9use tracing::debug;
10
11pub(crate) trait DbTxPruneExt: DbTxMut + DbTx {
12 fn prune_table_with_iterator<T: Table>(
16 &self,
17 keys: impl IntoIterator<Item = T::Key>,
18 limiter: &mut PruneLimiter,
19 mut delete_callback: impl FnMut(TableRow<T>),
20 ) -> Result<(usize, bool), DatabaseError> {
21 let mut cursor = self.cursor_write::<T>()?;
22 let mut keys = keys.into_iter().peekable();
23
24 let mut deleted_entries = 0;
25
26 let mut done = true;
27 while keys.peek().is_some() {
28 if limiter.is_limit_reached() {
29 debug!(
30 target: "providers::db",
31 ?limiter,
32 deleted_entries_limit = %limiter.is_deleted_entries_limit_reached(),
33 time_limit = %limiter.is_time_limit_reached(),
34 table = %T::NAME,
35 "Pruning limit reached"
36 );
37 done = false;
38 break
39 }
40
41 let key = keys.next().expect("peek() said Some");
42 let row = cursor.seek_exact(key)?;
43 if let Some(row) = row {
44 cursor.delete_current()?;
45 limiter.increment_deleted_entries_count();
46 deleted_entries += 1;
47 delete_callback(row);
48 }
49 }
50
51 Ok((deleted_entries, done))
52 }
53
54 fn prune_table_with_range<T: Table>(
58 &self,
59 keys: impl RangeBounds<T::Key> + Clone + Debug,
60 limiter: &mut PruneLimiter,
61 mut skip_filter: impl FnMut(&TableRow<T>) -> bool,
62 mut delete_callback: impl FnMut(TableRow<T>),
63 ) -> Result<(usize, bool), DatabaseError> {
64 let mut cursor = self.cursor_write::<T>()?;
65 let mut walker = cursor.walk_range(keys)?;
66
67 let mut deleted_entries = 0;
68
69 let done = loop {
70 if limiter.is_limit_reached() {
73 debug!(
74 target: "providers::db",
75 ?limiter,
76 deleted_entries_limit = %limiter.is_deleted_entries_limit_reached(),
77 time_limit = %limiter.is_time_limit_reached(),
78 table = %T::NAME,
79 "Pruning limit reached"
80 );
81 break false
82 }
83
84 let done = self.prune_table_with_range_step(
85 &mut walker,
86 limiter,
87 &mut skip_filter,
88 &mut delete_callback,
89 )?;
90
91 if done {
92 break true
93 }
94 deleted_entries += 1;
95 };
96
97 Ok((deleted_entries, done))
98 }
99
100 fn prune_table_with_range_step<T: Table>(
108 &self,
109 walker: &mut RangeWalker<'_, T, Self::CursorMut<T>>,
110 limiter: &mut PruneLimiter,
111 skip_filter: &mut impl FnMut(&TableRow<T>) -> bool,
112 delete_callback: &mut impl FnMut(TableRow<T>),
113 ) -> Result<bool, DatabaseError> {
114 let Some(res) = walker.next() else { return Ok(true) };
115
116 let row = res?;
117
118 if !skip_filter(&row) {
119 walker.delete_current()?;
120 limiter.increment_deleted_entries_count();
121 delete_callback(row);
122 }
123
124 Ok(false)
125 }
126
127 fn prune_dupsort_table_with_range<T: DupSort>(
131 &self,
132 keys: impl RangeBounds<T::Key> + Clone + Debug,
133 limiter: &mut PruneLimiter,
134 mut delete_callback: impl FnMut(TableRow<T>),
135 ) -> Result<(usize, bool), DatabaseError> {
136 let starting_entries = self.entries::<T>()?;
137 let mut cursor = self.cursor_dup_write::<T>()?;
138 let mut walker = cursor.walk_range(keys)?;
139
140 let done = loop {
141 if limiter.is_limit_reached() {
142 debug!(
143 target: "providers::db",
144 ?limiter,
145 deleted_entries_limit = %limiter.is_deleted_entries_limit_reached(),
146 time_limit = %limiter.is_time_limit_reached(),
147 table = %T::NAME,
148 "Pruning limit reached"
149 );
150 break false
151 }
152
153 let Some(res) = walker.next() else { break true };
154 let row = res?;
155
156 walker.delete_current_duplicates()?;
157 limiter.increment_deleted_entries_count();
158 delete_callback(row);
159 };
160
161 debug!(
162 target: "providers::db",
163 table=?T::NAME,
164 cursor_current=?cursor.current(),
165 "done walking",
166 );
167
168 let ending_entries = self.entries::<T>()?;
169
170 Ok((starting_entries - ending_entries, done))
171 }
172}
173
174impl<Tx> DbTxPruneExt for Tx where Tx: DbTxMut + DbTx {}
175
#[cfg(test)]
mod tests {
    use super::DbTxPruneExt;
    use crate::PruneLimiter;
    use reth_db_api::tables;
    use reth_primitives_traits::SignerRecoverable;
    use reth_provider::{DBProvider, DatabaseProviderFactory};
    use reth_stages::test_utils::{StorageKind, TestStageDB};
    use reth_testing_utils::generators::{self, random_block_range, BlockRangeParams};
    use std::sync::{
        atomic::{AtomicUsize, Ordering},
        Arc,
    };

    /// Owned list of keys. Its iterator bumps a shared counter on every `next()` call, so
    /// tests can observe exactly how far the pruner advanced the key iterator.
    struct CountingIter {
        data: Vec<u64>,
        calls: Arc<AtomicUsize>,
    }

    impl CountingIter {
        fn new(data: Vec<u64>, calls: Arc<AtomicUsize>) -> Self {
            Self { data, calls }
        }
    }

    /// Iterator produced by `CountingIter`. It counts every `next()` call, including the
    /// final one that yields `None`.
    struct CountingIntoIter {
        inner: std::vec::IntoIter<u64>,
        calls: Arc<AtomicUsize>,
    }

    impl Iterator for CountingIntoIter {
        type Item = u64;
        fn next(&mut self) -> Option<Self::Item> {
            self.calls.fetch_add(1, Ordering::SeqCst);
            self.inner.next()
        }
    }

    impl IntoIterator for CountingIter {
        type Item = u64;
        type IntoIter = CountingIntoIter;
        fn into_iter(self) -> Self::IntoIter {
            let Self { data, calls } = self;
            CountingIntoIter { inner: data.into_iter(), calls }
        }
    }

    /// Builds a fresh test database containing random blocks numbered `block_numbers`.
    ///
    /// One `TransactionSenders` row is inserted per transaction, keyed `0..total`.
    /// Returns the database together with `total`.
    fn seed_transaction_senders(
        block_numbers: std::ops::RangeInclusive<u64>,
        params: BlockRangeParams,
    ) -> (TestStageDB, usize) {
        let db = TestStageDB::default();
        let mut rng = generators::rng();

        let blocks = random_block_range(&mut rng, block_numbers, params);
        db.insert_blocks(blocks.iter(), StorageKind::Database(None)).expect("insert blocks");

        let mut tx_senders = Vec::new();
        for block in &blocks {
            for transaction in &block.body().transactions {
                let signer = transaction.recover_signer().expect("recover signer");
                let tx_number = tx_senders.len() as u64;
                tx_senders.push((tx_number, signer));
            }
        }
        let total = tx_senders.len();
        db.insert_transaction_senders(tx_senders).expect("insert transaction senders");

        (db, total)
    }

    #[test]
    fn prune_table_with_iterator_early_exit_does_not_overconsume() {
        let (db, total) = seed_transaction_senders(
            1..=3,
            BlockRangeParams {
                parent: Some(alloy_primitives::B256::ZERO),
                tx_count: 2..3,
                ..Default::default()
            },
        );
        let provider = db.factory.database_provider_rw().unwrap();

        let calls = Arc::new(AtomicUsize::new(0));
        let counting_iter = CountingIter::new((0..total as u64).collect(), Arc::clone(&calls));
        let mut limiter = PruneLimiter::default().set_deleted_entries_limit(2);

        let (pruned, done) = provider
            .tx_ref()
            .prune_table_with_iterator::<tables::TransactionSenders>(
                counting_iter,
                &mut limiter,
                |_| {},
            )
            .expect("prune");

        assert_eq!(pruned, 2);
        assert!(!done);
        // One pull per pruned key, plus the single peek that found the limit reached.
        assert_eq!(calls.load(Ordering::SeqCst), pruned + 1);

        provider.commit().expect("commit");
        assert_eq!(db.table::<tables::TransactionSenders>().unwrap().len(), total - 2);
    }

    #[test]
    fn prune_table_with_iterator_consumes_to_end_reports_done() {
        let (db, total) = seed_transaction_senders(
            1..=2,
            BlockRangeParams {
                parent: Some(alloy_primitives::B256::ZERO),
                tx_count: 1..2,
                ..Default::default()
            },
        );
        let provider = db.factory.database_provider_rw().unwrap();

        let calls = Arc::new(AtomicUsize::new(0));
        let counting_iter = CountingIter::new((0..total as u64).collect(), Arc::clone(&calls));
        let mut limiter = PruneLimiter::default().set_deleted_entries_limit(usize::MAX);

        let (pruned, done) = provider
            .tx_ref()
            .prune_table_with_iterator::<tables::TransactionSenders>(
                counting_iter,
                &mut limiter,
                |_| {},
            )
            .expect("prune");

        assert_eq!(pruned, total);
        assert!(done);
        // Every key is pulled once, plus the final call that reports exhaustion.
        assert_eq!(calls.load(Ordering::SeqCst), total + 1);

        provider.commit().expect("commit");
        assert_eq!(db.table::<tables::TransactionSenders>().unwrap().len(), 0);
    }
}