1use crate::PruneLimiter;
2use reth_db_api::{
3 cursor::{DbCursorRO, DbCursorRW, RangeWalker},
4 table::{Table, TableRow},
5 transaction::DbTxMut,
6 DatabaseError,
7};
8use std::{fmt::Debug, ops::RangeBounds};
9use tracing::debug;
10
11pub(crate) trait DbTxPruneExt: DbTxMut {
12 fn prune_table_with_iterator<T: Table>(
16 &self,
17 keys: impl IntoIterator<Item = T::Key>,
18 limiter: &mut PruneLimiter,
19 mut delete_callback: impl FnMut(TableRow<T>),
20 ) -> Result<(usize, bool), DatabaseError> {
21 let mut cursor = self.cursor_write::<T>()?;
22 let mut keys = keys.into_iter().peekable();
23
24 let mut deleted_entries = 0;
25
26 let mut done = true;
27 while keys.peek().is_some() {
28 if limiter.is_limit_reached() {
29 debug!(
30 target: "providers::db",
31 ?limiter,
32 deleted_entries_limit = %limiter.is_deleted_entries_limit_reached(),
33 time_limit = %limiter.is_time_limit_reached(),
34 table = %T::NAME,
35 "Pruning limit reached"
36 );
37 done = false;
38 break
39 }
40
41 let key = keys.next().expect("peek() said Some");
42 let row = cursor.seek_exact(key)?;
43 if let Some(row) = row {
44 cursor.delete_current()?;
45 limiter.increment_deleted_entries_count();
46 deleted_entries += 1;
47 delete_callback(row);
48 }
49 }
50
51 Ok((deleted_entries, done))
52 }
53
54 fn prune_table_with_range<T: Table>(
58 &self,
59 keys: impl RangeBounds<T::Key> + Clone + Debug,
60 limiter: &mut PruneLimiter,
61 mut skip_filter: impl FnMut(&TableRow<T>) -> bool,
62 mut delete_callback: impl FnMut(TableRow<T>),
63 ) -> Result<(usize, bool), DatabaseError> {
64 let mut cursor = self.cursor_write::<T>()?;
65 let mut walker = cursor.walk_range(keys)?;
66
67 let mut deleted_entries = 0;
68
69 let done = loop {
70 if limiter.is_limit_reached() {
73 debug!(
74 target: "providers::db",
75 ?limiter,
76 deleted_entries_limit = %limiter.is_deleted_entries_limit_reached(),
77 time_limit = %limiter.is_time_limit_reached(),
78 table = %T::NAME,
79 "Pruning limit reached"
80 );
81 break false
82 }
83
84 let done = self.prune_table_with_range_step(
85 &mut walker,
86 limiter,
87 &mut skip_filter,
88 &mut delete_callback,
89 )?;
90
91 if done {
92 break true
93 }
94 deleted_entries += 1;
95 };
96
97 Ok((deleted_entries, done))
98 }
99
100 fn prune_table_with_range_step<T: Table>(
108 &self,
109 walker: &mut RangeWalker<'_, T, Self::CursorMut<T>>,
110 limiter: &mut PruneLimiter,
111 skip_filter: &mut impl FnMut(&TableRow<T>) -> bool,
112 delete_callback: &mut impl FnMut(TableRow<T>),
113 ) -> Result<bool, DatabaseError> {
114 let Some(res) = walker.next() else { return Ok(true) };
115
116 let row = res?;
117
118 if !skip_filter(&row) {
119 walker.delete_current()?;
120 limiter.increment_deleted_entries_count();
121 delete_callback(row);
122 }
123
124 Ok(false)
125 }
126}
127
128impl<Tx> DbTxPruneExt for Tx where Tx: DbTxMut {}
129
#[cfg(test)]
mod tests {
    use super::DbTxPruneExt;
    use crate::PruneLimiter;
    use reth_db_api::tables;
    use reth_primitives_traits::SignerRecoverable;
    use reth_provider::{DBProvider, DatabaseProviderFactory};
    use reth_stages::test_utils::{StorageKind, TestStageDB};
    use reth_testing_utils::generators::{self, random_block_range, BlockRangeParams};
    use std::sync::{
        atomic::{AtomicUsize, Ordering},
        Arc,
    };

