reth_provider/providers/state/overlay.rs

1use alloy_primitives::{BlockNumber, B256};
2use reth_db_api::DatabaseError;
3use reth_errors::{ProviderError, ProviderResult};
4use reth_prune_types::PruneSegment;
5use reth_stages_types::StageId;
6use reth_storage_api::{
7    BlockNumReader, DBProvider, DatabaseProviderFactory, DatabaseProviderROFactory,
8    PruneCheckpointReader, StageCheckpointReader, TrieReader,
9};
10use reth_trie::{
11    hashed_cursor::{HashedCursorFactory, HashedPostStateCursorFactory},
12    trie_cursor::{InMemoryTrieCursorFactory, TrieCursorFactory},
13    updates::TrieUpdatesSorted,
14    HashedPostState, HashedPostStateSorted, KeccakKeyHasher,
15};
16use reth_trie_db::{
17    DatabaseHashedCursorFactory, DatabaseHashedPostState, DatabaseTrieCursorFactory,
18};
19use std::sync::Arc;
20use tracing::debug;
21
22/// Factory for creating overlay state providers with optional reverts and overlays.
23///
24/// This factory allows building an `OverlayStateProvider` whose DB state has been reverted to a
25/// particular block, and/or with additional overlay information added on top.
26#[derive(Debug, Clone)]
27pub struct OverlayStateProviderFactory<F> {
28    /// The underlying database provider factory
29    factory: F,
30    /// Optional block hash for collecting reverts
31    block_hash: Option<B256>,
32    /// Optional trie overlay
33    trie_overlay: Option<Arc<TrieUpdatesSorted>>,
34    /// Optional hashed state overlay
35    hashed_state_overlay: Option<Arc<HashedPostStateSorted>>,
36}
37
38impl<F> OverlayStateProviderFactory<F> {
39    /// Create a new overlay state provider factory
40    pub const fn new(factory: F) -> Self {
41        Self { factory, block_hash: None, trie_overlay: None, hashed_state_overlay: None }
42    }
43
44    /// Set the block hash for collecting reverts. All state will be reverted to the point
45    /// _after_ this block has been processed.
46    pub const fn with_block_hash(mut self, block_hash: Option<B256>) -> Self {
47        self.block_hash = block_hash;
48        self
49    }
50
51    /// Set the trie overlay.
52    ///
53    /// This overlay will be applied on top of any reverts applied via `with_block_hash`.
54    pub fn with_trie_overlay(mut self, trie_overlay: Option<Arc<TrieUpdatesSorted>>) -> Self {
55        self.trie_overlay = trie_overlay;
56        self
57    }
58
59    /// Set the hashed state overlay
60    ///
61    /// This overlay will be applied on top of any reverts applied via `with_block_hash`.
62    pub fn with_hashed_state_overlay(
63        mut self,
64        hashed_state_overlay: Option<Arc<HashedPostStateSorted>>,
65    ) -> Self {
66        self.hashed_state_overlay = hashed_state_overlay;
67        self
68    }
69}
70
71impl<F> OverlayStateProviderFactory<F>
72where
73    F: DatabaseProviderFactory,
74    F::Provider: TrieReader + StageCheckpointReader + PruneCheckpointReader + BlockNumReader,
75{
76    /// Returns the block number for [`Self`]'s `block_hash` field, if any.
77    fn get_block_number(&self, provider: &F::Provider) -> ProviderResult<Option<BlockNumber>> {
78        if let Some(block_hash) = self.block_hash {
79            Ok(Some(
80                provider
81                    .convert_hash_or_number(block_hash.into())?
82                    .ok_or_else(|| ProviderError::BlockHashNotFound(block_hash))?,
83            ))
84        } else {
85            Ok(None)
86        }
87    }
88
89    /// Returns whether or not it is required to collect reverts, and validates that there are
90    /// sufficient changesets to revert to the requested block number if so.
91    ///
92    /// Returns an error if the `MerkleChangeSets` checkpoint doesn't cover the requested block.
93    /// Takes into account both the stage checkpoint and the prune checkpoint to determine the
94    /// available data range.
95    fn reverts_required(
96        &self,
97        provider: &F::Provider,
98        requested_block: BlockNumber,
99    ) -> ProviderResult<bool> {
100        // Get the MerkleChangeSets stage and prune checkpoints.
101        let stage_checkpoint = provider.get_stage_checkpoint(StageId::MerkleChangeSets)?;
102        let prune_checkpoint = provider.get_prune_checkpoint(PruneSegment::MerkleChangeSets)?;
103
104        // Get the upper bound from stage checkpoint
105        let upper_bound =
106            stage_checkpoint.as_ref().map(|chk| chk.block_number).ok_or_else(|| {
107                ProviderError::InsufficientChangesets {
108                    requested: requested_block,
109                    available: 0..=0,
110                }
111            })?;
112
113        // If the requested block is the DB tip (determined by the MerkleChangeSets stage
114        // checkpoint) then there won't be any reverts necessary, and we can simply return Ok.
115        if upper_bound == requested_block {
116            return Ok(false)
117        }
118
119        // Extract the lower bound from prune checkpoint if available
120        // The prune checkpoint's block_number is the highest pruned block, so data is available
121        // starting from the next block
122        let lower_bound = prune_checkpoint
123            .and_then(|chk| chk.block_number)
124            .map(|block_number| block_number + 1)
125            .ok_or_else(|| ProviderError::InsufficientChangesets {
126                requested: requested_block,
127                available: 0..=upper_bound,
128            })?;
129
130        let available_range = lower_bound..=upper_bound;
131
132        // Check if the requested block is within the available range
133        if !available_range.contains(&requested_block) {
134            return Err(ProviderError::InsufficientChangesets {
135                requested: requested_block,
136                available: available_range,
137            });
138        }
139
140        Ok(true)
141    }
142}
143
144impl<F> DatabaseProviderROFactory for OverlayStateProviderFactory<F>
145where
146    F: DatabaseProviderFactory,
147    F::Provider: TrieReader + StageCheckpointReader + PruneCheckpointReader + BlockNumReader,
148{
149    type Provider = OverlayStateProvider<F::Provider>;
150
151    /// Create a read-only [`OverlayStateProvider`].
152    fn database_provider_ro(&self) -> ProviderResult<OverlayStateProvider<F::Provider>> {
153        // Get a read-only provider
154        let provider = self.factory.database_provider_ro()?;
155
156        // If block_hash is provided, collect reverts
157        let (trie_updates, hashed_state) = if let Some(from_block) =
158            self.get_block_number(&provider)? &&
159            self.reverts_required(&provider, from_block)?
160        {
161            // Collect trie reverts
162            let mut trie_reverts = provider.trie_reverts(from_block + 1)?;
163
164            // Collect state reverts
165            //
166            // TODO(mediocregopher) make from_reverts return sorted
167            // https://github.com/paradigmxyz/reth/issues/19382
168            let mut hashed_state_reverts = HashedPostState::from_reverts::<KeccakKeyHasher>(
169                provider.tx_ref(),
170                from_block + 1..,
171            )?
172            .into_sorted();
173
174            // Extend with overlays if provided. If the reverts are empty we should just use the
175            // overlays directly, because `extend_ref` will actually clone the overlay.
176            let trie_updates = match self.trie_overlay.as_ref() {
177                Some(trie_overlay) if trie_reverts.is_empty() => Arc::clone(trie_overlay),
178                Some(trie_overlay) => {
179                    trie_reverts.extend_ref(trie_overlay);
180                    Arc::new(trie_reverts)
181                }
182                None => Arc::new(trie_reverts),
183            };
184
185            let hashed_state_updates = match self.hashed_state_overlay.as_ref() {
186                Some(hashed_state_overlay) if hashed_state_reverts.is_empty() => {
187                    Arc::clone(hashed_state_overlay)
188                }
189                Some(hashed_state_overlay) => {
190                    hashed_state_reverts.extend_ref(hashed_state_overlay);
191                    Arc::new(hashed_state_reverts)
192                }
193                None => Arc::new(hashed_state_reverts),
194            };
195
196            debug!(
197                target: "providers::state::overlay",
198                block_hash = ?self.block_hash,
199                ?from_block,
200                num_trie_updates = ?trie_updates.total_len(),
201                num_state_updates = ?hashed_state_updates.total_len(),
202                "Reverted to target block",
203            );
204
205            (trie_updates, hashed_state_updates)
206        } else {
207            // If no block_hash, use overlays directly or defaults
208            let trie_updates =
209                self.trie_overlay.clone().unwrap_or_else(|| Arc::new(TrieUpdatesSorted::default()));
210            let hashed_state = self
211                .hashed_state_overlay
212                .clone()
213                .unwrap_or_else(|| Arc::new(HashedPostStateSorted::default()));
214
215            (trie_updates, hashed_state)
216        };
217
218        Ok(OverlayStateProvider::new(provider, trie_updates, hashed_state))
219    }
220}
221
222/// State provider with in-memory overlay from trie updates and hashed post state.
223///
224/// This provider uses in-memory trie updates and hashed post state as an overlay
225/// on top of a database provider, implementing [`TrieCursorFactory`] and [`HashedCursorFactory`]
226/// using the in-memory overlay factories.
227#[derive(Debug)]
228pub struct OverlayStateProvider<Provider: DBProvider> {
229    provider: Provider,
230    trie_updates: Arc<TrieUpdatesSorted>,
231    hashed_post_state: Arc<HashedPostStateSorted>,
232}
233
234impl<Provider> OverlayStateProvider<Provider>
235where
236    Provider: DBProvider,
237{
238    /// Create new overlay state provider. The `Provider` must be cloneable, which generally means
239    /// it should be wrapped in an `Arc`.
240    pub const fn new(
241        provider: Provider,
242        trie_updates: Arc<TrieUpdatesSorted>,
243        hashed_post_state: Arc<HashedPostStateSorted>,
244    ) -> Self {
245        Self { provider, trie_updates, hashed_post_state }
246    }
247}
248
249impl<Provider> TrieCursorFactory for OverlayStateProvider<Provider>
250where
251    Provider: DBProvider,
252{
253    type AccountTrieCursor<'a>
254        = <InMemoryTrieCursorFactory<
255        DatabaseTrieCursorFactory<&'a Provider::Tx>,
256        &'a TrieUpdatesSorted,
257    > as TrieCursorFactory>::AccountTrieCursor<'a>
258    where
259        Self: 'a;
260
261    type StorageTrieCursor<'a>
262        = <InMemoryTrieCursorFactory<
263        DatabaseTrieCursorFactory<&'a Provider::Tx>,
264        &'a TrieUpdatesSorted,
265    > as TrieCursorFactory>::StorageTrieCursor<'a>
266    where
267        Self: 'a;
268
269    fn account_trie_cursor(&self) -> Result<Self::AccountTrieCursor<'_>, DatabaseError> {
270        let db_trie_cursor_factory = DatabaseTrieCursorFactory::new(self.provider.tx_ref());
271        let trie_cursor_factory =
272            InMemoryTrieCursorFactory::new(db_trie_cursor_factory, self.trie_updates.as_ref());
273        trie_cursor_factory.account_trie_cursor()
274    }
275
276    fn storage_trie_cursor(
277        &self,
278        hashed_address: B256,
279    ) -> Result<Self::StorageTrieCursor<'_>, DatabaseError> {
280        let db_trie_cursor_factory = DatabaseTrieCursorFactory::new(self.provider.tx_ref());
281        let trie_cursor_factory =
282            InMemoryTrieCursorFactory::new(db_trie_cursor_factory, self.trie_updates.as_ref());
283        trie_cursor_factory.storage_trie_cursor(hashed_address)
284    }
285}
286
287impl<Provider> HashedCursorFactory for OverlayStateProvider<Provider>
288where
289    Provider: DBProvider,
290{
291    type AccountCursor<'a>
292        = <HashedPostStateCursorFactory<
293        DatabaseHashedCursorFactory<&'a Provider::Tx>,
294        &'a Arc<HashedPostStateSorted>,
295    > as HashedCursorFactory>::AccountCursor<'a>
296    where
297        Self: 'a;
298
299    type StorageCursor<'a>
300        = <HashedPostStateCursorFactory<
301        DatabaseHashedCursorFactory<&'a Provider::Tx>,
302        &'a Arc<HashedPostStateSorted>,
303    > as HashedCursorFactory>::StorageCursor<'a>
304    where
305        Self: 'a;
306
307    fn hashed_account_cursor(&self) -> Result<Self::AccountCursor<'_>, DatabaseError> {
308        let db_hashed_cursor_factory = DatabaseHashedCursorFactory::new(self.provider.tx_ref());
309        let hashed_cursor_factory =
310            HashedPostStateCursorFactory::new(db_hashed_cursor_factory, &self.hashed_post_state);
311        hashed_cursor_factory.hashed_account_cursor()
312    }
313
314    fn hashed_storage_cursor(
315        &self,
316        hashed_address: B256,
317    ) -> Result<Self::StorageCursor<'_>, DatabaseError> {
318        let db_hashed_cursor_factory = DatabaseHashedCursorFactory::new(self.provider.tx_ref());
319        let hashed_cursor_factory =
320            HashedPostStateCursorFactory::new(db_hashed_cursor_factory, &self.hashed_post_state);
321        hashed_cursor_factory.hashed_storage_cursor(hashed_address)
322    }
323}