reth_stages/test_utils/
test_db.rs

1use alloy_primitives::{keccak256, Address, BlockNumber, TxHash, TxNumber, B256};
2use reth_chainspec::MAINNET;
3use reth_db::{
4    test_utils::{create_test_rw_db, create_test_rw_db_with_path, create_test_static_files_dir},
5    DatabaseEnv,
6};
7use reth_db_api::{
8    common::KeyValue,
9    cursor::{DbCursorRO, DbCursorRW, DbDupCursorRO},
10    database::Database,
11    models::{AccountBeforeTx, StoredBlockBodyIndices},
12    table::Table,
13    tables,
14    transaction::{DbTx, DbTxMut},
15    DatabaseError as DbError,
16};
17use reth_ethereum_primitives::{Block, EthPrimitives, Receipt};
18use reth_primitives_traits::{Account, SealedBlock, SealedHeader, StorageEntry};
19use reth_provider::{
20    providers::{StaticFileProvider, StaticFileProviderRWRefMut, StaticFileWriter},
21    test_utils::MockNodeTypesWithDB,
22    HistoryWriter, ProviderError, ProviderFactory, StaticFileProviderFactory, StatsReader,
23};
24use reth_static_file_types::StaticFileSegment;
25use reth_storage_errors::provider::ProviderResult;
26use reth_testing_utils::generators::ChangeSet;
27use std::{collections::BTreeMap, fmt::Debug, path::Path};
28use tempfile::TempDir;
29
/// Test database that is used for testing stage implementations.
///
/// Bundles a [`ProviderFactory`] over a test database together with the temporary directory
/// that backs its static files.
#[derive(Debug)]
pub struct TestStageDB {
    /// Provider factory through which all reads and writes of the helpers on this type go.
    pub factory: ProviderFactory<MockNodeTypesWithDB>,
    /// Handle of the temporary directory returned by `create_test_static_files_dir`; the
    /// factory's static file provider is opened on that directory's path.
    pub temp_static_files_dir: TempDir,
}
36
37impl Default for TestStageDB {
38    /// Create a new instance of [`TestStageDB`]
39    fn default() -> Self {
40        let (static_dir, static_dir_path) = create_test_static_files_dir();
41        Self {
42            temp_static_files_dir: static_dir,
43            factory: ProviderFactory::new(
44                create_test_rw_db(),
45                MAINNET.clone(),
46                StaticFileProvider::read_write(static_dir_path).unwrap(),
47            )
48            .expect("failed to create test provider factory"),
49        }
50    }
51}
52
53impl TestStageDB {
54    pub fn new(path: &Path) -> Self {
55        let (static_dir, static_dir_path) = create_test_static_files_dir();
56
57        Self {
58            temp_static_files_dir: static_dir,
59            factory: ProviderFactory::new(
60                create_test_rw_db_with_path(path),
61                MAINNET.clone(),
62                StaticFileProvider::read_write(static_dir_path).unwrap(),
63            )
64            .expect("failed to create test provider factory"),
65        }
66    }
67
68    /// Invoke a callback with transaction committing it afterwards
69    pub fn commit<F>(&self, f: F) -> ProviderResult<()>
70    where
71        F: FnOnce(&<DatabaseEnv as Database>::TXMut) -> ProviderResult<()>,
72    {
73        let tx = self.factory.provider_rw()?;
74        f(tx.tx_ref())?;
75        tx.commit().expect("failed to commit");
76        Ok(())
77    }
78
79    /// Invoke a callback with a read transaction
80    pub fn query<F, Ok>(&self, f: F) -> ProviderResult<Ok>
81    where
82        F: FnOnce(&<DatabaseEnv as Database>::TX) -> ProviderResult<Ok>,
83    {
84        f(self.factory.provider()?.tx_ref())
85    }
86
87    /// Check if the table is empty
88    pub fn table_is_empty<T: Table>(&self) -> ProviderResult<bool> {
89        self.query(|tx| {
90            let last = tx.cursor_read::<T>()?.last()?;
91            Ok(last.is_none())
92        })
93    }
94
95    /// Return full table as Vec
96    pub fn table<T: Table>(&self) -> ProviderResult<Vec<KeyValue<T>>>
97    where
98        T::Key: Default + Ord,
99    {
100        self.query(|tx| {
101            Ok(tx
102                .cursor_read::<T>()?
103                .walk(Some(T::Key::default()))?
104                .collect::<Result<Vec<_>, DbError>>()?)
105        })
106    }
107
108    /// Return the number of entries in the table or static file segment
109    pub fn count_entries<T: Table>(&self) -> ProviderResult<usize> {
110        self.factory.provider()?.count_entries::<T>()
111    }
112
113    /// Check that there is no table entry above a given
114    /// number by [`Table::Key`]
115    pub fn ensure_no_entry_above<T, F>(&self, num: u64, mut selector: F) -> ProviderResult<()>
116    where
117        T: Table,
118        F: FnMut(T::Key) -> BlockNumber,
119    {
120        self.query(|tx| {
121            let mut cursor = tx.cursor_read::<T>()?;
122            if let Some((key, _)) = cursor.last()? {
123                assert!(selector(key) <= num);
124            }
125            Ok(())
126        })
127    }
128
129    /// Check that there is no table entry above a given
130    /// number by [`Table::Value`]
131    pub fn ensure_no_entry_above_by_value<T, F>(
132        &self,
133        num: u64,
134        mut selector: F,
135    ) -> ProviderResult<()>
136    where
137        T: Table,
138        F: FnMut(T::Value) -> BlockNumber,
139    {
140        self.query(|tx| {
141            let mut cursor = tx.cursor_read::<T>()?;
142            let mut rev_walker = cursor.walk_back(None)?;
143            while let Some((_, value)) = rev_walker.next().transpose()? {
144                assert!(selector(value) <= num);
145            }
146            Ok(())
147        })
148    }
149
150    /// Insert header to static file if `writer` exists, otherwise to DB.
151    pub fn insert_header<TX: DbTx + DbTxMut>(
152        writer: Option<&mut StaticFileProviderRWRefMut<'_, EthPrimitives>>,
153        tx: &TX,
154        header: &SealedHeader,
155    ) -> ProviderResult<()> {
156        if let Some(writer) = writer {
157            // Backfill: some tests start at a forward block number, but static files require no
158            // gaps.
159            let segment_header = writer.user_header();
160            if segment_header.block_end().is_none() && segment_header.expected_block_start() == 0 {
161                for block_number in 0..header.number {
162                    let mut prev = header.clone_header();
163                    prev.number = block_number;
164                    writer.append_header(&prev, &B256::ZERO)?;
165                }
166            }
167
168            writer.append_header(header.header(), &header.hash())?;
169        } else {
170            tx.put::<tables::CanonicalHeaders>(header.number, header.hash())?;
171            tx.put::<tables::Headers>(header.number, header.header().clone())?;
172        }
173
174        tx.put::<tables::HeaderNumbers>(header.hash(), header.number)?;
175        Ok(())
176    }
177
178    fn insert_headers_inner<'a, I>(&self, headers: I) -> ProviderResult<()>
179    where
180        I: IntoIterator<Item = &'a SealedHeader>,
181    {
182        let provider = self.factory.static_file_provider();
183        let mut writer = provider.latest_writer(StaticFileSegment::Headers)?;
184        let tx = self.factory.provider_rw()?.into_tx();
185
186        for header in headers {
187            Self::insert_header(Some(&mut writer), &tx, header)?;
188        }
189
190        writer.commit()?;
191        tx.commit()?;
192
193        Ok(())
194    }
195
196    /// Insert ordered collection of [`SealedHeader`] into the corresponding static file and tables
197    /// that are supposed to be populated by the headers stage.
198    pub fn insert_headers<'a, I>(&self, headers: I) -> ProviderResult<()>
199    where
200        I: IntoIterator<Item = &'a SealedHeader>,
201    {
202        self.insert_headers_inner::<I>(headers)
203    }
204
205    /// Insert ordered collection of [`SealedBlock`] into corresponding tables.
206    /// Superset functionality of [`TestStageDB::insert_headers`].
207    ///
208    /// If `tx_offset` is set to `None`, then transactions will be stored on static files, otherwise
209    /// database.
210    ///
211    /// Assumes that there's a single transition for each transaction (i.e. no block rewards).
212    pub fn insert_blocks<'a, I>(&self, blocks: I, storage_kind: StorageKind) -> ProviderResult<()>
213    where
214        I: IntoIterator<Item = &'a SealedBlock<Block>>,
215    {
216        let provider = self.factory.static_file_provider();
217
218        let tx = self.factory.provider_rw().unwrap().into_tx();
219        let mut next_tx_num = storage_kind.tx_offset();
220
221        let blocks = blocks.into_iter().collect::<Vec<_>>();
222
223        {
224            let mut headers_writer = storage_kind
225                .is_static()
226                .then(|| provider.latest_writer(StaticFileSegment::Headers).unwrap());
227
228            blocks.iter().try_for_each(|block| {
229                Self::insert_header(headers_writer.as_mut(), &tx, block.sealed_header())
230            })?;
231
232            if let Some(mut writer) = headers_writer {
233                writer.commit()?;
234            }
235        }
236
237        {
238            let mut txs_writer = storage_kind
239                .is_static()
240                .then(|| provider.latest_writer(StaticFileSegment::Transactions).unwrap());
241
242            blocks.into_iter().try_for_each(|block| {
243                // Insert into body tables.
244                let block_body_indices = StoredBlockBodyIndices {
245                    first_tx_num: next_tx_num,
246                    tx_count: block.transaction_count() as u64,
247                };
248
249                if !block.body().transactions.is_empty() {
250                    tx.put::<tables::TransactionBlocks>(
251                        block_body_indices.last_tx_num(),
252                        block.number,
253                    )?;
254                }
255                tx.put::<tables::BlockBodyIndices>(block.number, block_body_indices)?;
256
257                let res = block.body().transactions.iter().try_for_each(|body_tx| {
258                    if let Some(txs_writer) = &mut txs_writer {
259                        txs_writer.append_transaction(next_tx_num, body_tx)?;
260                    } else {
261                        tx.put::<tables::Transactions>(next_tx_num, body_tx.clone())?
262                    }
263                    next_tx_num += 1;
264                    Ok::<(), ProviderError>(())
265                });
266
267                if let Some(txs_writer) = &mut txs_writer {
268                    // Backfill: some tests start at a forward block number, but static files
269                    // require no gaps.
270                    let segment_header = txs_writer.user_header();
271                    if segment_header.block_end().is_none() &&
272                        segment_header.expected_block_start() == 0
273                    {
274                        for block in 0..block.number {
275                            txs_writer.increment_block(block)?;
276                        }
277                    }
278                    txs_writer.increment_block(block.number)?;
279                }
280                res
281            })?;
282
283            if let Some(txs_writer) = &mut txs_writer {
284                txs_writer.commit()?;
285            }
286        }
287
288        tx.commit()?;
289
290        Ok(())
291    }
292
293    pub fn insert_tx_hash_numbers<I>(&self, tx_hash_numbers: I) -> ProviderResult<()>
294    where
295        I: IntoIterator<Item = (TxHash, TxNumber)>,
296    {
297        self.commit(|tx| {
298            tx_hash_numbers.into_iter().try_for_each(|(tx_hash, tx_num)| {
299                // Insert into tx hash numbers table.
300                Ok(tx.put::<tables::TransactionHashNumbers>(tx_hash, tx_num)?)
301            })
302        })
303    }
304
305    /// Insert collection of ([`TxNumber`], [Receipt]) into the corresponding table.
306    pub fn insert_receipts<I>(&self, receipts: I) -> ProviderResult<()>
307    where
308        I: IntoIterator<Item = (TxNumber, Receipt)>,
309    {
310        self.commit(|tx| {
311            receipts.into_iter().try_for_each(|(tx_num, receipt)| {
312                // Insert into receipts table.
313                Ok(tx.put::<tables::Receipts>(tx_num, receipt)?)
314            })
315        })
316    }
317
318    /// Insert collection of ([`TxNumber`], [Receipt]) organized by respective block numbers into
319    /// the corresponding table or static file segment.
320    pub fn insert_receipts_by_block<I, J>(
321        &self,
322        receipts: I,
323        storage_kind: StorageKind,
324    ) -> ProviderResult<()>
325    where
326        I: IntoIterator<Item = (BlockNumber, J)>,
327        J: IntoIterator<Item = (TxNumber, Receipt)>,
328    {
329        match storage_kind {
330            StorageKind::Database(_) => self.commit(|tx| {
331                receipts.into_iter().try_for_each(|(_, receipts)| {
332                    for (tx_num, receipt) in receipts {
333                        tx.put::<tables::Receipts>(tx_num, receipt)?;
334                    }
335                    Ok(())
336                })
337            }),
338            StorageKind::Static => {
339                let provider = self.factory.static_file_provider();
340                let mut writer = provider.latest_writer(StaticFileSegment::Receipts)?;
341                let res = receipts.into_iter().try_for_each(|(block_num, receipts)| {
342                    writer.increment_block(block_num)?;
343                    writer.append_receipts(receipts.into_iter().map(Ok))?;
344                    Ok(())
345                });
346                writer.commit_without_sync_all()?;
347                res
348            }
349        }
350    }
351
352    pub fn insert_transaction_senders<I>(&self, transaction_senders: I) -> ProviderResult<()>
353    where
354        I: IntoIterator<Item = (TxNumber, Address)>,
355    {
356        self.commit(|tx| {
357            transaction_senders.into_iter().try_for_each(|(tx_num, sender)| {
358                // Insert into receipts table.
359                Ok(tx.put::<tables::TransactionSenders>(tx_num, sender)?)
360            })
361        })
362    }
363
364    /// Insert collection of ([Address], [Account]) into corresponding tables.
365    pub fn insert_accounts_and_storages<I, S>(&self, accounts: I) -> ProviderResult<()>
366    where
367        I: IntoIterator<Item = (Address, (Account, S))>,
368        S: IntoIterator<Item = StorageEntry>,
369    {
370        self.commit(|tx| {
371            accounts.into_iter().try_for_each(|(address, (account, storage))| {
372                let hashed_address = keccak256(address);
373
374                // Insert into account tables.
375                tx.put::<tables::PlainAccountState>(address, account)?;
376                tx.put::<tables::HashedAccounts>(hashed_address, account)?;
377
378                // Insert into storage tables.
379                storage.into_iter().filter(|e| !e.value.is_zero()).try_for_each(|entry| {
380                    let hashed_entry = StorageEntry { key: keccak256(entry.key), ..entry };
381
382                    let mut cursor = tx.cursor_dup_write::<tables::PlainStorageState>()?;
383                    if cursor
384                        .seek_by_key_subkey(address, entry.key)?
385                        .filter(|e| e.key == entry.key)
386                        .is_some()
387                    {
388                        cursor.delete_current()?;
389                    }
390                    cursor.upsert(address, &entry)?;
391
392                    let mut cursor = tx.cursor_dup_write::<tables::HashedStorages>()?;
393                    if cursor
394                        .seek_by_key_subkey(hashed_address, hashed_entry.key)?
395                        .filter(|e| e.key == hashed_entry.key)
396                        .is_some()
397                    {
398                        cursor.delete_current()?;
399                    }
400                    cursor.upsert(hashed_address, &hashed_entry)?;
401
402                    Ok(())
403                })
404            })
405        })
406    }
407
408    /// Insert collection of [`ChangeSet`] into corresponding tables.
409    pub fn insert_changesets<I>(
410        &self,
411        changesets: I,
412        block_offset: Option<u64>,
413    ) -> ProviderResult<()>
414    where
415        I: IntoIterator<Item = ChangeSet>,
416    {
417        let offset = block_offset.unwrap_or_default();
418        self.commit(|tx| {
419            changesets.into_iter().enumerate().try_for_each(|(block, changeset)| {
420                changeset.into_iter().try_for_each(|(address, old_account, old_storage)| {
421                    let block = offset + block as u64;
422                    // Insert into account changeset.
423                    tx.put::<tables::AccountChangeSets>(
424                        block,
425                        AccountBeforeTx { address, info: Some(old_account) },
426                    )?;
427
428                    let block_address = (block, address).into();
429
430                    // Insert into storage changeset.
431                    old_storage.into_iter().try_for_each(|entry| {
432                        Ok(tx.put::<tables::StorageChangeSets>(block_address, entry)?)
433                    })
434                })
435            })
436        })
437    }
438
439    pub fn insert_history<I>(&self, changesets: I, _block_offset: Option<u64>) -> ProviderResult<()>
440    where
441        I: IntoIterator<Item = ChangeSet>,
442    {
443        let mut accounts = BTreeMap::<Address, Vec<u64>>::new();
444        let mut storages = BTreeMap::<(Address, B256), Vec<u64>>::new();
445
446        for (block, changeset) in changesets.into_iter().enumerate() {
447            for (address, _, storage_entries) in changeset {
448                accounts.entry(address).or_default().push(block as u64);
449                for storage_entry in storage_entries {
450                    storages.entry((address, storage_entry.key)).or_default().push(block as u64);
451                }
452            }
453        }
454
455        let provider_rw = self.factory.provider_rw()?;
456        provider_rw.insert_account_history_index(accounts)?;
457        provider_rw.insert_storage_history_index(storages)?;
458        provider_rw.commit()?;
459
460        Ok(())
461    }
462}
463
/// Used to identify where to store data when setting up a test.
#[derive(Debug)]
pub enum StorageKind {
    /// Store in the database; the optional value is the transaction number offset reported by
    /// `tx_offset`.
    Database(Option<u64>),
    /// Store in static files.
    Static,
}

impl StorageKind {
    /// Whether this is the [`StorageKind::Database`] variant.
    #[expect(dead_code)]
    const fn is_database(&self) -> bool {
        match self {
            Self::Database(_) => true,
            Self::Static => false,
        }
    }

    /// Whether this is the [`StorageKind::Static`] variant.
    const fn is_static(&self) -> bool {
        match self {
            Self::Static => true,
            Self::Database(_) => false,
        }
    }

    /// Transaction number offset: the value carried by [`StorageKind::Database`], or 0 when it
    /// carries none or when the kind is [`StorageKind::Static`].
    fn tx_offset(&self) -> u64 {
        match self {
            Self::Database(Some(offset)) => *offset,
            Self::Database(None) | Self::Static => 0,
        }
    }
}