// reth_stages/test_utils/test_db.rs

1use alloy_primitives::{keccak256, Address, BlockNumber, TxHash, TxNumber, B256};
2use reth_chainspec::MAINNET;
3use reth_db::{
4    test_utils::{
5        create_test_rocksdb_dir, create_test_rw_db, create_test_rw_db_with_path,
6        create_test_static_files_dir,
7    },
8    DatabaseEnv,
9};
10use reth_db_api::{
11    common::KeyValue,
12    cursor::{DbCursorRO, DbCursorRW, DbDupCursorRO},
13    database::Database,
14    models::{AccountBeforeTx, StoredBlockBodyIndices},
15    table::Table,
16    tables,
17    transaction::{DbTx, DbTxMut},
18    DatabaseError as DbError,
19};
20use reth_ethereum_primitives::{Block, EthPrimitives, Receipt};
21use reth_primitives_traits::{Account, SealedBlock, SealedHeader, StorageEntry};
22use reth_provider::{
23    providers::{
24        RocksDBProvider, StaticFileProvider, StaticFileProviderRWRefMut, StaticFileWriter,
25    },
26    test_utils::MockNodeTypesWithDB,
27    DatabaseProviderFactory, EitherWriter, HistoryWriter, ProviderError, ProviderFactory,
28    RocksBatchArg, StaticFileProviderFactory, StatsReader,
29};
30use reth_static_file_types::StaticFileSegment;
31use reth_storage_errors::provider::ProviderResult;
32use reth_testing_utils::generators::ChangeSet;
33use std::{collections::BTreeMap, fmt::Debug, path::Path};
34use tempfile::TempDir;
35
36/// Test database that is used for testing stage implementations.
37#[derive(Debug)]
38pub struct TestStageDB {
39    pub factory: ProviderFactory<MockNodeTypesWithDB>,
40    pub temp_static_files_dir: TempDir,
41}
42
43impl Default for TestStageDB {
44    /// Create a new instance of [`TestStageDB`]
45    fn default() -> Self {
46        let (static_dir, static_dir_path) = create_test_static_files_dir();
47        let (_, rocksdb_dir_path) = create_test_rocksdb_dir();
48        Self {
49            temp_static_files_dir: static_dir,
50            factory: ProviderFactory::new(
51                create_test_rw_db(),
52                MAINNET.clone(),
53                StaticFileProvider::read_write(static_dir_path).unwrap(),
54                RocksDBProvider::builder(rocksdb_dir_path).with_default_tables().build().unwrap(),
55            )
56            .expect("failed to create test provider factory"),
57        }
58    }
59}
60
61impl TestStageDB {
62    pub fn new(path: &Path) -> Self {
63        let (static_dir, static_dir_path) = create_test_static_files_dir();
64        let (_, rocksdb_dir_path) = create_test_rocksdb_dir();
65
66        Self {
67            temp_static_files_dir: static_dir,
68            factory: ProviderFactory::new(
69                create_test_rw_db_with_path(path),
70                MAINNET.clone(),
71                StaticFileProvider::read_write(static_dir_path).unwrap(),
72                RocksDBProvider::builder(rocksdb_dir_path).with_default_tables().build().unwrap(),
73            )
74            .expect("failed to create test provider factory"),
75        }
76    }
77
78    /// Invoke a callback with transaction committing it afterwards
79    pub fn commit<F>(&self, f: F) -> ProviderResult<()>
80    where
81        F: FnOnce(&<DatabaseEnv as Database>::TXMut) -> ProviderResult<()>,
82    {
83        let tx = self.factory.provider_rw()?;
84        f(tx.tx_ref())?;
85        tx.commit().expect("failed to commit");
86        Ok(())
87    }
88
89    /// Invoke a callback with a read transaction
90    pub fn query<F, Ok>(&self, f: F) -> ProviderResult<Ok>
91    where
92        F: FnOnce(&<DatabaseEnv as Database>::TX) -> ProviderResult<Ok>,
93    {
94        f(self.factory.provider()?.tx_ref())
95    }
96
97    /// Invoke a callback with a provider that can be used to create transactions or fetch from
98    /// static files.
99    pub fn query_with_provider<F, Ok>(&self, f: F) -> ProviderResult<Ok>
100    where
101        F: FnOnce(
102            <ProviderFactory<MockNodeTypesWithDB> as DatabaseProviderFactory>::Provider,
103        ) -> ProviderResult<Ok>,
104    {
105        f(self.factory.provider()?)
106    }
107
108    /// Invoke a callback with a writable provider, committing afterwards.
109    pub fn commit_with_provider<F>(&self, f: F) -> ProviderResult<()>
110    where
111        F: FnOnce(
112            &<ProviderFactory<MockNodeTypesWithDB> as DatabaseProviderFactory>::ProviderRW,
113        ) -> ProviderResult<()>,
114    {
115        let provider = self.factory.provider_rw()?;
116        f(&provider)?;
117        provider.commit().expect("failed to commit");
118        Ok(())
119    }
120
121    /// Check if the table is empty
122    pub fn table_is_empty<T: Table>(&self) -> ProviderResult<bool> {
123        self.query(|tx| {
124            let last = tx.cursor_read::<T>()?.last()?;
125            Ok(last.is_none())
126        })
127    }
128
129    /// Return full table as Vec
130    pub fn table<T: Table>(&self) -> ProviderResult<Vec<KeyValue<T>>>
131    where
132        T::Key: Default + Ord,
133    {
134        self.query(|tx| {
135            Ok(tx
136                .cursor_read::<T>()?
137                .walk(Some(T::Key::default()))?
138                .collect::<Result<Vec<_>, DbError>>()?)
139        })
140    }
141
142    /// Return the number of entries in the table or static file segment
143    pub fn count_entries<T: Table>(&self) -> ProviderResult<usize> {
144        self.factory.provider()?.count_entries::<T>()
145    }
146
147    /// Check that there is no table entry above a given
148    /// number by [`Table::Key`]
149    pub fn ensure_no_entry_above<T, F>(&self, num: u64, mut selector: F) -> ProviderResult<()>
150    where
151        T: Table,
152        F: FnMut(T::Key) -> BlockNumber,
153    {
154        self.query(|tx| {
155            let mut cursor = tx.cursor_read::<T>()?;
156            if let Some((key, _)) = cursor.last()? {
157                assert!(selector(key) <= num);
158            }
159            Ok(())
160        })
161    }
162
163    /// Check that there is no table entry above a given
164    /// number by [`Table::Value`]
165    pub fn ensure_no_entry_above_by_value<T, F>(
166        &self,
167        num: u64,
168        mut selector: F,
169    ) -> ProviderResult<()>
170    where
171        T: Table,
172        F: FnMut(T::Value) -> BlockNumber,
173    {
174        self.query(|tx| {
175            let mut cursor = tx.cursor_read::<T>()?;
176            let mut rev_walker = cursor.walk_back(None)?;
177            while let Some((_, value)) = rev_walker.next().transpose()? {
178                assert!(selector(value) <= num);
179            }
180            Ok(())
181        })
182    }
183
184    /// Insert header to static file if `writer` exists, otherwise to DB.
185    pub fn insert_header<TX: DbTx + DbTxMut>(
186        writer: Option<&mut StaticFileProviderRWRefMut<'_, EthPrimitives>>,
187        tx: &TX,
188        header: &SealedHeader,
189    ) -> ProviderResult<()> {
190        if let Some(writer) = writer {
191            // Backfill: some tests start at a forward block number, but static files require no
192            // gaps.
193            let segment_header = writer.user_header();
194            if segment_header.block_end().is_none() && segment_header.expected_block_start() == 0 {
195                for block_number in 0..header.number {
196                    let mut prev = header.clone_header();
197                    prev.number = block_number;
198                    writer.append_header(&prev, &B256::ZERO)?;
199                }
200            }
201
202            writer.append_header(header.header(), &header.hash())?;
203        } else {
204            tx.put::<tables::CanonicalHeaders>(header.number, header.hash())?;
205            tx.put::<tables::Headers>(header.number, header.header().clone())?;
206        }
207
208        tx.put::<tables::HeaderNumbers>(header.hash(), header.number)?;
209        Ok(())
210    }
211
212    fn insert_headers_inner<'a, I>(&self, headers: I) -> ProviderResult<()>
213    where
214        I: IntoIterator<Item = &'a SealedHeader>,
215    {
216        let provider = self.factory.static_file_provider();
217        let mut writer = provider.latest_writer(StaticFileSegment::Headers)?;
218        let tx = self.factory.provider_rw()?.into_tx();
219
220        for header in headers {
221            Self::insert_header(Some(&mut writer), &tx, header)?;
222        }
223
224        writer.commit()?;
225        tx.commit()?;
226
227        Ok(())
228    }
229
230    /// Insert ordered collection of [`SealedHeader`] into the corresponding static file and tables
231    /// that are supposed to be populated by the headers stage.
232    pub fn insert_headers<'a, I>(&self, headers: I) -> ProviderResult<()>
233    where
234        I: IntoIterator<Item = &'a SealedHeader>,
235    {
236        self.insert_headers_inner::<I>(headers)
237    }
238
239    /// Insert ordered collection of [`SealedBlock`] into corresponding tables.
240    /// Superset functionality of [`TestStageDB::insert_headers`].
241    ///
242    /// If `tx_offset` is set to `None`, then transactions will be stored on static files, otherwise
243    /// database.
244    ///
245    /// Assumes that there's a single transition for each transaction (i.e. no block rewards).
246    pub fn insert_blocks<'a, I>(&self, blocks: I, storage_kind: StorageKind) -> ProviderResult<()>
247    where
248        I: IntoIterator<Item = &'a SealedBlock<Block>>,
249    {
250        let provider = self.factory.static_file_provider();
251
252        let tx = self.factory.provider_rw().unwrap().into_tx();
253        let mut next_tx_num = storage_kind.tx_offset();
254
255        let blocks = blocks.into_iter().collect::<Vec<_>>();
256
257        {
258            let mut headers_writer = storage_kind
259                .is_static()
260                .then(|| provider.latest_writer(StaticFileSegment::Headers).unwrap());
261
262            blocks.iter().try_for_each(|block| {
263                Self::insert_header(headers_writer.as_mut(), &tx, block.sealed_header())
264            })?;
265
266            if let Some(mut writer) = headers_writer {
267                writer.commit()?;
268            }
269        }
270
271        {
272            let mut txs_writer = storage_kind
273                .is_static()
274                .then(|| provider.latest_writer(StaticFileSegment::Transactions).unwrap());
275
276            blocks.into_iter().try_for_each(|block| {
277                // Insert into body tables.
278                let block_body_indices = StoredBlockBodyIndices {
279                    first_tx_num: next_tx_num,
280                    tx_count: block.transaction_count() as u64,
281                };
282
283                if !block.body().transactions.is_empty() {
284                    tx.put::<tables::TransactionBlocks>(
285                        block_body_indices.last_tx_num(),
286                        block.number,
287                    )?;
288                }
289                tx.put::<tables::BlockBodyIndices>(block.number, block_body_indices)?;
290
291                let res = block.body().transactions.iter().try_for_each(|body_tx| {
292                    if let Some(txs_writer) = &mut txs_writer {
293                        txs_writer.append_transaction(next_tx_num, body_tx)?;
294                    } else {
295                        tx.put::<tables::Transactions>(next_tx_num, body_tx.clone())?
296                    }
297                    next_tx_num += 1;
298                    Ok::<(), ProviderError>(())
299                });
300
301                if let Some(txs_writer) = &mut txs_writer {
302                    // Backfill: some tests start at a forward block number, but static files
303                    // require no gaps.
304                    let segment_header = txs_writer.user_header();
305                    if segment_header.block_end().is_none() &&
306                        segment_header.expected_block_start() == 0
307                    {
308                        for block in 0..block.number {
309                            txs_writer.increment_block(block)?;
310                        }
311                    }
312                    txs_writer.increment_block(block.number)?;
313                }
314                res
315            })?;
316
317            if let Some(txs_writer) = &mut txs_writer {
318                txs_writer.commit()?;
319            }
320        }
321
322        tx.commit()?;
323
324        Ok(())
325    }
326
327    pub fn insert_tx_hash_numbers<I>(&self, tx_hash_numbers: I) -> ProviderResult<()>
328    where
329        I: IntoIterator<Item = (TxHash, TxNumber)>,
330    {
331        self.commit_with_provider(|provider| {
332            provider.with_rocksdb_batch(|batch: RocksBatchArg<'_>| {
333                let mut writer = EitherWriter::new_transaction_hash_numbers(provider, batch)?;
334                for (tx_hash, tx_num) in tx_hash_numbers {
335                    writer.put_transaction_hash_number(tx_hash, tx_num, false)?;
336                }
337                Ok(((), writer.into_raw_rocksdb_batch()))
338            })
339        })
340    }
341
342    /// Insert collection of ([`TxNumber`], [Receipt]) into the corresponding table.
343    pub fn insert_receipts<I>(&self, receipts: I) -> ProviderResult<()>
344    where
345        I: IntoIterator<Item = (TxNumber, Receipt)>,
346    {
347        self.commit(|tx| {
348            receipts.into_iter().try_for_each(|(tx_num, receipt)| {
349                // Insert into receipts table.
350                Ok(tx.put::<tables::Receipts>(tx_num, receipt)?)
351            })
352        })
353    }
354
355    /// Insert collection of ([`TxNumber`], [Receipt]) organized by respective block numbers into
356    /// the corresponding table or static file segment.
357    pub fn insert_receipts_by_block<I, J>(
358        &self,
359        receipts: I,
360        storage_kind: StorageKind,
361    ) -> ProviderResult<()>
362    where
363        I: IntoIterator<Item = (BlockNumber, J)>,
364        J: IntoIterator<Item = (TxNumber, Receipt)>,
365    {
366        match storage_kind {
367            StorageKind::Database(_) => self.commit(|tx| {
368                receipts.into_iter().try_for_each(|(_, receipts)| {
369                    for (tx_num, receipt) in receipts {
370                        tx.put::<tables::Receipts>(tx_num, receipt)?;
371                    }
372                    Ok(())
373                })
374            }),
375            StorageKind::Static => {
376                let provider = self.factory.static_file_provider();
377                let mut writer = provider.latest_writer(StaticFileSegment::Receipts)?;
378                let res = receipts.into_iter().try_for_each(|(block_num, receipts)| {
379                    writer.increment_block(block_num)?;
380                    writer.append_receipts(receipts.into_iter().map(Ok))?;
381                    Ok(())
382                });
383                writer.commit_without_sync_all()?;
384                res
385            }
386        }
387    }
388
389    pub fn insert_transaction_senders<I>(&self, transaction_senders: I) -> ProviderResult<()>
390    where
391        I: IntoIterator<Item = (TxNumber, Address)>,
392    {
393        self.commit(|tx| {
394            transaction_senders.into_iter().try_for_each(|(tx_num, sender)| {
395                // Insert into receipts table.
396                Ok(tx.put::<tables::TransactionSenders>(tx_num, sender)?)
397            })
398        })
399    }
400
401    /// Insert collection of ([Address], [Account]) into corresponding tables.
402    pub fn insert_accounts_and_storages<I, S>(&self, accounts: I) -> ProviderResult<()>
403    where
404        I: IntoIterator<Item = (Address, (Account, S))>,
405        S: IntoIterator<Item = StorageEntry>,
406    {
407        self.commit(|tx| {
408            accounts.into_iter().try_for_each(|(address, (account, storage))| {
409                let hashed_address = keccak256(address);
410
411                // Insert into account tables.
412                tx.put::<tables::PlainAccountState>(address, account)?;
413                tx.put::<tables::HashedAccounts>(hashed_address, account)?;
414
415                // Insert into storage tables.
416                storage.into_iter().filter(|e| !e.value.is_zero()).try_for_each(|entry| {
417                    let hashed_entry = StorageEntry { key: keccak256(entry.key), ..entry };
418
419                    let mut cursor = tx.cursor_dup_write::<tables::PlainStorageState>()?;
420                    if cursor
421                        .seek_by_key_subkey(address, entry.key)?
422                        .filter(|e| e.key == entry.key)
423                        .is_some()
424                    {
425                        cursor.delete_current()?;
426                    }
427                    cursor.upsert(address, &entry)?;
428
429                    let mut cursor = tx.cursor_dup_write::<tables::HashedStorages>()?;
430                    if cursor
431                        .seek_by_key_subkey(hashed_address, hashed_entry.key)?
432                        .filter(|e| e.key == hashed_entry.key)
433                        .is_some()
434                    {
435                        cursor.delete_current()?;
436                    }
437                    cursor.upsert(hashed_address, &hashed_entry)?;
438
439                    Ok(())
440                })
441            })
442        })
443    }
444
445    /// Insert collection of [`ChangeSet`] into corresponding tables.
446    pub fn insert_changesets<I>(
447        &self,
448        changesets: I,
449        block_offset: Option<u64>,
450    ) -> ProviderResult<()>
451    where
452        I: IntoIterator<Item = ChangeSet>,
453    {
454        let offset = block_offset.unwrap_or_default();
455        self.commit(|tx| {
456            changesets.into_iter().enumerate().try_for_each(|(block, changeset)| {
457                changeset.into_iter().try_for_each(|(address, old_account, old_storage)| {
458                    let block = offset + block as u64;
459                    // Insert into account changeset.
460                    tx.put::<tables::AccountChangeSets>(
461                        block,
462                        AccountBeforeTx { address, info: Some(old_account) },
463                    )?;
464
465                    let block_address = (block, address).into();
466
467                    // Insert into storage changeset.
468                    old_storage.into_iter().try_for_each(|entry| {
469                        Ok(tx.put::<tables::StorageChangeSets>(block_address, entry)?)
470                    })
471                })
472            })
473        })
474    }
475
476    pub fn insert_history<I>(&self, changesets: I, _block_offset: Option<u64>) -> ProviderResult<()>
477    where
478        I: IntoIterator<Item = ChangeSet>,
479    {
480        let mut accounts = BTreeMap::<Address, Vec<u64>>::new();
481        let mut storages = BTreeMap::<(Address, B256), Vec<u64>>::new();
482
483        for (block, changeset) in changesets.into_iter().enumerate() {
484            for (address, _, storage_entries) in changeset {
485                accounts.entry(address).or_default().push(block as u64);
486                for storage_entry in storage_entries {
487                    storages.entry((address, storage_entry.key)).or_default().push(block as u64);
488                }
489            }
490        }
491
492        let provider_rw = self.factory.provider_rw()?;
493        provider_rw.insert_account_history_index(accounts)?;
494        provider_rw.insert_storage_history_index(storages)?;
495        provider_rw.commit()?;
496
497        Ok(())
498    }
499}
500
/// Used to identify where to store data when setting up a test.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StorageKind {
    /// Store data in database tables. The optional value is the first transaction number to
    /// assign when inserting blocks (defaults to `0`).
    Database(Option<u64>),
    /// Store data in static files.
    Static,
}

impl StorageKind {
    /// Returns `true` if data should be written to database tables.
    #[expect(dead_code)]
    const fn is_database(&self) -> bool {
        matches!(self, Self::Database(_))
    }

    /// Returns `true` if data should be written to static files.
    const fn is_static(&self) -> bool {
        matches!(self, Self::Static)
    }

    /// First transaction number to assign: the configured database offset, or `0` when no
    /// offset is set or when storing to static files.
    const fn tx_offset(&self) -> u64 {
        match self {
            Self::Database(Some(offset)) => *offset,
            Self::Database(None) | Self::Static => 0,
        }
    }
}