    /// Key iterator that bumps a shared counter on every `next` call, including calls that
    /// return `None`. Tests use it to observe how many keys the pruner pulled.
    struct CountingKeys {
        keys: std::vec::IntoIter<u64>,
        next_calls: Arc<AtomicUsize>,
    }

    impl CountingKeys {
        fn new(keys: Vec<u64>, next_calls: Arc<AtomicUsize>) -> Self {
            Self { keys: keys.into_iter(), next_calls }
        }
    }

    impl Iterator for CountingKeys {
        type Item = u64;

        fn next(&mut self) -> Option<u64> {
            self.next_calls.fetch_add(1, Ordering::SeqCst);
            self.keys.next()
        }
    }

    /// Builds a fresh test database and fills it with random blocks.
    ///
    /// The blocks are numbered by `block_range`, and each has a transaction count drawn from
    /// `tx_count`. One `TransactionSenders` row is written per transaction, keyed by a
    /// sequential transaction number starting at 0.
    ///
    /// Returns the database together with the number of sender rows written.
    fn seed_transaction_senders(
        block_range: std::ops::RangeInclusive<u64>,
        tx_count: std::ops::Range<u8>,
    ) -> (TestStageDB, usize) {
        let db = TestStageDB::default();
        let mut rng = generators::rng();

        let blocks = random_block_range(
            &mut rng,
            block_range,
            BlockRangeParams {
                parent: Some(alloy_primitives::B256::ZERO),
                tx_count,
                ..Default::default()
            },
        );
        db.insert_blocks(blocks.iter(), StorageKind::Database(None)).expect("insert blocks");

        let senders = blocks
            .iter()
            .flat_map(|block| block.body().transactions.iter())
            .zip(0u64..)
            .map(|(transaction, tx_num)| {
                (tx_num, transaction.recover_signer().expect("recover signer"))
            })
            .collect::<Vec<_>>();
        let sender_count = senders.len();
        db.insert_transaction_senders(senders).expect("insert transaction senders");

        (db, sender_count)
    }

    /// Prunes `TransactionSenders` for keys `0..total` with a deleted-entries limit of `limit`,
    /// then commits.
    ///
    /// Returns `(pruned, done, next_calls)`, where `next_calls` is how many times the key
    /// iterator's `next` was invoked.
    fn prune_senders_counting_next(
        db: &TestStageDB,
        total: usize,
        limit: usize,
    ) -> (usize, bool, usize) {
        let provider = db.factory.database_provider_rw().unwrap();

        let next_calls = Arc::new(AtomicUsize::new(0));
        let keys = CountingKeys::new((0..total as u64).collect(), Arc::clone(&next_calls));
        let mut limiter = PruneLimiter::default().set_deleted_entries_limit(limit);

        let (pruned, done) = provider
            .tx_ref()
            .prune_table_with_iterator::<tables::TransactionSenders>(keys, &mut limiter, |_| {})
            .expect("prune");

        provider.commit().expect("commit");
        (pruned, done, next_calls.load(Ordering::SeqCst))
    }

    #[test]
    fn prune_table_with_iterator_early_exit_does_not_overconsume() {
        let (db, total) = seed_transaction_senders(1..=3, 2..3);

        let (pruned, done, next_calls) = prune_senders_counting_next(&db, total, 2);

        assert_eq!(pruned, 2);
        assert!(!done);
        // One call per pruned key plus the single peek made before the limit check stopped
        // the loop.
        assert_eq!(next_calls, pruned + 1);
        assert_eq!(db.table::<tables::TransactionSenders>().unwrap().len(), total - 2);
    }

    #[test]
    fn prune_table_with_iterator_consumes_to_end_reports_done() {
        let (db, total) = seed_transaction_senders(1..=2, 1..2);

        let (pruned, done, next_calls) = prune_senders_counting_next(&db, total, usize::MAX);

        assert_eq!(pruned, total);
        assert!(done);
        // One call per key plus the final peek that observed exhaustion.
        assert_eq!(next_calls, total + 1);
        assert_eq!(db.table::<tables::TransactionSenders>().unwrap().len(), 0);
    }
}