1use alloy_primitives::{keccak256, Address, BlockNumber, TxHash, TxNumber, B256, U256};
2use reth_chainspec::MAINNET;
3use reth_db::{
4 test_utils::{create_test_rw_db, create_test_rw_db_with_path, create_test_static_files_dir},
5 DatabaseEnv,
6};
7use reth_db_api::{
8 common::KeyValue,
9 cursor::{DbCursorRO, DbCursorRW, DbDupCursorRO},
10 database::Database,
11 models::{AccountBeforeTx, StoredBlockBodyIndices},
12 table::Table,
13 tables,
14 transaction::{DbTx, DbTxMut},
15 DatabaseError as DbError,
16};
17use reth_ethereum_primitives::{Block, EthPrimitives, Receipt};
18use reth_primitives_traits::{Account, SealedBlock, SealedHeader, StorageEntry};
19use reth_provider::{
20 providers::{StaticFileProvider, StaticFileProviderRWRefMut, StaticFileWriter},
21 test_utils::MockNodeTypesWithDB,
22 HistoryWriter, ProviderError, ProviderFactory, StaticFileProviderFactory, StatsReader,
23};
24use reth_static_file_types::StaticFileSegment;
25use reth_storage_errors::provider::ProviderResult;
26use reth_testing_utils::generators::ChangeSet;
27use std::{collections::BTreeMap, fmt::Debug, path::Path};
28use tempfile::TempDir;
29
/// Test harness bundling a [`ProviderFactory`] over a test database with the temporary
/// directory that backs its static files.
#[derive(Debug)]
pub struct TestStageDB {
    /// Provider factory over the test database and the read-write static file provider.
    pub factory: ProviderFactory<MockNodeTypesWithDB>,
    /// Temporary directory holding the static files used by `factory`.
    ///
    /// Declared after `factory` on purpose: Rust drops fields in declaration order, so the
    /// factory (and its static file provider) is dropped before this directory is cleaned up.
    pub temp_static_files_dir: TempDir,
}
36
37impl Default for TestStageDB {
38 fn default() -> Self {
40 let (static_dir, static_dir_path) = create_test_static_files_dir();
41 Self {
42 temp_static_files_dir: static_dir,
43 factory: ProviderFactory::new(
44 create_test_rw_db(),
45 MAINNET.clone(),
46 StaticFileProvider::read_write(static_dir_path).unwrap(),
47 ),
48 }
49 }
50}
51
52impl TestStageDB {
53 pub fn new(path: &Path) -> Self {
54 let (static_dir, static_dir_path) = create_test_static_files_dir();
55
56 Self {
57 temp_static_files_dir: static_dir,
58 factory: ProviderFactory::new(
59 create_test_rw_db_with_path(path),
60 MAINNET.clone(),
61 StaticFileProvider::read_write(static_dir_path).unwrap(),
62 ),
63 }
64 }
65
66 pub fn commit<F>(&self, f: F) -> ProviderResult<()>
68 where
69 F: FnOnce(&<DatabaseEnv as Database>::TXMut) -> ProviderResult<()>,
70 {
71 let tx = self.factory.provider_rw()?;
72 f(tx.tx_ref())?;
73 tx.commit().expect("failed to commit");
74 Ok(())
75 }
76
77 pub fn query<F, Ok>(&self, f: F) -> ProviderResult<Ok>
79 where
80 F: FnOnce(&<DatabaseEnv as Database>::TX) -> ProviderResult<Ok>,
81 {
82 f(self.factory.provider()?.tx_ref())
83 }
84
85 pub fn table_is_empty<T: Table>(&self) -> ProviderResult<bool> {
87 self.query(|tx| {
88 let last = tx.cursor_read::<T>()?.last()?;
89 Ok(last.is_none())
90 })
91 }
92
93 pub fn table<T: Table>(&self) -> ProviderResult<Vec<KeyValue<T>>>
95 where
96 T::Key: Default + Ord,
97 {
98 self.query(|tx| {
99 Ok(tx
100 .cursor_read::<T>()?
101 .walk(Some(T::Key::default()))?
102 .collect::<Result<Vec<_>, DbError>>()?)
103 })
104 }
105
106 pub fn count_entries<T: Table>(&self) -> ProviderResult<usize> {
108 self.factory.provider()?.count_entries::<T>()
109 }
110
111 pub fn ensure_no_entry_above<T, F>(&self, num: u64, mut selector: F) -> ProviderResult<()>
114 where
115 T: Table,
116 F: FnMut(T::Key) -> BlockNumber,
117 {
118 self.query(|tx| {
119 let mut cursor = tx.cursor_read::<T>()?;
120 if let Some((key, _)) = cursor.last()? {
121 assert!(selector(key) <= num);
122 }
123 Ok(())
124 })
125 }
126
127 pub fn ensure_no_entry_above_by_value<T, F>(
130 &self,
131 num: u64,
132 mut selector: F,
133 ) -> ProviderResult<()>
134 where
135 T: Table,
136 F: FnMut(T::Value) -> BlockNumber,
137 {
138 self.query(|tx| {
139 let mut cursor = tx.cursor_read::<T>()?;
140 let mut rev_walker = cursor.walk_back(None)?;
141 while let Some((_, value)) = rev_walker.next().transpose()? {
142 assert!(selector(value) <= num);
143 }
144 Ok(())
145 })
146 }
147
148 pub fn insert_header<TX: DbTx + DbTxMut>(
150 writer: Option<&mut StaticFileProviderRWRefMut<'_, EthPrimitives>>,
151 tx: &TX,
152 header: &SealedHeader,
153 td: U256,
154 ) -> ProviderResult<()> {
155 if let Some(writer) = writer {
156 let segment_header = writer.user_header();
159 if segment_header.block_end().is_none() && segment_header.expected_block_start() == 0 {
160 for block_number in 0..header.number {
161 let mut prev = header.clone_header();
162 prev.number = block_number;
163 writer.append_header(&prev, U256::ZERO, &B256::ZERO)?;
164 }
165 }
166
167 writer.append_header(header.header(), td, &header.hash())?;
168 } else {
169 tx.put::<tables::CanonicalHeaders>(header.number, header.hash())?;
170 tx.put::<tables::HeaderTerminalDifficulties>(header.number, td.into())?;
171 tx.put::<tables::Headers>(header.number, header.header().clone())?;
172 }
173
174 tx.put::<tables::HeaderNumbers>(header.hash(), header.number)?;
175 Ok(())
176 }
177
178 fn insert_headers_inner<'a, I, const TD: bool>(&self, headers: I) -> ProviderResult<()>
179 where
180 I: IntoIterator<Item = &'a SealedHeader>,
181 {
182 let provider = self.factory.static_file_provider();
183 let mut writer = provider.latest_writer(StaticFileSegment::Headers)?;
184 let tx = self.factory.provider_rw()?.into_tx();
185 let mut td = U256::ZERO;
186
187 for header in headers {
188 if TD {
189 td += header.difficulty;
190 }
191 Self::insert_header(Some(&mut writer), &tx, header, td)?;
192 }
193
194 writer.commit()?;
195 tx.commit()?;
196
197 Ok(())
198 }
199
200 pub fn insert_headers<'a, I>(&self, headers: I) -> ProviderResult<()>
203 where
204 I: IntoIterator<Item = &'a SealedHeader>,
205 {
206 self.insert_headers_inner::<I, false>(headers)
207 }
208
209 pub fn insert_headers_with_td<'a, I>(&self, headers: I) -> ProviderResult<()>
213 where
214 I: IntoIterator<Item = &'a SealedHeader>,
215 {
216 self.insert_headers_inner::<I, true>(headers)
217 }
218
219 pub fn insert_blocks<'a, I>(&self, blocks: I, storage_kind: StorageKind) -> ProviderResult<()>
227 where
228 I: IntoIterator<Item = &'a SealedBlock<Block>>,
229 {
230 let provider = self.factory.static_file_provider();
231
232 let tx = self.factory.provider_rw().unwrap().into_tx();
233 let mut next_tx_num = storage_kind.tx_offset();
234
235 let blocks = blocks.into_iter().collect::<Vec<_>>();
236
237 {
238 let mut headers_writer = storage_kind
239 .is_static()
240 .then(|| provider.latest_writer(StaticFileSegment::Headers).unwrap());
241
242 blocks.iter().try_for_each(|block| {
243 Self::insert_header(headers_writer.as_mut(), &tx, block.sealed_header(), U256::ZERO)
244 })?;
245
246 if let Some(mut writer) = headers_writer {
247 writer.commit()?;
248 }
249 }
250
251 {
252 let mut txs_writer = storage_kind
253 .is_static()
254 .then(|| provider.latest_writer(StaticFileSegment::Transactions).unwrap());
255
256 blocks.into_iter().try_for_each(|block| {
257 let block_body_indices = StoredBlockBodyIndices {
259 first_tx_num: next_tx_num,
260 tx_count: block.transaction_count() as u64,
261 };
262
263 if !block.body().transactions.is_empty() {
264 tx.put::<tables::TransactionBlocks>(
265 block_body_indices.last_tx_num(),
266 block.number,
267 )?;
268 }
269 tx.put::<tables::BlockBodyIndices>(block.number, block_body_indices)?;
270
271 let res = block.body().transactions.iter().try_for_each(|body_tx| {
272 if let Some(txs_writer) = &mut txs_writer {
273 txs_writer.append_transaction(next_tx_num, body_tx)?;
274 } else {
275 tx.put::<tables::Transactions>(next_tx_num, body_tx.clone())?
276 }
277 next_tx_num += 1;
278 Ok::<(), ProviderError>(())
279 });
280
281 if let Some(txs_writer) = &mut txs_writer {
282 let segment_header = txs_writer.user_header();
285 if segment_header.block_end().is_none() &&
286 segment_header.expected_block_start() == 0
287 {
288 for block in 0..block.number {
289 txs_writer.increment_block(block)?;
290 }
291 }
292 txs_writer.increment_block(block.number)?;
293 }
294 res
295 })?;
296
297 if let Some(txs_writer) = &mut txs_writer {
298 txs_writer.commit()?;
299 }
300 }
301
302 tx.commit()?;
303
304 Ok(())
305 }
306
307 pub fn insert_tx_hash_numbers<I>(&self, tx_hash_numbers: I) -> ProviderResult<()>
308 where
309 I: IntoIterator<Item = (TxHash, TxNumber)>,
310 {
311 self.commit(|tx| {
312 tx_hash_numbers.into_iter().try_for_each(|(tx_hash, tx_num)| {
313 Ok(tx.put::<tables::TransactionHashNumbers>(tx_hash, tx_num)?)
315 })
316 })
317 }
318
319 pub fn insert_receipts<I>(&self, receipts: I) -> ProviderResult<()>
321 where
322 I: IntoIterator<Item = (TxNumber, Receipt)>,
323 {
324 self.commit(|tx| {
325 receipts.into_iter().try_for_each(|(tx_num, receipt)| {
326 Ok(tx.put::<tables::Receipts>(tx_num, receipt)?)
328 })
329 })
330 }
331
332 pub fn insert_receipts_by_block<I, J>(
335 &self,
336 receipts: I,
337 storage_kind: StorageKind,
338 ) -> ProviderResult<()>
339 where
340 I: IntoIterator<Item = (BlockNumber, J)>,
341 J: IntoIterator<Item = (TxNumber, Receipt)>,
342 {
343 match storage_kind {
344 StorageKind::Database(_) => self.commit(|tx| {
345 receipts.into_iter().try_for_each(|(_, receipts)| {
346 for (tx_num, receipt) in receipts {
347 tx.put::<tables::Receipts>(tx_num, receipt)?;
348 }
349 Ok(())
350 })
351 }),
352 StorageKind::Static => {
353 let provider = self.factory.static_file_provider();
354 let mut writer = provider.latest_writer(StaticFileSegment::Receipts)?;
355 let res = receipts.into_iter().try_for_each(|(block_num, receipts)| {
356 writer.increment_block(block_num)?;
357 writer.append_receipts(receipts.into_iter().map(Ok))?;
358 Ok(())
359 });
360 writer.commit_without_sync_all()?;
361 res
362 }
363 }
364 }
365
366 pub fn insert_transaction_senders<I>(&self, transaction_senders: I) -> ProviderResult<()>
367 where
368 I: IntoIterator<Item = (TxNumber, Address)>,
369 {
370 self.commit(|tx| {
371 transaction_senders.into_iter().try_for_each(|(tx_num, sender)| {
372 Ok(tx.put::<tables::TransactionSenders>(tx_num, sender)?)
374 })
375 })
376 }
377
378 pub fn insert_accounts_and_storages<I, S>(&self, accounts: I) -> ProviderResult<()>
380 where
381 I: IntoIterator<Item = (Address, (Account, S))>,
382 S: IntoIterator<Item = StorageEntry>,
383 {
384 self.commit(|tx| {
385 accounts.into_iter().try_for_each(|(address, (account, storage))| {
386 let hashed_address = keccak256(address);
387
388 tx.put::<tables::PlainAccountState>(address, account)?;
390 tx.put::<tables::HashedAccounts>(hashed_address, account)?;
391
392 storage.into_iter().filter(|e| !e.value.is_zero()).try_for_each(|entry| {
394 let hashed_entry = StorageEntry { key: keccak256(entry.key), ..entry };
395
396 let mut cursor = tx.cursor_dup_write::<tables::PlainStorageState>()?;
397 if cursor
398 .seek_by_key_subkey(address, entry.key)?
399 .filter(|e| e.key == entry.key)
400 .is_some()
401 {
402 cursor.delete_current()?;
403 }
404 cursor.upsert(address, &entry)?;
405
406 let mut cursor = tx.cursor_dup_write::<tables::HashedStorages>()?;
407 if cursor
408 .seek_by_key_subkey(hashed_address, hashed_entry.key)?
409 .filter(|e| e.key == hashed_entry.key)
410 .is_some()
411 {
412 cursor.delete_current()?;
413 }
414 cursor.upsert(hashed_address, &hashed_entry)?;
415
416 Ok(())
417 })
418 })
419 })
420 }
421
422 pub fn insert_changesets<I>(
424 &self,
425 changesets: I,
426 block_offset: Option<u64>,
427 ) -> ProviderResult<()>
428 where
429 I: IntoIterator<Item = ChangeSet>,
430 {
431 let offset = block_offset.unwrap_or_default();
432 self.commit(|tx| {
433 changesets.into_iter().enumerate().try_for_each(|(block, changeset)| {
434 changeset.into_iter().try_for_each(|(address, old_account, old_storage)| {
435 let block = offset + block as u64;
436 tx.put::<tables::AccountChangeSets>(
438 block,
439 AccountBeforeTx { address, info: Some(old_account) },
440 )?;
441
442 let block_address = (block, address).into();
443
444 old_storage.into_iter().try_for_each(|entry| {
446 Ok(tx.put::<tables::StorageChangeSets>(block_address, entry)?)
447 })
448 })
449 })
450 })
451 }
452
453 pub fn insert_history<I>(&self, changesets: I, _block_offset: Option<u64>) -> ProviderResult<()>
454 where
455 I: IntoIterator<Item = ChangeSet>,
456 {
457 let mut accounts = BTreeMap::<Address, Vec<u64>>::new();
458 let mut storages = BTreeMap::<(Address, B256), Vec<u64>>::new();
459
460 for (block, changeset) in changesets.into_iter().enumerate() {
461 for (address, _, storage_entries) in changeset {
462 accounts.entry(address).or_default().push(block as u64);
463 for storage_entry in storage_entries {
464 storages.entry((address, storage_entry.key)).or_default().push(block as u64);
465 }
466 }
467 }
468
469 let provider_rw = self.factory.provider_rw()?;
470 provider_rw.insert_account_history_index(accounts)?;
471 provider_rw.insert_storage_history_index(storages)?;
472 provider_rw.commit()?;
473
474 Ok(())
475 }
476}
477
/// Where `TestStageDB` helpers write test data.
///
/// `Copy` and `Eq` are derived so one kind value can be passed to several helpers that take it
/// by value, and compared in assertions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StorageKind {
    /// Database tables. The optional value is the first transaction number assigned when
    /// inserting blocks (`0` when `None`).
    Database(Option<u64>),
    /// Static files.
    Static,
}
484
485impl StorageKind {
486 #[expect(dead_code)]
487 const fn is_database(&self) -> bool {
488 matches!(self, Self::Database(_))
489 }
490
491 const fn is_static(&self) -> bool {
492 matches!(self, Self::Static)
493 }
494
495 fn tx_offset(&self) -> u64 {
496 if let Self::Database(offset) = self {
497 return offset.unwrap_or_default();
498 }
499 0
500 }
501}