1use alloy_primitives::{keccak256, Address, BlockNumber, TxHash, TxNumber, B256};
2use reth_chainspec::MAINNET;
3use reth_db::{
4 test_utils::{
5 create_test_rocksdb_dir, create_test_rw_db, create_test_rw_db_with_path,
6 create_test_static_files_dir,
7 },
8 DatabaseEnv,
9};
10use reth_db_api::{
11 common::KeyValue,
12 cursor::{DbCursorRO, DbCursorRW, DbDupCursorRO},
13 database::Database,
14 models::{AccountBeforeTx, StoredBlockBodyIndices},
15 table::Table,
16 tables,
17 transaction::{DbTx, DbTxMut},
18 DatabaseError as DbError,
19};
20use reth_ethereum_primitives::{Block, EthPrimitives, Receipt};
21use reth_primitives_traits::{Account, SealedBlock, SealedHeader, StorageEntry};
22use reth_provider::{
23 providers::{
24 RocksDBProvider, StaticFileProvider, StaticFileProviderRWRefMut, StaticFileWriter,
25 },
26 test_utils::MockNodeTypesWithDB,
27 HistoryWriter, ProviderError, ProviderFactory, StaticFileProviderFactory, StatsReader,
28};
29use reth_static_file_types::StaticFileSegment;
30use reth_storage_errors::provider::ProviderResult;
31use reth_testing_utils::generators::ChangeSet;
32use std::{collections::BTreeMap, fmt::Debug, path::Path};
33use tempfile::TempDir;
34
35#[derive(Debug)]
37pub struct TestStageDB {
38 pub factory: ProviderFactory<MockNodeTypesWithDB>,
39 pub temp_static_files_dir: TempDir,
40}
41
42impl Default for TestStageDB {
43 fn default() -> Self {
45 let (static_dir, static_dir_path) = create_test_static_files_dir();
46 let (_, rocksdb_dir_path) = create_test_rocksdb_dir();
47 Self {
48 temp_static_files_dir: static_dir,
49 factory: ProviderFactory::new(
50 create_test_rw_db(),
51 MAINNET.clone(),
52 StaticFileProvider::read_write(static_dir_path).unwrap(),
53 RocksDBProvider::builder(rocksdb_dir_path).with_default_tables().build().unwrap(),
54 )
55 .expect("failed to create test provider factory"),
56 }
57 }
58}
59
60impl TestStageDB {
61 pub fn new(path: &Path) -> Self {
62 let (static_dir, static_dir_path) = create_test_static_files_dir();
63 let (_, rocksdb_dir_path) = create_test_rocksdb_dir();
64
65 Self {
66 temp_static_files_dir: static_dir,
67 factory: ProviderFactory::new(
68 create_test_rw_db_with_path(path),
69 MAINNET.clone(),
70 StaticFileProvider::read_write(static_dir_path).unwrap(),
71 RocksDBProvider::builder(rocksdb_dir_path).with_default_tables().build().unwrap(),
72 )
73 .expect("failed to create test provider factory"),
74 }
75 }
76
77 pub fn commit<F>(&self, f: F) -> ProviderResult<()>
79 where
80 F: FnOnce(&<DatabaseEnv as Database>::TXMut) -> ProviderResult<()>,
81 {
82 let tx = self.factory.provider_rw()?;
83 f(tx.tx_ref())?;
84 tx.commit().expect("failed to commit");
85 Ok(())
86 }
87
88 pub fn query<F, Ok>(&self, f: F) -> ProviderResult<Ok>
90 where
91 F: FnOnce(&<DatabaseEnv as Database>::TX) -> ProviderResult<Ok>,
92 {
93 f(self.factory.provider()?.tx_ref())
94 }
95
96 pub fn table_is_empty<T: Table>(&self) -> ProviderResult<bool> {
98 self.query(|tx| {
99 let last = tx.cursor_read::<T>()?.last()?;
100 Ok(last.is_none())
101 })
102 }
103
104 pub fn table<T: Table>(&self) -> ProviderResult<Vec<KeyValue<T>>>
106 where
107 T::Key: Default + Ord,
108 {
109 self.query(|tx| {
110 Ok(tx
111 .cursor_read::<T>()?
112 .walk(Some(T::Key::default()))?
113 .collect::<Result<Vec<_>, DbError>>()?)
114 })
115 }
116
117 pub fn count_entries<T: Table>(&self) -> ProviderResult<usize> {
119 self.factory.provider()?.count_entries::<T>()
120 }
121
122 pub fn ensure_no_entry_above<T, F>(&self, num: u64, mut selector: F) -> ProviderResult<()>
125 where
126 T: Table,
127 F: FnMut(T::Key) -> BlockNumber,
128 {
129 self.query(|tx| {
130 let mut cursor = tx.cursor_read::<T>()?;
131 if let Some((key, _)) = cursor.last()? {
132 assert!(selector(key) <= num);
133 }
134 Ok(())
135 })
136 }
137
138 pub fn ensure_no_entry_above_by_value<T, F>(
141 &self,
142 num: u64,
143 mut selector: F,
144 ) -> ProviderResult<()>
145 where
146 T: Table,
147 F: FnMut(T::Value) -> BlockNumber,
148 {
149 self.query(|tx| {
150 let mut cursor = tx.cursor_read::<T>()?;
151 let mut rev_walker = cursor.walk_back(None)?;
152 while let Some((_, value)) = rev_walker.next().transpose()? {
153 assert!(selector(value) <= num);
154 }
155 Ok(())
156 })
157 }
158
159 pub fn insert_header<TX: DbTx + DbTxMut>(
161 writer: Option<&mut StaticFileProviderRWRefMut<'_, EthPrimitives>>,
162 tx: &TX,
163 header: &SealedHeader,
164 ) -> ProviderResult<()> {
165 if let Some(writer) = writer {
166 let segment_header = writer.user_header();
169 if segment_header.block_end().is_none() && segment_header.expected_block_start() == 0 {
170 for block_number in 0..header.number {
171 let mut prev = header.clone_header();
172 prev.number = block_number;
173 writer.append_header(&prev, &B256::ZERO)?;
174 }
175 }
176
177 writer.append_header(header.header(), &header.hash())?;
178 } else {
179 tx.put::<tables::CanonicalHeaders>(header.number, header.hash())?;
180 tx.put::<tables::Headers>(header.number, header.header().clone())?;
181 }
182
183 tx.put::<tables::HeaderNumbers>(header.hash(), header.number)?;
184 Ok(())
185 }
186
187 fn insert_headers_inner<'a, I>(&self, headers: I) -> ProviderResult<()>
188 where
189 I: IntoIterator<Item = &'a SealedHeader>,
190 {
191 let provider = self.factory.static_file_provider();
192 let mut writer = provider.latest_writer(StaticFileSegment::Headers)?;
193 let tx = self.factory.provider_rw()?.into_tx();
194
195 for header in headers {
196 Self::insert_header(Some(&mut writer), &tx, header)?;
197 }
198
199 writer.commit()?;
200 tx.commit()?;
201
202 Ok(())
203 }
204
205 pub fn insert_headers<'a, I>(&self, headers: I) -> ProviderResult<()>
208 where
209 I: IntoIterator<Item = &'a SealedHeader>,
210 {
211 self.insert_headers_inner::<I>(headers)
212 }
213
214 pub fn insert_blocks<'a, I>(&self, blocks: I, storage_kind: StorageKind) -> ProviderResult<()>
222 where
223 I: IntoIterator<Item = &'a SealedBlock<Block>>,
224 {
225 let provider = self.factory.static_file_provider();
226
227 let tx = self.factory.provider_rw().unwrap().into_tx();
228 let mut next_tx_num = storage_kind.tx_offset();
229
230 let blocks = blocks.into_iter().collect::<Vec<_>>();
231
232 {
233 let mut headers_writer = storage_kind
234 .is_static()
235 .then(|| provider.latest_writer(StaticFileSegment::Headers).unwrap());
236
237 blocks.iter().try_for_each(|block| {
238 Self::insert_header(headers_writer.as_mut(), &tx, block.sealed_header())
239 })?;
240
241 if let Some(mut writer) = headers_writer {
242 writer.commit()?;
243 }
244 }
245
246 {
247 let mut txs_writer = storage_kind
248 .is_static()
249 .then(|| provider.latest_writer(StaticFileSegment::Transactions).unwrap());
250
251 blocks.into_iter().try_for_each(|block| {
252 let block_body_indices = StoredBlockBodyIndices {
254 first_tx_num: next_tx_num,
255 tx_count: block.transaction_count() as u64,
256 };
257
258 if !block.body().transactions.is_empty() {
259 tx.put::<tables::TransactionBlocks>(
260 block_body_indices.last_tx_num(),
261 block.number,
262 )?;
263 }
264 tx.put::<tables::BlockBodyIndices>(block.number, block_body_indices)?;
265
266 let res = block.body().transactions.iter().try_for_each(|body_tx| {
267 if let Some(txs_writer) = &mut txs_writer {
268 txs_writer.append_transaction(next_tx_num, body_tx)?;
269 } else {
270 tx.put::<tables::Transactions>(next_tx_num, body_tx.clone())?
271 }
272 next_tx_num += 1;
273 Ok::<(), ProviderError>(())
274 });
275
276 if let Some(txs_writer) = &mut txs_writer {
277 let segment_header = txs_writer.user_header();
280 if segment_header.block_end().is_none() &&
281 segment_header.expected_block_start() == 0
282 {
283 for block in 0..block.number {
284 txs_writer.increment_block(block)?;
285 }
286 }
287 txs_writer.increment_block(block.number)?;
288 }
289 res
290 })?;
291
292 if let Some(txs_writer) = &mut txs_writer {
293 txs_writer.commit()?;
294 }
295 }
296
297 tx.commit()?;
298
299 Ok(())
300 }
301
302 pub fn insert_tx_hash_numbers<I>(&self, tx_hash_numbers: I) -> ProviderResult<()>
303 where
304 I: IntoIterator<Item = (TxHash, TxNumber)>,
305 {
306 self.commit(|tx| {
307 tx_hash_numbers.into_iter().try_for_each(|(tx_hash, tx_num)| {
308 Ok(tx.put::<tables::TransactionHashNumbers>(tx_hash, tx_num)?)
310 })
311 })
312 }
313
314 pub fn insert_receipts<I>(&self, receipts: I) -> ProviderResult<()>
316 where
317 I: IntoIterator<Item = (TxNumber, Receipt)>,
318 {
319 self.commit(|tx| {
320 receipts.into_iter().try_for_each(|(tx_num, receipt)| {
321 Ok(tx.put::<tables::Receipts>(tx_num, receipt)?)
323 })
324 })
325 }
326
327 pub fn insert_receipts_by_block<I, J>(
330 &self,
331 receipts: I,
332 storage_kind: StorageKind,
333 ) -> ProviderResult<()>
334 where
335 I: IntoIterator<Item = (BlockNumber, J)>,
336 J: IntoIterator<Item = (TxNumber, Receipt)>,
337 {
338 match storage_kind {
339 StorageKind::Database(_) => self.commit(|tx| {
340 receipts.into_iter().try_for_each(|(_, receipts)| {
341 for (tx_num, receipt) in receipts {
342 tx.put::<tables::Receipts>(tx_num, receipt)?;
343 }
344 Ok(())
345 })
346 }),
347 StorageKind::Static => {
348 let provider = self.factory.static_file_provider();
349 let mut writer = provider.latest_writer(StaticFileSegment::Receipts)?;
350 let res = receipts.into_iter().try_for_each(|(block_num, receipts)| {
351 writer.increment_block(block_num)?;
352 writer.append_receipts(receipts.into_iter().map(Ok))?;
353 Ok(())
354 });
355 writer.commit_without_sync_all()?;
356 res
357 }
358 }
359 }
360
361 pub fn insert_transaction_senders<I>(&self, transaction_senders: I) -> ProviderResult<()>
362 where
363 I: IntoIterator<Item = (TxNumber, Address)>,
364 {
365 self.commit(|tx| {
366 transaction_senders.into_iter().try_for_each(|(tx_num, sender)| {
367 Ok(tx.put::<tables::TransactionSenders>(tx_num, sender)?)
369 })
370 })
371 }
372
373 pub fn insert_accounts_and_storages<I, S>(&self, accounts: I) -> ProviderResult<()>
375 where
376 I: IntoIterator<Item = (Address, (Account, S))>,
377 S: IntoIterator<Item = StorageEntry>,
378 {
379 self.commit(|tx| {
380 accounts.into_iter().try_for_each(|(address, (account, storage))| {
381 let hashed_address = keccak256(address);
382
383 tx.put::<tables::PlainAccountState>(address, account)?;
385 tx.put::<tables::HashedAccounts>(hashed_address, account)?;
386
387 storage.into_iter().filter(|e| !e.value.is_zero()).try_for_each(|entry| {
389 let hashed_entry = StorageEntry { key: keccak256(entry.key), ..entry };
390
391 let mut cursor = tx.cursor_dup_write::<tables::PlainStorageState>()?;
392 if cursor
393 .seek_by_key_subkey(address, entry.key)?
394 .filter(|e| e.key == entry.key)
395 .is_some()
396 {
397 cursor.delete_current()?;
398 }
399 cursor.upsert(address, &entry)?;
400
401 let mut cursor = tx.cursor_dup_write::<tables::HashedStorages>()?;
402 if cursor
403 .seek_by_key_subkey(hashed_address, hashed_entry.key)?
404 .filter(|e| e.key == hashed_entry.key)
405 .is_some()
406 {
407 cursor.delete_current()?;
408 }
409 cursor.upsert(hashed_address, &hashed_entry)?;
410
411 Ok(())
412 })
413 })
414 })
415 }
416
417 pub fn insert_changesets<I>(
419 &self,
420 changesets: I,
421 block_offset: Option<u64>,
422 ) -> ProviderResult<()>
423 where
424 I: IntoIterator<Item = ChangeSet>,
425 {
426 let offset = block_offset.unwrap_or_default();
427 self.commit(|tx| {
428 changesets.into_iter().enumerate().try_for_each(|(block, changeset)| {
429 changeset.into_iter().try_for_each(|(address, old_account, old_storage)| {
430 let block = offset + block as u64;
431 tx.put::<tables::AccountChangeSets>(
433 block,
434 AccountBeforeTx { address, info: Some(old_account) },
435 )?;
436
437 let block_address = (block, address).into();
438
439 old_storage.into_iter().try_for_each(|entry| {
441 Ok(tx.put::<tables::StorageChangeSets>(block_address, entry)?)
442 })
443 })
444 })
445 })
446 }
447
448 pub fn insert_history<I>(&self, changesets: I, _block_offset: Option<u64>) -> ProviderResult<()>
449 where
450 I: IntoIterator<Item = ChangeSet>,
451 {
452 let mut accounts = BTreeMap::<Address, Vec<u64>>::new();
453 let mut storages = BTreeMap::<(Address, B256), Vec<u64>>::new();
454
455 for (block, changeset) in changesets.into_iter().enumerate() {
456 for (address, _, storage_entries) in changeset {
457 accounts.entry(address).or_default().push(block as u64);
458 for storage_entry in storage_entries {
459 storages.entry((address, storage_entry.key)).or_default().push(block as u64);
460 }
461 }
462 }
463
464 let provider_rw = self.factory.provider_rw()?;
465 provider_rw.insert_account_history_index(accounts)?;
466 provider_rw.insert_storage_history_index(storages)?;
467 provider_rw.commit()?;
468
469 Ok(())
470 }
471}
472
/// Selects where test data is written.
#[derive(Debug)]
pub enum StorageKind {
    /// Write to database tables; the optional value is the transaction number offset reported
    /// by `tx_offset`.
    Database(Option<u64>),
    /// Write to static files.
    Static,
}

impl StorageKind {
    /// Returns `true` for [`StorageKind::Database`].
    // Not called anywhere in this file at present, hence the lint expectation.
    #[expect(dead_code)]
    const fn is_database(&self) -> bool {
        match self {
            Self::Database(_) => true,
            Self::Static => false,
        }
    }

    /// Returns `true` for [`StorageKind::Static`].
    const fn is_static(&self) -> bool {
        match self {
            Self::Static => true,
            Self::Database(_) => false,
        }
    }

    /// Returns the transaction number offset: the configured value for
    /// [`StorageKind::Database`] (0 when unset) and 0 for [`StorageKind::Static`].
    fn tx_offset(&self) -> u64 {
        match self {
            Self::Database(offset) => offset.unwrap_or_default(),
            Self::Static => 0,
        }
    }
}