1use alloy_primitives::{keccak256, Address, BlockNumber, TxHash, TxNumber, B256};
2use reth_chainspec::MAINNET;
3use reth_db::{
4 test_utils::{
5 create_test_rocksdb_dir, create_test_rw_db, create_test_rw_db_with_path,
6 create_test_static_files_dir,
7 },
8 DatabaseEnv,
9};
10use reth_db_api::{
11 common::KeyValue,
12 cursor::{DbCursorRO, DbCursorRW, DbDupCursorRO},
13 database::Database,
14 models::{AccountBeforeTx, StorageBeforeTx, StoredBlockBodyIndices},
15 table::Table,
16 tables,
17 transaction::{DbTx, DbTxMut},
18 DatabaseError as DbError,
19};
20use reth_ethereum_primitives::{Block, EthPrimitives, Receipt};
21use reth_primitives_traits::{Account, SealedBlock, SealedHeader, StorageEntry};
22use reth_provider::{
23 providers::{
24 RocksDBProvider, StaticFileProvider, StaticFileProviderRWRefMut, StaticFileWriter,
25 },
26 test_utils::MockNodeTypesWithDB,
27 DatabaseProviderFactory, EitherWriter, HistoryWriter, ProviderError, ProviderFactory,
28 RocksBatchArg, StaticFileProviderFactory, StatsReader,
29};
30use reth_static_file_types::StaticFileSegment;
31use reth_storage_errors::provider::ProviderResult;
32use reth_testing_utils::generators::ChangeSet;
33use std::{collections::BTreeMap, fmt::Debug, path::Path};
34use tempfile::TempDir;
35
36#[derive(Debug)]
38pub struct TestStageDB {
39 pub factory: ProviderFactory<MockNodeTypesWithDB>,
40 pub temp_static_files_dir: TempDir,
41 pub temp_rocksdb_dir: TempDir,
42}
43
44impl Default for TestStageDB {
45 fn default() -> Self {
47 let (static_dir, static_dir_path) = create_test_static_files_dir();
48 let (rocksdb_dir, rocksdb_dir_path) = create_test_rocksdb_dir();
49 Self {
50 temp_static_files_dir: static_dir,
51 temp_rocksdb_dir: rocksdb_dir,
52 factory: ProviderFactory::new(
53 create_test_rw_db(),
54 MAINNET.clone(),
55 StaticFileProvider::read_write(static_dir_path).unwrap(),
56 RocksDBProvider::builder(rocksdb_dir_path).with_default_tables().build().unwrap(),
57 reth_tasks::Runtime::test(),
58 )
59 .expect("failed to create test provider factory"),
60 }
61 }
62}
63
64impl TestStageDB {
65 pub fn new(path: &Path) -> Self {
66 let (static_dir, static_dir_path) = create_test_static_files_dir();
67 let (rocksdb_dir, rocksdb_dir_path) = create_test_rocksdb_dir();
68
69 Self {
70 temp_static_files_dir: static_dir,
71 temp_rocksdb_dir: rocksdb_dir,
72 factory: ProviderFactory::new(
73 create_test_rw_db_with_path(path),
74 MAINNET.clone(),
75 StaticFileProvider::read_write(static_dir_path).unwrap(),
76 RocksDBProvider::builder(rocksdb_dir_path).with_default_tables().build().unwrap(),
77 reth_tasks::Runtime::test(),
78 )
79 .expect("failed to create test provider factory"),
80 }
81 }
82
83 pub fn commit<F>(&self, f: F) -> ProviderResult<()>
85 where
86 F: FnOnce(&<DatabaseEnv as Database>::TXMut) -> ProviderResult<()>,
87 {
88 let tx = self.factory.provider_rw()?;
89 f(tx.tx_ref())?;
90 tx.commit().expect("failed to commit");
91 Ok(())
92 }
93
94 pub fn query<F, Ok>(&self, f: F) -> ProviderResult<Ok>
96 where
97 F: FnOnce(&<DatabaseEnv as Database>::TX) -> ProviderResult<Ok>,
98 {
99 f(self.factory.provider()?.tx_ref())
100 }
101
102 pub fn query_with_provider<F, Ok>(&self, f: F) -> ProviderResult<Ok>
105 where
106 F: FnOnce(
107 <ProviderFactory<MockNodeTypesWithDB> as DatabaseProviderFactory>::Provider,
108 ) -> ProviderResult<Ok>,
109 {
110 f(self.factory.provider()?)
111 }
112
113 pub fn commit_with_provider<F>(&self, f: F) -> ProviderResult<()>
115 where
116 F: FnOnce(
117 &<ProviderFactory<MockNodeTypesWithDB> as DatabaseProviderFactory>::ProviderRW,
118 ) -> ProviderResult<()>,
119 {
120 let provider = self.factory.provider_rw()?;
121 f(&provider)?;
122 provider.commit().expect("failed to commit");
123 Ok(())
124 }
125
126 pub fn table_is_empty<T: Table>(&self) -> ProviderResult<bool> {
128 self.query(|tx| {
129 let last = tx.cursor_read::<T>()?.last()?;
130 Ok(last.is_none())
131 })
132 }
133
134 pub fn table<T: Table>(&self) -> ProviderResult<Vec<KeyValue<T>>>
136 where
137 T::Key: Default + Ord,
138 {
139 self.query(|tx| {
140 Ok(tx
141 .cursor_read::<T>()?
142 .walk(Some(T::Key::default()))?
143 .collect::<Result<Vec<_>, DbError>>()?)
144 })
145 }
146
147 pub fn count_entries<T: Table>(&self) -> ProviderResult<usize> {
149 self.factory.provider()?.count_entries::<T>()
150 }
151
152 pub fn ensure_no_entry_above<T, F>(&self, num: u64, mut selector: F) -> ProviderResult<()>
155 where
156 T: Table,
157 F: FnMut(T::Key) -> BlockNumber,
158 {
159 self.query(|tx| {
160 let mut cursor = tx.cursor_read::<T>()?;
161 if let Some((key, _)) = cursor.last()? {
162 assert!(selector(key) <= num);
163 }
164 Ok(())
165 })
166 }
167
168 pub fn ensure_no_entry_above_by_value<T, F>(
171 &self,
172 num: u64,
173 mut selector: F,
174 ) -> ProviderResult<()>
175 where
176 T: Table,
177 F: FnMut(T::Value) -> BlockNumber,
178 {
179 self.query(|tx| {
180 let mut cursor = tx.cursor_read::<T>()?;
181 let mut rev_walker = cursor.walk_back(None)?;
182 while let Some((_, value)) = rev_walker.next().transpose()? {
183 assert!(selector(value) <= num);
184 }
185 Ok(())
186 })
187 }
188
189 pub fn insert_header<TX: DbTx + DbTxMut>(
191 writer: Option<&mut StaticFileProviderRWRefMut<'_, EthPrimitives>>,
192 tx: &TX,
193 header: &SealedHeader,
194 ) -> ProviderResult<()> {
195 if let Some(writer) = writer {
196 let segment_header = writer.user_header();
199 if segment_header.block_end().is_none() && segment_header.expected_block_start() == 0 {
200 for block_number in 0..header.number {
201 let mut prev = header.clone_header();
202 prev.number = block_number;
203 writer.append_header(&prev, &B256::ZERO)?;
204 }
205 }
206
207 writer.append_header(header.header(), &header.hash())?;
208 } else {
209 tx.put::<tables::CanonicalHeaders>(header.number, header.hash())?;
210 tx.put::<tables::Headers>(header.number, header.header().clone())?;
211 }
212
213 tx.put::<tables::HeaderNumbers>(header.hash(), header.number)?;
214 Ok(())
215 }
216
217 fn insert_headers_inner<'a, I>(&self, headers: I) -> ProviderResult<()>
218 where
219 I: IntoIterator<Item = &'a SealedHeader>,
220 {
221 let provider = self.factory.static_file_provider();
222 let mut writer = provider.latest_writer(StaticFileSegment::Headers)?;
223 let tx = self.factory.provider_rw()?.into_tx();
224
225 for header in headers {
226 Self::insert_header(Some(&mut writer), &tx, header)?;
227 }
228
229 writer.commit()?;
230 tx.commit()?;
231
232 Ok(())
233 }
234
235 pub fn insert_headers<'a, I>(&self, headers: I) -> ProviderResult<()>
238 where
239 I: IntoIterator<Item = &'a SealedHeader>,
240 {
241 self.insert_headers_inner::<I>(headers)
242 }
243
244 pub fn insert_blocks<'a, I>(&self, blocks: I, storage_kind: StorageKind) -> ProviderResult<()>
252 where
253 I: IntoIterator<Item = &'a SealedBlock<Block>>,
254 {
255 let provider = self.factory.static_file_provider();
256
257 let tx = self.factory.provider_rw().unwrap().into_tx();
258 let mut next_tx_num = storage_kind.tx_offset();
259
260 let blocks = blocks.into_iter().collect::<Vec<_>>();
261
262 {
263 let mut headers_writer = storage_kind
264 .is_static()
265 .then(|| provider.latest_writer(StaticFileSegment::Headers).unwrap());
266
267 blocks.iter().try_for_each(|block| {
268 Self::insert_header(headers_writer.as_mut(), &tx, block.sealed_header())
269 })?;
270
271 if let Some(mut writer) = headers_writer {
272 writer.commit()?;
273 }
274 }
275
276 {
277 let mut txs_writer = storage_kind
278 .is_static()
279 .then(|| provider.latest_writer(StaticFileSegment::Transactions).unwrap());
280
281 blocks.into_iter().try_for_each(|block| {
282 let block_body_indices = StoredBlockBodyIndices {
284 first_tx_num: next_tx_num,
285 tx_count: block.transaction_count() as u64,
286 };
287
288 if !block.body().transactions.is_empty() {
289 tx.put::<tables::TransactionBlocks>(
290 block_body_indices.last_tx_num(),
291 block.number,
292 )?;
293 }
294 tx.put::<tables::BlockBodyIndices>(block.number, block_body_indices)?;
295
296 let res = block.body().transactions.iter().try_for_each(|body_tx| {
297 if let Some(txs_writer) = &mut txs_writer {
298 txs_writer.append_transaction(next_tx_num, body_tx)?;
299 } else {
300 tx.put::<tables::Transactions>(next_tx_num, body_tx.clone())?
301 }
302 next_tx_num += 1;
303 Ok::<(), ProviderError>(())
304 });
305
306 if let Some(txs_writer) = &mut txs_writer {
307 let segment_header = txs_writer.user_header();
310 if segment_header.block_end().is_none() &&
311 segment_header.expected_block_start() == 0
312 {
313 for block in 0..block.number {
314 txs_writer.increment_block(block)?;
315 }
316 }
317 txs_writer.increment_block(block.number)?;
318 }
319 res
320 })?;
321
322 if let Some(txs_writer) = &mut txs_writer {
323 txs_writer.commit()?;
324 }
325 }
326
327 tx.commit()?;
328
329 Ok(())
330 }
331
332 pub fn insert_tx_hash_numbers<I>(&self, tx_hash_numbers: I) -> ProviderResult<()>
333 where
334 I: IntoIterator<Item = (TxHash, TxNumber)>,
335 {
336 self.commit_with_provider(|provider| {
337 provider.with_rocksdb_batch(|batch: RocksBatchArg<'_>| {
338 let mut writer = EitherWriter::new_transaction_hash_numbers(provider, batch)?;
339 for (tx_hash, tx_num) in tx_hash_numbers {
340 writer.put_transaction_hash_number(tx_hash, tx_num, false)?;
341 }
342 Ok(((), writer.into_raw_rocksdb_batch()))
343 })
344 })
345 }
346
347 pub fn insert_receipts<I>(&self, receipts: I) -> ProviderResult<()>
349 where
350 I: IntoIterator<Item = (TxNumber, Receipt)>,
351 {
352 self.commit(|tx| {
353 receipts.into_iter().try_for_each(|(tx_num, receipt)| {
354 Ok(tx.put::<tables::Receipts>(tx_num, receipt)?)
356 })
357 })
358 }
359
360 pub fn insert_receipts_by_block<I, J>(
363 &self,
364 receipts: I,
365 storage_kind: StorageKind,
366 ) -> ProviderResult<()>
367 where
368 I: IntoIterator<Item = (BlockNumber, J)>,
369 J: IntoIterator<Item = (TxNumber, Receipt)>,
370 {
371 match storage_kind {
372 StorageKind::Database(_) => self.commit(|tx| {
373 receipts.into_iter().try_for_each(|(_, receipts)| {
374 for (tx_num, receipt) in receipts {
375 tx.put::<tables::Receipts>(tx_num, receipt)?;
376 }
377 Ok(())
378 })
379 }),
380 StorageKind::Static => {
381 let provider = self.factory.static_file_provider();
382 let mut writer = provider.latest_writer(StaticFileSegment::Receipts)?;
383 let res = receipts.into_iter().try_for_each(|(block_num, receipts)| {
384 writer.increment_block(block_num)?;
385 writer.append_receipts(receipts.into_iter().map(Ok))?;
386 Ok(())
387 });
388 writer.commit_without_sync_all()?;
389 res
390 }
391 }
392 }
393
394 pub fn insert_transaction_senders<I>(&self, transaction_senders: I) -> ProviderResult<()>
395 where
396 I: IntoIterator<Item = (TxNumber, Address)>,
397 {
398 self.commit(|tx| {
399 transaction_senders.into_iter().try_for_each(|(tx_num, sender)| {
400 Ok(tx.put::<tables::TransactionSenders>(tx_num, sender)?)
402 })
403 })
404 }
405
406 pub fn insert_accounts_and_storages<I, S>(&self, accounts: I) -> ProviderResult<()>
408 where
409 I: IntoIterator<Item = (Address, (Account, S))>,
410 S: IntoIterator<Item = StorageEntry>,
411 {
412 self.commit(|tx| {
413 accounts.into_iter().try_for_each(|(address, (account, storage))| {
414 let hashed_address = keccak256(address);
415
416 tx.put::<tables::PlainAccountState>(address, account)?;
418 tx.put::<tables::HashedAccounts>(hashed_address, account)?;
419
420 storage.into_iter().filter(|e| !e.value.is_zero()).try_for_each(|entry| {
422 let hashed_entry = StorageEntry { key: keccak256(entry.key), ..entry };
423
424 let mut cursor = tx.cursor_dup_write::<tables::PlainStorageState>()?;
425 if cursor
426 .seek_by_key_subkey(address, entry.key)?
427 .filter(|e| e.key == entry.key)
428 .is_some()
429 {
430 cursor.delete_current()?;
431 }
432 cursor.upsert(address, &entry)?;
433
434 let mut cursor = tx.cursor_dup_write::<tables::HashedStorages>()?;
435 if cursor
436 .seek_by_key_subkey(hashed_address, hashed_entry.key)?
437 .filter(|e| e.key == hashed_entry.key)
438 .is_some()
439 {
440 cursor.delete_current()?;
441 }
442 cursor.upsert(hashed_address, &hashed_entry)?;
443
444 Ok(())
445 })
446 })
447 })
448 }
449
450 pub fn insert_changesets<I>(
452 &self,
453 changesets: I,
454 block_offset: Option<u64>,
455 ) -> ProviderResult<()>
456 where
457 I: IntoIterator<Item = ChangeSet>,
458 {
459 let offset = block_offset.unwrap_or_default();
460 self.commit(|tx| {
461 changesets.into_iter().enumerate().try_for_each(|(block, changeset)| {
462 changeset.into_iter().try_for_each(|(address, old_account, old_storage)| {
463 let block = offset + block as u64;
464 tx.put::<tables::AccountChangeSets>(
466 block,
467 AccountBeforeTx { address, info: Some(old_account) },
468 )?;
469
470 let block_address = (block, address).into();
471
472 old_storage.into_iter().try_for_each(|entry| {
474 Ok(tx.put::<tables::StorageChangeSets>(block_address, entry)?)
475 })
476 })
477 })
478 })
479 }
480
481 pub fn insert_changesets_to_static_files<I>(
483 &self,
484 changesets: I,
485 block_offset: Option<u64>,
486 ) -> ProviderResult<()>
487 where
488 I: IntoIterator<Item = ChangeSet>,
489 {
490 let offset = block_offset.unwrap_or_default();
491 let static_file_provider = self.factory.static_file_provider();
492
493 let mut account_changeset_writer =
494 static_file_provider.latest_writer(StaticFileSegment::AccountChangeSets)?;
495 let mut storage_changeset_writer =
496 static_file_provider.latest_writer(StaticFileSegment::StorageChangeSets)?;
497
498 for (block, changeset) in changesets.into_iter().enumerate() {
499 let block_number = offset + block as u64;
500
501 let mut account_changesets = Vec::new();
502 let mut storage_changesets = Vec::new();
503
504 for (address, old_account, old_storage) in changeset {
505 account_changesets.push(AccountBeforeTx { address, info: Some(old_account) });
506
507 for entry in old_storage {
508 storage_changesets.push(StorageBeforeTx {
509 address,
510 key: entry.key,
511 value: entry.value,
512 });
513 }
514 }
515
516 account_changeset_writer.append_account_changeset(account_changesets, block_number)?;
517 storage_changeset_writer.append_storage_changeset(storage_changesets, block_number)?;
518 }
519
520 account_changeset_writer.commit()?;
521 storage_changeset_writer.commit()?;
522
523 Ok(())
524 }
525
526 pub fn insert_history<I>(&self, changesets: I, _block_offset: Option<u64>) -> ProviderResult<()>
527 where
528 I: IntoIterator<Item = ChangeSet>,
529 {
530 let mut accounts = BTreeMap::<Address, Vec<u64>>::new();
531 let mut storages = BTreeMap::<(Address, B256), Vec<u64>>::new();
532
533 for (block, changeset) in changesets.into_iter().enumerate() {
534 for (address, _, storage_entries) in changeset {
535 accounts.entry(address).or_default().push(block as u64);
536 for storage_entry in storage_entries {
537 storages.entry((address, storage_entry.key)).or_default().push(block as u64);
538 }
539 }
540 }
541
542 let provider_rw = self.factory.provider_rw()?;
543 provider_rw.insert_account_history_index(accounts)?;
544 provider_rw.insert_storage_history_index(storages)?;
545 provider_rw.commit()?;
546
547 Ok(())
548 }
549}
550
/// Selects where test helpers such as `TestStageDB::insert_blocks` store block data.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StorageKind {
    /// Store data in database tables. The optional value is the first transaction number to
    /// assign (0 when `None`).
    Database(Option<u64>),
    /// Store data in static files.
    Static,
}

impl StorageKind {
    /// Returns `true` for [`StorageKind::Database`].
    #[expect(dead_code)]
    const fn is_database(&self) -> bool {
        matches!(self, Self::Database(_))
    }

    /// Returns `true` for [`StorageKind::Static`].
    const fn is_static(&self) -> bool {
        matches!(self, Self::Static)
    }

    /// Returns the first transaction number to assign.
    ///
    /// This is the configured offset for [`StorageKind::Database`] (0 when unset) and always 0
    /// for [`StorageKind::Static`]. The match is exhaustive so adding a variant forces a
    /// decision here.
    const fn tx_offset(&self) -> u64 {
        match self {
            Self::Database(Some(offset)) => *offset,
            Self::Database(None) | Self::Static => 0,
        }
    }
}