Skip to main content

reth_stages/test_utils/
test_db.rs

1use alloy_primitives::{keccak256, Address, BlockNumber, TxHash, TxNumber, B256};
2use reth_chainspec::MAINNET;
3use reth_db::{
4    test_utils::{
5        create_test_rocksdb_dir, create_test_rw_db, create_test_rw_db_with_path,
6        create_test_static_files_dir,
7    },
8    DatabaseEnv,
9};
10use reth_db_api::{
11    common::KeyValue,
12    cursor::{DbCursorRO, DbCursorRW, DbDupCursorRO},
13    database::Database,
14    models::{AccountBeforeTx, StorageBeforeTx, StoredBlockBodyIndices},
15    table::Table,
16    tables,
17    transaction::{DbTx, DbTxMut},
18    DatabaseError as DbError,
19};
20use reth_ethereum_primitives::{Block, EthPrimitives, Receipt};
21use reth_primitives_traits::{Account, SealedBlock, SealedHeader, StorageEntry};
22use reth_provider::{
23    providers::{
24        RocksDBProvider, StaticFileProvider, StaticFileProviderRWRefMut, StaticFileWriter,
25    },
26    test_utils::MockNodeTypesWithDB,
27    DatabaseProviderFactory, EitherWriter, HistoryWriter, ProviderError, ProviderFactory,
28    RocksBatchArg, StaticFileProviderFactory, StatsReader,
29};
30use reth_static_file_types::StaticFileSegment;
31use reth_storage_errors::provider::ProviderResult;
32use reth_testing_utils::generators::ChangeSet;
33use std::{collections::BTreeMap, fmt::Debug, path::Path};
34use tempfile::TempDir;
35
/// Test database that is used for testing stage implementations.
///
/// Bundles the provider factory together with the temporary directory handles whose paths were
/// handed to the static file and RocksDB providers when the factory was built.
#[derive(Debug)]
pub struct TestStageDB {
    /// Provider factory used to open read-only and read-write providers over the test storage.
    pub factory: ProviderFactory<MockNodeTypesWithDB>,
    /// Handle to the temporary directory whose path is given to the static file provider.
    // NOTE(review): presumably stored so the directory outlives the factory — confirm.
    pub temp_static_files_dir: TempDir,
    /// Handle to the temporary directory whose path is given to the RocksDB provider.
    pub temp_rocksdb_dir: TempDir,
}
43
44impl Default for TestStageDB {
45    /// Create a new instance of [`TestStageDB`]
46    fn default() -> Self {
47        let (static_dir, static_dir_path) = create_test_static_files_dir();
48        let (rocksdb_dir, rocksdb_dir_path) = create_test_rocksdb_dir();
49        Self {
50            temp_static_files_dir: static_dir,
51            temp_rocksdb_dir: rocksdb_dir,
52            factory: ProviderFactory::new(
53                create_test_rw_db(),
54                MAINNET.clone(),
55                StaticFileProvider::read_write(static_dir_path).unwrap(),
56                RocksDBProvider::builder(rocksdb_dir_path).with_default_tables().build().unwrap(),
57                reth_tasks::Runtime::test(),
58            )
59            .expect("failed to create test provider factory"),
60        }
61    }
62}
63
64impl TestStageDB {
65    pub fn new(path: &Path) -> Self {
66        let (static_dir, static_dir_path) = create_test_static_files_dir();
67        let (rocksdb_dir, rocksdb_dir_path) = create_test_rocksdb_dir();
68
69        Self {
70            temp_static_files_dir: static_dir,
71            temp_rocksdb_dir: rocksdb_dir,
72            factory: ProviderFactory::new(
73                create_test_rw_db_with_path(path),
74                MAINNET.clone(),
75                StaticFileProvider::read_write(static_dir_path).unwrap(),
76                RocksDBProvider::builder(rocksdb_dir_path).with_default_tables().build().unwrap(),
77                reth_tasks::Runtime::test(),
78            )
79            .expect("failed to create test provider factory"),
80        }
81    }
82
83    /// Invoke a callback with transaction committing it afterwards
84    pub fn commit<F>(&self, f: F) -> ProviderResult<()>
85    where
86        F: FnOnce(&<DatabaseEnv as Database>::TXMut) -> ProviderResult<()>,
87    {
88        let tx = self.factory.provider_rw()?;
89        f(tx.tx_ref())?;
90        tx.commit().expect("failed to commit");
91        Ok(())
92    }
93
94    /// Invoke a callback with a read transaction
95    pub fn query<F, Ok>(&self, f: F) -> ProviderResult<Ok>
96    where
97        F: FnOnce(&<DatabaseEnv as Database>::TX) -> ProviderResult<Ok>,
98    {
99        f(self.factory.provider()?.tx_ref())
100    }
101
102    /// Invoke a callback with a provider that can be used to create transactions or fetch from
103    /// static files.
104    pub fn query_with_provider<F, Ok>(&self, f: F) -> ProviderResult<Ok>
105    where
106        F: FnOnce(
107            <ProviderFactory<MockNodeTypesWithDB> as DatabaseProviderFactory>::Provider,
108        ) -> ProviderResult<Ok>,
109    {
110        f(self.factory.provider()?)
111    }
112
113    /// Invoke a callback with a writable provider, committing afterwards.
114    pub fn commit_with_provider<F>(&self, f: F) -> ProviderResult<()>
115    where
116        F: FnOnce(
117            &<ProviderFactory<MockNodeTypesWithDB> as DatabaseProviderFactory>::ProviderRW,
118        ) -> ProviderResult<()>,
119    {
120        let provider = self.factory.provider_rw()?;
121        f(&provider)?;
122        provider.commit().expect("failed to commit");
123        Ok(())
124    }
125
126    /// Check if the table is empty
127    pub fn table_is_empty<T: Table>(&self) -> ProviderResult<bool> {
128        self.query(|tx| {
129            let last = tx.cursor_read::<T>()?.last()?;
130            Ok(last.is_none())
131        })
132    }
133
134    /// Return full table as Vec
135    pub fn table<T: Table>(&self) -> ProviderResult<Vec<KeyValue<T>>>
136    where
137        T::Key: Default + Ord,
138    {
139        self.query(|tx| {
140            Ok(tx
141                .cursor_read::<T>()?
142                .walk(Some(T::Key::default()))?
143                .collect::<Result<Vec<_>, DbError>>()?)
144        })
145    }
146
147    /// Return the number of entries in the table or static file segment
148    pub fn count_entries<T: Table>(&self) -> ProviderResult<usize> {
149        self.factory.provider()?.count_entries::<T>()
150    }
151
152    /// Check that there is no table entry above a given
153    /// number by [`Table::Key`]
154    pub fn ensure_no_entry_above<T, F>(&self, num: u64, mut selector: F) -> ProviderResult<()>
155    where
156        T: Table,
157        F: FnMut(T::Key) -> BlockNumber,
158    {
159        self.query(|tx| {
160            let mut cursor = tx.cursor_read::<T>()?;
161            if let Some((key, _)) = cursor.last()? {
162                assert!(selector(key) <= num);
163            }
164            Ok(())
165        })
166    }
167
168    /// Check that there is no table entry above a given
169    /// number by [`Table::Value`]
170    pub fn ensure_no_entry_above_by_value<T, F>(
171        &self,
172        num: u64,
173        mut selector: F,
174    ) -> ProviderResult<()>
175    where
176        T: Table,
177        F: FnMut(T::Value) -> BlockNumber,
178    {
179        self.query(|tx| {
180            let mut cursor = tx.cursor_read::<T>()?;
181            let mut rev_walker = cursor.walk_back(None)?;
182            while let Some((_, value)) = rev_walker.next().transpose()? {
183                assert!(selector(value) <= num);
184            }
185            Ok(())
186        })
187    }
188
189    /// Insert header to static file if `writer` exists, otherwise to DB.
190    pub fn insert_header<TX: DbTx + DbTxMut>(
191        writer: Option<&mut StaticFileProviderRWRefMut<'_, EthPrimitives>>,
192        tx: &TX,
193        header: &SealedHeader,
194    ) -> ProviderResult<()> {
195        if let Some(writer) = writer {
196            // Backfill: some tests start at a forward block number, but static files require no
197            // gaps.
198            let segment_header = writer.user_header();
199            if segment_header.block_end().is_none() && segment_header.expected_block_start() == 0 {
200                for block_number in 0..header.number {
201                    let mut prev = header.clone_header();
202                    prev.number = block_number;
203                    writer.append_header(&prev, &B256::ZERO)?;
204                }
205            }
206
207            writer.append_header(header.header(), &header.hash())?;
208        } else {
209            tx.put::<tables::CanonicalHeaders>(header.number, header.hash())?;
210            tx.put::<tables::Headers>(header.number, header.header().clone())?;
211        }
212
213        tx.put::<tables::HeaderNumbers>(header.hash(), header.number)?;
214        Ok(())
215    }
216
217    fn insert_headers_inner<'a, I>(&self, headers: I) -> ProviderResult<()>
218    where
219        I: IntoIterator<Item = &'a SealedHeader>,
220    {
221        let provider = self.factory.static_file_provider();
222        let mut writer = provider.latest_writer(StaticFileSegment::Headers)?;
223        let tx = self.factory.provider_rw()?.into_tx();
224
225        for header in headers {
226            Self::insert_header(Some(&mut writer), &tx, header)?;
227        }
228
229        writer.commit()?;
230        tx.commit()?;
231
232        Ok(())
233    }
234
235    /// Insert ordered collection of [`SealedHeader`] into the corresponding static file and tables
236    /// that are supposed to be populated by the headers stage.
237    pub fn insert_headers<'a, I>(&self, headers: I) -> ProviderResult<()>
238    where
239        I: IntoIterator<Item = &'a SealedHeader>,
240    {
241        self.insert_headers_inner::<I>(headers)
242    }
243
244    /// Insert ordered collection of [`SealedBlock`] into corresponding tables.
245    /// Superset functionality of [`TestStageDB::insert_headers`].
246    ///
247    /// If `tx_offset` is set to `None`, then transactions will be stored on static files, otherwise
248    /// database.
249    ///
250    /// Assumes that there's a single transition for each transaction (i.e. no block rewards).
251    pub fn insert_blocks<'a, I>(&self, blocks: I, storage_kind: StorageKind) -> ProviderResult<()>
252    where
253        I: IntoIterator<Item = &'a SealedBlock<Block>>,
254    {
255        let provider = self.factory.static_file_provider();
256
257        let tx = self.factory.provider_rw().unwrap().into_tx();
258        let mut next_tx_num = storage_kind.tx_offset();
259
260        let blocks = blocks.into_iter().collect::<Vec<_>>();
261
262        {
263            let mut headers_writer = storage_kind
264                .is_static()
265                .then(|| provider.latest_writer(StaticFileSegment::Headers).unwrap());
266
267            blocks.iter().try_for_each(|block| {
268                Self::insert_header(headers_writer.as_mut(), &tx, block.sealed_header())
269            })?;
270
271            if let Some(mut writer) = headers_writer {
272                writer.commit()?;
273            }
274        }
275
276        {
277            let mut txs_writer = storage_kind
278                .is_static()
279                .then(|| provider.latest_writer(StaticFileSegment::Transactions).unwrap());
280
281            blocks.into_iter().try_for_each(|block| {
282                // Insert into body tables.
283                let block_body_indices = StoredBlockBodyIndices {
284                    first_tx_num: next_tx_num,
285                    tx_count: block.transaction_count() as u64,
286                };
287
288                if !block.body().transactions.is_empty() {
289                    tx.put::<tables::TransactionBlocks>(
290                        block_body_indices.last_tx_num(),
291                        block.number,
292                    )?;
293                }
294                tx.put::<tables::BlockBodyIndices>(block.number, block_body_indices)?;
295
296                let res = block.body().transactions.iter().try_for_each(|body_tx| {
297                    if let Some(txs_writer) = &mut txs_writer {
298                        txs_writer.append_transaction(next_tx_num, body_tx)?;
299                    } else {
300                        tx.put::<tables::Transactions>(next_tx_num, body_tx.clone())?
301                    }
302                    next_tx_num += 1;
303                    Ok::<(), ProviderError>(())
304                });
305
306                if let Some(txs_writer) = &mut txs_writer {
307                    // Backfill: some tests start at a forward block number, but static files
308                    // require no gaps.
309                    let segment_header = txs_writer.user_header();
310                    if segment_header.block_end().is_none() &&
311                        segment_header.expected_block_start() == 0
312                    {
313                        for block in 0..block.number {
314                            txs_writer.increment_block(block)?;
315                        }
316                    }
317                    txs_writer.increment_block(block.number)?;
318                }
319                res
320            })?;
321
322            if let Some(txs_writer) = &mut txs_writer {
323                txs_writer.commit()?;
324            }
325        }
326
327        tx.commit()?;
328
329        Ok(())
330    }
331
332    pub fn insert_tx_hash_numbers<I>(&self, tx_hash_numbers: I) -> ProviderResult<()>
333    where
334        I: IntoIterator<Item = (TxHash, TxNumber)>,
335    {
336        self.commit_with_provider(|provider| {
337            provider.with_rocksdb_batch(|batch: RocksBatchArg<'_>| {
338                let mut writer = EitherWriter::new_transaction_hash_numbers(provider, batch)?;
339                for (tx_hash, tx_num) in tx_hash_numbers {
340                    writer.put_transaction_hash_number(tx_hash, tx_num, false)?;
341                }
342                Ok(((), writer.into_raw_rocksdb_batch()))
343            })
344        })
345    }
346
347    /// Insert collection of ([`TxNumber`], [Receipt]) into the corresponding table.
348    pub fn insert_receipts<I>(&self, receipts: I) -> ProviderResult<()>
349    where
350        I: IntoIterator<Item = (TxNumber, Receipt)>,
351    {
352        self.commit(|tx| {
353            receipts.into_iter().try_for_each(|(tx_num, receipt)| {
354                // Insert into receipts table.
355                Ok(tx.put::<tables::Receipts>(tx_num, receipt)?)
356            })
357        })
358    }
359
360    /// Insert collection of ([`TxNumber`], [Receipt]) organized by respective block numbers into
361    /// the corresponding table or static file segment.
362    pub fn insert_receipts_by_block<I, J>(
363        &self,
364        receipts: I,
365        storage_kind: StorageKind,
366    ) -> ProviderResult<()>
367    where
368        I: IntoIterator<Item = (BlockNumber, J)>,
369        J: IntoIterator<Item = (TxNumber, Receipt)>,
370    {
371        match storage_kind {
372            StorageKind::Database(_) => self.commit(|tx| {
373                receipts.into_iter().try_for_each(|(_, receipts)| {
374                    for (tx_num, receipt) in receipts {
375                        tx.put::<tables::Receipts>(tx_num, receipt)?;
376                    }
377                    Ok(())
378                })
379            }),
380            StorageKind::Static => {
381                let provider = self.factory.static_file_provider();
382                let mut writer = provider.latest_writer(StaticFileSegment::Receipts)?;
383                let res = receipts.into_iter().try_for_each(|(block_num, receipts)| {
384                    writer.increment_block(block_num)?;
385                    writer.append_receipts(receipts.into_iter().map(Ok))?;
386                    Ok(())
387                });
388                writer.commit_without_sync_all()?;
389                res
390            }
391        }
392    }
393
394    pub fn insert_transaction_senders<I>(&self, transaction_senders: I) -> ProviderResult<()>
395    where
396        I: IntoIterator<Item = (TxNumber, Address)>,
397    {
398        self.commit(|tx| {
399            transaction_senders.into_iter().try_for_each(|(tx_num, sender)| {
400                // Insert into receipts table.
401                Ok(tx.put::<tables::TransactionSenders>(tx_num, sender)?)
402            })
403        })
404    }
405
406    /// Insert collection of ([Address], [Account]) into corresponding tables.
407    pub fn insert_accounts_and_storages<I, S>(&self, accounts: I) -> ProviderResult<()>
408    where
409        I: IntoIterator<Item = (Address, (Account, S))>,
410        S: IntoIterator<Item = StorageEntry>,
411    {
412        self.commit(|tx| {
413            accounts.into_iter().try_for_each(|(address, (account, storage))| {
414                let hashed_address = keccak256(address);
415
416                // Insert into account tables.
417                tx.put::<tables::PlainAccountState>(address, account)?;
418                tx.put::<tables::HashedAccounts>(hashed_address, account)?;
419
420                // Insert into storage tables.
421                storage.into_iter().filter(|e| !e.value.is_zero()).try_for_each(|entry| {
422                    let hashed_entry = StorageEntry { key: keccak256(entry.key), ..entry };
423
424                    let mut cursor = tx.cursor_dup_write::<tables::PlainStorageState>()?;
425                    if cursor
426                        .seek_by_key_subkey(address, entry.key)?
427                        .filter(|e| e.key == entry.key)
428                        .is_some()
429                    {
430                        cursor.delete_current()?;
431                    }
432                    cursor.upsert(address, &entry)?;
433
434                    let mut cursor = tx.cursor_dup_write::<tables::HashedStorages>()?;
435                    if cursor
436                        .seek_by_key_subkey(hashed_address, hashed_entry.key)?
437                        .filter(|e| e.key == hashed_entry.key)
438                        .is_some()
439                    {
440                        cursor.delete_current()?;
441                    }
442                    cursor.upsert(hashed_address, &hashed_entry)?;
443
444                    Ok(())
445                })
446            })
447        })
448    }
449
450    /// Insert collection of [`ChangeSet`] into corresponding tables.
451    pub fn insert_changesets<I>(
452        &self,
453        changesets: I,
454        block_offset: Option<u64>,
455    ) -> ProviderResult<()>
456    where
457        I: IntoIterator<Item = ChangeSet>,
458    {
459        let offset = block_offset.unwrap_or_default();
460        self.commit(|tx| {
461            changesets.into_iter().enumerate().try_for_each(|(block, changeset)| {
462                changeset.into_iter().try_for_each(|(address, old_account, old_storage)| {
463                    let block = offset + block as u64;
464                    // Insert into account changeset.
465                    tx.put::<tables::AccountChangeSets>(
466                        block,
467                        AccountBeforeTx { address, info: Some(old_account) },
468                    )?;
469
470                    let block_address = (block, address).into();
471
472                    // Insert into storage changeset.
473                    old_storage.into_iter().try_for_each(|entry| {
474                        Ok(tx.put::<tables::StorageChangeSets>(block_address, entry)?)
475                    })
476                })
477            })
478        })
479    }
480
481    /// Insert collection of [`ChangeSet`] into static files (account and storage changesets).
482    pub fn insert_changesets_to_static_files<I>(
483        &self,
484        changesets: I,
485        block_offset: Option<u64>,
486    ) -> ProviderResult<()>
487    where
488        I: IntoIterator<Item = ChangeSet>,
489    {
490        let offset = block_offset.unwrap_or_default();
491        let static_file_provider = self.factory.static_file_provider();
492
493        let mut account_changeset_writer =
494            static_file_provider.latest_writer(StaticFileSegment::AccountChangeSets)?;
495        let mut storage_changeset_writer =
496            static_file_provider.latest_writer(StaticFileSegment::StorageChangeSets)?;
497
498        for (block, changeset) in changesets.into_iter().enumerate() {
499            let block_number = offset + block as u64;
500
501            let mut account_changesets = Vec::new();
502            let mut storage_changesets = Vec::new();
503
504            for (address, old_account, old_storage) in changeset {
505                account_changesets.push(AccountBeforeTx { address, info: Some(old_account) });
506
507                for entry in old_storage {
508                    storage_changesets.push(StorageBeforeTx {
509                        address,
510                        key: entry.key,
511                        value: entry.value,
512                    });
513                }
514            }
515
516            account_changeset_writer.append_account_changeset(account_changesets, block_number)?;
517            storage_changeset_writer.append_storage_changeset(storage_changesets, block_number)?;
518        }
519
520        account_changeset_writer.commit()?;
521        storage_changeset_writer.commit()?;
522
523        Ok(())
524    }
525
526    pub fn insert_history<I>(&self, changesets: I, _block_offset: Option<u64>) -> ProviderResult<()>
527    where
528        I: IntoIterator<Item = ChangeSet>,
529    {
530        let mut accounts = BTreeMap::<Address, Vec<u64>>::new();
531        let mut storages = BTreeMap::<(Address, B256), Vec<u64>>::new();
532
533        for (block, changeset) in changesets.into_iter().enumerate() {
534            for (address, _, storage_entries) in changeset {
535                accounts.entry(address).or_default().push(block as u64);
536                for storage_entry in storage_entries {
537                    storages.entry((address, storage_entry.key)).or_default().push(block as u64);
538                }
539            }
540        }
541
542        let provider_rw = self.factory.provider_rw()?;
543        provider_rw.insert_account_history_index(accounts)?;
544        provider_rw.insert_storage_history_index(storages)?;
545        provider_rw.commit()?;
546
547        Ok(())
548    }
549}
550
/// Used to identify where to store data when setting up a test.
#[derive(Debug)]
pub enum StorageKind {
    /// Store in the database. The optional value is the transaction number offset reported by
    /// `tx_offset`; `None` counts as `0`.
    Database(Option<u64>),
    /// Store in static files.
    Static,
}

impl StorageKind {
    /// Returns `true` for the [`StorageKind::Database`] variant.
    #[expect(dead_code)]
    const fn is_database(&self) -> bool {
        match self {
            Self::Database(_) => true,
            Self::Static => false,
        }
    }

    /// Returns `true` for the [`StorageKind::Static`] variant.
    const fn is_static(&self) -> bool {
        match self {
            Self::Static => true,
            Self::Database(_) => false,
        }
    }

    /// Returns the transaction number offset: the value carried by [`StorageKind::Database`]
    /// (`0` when absent), and `0` for [`StorageKind::Static`].
    fn tx_offset(&self) -> u64 {
        match self {
            Self::Database(offset) => offset.unwrap_or_default(),
            Self::Static => 0,
        }
    }
}