1use alloy_primitives::{keccak256, Address, BlockNumber, TxHash, TxNumber, B256};
2use reth_chainspec::MAINNET;
3use reth_db::{
4 test_utils::{create_test_rw_db, create_test_rw_db_with_path, create_test_static_files_dir},
5 DatabaseEnv,
6};
7use reth_db_api::{
8 common::KeyValue,
9 cursor::{DbCursorRO, DbCursorRW, DbDupCursorRO},
10 database::Database,
11 models::{AccountBeforeTx, StoredBlockBodyIndices},
12 table::Table,
13 tables,
14 transaction::{DbTx, DbTxMut},
15 DatabaseError as DbError,
16};
17use reth_ethereum_primitives::{Block, EthPrimitives, Receipt};
18use reth_primitives_traits::{Account, SealedBlock, SealedHeader, StorageEntry};
19use reth_provider::{
20 providers::{StaticFileProvider, StaticFileProviderRWRefMut, StaticFileWriter},
21 test_utils::MockNodeTypesWithDB,
22 HistoryWriter, ProviderError, ProviderFactory, StaticFileProviderFactory, StatsReader,
23};
24use reth_static_file_types::StaticFileSegment;
25use reth_storage_errors::provider::ProviderResult;
26use reth_testing_utils::generators::ChangeSet;
27use std::{collections::BTreeMap, fmt::Debug, path::Path};
28use tempfile::TempDir;
29
/// Test harness bundling a [`ProviderFactory`] over a read-write test database with the
/// temporary directory that backs its static file provider.
#[derive(Debug)]
pub struct TestStageDB {
    /// Provider factory used to open read-only (`provider`) and read-write (`provider_rw`)
    /// providers and the static file provider.
    pub factory: ProviderFactory<MockNodeTypesWithDB>,
    /// Handle to the temporary directory whose path is passed to the static file provider.
    ///
    /// NOTE(review): presumably stored here so the directory outlives the factory that uses it.
    /// Fields are dropped in declaration order, so `factory` is dropped before this handle.
    /// Keep the field order as is.
    pub temp_static_files_dir: TempDir,
}
36
37impl Default for TestStageDB {
38 fn default() -> Self {
40 let (static_dir, static_dir_path) = create_test_static_files_dir();
41 Self {
42 temp_static_files_dir: static_dir,
43 factory: ProviderFactory::new(
44 create_test_rw_db(),
45 MAINNET.clone(),
46 StaticFileProvider::read_write(static_dir_path).unwrap(),
47 )
48 .expect("failed to create test provider factory"),
49 }
50 }
51}
52
53impl TestStageDB {
54 pub fn new(path: &Path) -> Self {
55 let (static_dir, static_dir_path) = create_test_static_files_dir();
56
57 Self {
58 temp_static_files_dir: static_dir,
59 factory: ProviderFactory::new(
60 create_test_rw_db_with_path(path),
61 MAINNET.clone(),
62 StaticFileProvider::read_write(static_dir_path).unwrap(),
63 )
64 .expect("failed to create test provider factory"),
65 }
66 }
67
68 pub fn commit<F>(&self, f: F) -> ProviderResult<()>
70 where
71 F: FnOnce(&<DatabaseEnv as Database>::TXMut) -> ProviderResult<()>,
72 {
73 let tx = self.factory.provider_rw()?;
74 f(tx.tx_ref())?;
75 tx.commit().expect("failed to commit");
76 Ok(())
77 }
78
79 pub fn query<F, Ok>(&self, f: F) -> ProviderResult<Ok>
81 where
82 F: FnOnce(&<DatabaseEnv as Database>::TX) -> ProviderResult<Ok>,
83 {
84 f(self.factory.provider()?.tx_ref())
85 }
86
87 pub fn table_is_empty<T: Table>(&self) -> ProviderResult<bool> {
89 self.query(|tx| {
90 let last = tx.cursor_read::<T>()?.last()?;
91 Ok(last.is_none())
92 })
93 }
94
95 pub fn table<T: Table>(&self) -> ProviderResult<Vec<KeyValue<T>>>
97 where
98 T::Key: Default + Ord,
99 {
100 self.query(|tx| {
101 Ok(tx
102 .cursor_read::<T>()?
103 .walk(Some(T::Key::default()))?
104 .collect::<Result<Vec<_>, DbError>>()?)
105 })
106 }
107
108 pub fn count_entries<T: Table>(&self) -> ProviderResult<usize> {
110 self.factory.provider()?.count_entries::<T>()
111 }
112
113 pub fn ensure_no_entry_above<T, F>(&self, num: u64, mut selector: F) -> ProviderResult<()>
116 where
117 T: Table,
118 F: FnMut(T::Key) -> BlockNumber,
119 {
120 self.query(|tx| {
121 let mut cursor = tx.cursor_read::<T>()?;
122 if let Some((key, _)) = cursor.last()? {
123 assert!(selector(key) <= num);
124 }
125 Ok(())
126 })
127 }
128
129 pub fn ensure_no_entry_above_by_value<T, F>(
132 &self,
133 num: u64,
134 mut selector: F,
135 ) -> ProviderResult<()>
136 where
137 T: Table,
138 F: FnMut(T::Value) -> BlockNumber,
139 {
140 self.query(|tx| {
141 let mut cursor = tx.cursor_read::<T>()?;
142 let mut rev_walker = cursor.walk_back(None)?;
143 while let Some((_, value)) = rev_walker.next().transpose()? {
144 assert!(selector(value) <= num);
145 }
146 Ok(())
147 })
148 }
149
150 pub fn insert_header<TX: DbTx + DbTxMut>(
152 writer: Option<&mut StaticFileProviderRWRefMut<'_, EthPrimitives>>,
153 tx: &TX,
154 header: &SealedHeader,
155 ) -> ProviderResult<()> {
156 if let Some(writer) = writer {
157 let segment_header = writer.user_header();
160 if segment_header.block_end().is_none() && segment_header.expected_block_start() == 0 {
161 for block_number in 0..header.number {
162 let mut prev = header.clone_header();
163 prev.number = block_number;
164 writer.append_header(&prev, &B256::ZERO)?;
165 }
166 }
167
168 writer.append_header(header.header(), &header.hash())?;
169 } else {
170 tx.put::<tables::CanonicalHeaders>(header.number, header.hash())?;
171 tx.put::<tables::Headers>(header.number, header.header().clone())?;
172 }
173
174 tx.put::<tables::HeaderNumbers>(header.hash(), header.number)?;
175 Ok(())
176 }
177
178 fn insert_headers_inner<'a, I>(&self, headers: I) -> ProviderResult<()>
179 where
180 I: IntoIterator<Item = &'a SealedHeader>,
181 {
182 let provider = self.factory.static_file_provider();
183 let mut writer = provider.latest_writer(StaticFileSegment::Headers)?;
184 let tx = self.factory.provider_rw()?.into_tx();
185
186 for header in headers {
187 Self::insert_header(Some(&mut writer), &tx, header)?;
188 }
189
190 writer.commit()?;
191 tx.commit()?;
192
193 Ok(())
194 }
195
196 pub fn insert_headers<'a, I>(&self, headers: I) -> ProviderResult<()>
199 where
200 I: IntoIterator<Item = &'a SealedHeader>,
201 {
202 self.insert_headers_inner::<I>(headers)
203 }
204
205 pub fn insert_blocks<'a, I>(&self, blocks: I, storage_kind: StorageKind) -> ProviderResult<()>
213 where
214 I: IntoIterator<Item = &'a SealedBlock<Block>>,
215 {
216 let provider = self.factory.static_file_provider();
217
218 let tx = self.factory.provider_rw().unwrap().into_tx();
219 let mut next_tx_num = storage_kind.tx_offset();
220
221 let blocks = blocks.into_iter().collect::<Vec<_>>();
222
223 {
224 let mut headers_writer = storage_kind
225 .is_static()
226 .then(|| provider.latest_writer(StaticFileSegment::Headers).unwrap());
227
228 blocks.iter().try_for_each(|block| {
229 Self::insert_header(headers_writer.as_mut(), &tx, block.sealed_header())
230 })?;
231
232 if let Some(mut writer) = headers_writer {
233 writer.commit()?;
234 }
235 }
236
237 {
238 let mut txs_writer = storage_kind
239 .is_static()
240 .then(|| provider.latest_writer(StaticFileSegment::Transactions).unwrap());
241
242 blocks.into_iter().try_for_each(|block| {
243 let block_body_indices = StoredBlockBodyIndices {
245 first_tx_num: next_tx_num,
246 tx_count: block.transaction_count() as u64,
247 };
248
249 if !block.body().transactions.is_empty() {
250 tx.put::<tables::TransactionBlocks>(
251 block_body_indices.last_tx_num(),
252 block.number,
253 )?;
254 }
255 tx.put::<tables::BlockBodyIndices>(block.number, block_body_indices)?;
256
257 let res = block.body().transactions.iter().try_for_each(|body_tx| {
258 if let Some(txs_writer) = &mut txs_writer {
259 txs_writer.append_transaction(next_tx_num, body_tx)?;
260 } else {
261 tx.put::<tables::Transactions>(next_tx_num, body_tx.clone())?
262 }
263 next_tx_num += 1;
264 Ok::<(), ProviderError>(())
265 });
266
267 if let Some(txs_writer) = &mut txs_writer {
268 let segment_header = txs_writer.user_header();
271 if segment_header.block_end().is_none() &&
272 segment_header.expected_block_start() == 0
273 {
274 for block in 0..block.number {
275 txs_writer.increment_block(block)?;
276 }
277 }
278 txs_writer.increment_block(block.number)?;
279 }
280 res
281 })?;
282
283 if let Some(txs_writer) = &mut txs_writer {
284 txs_writer.commit()?;
285 }
286 }
287
288 tx.commit()?;
289
290 Ok(())
291 }
292
293 pub fn insert_tx_hash_numbers<I>(&self, tx_hash_numbers: I) -> ProviderResult<()>
294 where
295 I: IntoIterator<Item = (TxHash, TxNumber)>,
296 {
297 self.commit(|tx| {
298 tx_hash_numbers.into_iter().try_for_each(|(tx_hash, tx_num)| {
299 Ok(tx.put::<tables::TransactionHashNumbers>(tx_hash, tx_num)?)
301 })
302 })
303 }
304
305 pub fn insert_receipts<I>(&self, receipts: I) -> ProviderResult<()>
307 where
308 I: IntoIterator<Item = (TxNumber, Receipt)>,
309 {
310 self.commit(|tx| {
311 receipts.into_iter().try_for_each(|(tx_num, receipt)| {
312 Ok(tx.put::<tables::Receipts>(tx_num, receipt)?)
314 })
315 })
316 }
317
318 pub fn insert_receipts_by_block<I, J>(
321 &self,
322 receipts: I,
323 storage_kind: StorageKind,
324 ) -> ProviderResult<()>
325 where
326 I: IntoIterator<Item = (BlockNumber, J)>,
327 J: IntoIterator<Item = (TxNumber, Receipt)>,
328 {
329 match storage_kind {
330 StorageKind::Database(_) => self.commit(|tx| {
331 receipts.into_iter().try_for_each(|(_, receipts)| {
332 for (tx_num, receipt) in receipts {
333 tx.put::<tables::Receipts>(tx_num, receipt)?;
334 }
335 Ok(())
336 })
337 }),
338 StorageKind::Static => {
339 let provider = self.factory.static_file_provider();
340 let mut writer = provider.latest_writer(StaticFileSegment::Receipts)?;
341 let res = receipts.into_iter().try_for_each(|(block_num, receipts)| {
342 writer.increment_block(block_num)?;
343 writer.append_receipts(receipts.into_iter().map(Ok))?;
344 Ok(())
345 });
346 writer.commit_without_sync_all()?;
347 res
348 }
349 }
350 }
351
352 pub fn insert_transaction_senders<I>(&self, transaction_senders: I) -> ProviderResult<()>
353 where
354 I: IntoIterator<Item = (TxNumber, Address)>,
355 {
356 self.commit(|tx| {
357 transaction_senders.into_iter().try_for_each(|(tx_num, sender)| {
358 Ok(tx.put::<tables::TransactionSenders>(tx_num, sender)?)
360 })
361 })
362 }
363
364 pub fn insert_accounts_and_storages<I, S>(&self, accounts: I) -> ProviderResult<()>
366 where
367 I: IntoIterator<Item = (Address, (Account, S))>,
368 S: IntoIterator<Item = StorageEntry>,
369 {
370 self.commit(|tx| {
371 accounts.into_iter().try_for_each(|(address, (account, storage))| {
372 let hashed_address = keccak256(address);
373
374 tx.put::<tables::PlainAccountState>(address, account)?;
376 tx.put::<tables::HashedAccounts>(hashed_address, account)?;
377
378 storage.into_iter().filter(|e| !e.value.is_zero()).try_for_each(|entry| {
380 let hashed_entry = StorageEntry { key: keccak256(entry.key), ..entry };
381
382 let mut cursor = tx.cursor_dup_write::<tables::PlainStorageState>()?;
383 if cursor
384 .seek_by_key_subkey(address, entry.key)?
385 .filter(|e| e.key == entry.key)
386 .is_some()
387 {
388 cursor.delete_current()?;
389 }
390 cursor.upsert(address, &entry)?;
391
392 let mut cursor = tx.cursor_dup_write::<tables::HashedStorages>()?;
393 if cursor
394 .seek_by_key_subkey(hashed_address, hashed_entry.key)?
395 .filter(|e| e.key == hashed_entry.key)
396 .is_some()
397 {
398 cursor.delete_current()?;
399 }
400 cursor.upsert(hashed_address, &hashed_entry)?;
401
402 Ok(())
403 })
404 })
405 })
406 }
407
408 pub fn insert_changesets<I>(
410 &self,
411 changesets: I,
412 block_offset: Option<u64>,
413 ) -> ProviderResult<()>
414 where
415 I: IntoIterator<Item = ChangeSet>,
416 {
417 let offset = block_offset.unwrap_or_default();
418 self.commit(|tx| {
419 changesets.into_iter().enumerate().try_for_each(|(block, changeset)| {
420 changeset.into_iter().try_for_each(|(address, old_account, old_storage)| {
421 let block = offset + block as u64;
422 tx.put::<tables::AccountChangeSets>(
424 block,
425 AccountBeforeTx { address, info: Some(old_account) },
426 )?;
427
428 let block_address = (block, address).into();
429
430 old_storage.into_iter().try_for_each(|entry| {
432 Ok(tx.put::<tables::StorageChangeSets>(block_address, entry)?)
433 })
434 })
435 })
436 })
437 }
438
439 pub fn insert_history<I>(&self, changesets: I, _block_offset: Option<u64>) -> ProviderResult<()>
440 where
441 I: IntoIterator<Item = ChangeSet>,
442 {
443 let mut accounts = BTreeMap::<Address, Vec<u64>>::new();
444 let mut storages = BTreeMap::<(Address, B256), Vec<u64>>::new();
445
446 for (block, changeset) in changesets.into_iter().enumerate() {
447 for (address, _, storage_entries) in changeset {
448 accounts.entry(address).or_default().push(block as u64);
449 for storage_entry in storage_entries {
450 storages.entry((address, storage_entry.key)).or_default().push(block as u64);
451 }
452 }
453 }
454
455 let provider_rw = self.factory.provider_rw()?;
456 provider_rw.insert_account_history_index(accounts)?;
457 provider_rw.insert_storage_history_index(storages)?;
458 provider_rw.commit()?;
459
460 Ok(())
461 }
462}
463
/// Selects where [`TestStageDB`] helpers write block data.
///
/// The payload is a small `Option<u64>`, so the type is `Copy`. This lets one value be reused
/// across several helper calls that take it by value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StorageKind {
    /// Write into database tables. The value is the first transaction number to assign.
    /// `None` is treated as 0.
    Database(Option<u64>),
    /// Write into static files. Transaction numbering starts at 0.
    Static,
}
470
471impl StorageKind {
472 #[expect(dead_code)]
473 const fn is_database(&self) -> bool {
474 matches!(self, Self::Database(_))
475 }
476
477 const fn is_static(&self) -> bool {
478 matches!(self, Self::Static)
479 }
480
481 fn tx_offset(&self) -> u64 {
482 if let Self::Database(offset) = self {
483 return offset.unwrap_or_default();
484 }
485 0
486 }
487}