1use alloy_primitives::{keccak256, Address, BlockNumber, TxHash, TxNumber, B256};
2use reth_chainspec::MAINNET;
3use reth_db::{
4 test_utils::{
5 create_test_rocksdb_dir, create_test_rw_db, create_test_rw_db_with_path,
6 create_test_static_files_dir,
7 },
8 DatabaseEnv,
9};
10use reth_db_api::{
11 common::KeyValue,
12 cursor::{DbCursorRO, DbCursorRW, DbDupCursorRO},
13 database::Database,
14 models::{AccountBeforeTx, StoredBlockBodyIndices},
15 table::Table,
16 tables,
17 transaction::{DbTx, DbTxMut},
18 DatabaseError as DbError,
19};
20use reth_ethereum_primitives::{Block, EthPrimitives, Receipt};
21use reth_primitives_traits::{Account, SealedBlock, SealedHeader, StorageEntry};
22use reth_provider::{
23 providers::{
24 RocksDBProvider, StaticFileProvider, StaticFileProviderRWRefMut, StaticFileWriter,
25 },
26 test_utils::MockNodeTypesWithDB,
27 DatabaseProviderFactory, EitherWriter, HistoryWriter, ProviderError, ProviderFactory,
28 RocksBatchArg, StaticFileProviderFactory, StatsReader,
29};
30use reth_static_file_types::StaticFileSegment;
31use reth_storage_errors::provider::ProviderResult;
32use reth_testing_utils::generators::ChangeSet;
33use std::{collections::BTreeMap, fmt::Debug, path::Path};
34use tempfile::TempDir;
35
/// Test harness wrapping a [`ProviderFactory`] over temporary storage, with helpers for seeding
/// and inspecting tables in stage tests.
#[derive(Debug)]
pub struct TestStageDB {
    /// Provider factory over the test database, the static-files directory and a RocksDB
    /// instance.
    pub factory: ProviderFactory<MockNodeTypesWithDB>,
    /// Guard for the temporary static-files directory. Holding it here keeps the directory alive
    /// as long as this struct. Fields drop in declaration order, so `factory` is dropped before
    /// this guard removes the directory.
    pub temp_static_files_dir: TempDir,
}
42
43impl Default for TestStageDB {
44 fn default() -> Self {
46 let (static_dir, static_dir_path) = create_test_static_files_dir();
47 let (_, rocksdb_dir_path) = create_test_rocksdb_dir();
48 Self {
49 temp_static_files_dir: static_dir,
50 factory: ProviderFactory::new(
51 create_test_rw_db(),
52 MAINNET.clone(),
53 StaticFileProvider::read_write(static_dir_path).unwrap(),
54 RocksDBProvider::builder(rocksdb_dir_path).with_default_tables().build().unwrap(),
55 )
56 .expect("failed to create test provider factory"),
57 }
58 }
59}
60
61impl TestStageDB {
62 pub fn new(path: &Path) -> Self {
63 let (static_dir, static_dir_path) = create_test_static_files_dir();
64 let (_, rocksdb_dir_path) = create_test_rocksdb_dir();
65
66 Self {
67 temp_static_files_dir: static_dir,
68 factory: ProviderFactory::new(
69 create_test_rw_db_with_path(path),
70 MAINNET.clone(),
71 StaticFileProvider::read_write(static_dir_path).unwrap(),
72 RocksDBProvider::builder(rocksdb_dir_path).with_default_tables().build().unwrap(),
73 )
74 .expect("failed to create test provider factory"),
75 }
76 }
77
78 pub fn commit<F>(&self, f: F) -> ProviderResult<()>
80 where
81 F: FnOnce(&<DatabaseEnv as Database>::TXMut) -> ProviderResult<()>,
82 {
83 let tx = self.factory.provider_rw()?;
84 f(tx.tx_ref())?;
85 tx.commit().expect("failed to commit");
86 Ok(())
87 }
88
89 pub fn query<F, Ok>(&self, f: F) -> ProviderResult<Ok>
91 where
92 F: FnOnce(&<DatabaseEnv as Database>::TX) -> ProviderResult<Ok>,
93 {
94 f(self.factory.provider()?.tx_ref())
95 }
96
97 pub fn query_with_provider<F, Ok>(&self, f: F) -> ProviderResult<Ok>
100 where
101 F: FnOnce(
102 <ProviderFactory<MockNodeTypesWithDB> as DatabaseProviderFactory>::Provider,
103 ) -> ProviderResult<Ok>,
104 {
105 f(self.factory.provider()?)
106 }
107
108 pub fn commit_with_provider<F>(&self, f: F) -> ProviderResult<()>
110 where
111 F: FnOnce(
112 &<ProviderFactory<MockNodeTypesWithDB> as DatabaseProviderFactory>::ProviderRW,
113 ) -> ProviderResult<()>,
114 {
115 let provider = self.factory.provider_rw()?;
116 f(&provider)?;
117 provider.commit().expect("failed to commit");
118 Ok(())
119 }
120
121 pub fn table_is_empty<T: Table>(&self) -> ProviderResult<bool> {
123 self.query(|tx| {
124 let last = tx.cursor_read::<T>()?.last()?;
125 Ok(last.is_none())
126 })
127 }
128
129 pub fn table<T: Table>(&self) -> ProviderResult<Vec<KeyValue<T>>>
131 where
132 T::Key: Default + Ord,
133 {
134 self.query(|tx| {
135 Ok(tx
136 .cursor_read::<T>()?
137 .walk(Some(T::Key::default()))?
138 .collect::<Result<Vec<_>, DbError>>()?)
139 })
140 }
141
142 pub fn count_entries<T: Table>(&self) -> ProviderResult<usize> {
144 self.factory.provider()?.count_entries::<T>()
145 }
146
147 pub fn ensure_no_entry_above<T, F>(&self, num: u64, mut selector: F) -> ProviderResult<()>
150 where
151 T: Table,
152 F: FnMut(T::Key) -> BlockNumber,
153 {
154 self.query(|tx| {
155 let mut cursor = tx.cursor_read::<T>()?;
156 if let Some((key, _)) = cursor.last()? {
157 assert!(selector(key) <= num);
158 }
159 Ok(())
160 })
161 }
162
163 pub fn ensure_no_entry_above_by_value<T, F>(
166 &self,
167 num: u64,
168 mut selector: F,
169 ) -> ProviderResult<()>
170 where
171 T: Table,
172 F: FnMut(T::Value) -> BlockNumber,
173 {
174 self.query(|tx| {
175 let mut cursor = tx.cursor_read::<T>()?;
176 let mut rev_walker = cursor.walk_back(None)?;
177 while let Some((_, value)) = rev_walker.next().transpose()? {
178 assert!(selector(value) <= num);
179 }
180 Ok(())
181 })
182 }
183
184 pub fn insert_header<TX: DbTx + DbTxMut>(
186 writer: Option<&mut StaticFileProviderRWRefMut<'_, EthPrimitives>>,
187 tx: &TX,
188 header: &SealedHeader,
189 ) -> ProviderResult<()> {
190 if let Some(writer) = writer {
191 let segment_header = writer.user_header();
194 if segment_header.block_end().is_none() && segment_header.expected_block_start() == 0 {
195 for block_number in 0..header.number {
196 let mut prev = header.clone_header();
197 prev.number = block_number;
198 writer.append_header(&prev, &B256::ZERO)?;
199 }
200 }
201
202 writer.append_header(header.header(), &header.hash())?;
203 } else {
204 tx.put::<tables::CanonicalHeaders>(header.number, header.hash())?;
205 tx.put::<tables::Headers>(header.number, header.header().clone())?;
206 }
207
208 tx.put::<tables::HeaderNumbers>(header.hash(), header.number)?;
209 Ok(())
210 }
211
212 fn insert_headers_inner<'a, I>(&self, headers: I) -> ProviderResult<()>
213 where
214 I: IntoIterator<Item = &'a SealedHeader>,
215 {
216 let provider = self.factory.static_file_provider();
217 let mut writer = provider.latest_writer(StaticFileSegment::Headers)?;
218 let tx = self.factory.provider_rw()?.into_tx();
219
220 for header in headers {
221 Self::insert_header(Some(&mut writer), &tx, header)?;
222 }
223
224 writer.commit()?;
225 tx.commit()?;
226
227 Ok(())
228 }
229
230 pub fn insert_headers<'a, I>(&self, headers: I) -> ProviderResult<()>
233 where
234 I: IntoIterator<Item = &'a SealedHeader>,
235 {
236 self.insert_headers_inner::<I>(headers)
237 }
238
239 pub fn insert_blocks<'a, I>(&self, blocks: I, storage_kind: StorageKind) -> ProviderResult<()>
247 where
248 I: IntoIterator<Item = &'a SealedBlock<Block>>,
249 {
250 let provider = self.factory.static_file_provider();
251
252 let tx = self.factory.provider_rw().unwrap().into_tx();
253 let mut next_tx_num = storage_kind.tx_offset();
254
255 let blocks = blocks.into_iter().collect::<Vec<_>>();
256
257 {
258 let mut headers_writer = storage_kind
259 .is_static()
260 .then(|| provider.latest_writer(StaticFileSegment::Headers).unwrap());
261
262 blocks.iter().try_for_each(|block| {
263 Self::insert_header(headers_writer.as_mut(), &tx, block.sealed_header())
264 })?;
265
266 if let Some(mut writer) = headers_writer {
267 writer.commit()?;
268 }
269 }
270
271 {
272 let mut txs_writer = storage_kind
273 .is_static()
274 .then(|| provider.latest_writer(StaticFileSegment::Transactions).unwrap());
275
276 blocks.into_iter().try_for_each(|block| {
277 let block_body_indices = StoredBlockBodyIndices {
279 first_tx_num: next_tx_num,
280 tx_count: block.transaction_count() as u64,
281 };
282
283 if !block.body().transactions.is_empty() {
284 tx.put::<tables::TransactionBlocks>(
285 block_body_indices.last_tx_num(),
286 block.number,
287 )?;
288 }
289 tx.put::<tables::BlockBodyIndices>(block.number, block_body_indices)?;
290
291 let res = block.body().transactions.iter().try_for_each(|body_tx| {
292 if let Some(txs_writer) = &mut txs_writer {
293 txs_writer.append_transaction(next_tx_num, body_tx)?;
294 } else {
295 tx.put::<tables::Transactions>(next_tx_num, body_tx.clone())?
296 }
297 next_tx_num += 1;
298 Ok::<(), ProviderError>(())
299 });
300
301 if let Some(txs_writer) = &mut txs_writer {
302 let segment_header = txs_writer.user_header();
305 if segment_header.block_end().is_none() &&
306 segment_header.expected_block_start() == 0
307 {
308 for block in 0..block.number {
309 txs_writer.increment_block(block)?;
310 }
311 }
312 txs_writer.increment_block(block.number)?;
313 }
314 res
315 })?;
316
317 if let Some(txs_writer) = &mut txs_writer {
318 txs_writer.commit()?;
319 }
320 }
321
322 tx.commit()?;
323
324 Ok(())
325 }
326
327 pub fn insert_tx_hash_numbers<I>(&self, tx_hash_numbers: I) -> ProviderResult<()>
328 where
329 I: IntoIterator<Item = (TxHash, TxNumber)>,
330 {
331 self.commit_with_provider(|provider| {
332 provider.with_rocksdb_batch(|batch: RocksBatchArg<'_>| {
333 let mut writer = EitherWriter::new_transaction_hash_numbers(provider, batch)?;
334 for (tx_hash, tx_num) in tx_hash_numbers {
335 writer.put_transaction_hash_number(tx_hash, tx_num, false)?;
336 }
337 Ok(((), writer.into_raw_rocksdb_batch()))
338 })
339 })
340 }
341
342 pub fn insert_receipts<I>(&self, receipts: I) -> ProviderResult<()>
344 where
345 I: IntoIterator<Item = (TxNumber, Receipt)>,
346 {
347 self.commit(|tx| {
348 receipts.into_iter().try_for_each(|(tx_num, receipt)| {
349 Ok(tx.put::<tables::Receipts>(tx_num, receipt)?)
351 })
352 })
353 }
354
355 pub fn insert_receipts_by_block<I, J>(
358 &self,
359 receipts: I,
360 storage_kind: StorageKind,
361 ) -> ProviderResult<()>
362 where
363 I: IntoIterator<Item = (BlockNumber, J)>,
364 J: IntoIterator<Item = (TxNumber, Receipt)>,
365 {
366 match storage_kind {
367 StorageKind::Database(_) => self.commit(|tx| {
368 receipts.into_iter().try_for_each(|(_, receipts)| {
369 for (tx_num, receipt) in receipts {
370 tx.put::<tables::Receipts>(tx_num, receipt)?;
371 }
372 Ok(())
373 })
374 }),
375 StorageKind::Static => {
376 let provider = self.factory.static_file_provider();
377 let mut writer = provider.latest_writer(StaticFileSegment::Receipts)?;
378 let res = receipts.into_iter().try_for_each(|(block_num, receipts)| {
379 writer.increment_block(block_num)?;
380 writer.append_receipts(receipts.into_iter().map(Ok))?;
381 Ok(())
382 });
383 writer.commit_without_sync_all()?;
384 res
385 }
386 }
387 }
388
389 pub fn insert_transaction_senders<I>(&self, transaction_senders: I) -> ProviderResult<()>
390 where
391 I: IntoIterator<Item = (TxNumber, Address)>,
392 {
393 self.commit(|tx| {
394 transaction_senders.into_iter().try_for_each(|(tx_num, sender)| {
395 Ok(tx.put::<tables::TransactionSenders>(tx_num, sender)?)
397 })
398 })
399 }
400
401 pub fn insert_accounts_and_storages<I, S>(&self, accounts: I) -> ProviderResult<()>
403 where
404 I: IntoIterator<Item = (Address, (Account, S))>,
405 S: IntoIterator<Item = StorageEntry>,
406 {
407 self.commit(|tx| {
408 accounts.into_iter().try_for_each(|(address, (account, storage))| {
409 let hashed_address = keccak256(address);
410
411 tx.put::<tables::PlainAccountState>(address, account)?;
413 tx.put::<tables::HashedAccounts>(hashed_address, account)?;
414
415 storage.into_iter().filter(|e| !e.value.is_zero()).try_for_each(|entry| {
417 let hashed_entry = StorageEntry { key: keccak256(entry.key), ..entry };
418
419 let mut cursor = tx.cursor_dup_write::<tables::PlainStorageState>()?;
420 if cursor
421 .seek_by_key_subkey(address, entry.key)?
422 .filter(|e| e.key == entry.key)
423 .is_some()
424 {
425 cursor.delete_current()?;
426 }
427 cursor.upsert(address, &entry)?;
428
429 let mut cursor = tx.cursor_dup_write::<tables::HashedStorages>()?;
430 if cursor
431 .seek_by_key_subkey(hashed_address, hashed_entry.key)?
432 .filter(|e| e.key == hashed_entry.key)
433 .is_some()
434 {
435 cursor.delete_current()?;
436 }
437 cursor.upsert(hashed_address, &hashed_entry)?;
438
439 Ok(())
440 })
441 })
442 })
443 }
444
445 pub fn insert_changesets<I>(
447 &self,
448 changesets: I,
449 block_offset: Option<u64>,
450 ) -> ProviderResult<()>
451 where
452 I: IntoIterator<Item = ChangeSet>,
453 {
454 let offset = block_offset.unwrap_or_default();
455 self.commit(|tx| {
456 changesets.into_iter().enumerate().try_for_each(|(block, changeset)| {
457 changeset.into_iter().try_for_each(|(address, old_account, old_storage)| {
458 let block = offset + block as u64;
459 tx.put::<tables::AccountChangeSets>(
461 block,
462 AccountBeforeTx { address, info: Some(old_account) },
463 )?;
464
465 let block_address = (block, address).into();
466
467 old_storage.into_iter().try_for_each(|entry| {
469 Ok(tx.put::<tables::StorageChangeSets>(block_address, entry)?)
470 })
471 })
472 })
473 })
474 }
475
476 pub fn insert_history<I>(&self, changesets: I, _block_offset: Option<u64>) -> ProviderResult<()>
477 where
478 I: IntoIterator<Item = ChangeSet>,
479 {
480 let mut accounts = BTreeMap::<Address, Vec<u64>>::new();
481 let mut storages = BTreeMap::<(Address, B256), Vec<u64>>::new();
482
483 for (block, changeset) in changesets.into_iter().enumerate() {
484 for (address, _, storage_entries) in changeset {
485 accounts.entry(address).or_default().push(block as u64);
486 for storage_entry in storage_entries {
487 storages.entry((address, storage_entry.key)).or_default().push(block as u64);
488 }
489 }
490 }
491
492 let provider_rw = self.factory.provider_rw()?;
493 provider_rw.insert_account_history_index(accounts)?;
494 provider_rw.insert_storage_history_index(storages)?;
495 provider_rw.commit()?;
496
497 Ok(())
498 }
499}
500
501#[derive(Debug)]
503pub enum StorageKind {
504 Database(Option<u64>),
505 Static,
506}
507
508impl StorageKind {
509 #[expect(dead_code)]
510 const fn is_database(&self) -> bool {
511 matches!(self, Self::Database(_))
512 }
513
514 const fn is_static(&self) -> bool {
515 matches!(self, Self::Static)
516 }
517
518 fn tx_offset(&self) -> u64 {
519 if let Self::Database(offset) = self {
520 return offset.unwrap_or_default();
521 }
522 0
523 }
524}