reth_stages/test_utils/
test_db.rs

1use alloy_primitives::{keccak256, Address, BlockNumber, TxHash, TxNumber, B256, U256};
2use reth_chainspec::MAINNET;
3use reth_db::{
4    test_utils::{create_test_rw_db, create_test_rw_db_with_path, create_test_static_files_dir},
5    DatabaseEnv,
6};
7use reth_db_api::{
8    common::KeyValue,
9    cursor::{DbCursorRO, DbCursorRW, DbDupCursorRO},
10    database::Database,
11    models::{AccountBeforeTx, StoredBlockBodyIndices},
12    table::Table,
13    tables,
14    transaction::{DbTx, DbTxMut},
15    DatabaseError as DbError,
16};
17use reth_ethereum_primitives::{Block, EthPrimitives, Receipt};
18use reth_primitives_traits::{Account, SealedBlock, SealedHeader, StorageEntry};
19use reth_provider::{
20    providers::{StaticFileProvider, StaticFileProviderRWRefMut, StaticFileWriter},
21    test_utils::MockNodeTypesWithDB,
22    HistoryWriter, ProviderError, ProviderFactory, StaticFileProviderFactory, StatsReader,
23};
24use reth_static_file_types::StaticFileSegment;
25use reth_storage_errors::provider::ProviderResult;
26use reth_testing_utils::generators::ChangeSet;
27use std::{collections::BTreeMap, fmt::Debug, path::Path};
28use tempfile::TempDir;
29
30/// Test database that is used for testing stage implementations.
31#[derive(Debug)]
32pub struct TestStageDB {
33    pub factory: ProviderFactory<MockNodeTypesWithDB>,
34    pub temp_static_files_dir: TempDir,
35}
36
37impl Default for TestStageDB {
38    /// Create a new instance of [`TestStageDB`]
39    fn default() -> Self {
40        let (static_dir, static_dir_path) = create_test_static_files_dir();
41        Self {
42            temp_static_files_dir: static_dir,
43            factory: ProviderFactory::new(
44                create_test_rw_db(),
45                MAINNET.clone(),
46                StaticFileProvider::read_write(static_dir_path).unwrap(),
47            ),
48        }
49    }
50}
51
52impl TestStageDB {
53    pub fn new(path: &Path) -> Self {
54        let (static_dir, static_dir_path) = create_test_static_files_dir();
55
56        Self {
57            temp_static_files_dir: static_dir,
58            factory: ProviderFactory::new(
59                create_test_rw_db_with_path(path),
60                MAINNET.clone(),
61                StaticFileProvider::read_write(static_dir_path).unwrap(),
62            ),
63        }
64    }
65
66    /// Invoke a callback with transaction committing it afterwards
67    pub fn commit<F>(&self, f: F) -> ProviderResult<()>
68    where
69        F: FnOnce(&<DatabaseEnv as Database>::TXMut) -> ProviderResult<()>,
70    {
71        let tx = self.factory.provider_rw()?;
72        f(tx.tx_ref())?;
73        tx.commit().expect("failed to commit");
74        Ok(())
75    }
76
77    /// Invoke a callback with a read transaction
78    pub fn query<F, Ok>(&self, f: F) -> ProviderResult<Ok>
79    where
80        F: FnOnce(&<DatabaseEnv as Database>::TX) -> ProviderResult<Ok>,
81    {
82        f(self.factory.provider()?.tx_ref())
83    }
84
85    /// Check if the table is empty
86    pub fn table_is_empty<T: Table>(&self) -> ProviderResult<bool> {
87        self.query(|tx| {
88            let last = tx.cursor_read::<T>()?.last()?;
89            Ok(last.is_none())
90        })
91    }
92
93    /// Return full table as Vec
94    pub fn table<T: Table>(&self) -> ProviderResult<Vec<KeyValue<T>>>
95    where
96        T::Key: Default + Ord,
97    {
98        self.query(|tx| {
99            Ok(tx
100                .cursor_read::<T>()?
101                .walk(Some(T::Key::default()))?
102                .collect::<Result<Vec<_>, DbError>>()?)
103        })
104    }
105
106    /// Return the number of entries in the table or static file segment
107    pub fn count_entries<T: Table>(&self) -> ProviderResult<usize> {
108        self.factory.provider()?.count_entries::<T>()
109    }
110
111    /// Check that there is no table entry above a given
112    /// number by [`Table::Key`]
113    pub fn ensure_no_entry_above<T, F>(&self, num: u64, mut selector: F) -> ProviderResult<()>
114    where
115        T: Table,
116        F: FnMut(T::Key) -> BlockNumber,
117    {
118        self.query(|tx| {
119            let mut cursor = tx.cursor_read::<T>()?;
120            if let Some((key, _)) = cursor.last()? {
121                assert!(selector(key) <= num);
122            }
123            Ok(())
124        })
125    }
126
127    /// Check that there is no table entry above a given
128    /// number by [`Table::Value`]
129    pub fn ensure_no_entry_above_by_value<T, F>(
130        &self,
131        num: u64,
132        mut selector: F,
133    ) -> ProviderResult<()>
134    where
135        T: Table,
136        F: FnMut(T::Value) -> BlockNumber,
137    {
138        self.query(|tx| {
139            let mut cursor = tx.cursor_read::<T>()?;
140            let mut rev_walker = cursor.walk_back(None)?;
141            while let Some((_, value)) = rev_walker.next().transpose()? {
142                assert!(selector(value) <= num);
143            }
144            Ok(())
145        })
146    }
147
148    /// Insert header to static file if `writer` exists, otherwise to DB.
149    pub fn insert_header<TX: DbTx + DbTxMut>(
150        writer: Option<&mut StaticFileProviderRWRefMut<'_, EthPrimitives>>,
151        tx: &TX,
152        header: &SealedHeader,
153        td: U256,
154    ) -> ProviderResult<()> {
155        if let Some(writer) = writer {
156            // Backfill: some tests start at a forward block number, but static files require no
157            // gaps.
158            let segment_header = writer.user_header();
159            if segment_header.block_end().is_none() && segment_header.expected_block_start() == 0 {
160                for block_number in 0..header.number {
161                    let mut prev = header.clone_header();
162                    prev.number = block_number;
163                    writer.append_header(&prev, U256::ZERO, &B256::ZERO)?;
164                }
165            }
166
167            writer.append_header(header.header(), td, &header.hash())?;
168        } else {
169            tx.put::<tables::CanonicalHeaders>(header.number, header.hash())?;
170            tx.put::<tables::HeaderTerminalDifficulties>(header.number, td.into())?;
171            tx.put::<tables::Headers>(header.number, header.header().clone())?;
172        }
173
174        tx.put::<tables::HeaderNumbers>(header.hash(), header.number)?;
175        Ok(())
176    }
177
178    fn insert_headers_inner<'a, I, const TD: bool>(&self, headers: I) -> ProviderResult<()>
179    where
180        I: IntoIterator<Item = &'a SealedHeader>,
181    {
182        let provider = self.factory.static_file_provider();
183        let mut writer = provider.latest_writer(StaticFileSegment::Headers)?;
184        let tx = self.factory.provider_rw()?.into_tx();
185        let mut td = U256::ZERO;
186
187        for header in headers {
188            if TD {
189                td += header.difficulty;
190            }
191            Self::insert_header(Some(&mut writer), &tx, header, td)?;
192        }
193
194        writer.commit()?;
195        tx.commit()?;
196
197        Ok(())
198    }
199
200    /// Insert ordered collection of [`SealedHeader`] into the corresponding static file and tables
201    /// that are supposed to be populated by the headers stage.
202    pub fn insert_headers<'a, I>(&self, headers: I) -> ProviderResult<()>
203    where
204        I: IntoIterator<Item = &'a SealedHeader>,
205    {
206        self.insert_headers_inner::<I, false>(headers)
207    }
208
209    /// Inserts total difficulty of headers into the corresponding static file and tables.
210    ///
211    /// Superset functionality of [`TestStageDB::insert_headers`].
212    pub fn insert_headers_with_td<'a, I>(&self, headers: I) -> ProviderResult<()>
213    where
214        I: IntoIterator<Item = &'a SealedHeader>,
215    {
216        self.insert_headers_inner::<I, true>(headers)
217    }
218
219    /// Insert ordered collection of [`SealedBlock`] into corresponding tables.
220    /// Superset functionality of [`TestStageDB::insert_headers`].
221    ///
222    /// If `tx_offset` is set to `None`, then transactions will be stored on static files, otherwise
223    /// database.
224    ///
225    /// Assumes that there's a single transition for each transaction (i.e. no block rewards).
226    pub fn insert_blocks<'a, I>(&self, blocks: I, storage_kind: StorageKind) -> ProviderResult<()>
227    where
228        I: IntoIterator<Item = &'a SealedBlock<Block>>,
229    {
230        let provider = self.factory.static_file_provider();
231
232        let tx = self.factory.provider_rw().unwrap().into_tx();
233        let mut next_tx_num = storage_kind.tx_offset();
234
235        let blocks = blocks.into_iter().collect::<Vec<_>>();
236
237        {
238            let mut headers_writer = storage_kind
239                .is_static()
240                .then(|| provider.latest_writer(StaticFileSegment::Headers).unwrap());
241
242            blocks.iter().try_for_each(|block| {
243                Self::insert_header(headers_writer.as_mut(), &tx, block.sealed_header(), U256::ZERO)
244            })?;
245
246            if let Some(mut writer) = headers_writer {
247                writer.commit()?;
248            }
249        }
250
251        {
252            let mut txs_writer = storage_kind
253                .is_static()
254                .then(|| provider.latest_writer(StaticFileSegment::Transactions).unwrap());
255
256            blocks.into_iter().try_for_each(|block| {
257                // Insert into body tables.
258                let block_body_indices = StoredBlockBodyIndices {
259                    first_tx_num: next_tx_num,
260                    tx_count: block.transaction_count() as u64,
261                };
262
263                if !block.body().transactions.is_empty() {
264                    tx.put::<tables::TransactionBlocks>(
265                        block_body_indices.last_tx_num(),
266                        block.number,
267                    )?;
268                }
269                tx.put::<tables::BlockBodyIndices>(block.number, block_body_indices)?;
270
271                let res = block.body().transactions.iter().try_for_each(|body_tx| {
272                    if let Some(txs_writer) = &mut txs_writer {
273                        txs_writer.append_transaction(next_tx_num, body_tx)?;
274                    } else {
275                        tx.put::<tables::Transactions>(next_tx_num, body_tx.clone())?
276                    }
277                    next_tx_num += 1;
278                    Ok::<(), ProviderError>(())
279                });
280
281                if let Some(txs_writer) = &mut txs_writer {
282                    // Backfill: some tests start at a forward block number, but static files
283                    // require no gaps.
284                    let segment_header = txs_writer.user_header();
285                    if segment_header.block_end().is_none() &&
286                        segment_header.expected_block_start() == 0
287                    {
288                        for block in 0..block.number {
289                            txs_writer.increment_block(block)?;
290                        }
291                    }
292                    txs_writer.increment_block(block.number)?;
293                }
294                res
295            })?;
296
297            if let Some(txs_writer) = &mut txs_writer {
298                txs_writer.commit()?;
299            }
300        }
301
302        tx.commit()?;
303
304        Ok(())
305    }
306
307    pub fn insert_tx_hash_numbers<I>(&self, tx_hash_numbers: I) -> ProviderResult<()>
308    where
309        I: IntoIterator<Item = (TxHash, TxNumber)>,
310    {
311        self.commit(|tx| {
312            tx_hash_numbers.into_iter().try_for_each(|(tx_hash, tx_num)| {
313                // Insert into tx hash numbers table.
314                Ok(tx.put::<tables::TransactionHashNumbers>(tx_hash, tx_num)?)
315            })
316        })
317    }
318
319    /// Insert collection of ([`TxNumber`], [Receipt]) into the corresponding table.
320    pub fn insert_receipts<I>(&self, receipts: I) -> ProviderResult<()>
321    where
322        I: IntoIterator<Item = (TxNumber, Receipt)>,
323    {
324        self.commit(|tx| {
325            receipts.into_iter().try_for_each(|(tx_num, receipt)| {
326                // Insert into receipts table.
327                Ok(tx.put::<tables::Receipts>(tx_num, receipt)?)
328            })
329        })
330    }
331
332    /// Insert collection of ([`TxNumber`], [Receipt]) organized by respective block numbers into
333    /// the corresponding table or static file segment.
334    pub fn insert_receipts_by_block<I, J>(
335        &self,
336        receipts: I,
337        storage_kind: StorageKind,
338    ) -> ProviderResult<()>
339    where
340        I: IntoIterator<Item = (BlockNumber, J)>,
341        J: IntoIterator<Item = (TxNumber, Receipt)>,
342    {
343        match storage_kind {
344            StorageKind::Database(_) => self.commit(|tx| {
345                receipts.into_iter().try_for_each(|(_, receipts)| {
346                    for (tx_num, receipt) in receipts {
347                        tx.put::<tables::Receipts>(tx_num, receipt)?;
348                    }
349                    Ok(())
350                })
351            }),
352            StorageKind::Static => {
353                let provider = self.factory.static_file_provider();
354                let mut writer = provider.latest_writer(StaticFileSegment::Receipts)?;
355                let res = receipts.into_iter().try_for_each(|(block_num, receipts)| {
356                    writer.increment_block(block_num)?;
357                    writer.append_receipts(receipts.into_iter().map(Ok))?;
358                    Ok(())
359                });
360                writer.commit_without_sync_all()?;
361                res
362            }
363        }
364    }
365
366    pub fn insert_transaction_senders<I>(&self, transaction_senders: I) -> ProviderResult<()>
367    where
368        I: IntoIterator<Item = (TxNumber, Address)>,
369    {
370        self.commit(|tx| {
371            transaction_senders.into_iter().try_for_each(|(tx_num, sender)| {
372                // Insert into receipts table.
373                Ok(tx.put::<tables::TransactionSenders>(tx_num, sender)?)
374            })
375        })
376    }
377
378    /// Insert collection of ([Address], [Account]) into corresponding tables.
379    pub fn insert_accounts_and_storages<I, S>(&self, accounts: I) -> ProviderResult<()>
380    where
381        I: IntoIterator<Item = (Address, (Account, S))>,
382        S: IntoIterator<Item = StorageEntry>,
383    {
384        self.commit(|tx| {
385            accounts.into_iter().try_for_each(|(address, (account, storage))| {
386                let hashed_address = keccak256(address);
387
388                // Insert into account tables.
389                tx.put::<tables::PlainAccountState>(address, account)?;
390                tx.put::<tables::HashedAccounts>(hashed_address, account)?;
391
392                // Insert into storage tables.
393                storage.into_iter().filter(|e| !e.value.is_zero()).try_for_each(|entry| {
394                    let hashed_entry = StorageEntry { key: keccak256(entry.key), ..entry };
395
396                    let mut cursor = tx.cursor_dup_write::<tables::PlainStorageState>()?;
397                    if cursor
398                        .seek_by_key_subkey(address, entry.key)?
399                        .filter(|e| e.key == entry.key)
400                        .is_some()
401                    {
402                        cursor.delete_current()?;
403                    }
404                    cursor.upsert(address, &entry)?;
405
406                    let mut cursor = tx.cursor_dup_write::<tables::HashedStorages>()?;
407                    if cursor
408                        .seek_by_key_subkey(hashed_address, hashed_entry.key)?
409                        .filter(|e| e.key == hashed_entry.key)
410                        .is_some()
411                    {
412                        cursor.delete_current()?;
413                    }
414                    cursor.upsert(hashed_address, &hashed_entry)?;
415
416                    Ok(())
417                })
418            })
419        })
420    }
421
422    /// Insert collection of [`ChangeSet`] into corresponding tables.
423    pub fn insert_changesets<I>(
424        &self,
425        changesets: I,
426        block_offset: Option<u64>,
427    ) -> ProviderResult<()>
428    where
429        I: IntoIterator<Item = ChangeSet>,
430    {
431        let offset = block_offset.unwrap_or_default();
432        self.commit(|tx| {
433            changesets.into_iter().enumerate().try_for_each(|(block, changeset)| {
434                changeset.into_iter().try_for_each(|(address, old_account, old_storage)| {
435                    let block = offset + block as u64;
436                    // Insert into account changeset.
437                    tx.put::<tables::AccountChangeSets>(
438                        block,
439                        AccountBeforeTx { address, info: Some(old_account) },
440                    )?;
441
442                    let block_address = (block, address).into();
443
444                    // Insert into storage changeset.
445                    old_storage.into_iter().try_for_each(|entry| {
446                        Ok(tx.put::<tables::StorageChangeSets>(block_address, entry)?)
447                    })
448                })
449            })
450        })
451    }
452
453    pub fn insert_history<I>(&self, changesets: I, _block_offset: Option<u64>) -> ProviderResult<()>
454    where
455        I: IntoIterator<Item = ChangeSet>,
456    {
457        let mut accounts = BTreeMap::<Address, Vec<u64>>::new();
458        let mut storages = BTreeMap::<(Address, B256), Vec<u64>>::new();
459
460        for (block, changeset) in changesets.into_iter().enumerate() {
461            for (address, _, storage_entries) in changeset {
462                accounts.entry(address).or_default().push(block as u64);
463                for storage_entry in storage_entries {
464                    storages.entry((address, storage_entry.key)).or_default().push(block as u64);
465                }
466            }
467        }
468
469        let provider_rw = self.factory.provider_rw()?;
470        provider_rw.insert_account_history_index(accounts)?;
471        provider_rw.insert_storage_history_index(storages)?;
472        provider_rw.commit()?;
473
474        Ok(())
475    }
476}
477
/// Used to identify where to store data when setting up a test.
#[derive(Debug)]
pub enum StorageKind {
    /// Store data in the database; the optional value is the starting transaction number
    /// offset (treated as zero when absent).
    Database(Option<u64>),
    /// Store data in static files.
    Static,
}

impl StorageKind {
    /// Whether this kind targets the database.
    #[expect(dead_code)]
    const fn is_database(&self) -> bool {
        match self {
            Self::Database(_) => true,
            Self::Static => false,
        }
    }

    /// Whether this kind targets static files.
    const fn is_static(&self) -> bool {
        match self {
            Self::Static => true,
            Self::Database(_) => false,
        }
    }

    /// Starting transaction number: the configured offset for database storage (zero when not
    /// set), and always zero for static storage.
    fn tx_offset(&self) -> u64 {
        match self {
            Self::Database(offset) => offset.unwrap_or_default(),
            Self::Static => 0,
        }
    }
}
501}