reth_trie/
trie.rs

1use crate::{
2    hashed_cursor::{HashedCursorFactory, HashedStorageCursor},
3    node_iter::{TrieElement, TrieNodeIter},
4    prefix_set::{PrefixSet, TriePrefixSets},
5    progress::{IntermediateStateRootState, StateRootProgress},
6    stats::TrieTracker,
7    trie_cursor::TrieCursorFactory,
8    updates::{StorageTrieUpdates, TrieUpdates},
9    walker::TrieWalker,
10    HashBuilder, Nibbles, TRIE_ACCOUNT_RLP_MAX_SIZE,
11};
12use alloy_consensus::EMPTY_ROOT_HASH;
13use alloy_primitives::{keccak256, Address, B256};
14use alloy_rlp::{BufMut, Encodable};
15use reth_execution_errors::{StateRootError, StorageRootError};
16use tracing::trace;
17
18#[cfg(feature = "metrics")]
19use crate::metrics::{StateRootMetrics, TrieRootMetrics};
20
/// `StateRoot` is used to compute the root node of a state trie.
///
/// The calculation walks account entries obtained from the two cursor factories and, for each
/// account leaf, computes the account's storage root via [`StorageRoot`].
#[derive(Debug)]
pub struct StateRoot<T, H> {
    /// The factory for trie cursors.
    pub trie_cursor_factory: T,
    /// The factory for hashed cursors.
    pub hashed_cursor_factory: H,
    /// A set of prefix sets that have changed.
    pub prefix_sets: TriePrefixSets,
    /// Previous intermediate state. When present, the calculation resumes from its hash builder,
    /// walker stack and last account key instead of starting from scratch.
    previous_state: Option<IntermediateStateRootState>,
    /// The number of updates after which the intermediate progress should be returned.
    /// Only consulted when updates are being retained.
    threshold: u64,
    #[cfg(feature = "metrics")]
    /// State root metrics.
    metrics: StateRootMetrics,
}
38
39impl<T, H> StateRoot<T, H> {
40    /// Creates [`StateRoot`] with `trie_cursor_factory` and `hashed_cursor_factory`. All other
41    /// parameters are set to reasonable defaults.
42    ///
43    /// The cursors created by given factories are then used to walk through the accounts and
44    /// calculate the state root value with.
45    pub fn new(trie_cursor_factory: T, hashed_cursor_factory: H) -> Self {
46        Self {
47            trie_cursor_factory,
48            hashed_cursor_factory,
49            prefix_sets: TriePrefixSets::default(),
50            previous_state: None,
51            threshold: 100_000,
52            #[cfg(feature = "metrics")]
53            metrics: StateRootMetrics::default(),
54        }
55    }
56
57    /// Set the prefix sets.
58    pub fn with_prefix_sets(mut self, prefix_sets: TriePrefixSets) -> Self {
59        self.prefix_sets = prefix_sets;
60        self
61    }
62
63    /// Set the threshold.
64    pub const fn with_threshold(mut self, threshold: u64) -> Self {
65        self.threshold = threshold;
66        self
67    }
68
69    /// Set the threshold to maximum value so that intermediate progress is not returned.
70    pub const fn with_no_threshold(mut self) -> Self {
71        self.threshold = u64::MAX;
72        self
73    }
74
75    /// Set the previously recorded intermediate state.
76    pub fn with_intermediate_state(mut self, state: Option<IntermediateStateRootState>) -> Self {
77        self.previous_state = state;
78        self
79    }
80
81    /// Set the hashed cursor factory.
82    pub fn with_hashed_cursor_factory<HF>(self, hashed_cursor_factory: HF) -> StateRoot<T, HF> {
83        StateRoot {
84            trie_cursor_factory: self.trie_cursor_factory,
85            hashed_cursor_factory,
86            prefix_sets: self.prefix_sets,
87            threshold: self.threshold,
88            previous_state: self.previous_state,
89            #[cfg(feature = "metrics")]
90            metrics: self.metrics,
91        }
92    }
93
94    /// Set the trie cursor factory.
95    pub fn with_trie_cursor_factory<TF>(self, trie_cursor_factory: TF) -> StateRoot<TF, H> {
96        StateRoot {
97            trie_cursor_factory,
98            hashed_cursor_factory: self.hashed_cursor_factory,
99            prefix_sets: self.prefix_sets,
100            threshold: self.threshold,
101            previous_state: self.previous_state,
102            #[cfg(feature = "metrics")]
103            metrics: self.metrics,
104        }
105    }
106}
107
108impl<T, H> StateRoot<T, H>
109where
110    T: TrieCursorFactory + Clone,
111    H: HashedCursorFactory + Clone,
112{
113    /// Walks the intermediate nodes of existing state trie (if any) and hashed entries. Feeds the
114    /// nodes into the hash builder. Collects the updates in the process.
115    ///
116    /// Ignores the threshold.
117    ///
118    /// # Returns
119    ///
120    /// The intermediate progress of state root computation and the trie updates.
121    pub fn root_with_updates(self) -> Result<(B256, TrieUpdates), StateRootError> {
122        match self.with_no_threshold().calculate(true)? {
123            StateRootProgress::Complete(root, _, updates) => Ok((root, updates)),
124            StateRootProgress::Progress(..) => unreachable!(), // unreachable threshold
125        }
126    }
127
128    /// Walks the intermediate nodes of existing state trie (if any) and hashed entries. Feeds the
129    /// nodes into the hash builder.
130    ///
131    /// # Returns
132    ///
133    /// The state root hash.
134    pub fn root(self) -> Result<B256, StateRootError> {
135        match self.calculate(false)? {
136            StateRootProgress::Complete(root, _, _) => Ok(root),
137            StateRootProgress::Progress(..) => unreachable!(), // update retenion is disabled
138        }
139    }
140
141    /// Walks the intermediate nodes of existing state trie (if any) and hashed entries. Feeds the
142    /// nodes into the hash builder. Collects the updates in the process.
143    ///
144    /// # Returns
145    ///
146    /// The intermediate progress of state root computation.
147    pub fn root_with_progress(self) -> Result<StateRootProgress, StateRootError> {
148        self.calculate(true)
149    }
150
151    fn calculate(self, retain_updates: bool) -> Result<StateRootProgress, StateRootError> {
152        trace!(target: "trie::state_root", "calculating state root");
153        let mut tracker = TrieTracker::default();
154        let mut trie_updates = TrieUpdates::default();
155
156        let trie_cursor = self.trie_cursor_factory.account_trie_cursor()?;
157
158        let hashed_account_cursor = self.hashed_cursor_factory.hashed_account_cursor()?;
159        let (mut hash_builder, mut account_node_iter) = match self.previous_state {
160            Some(state) => {
161                let hash_builder = state.hash_builder.with_updates(retain_updates);
162                let walker = TrieWalker::from_stack(
163                    trie_cursor,
164                    state.walker_stack,
165                    self.prefix_sets.account_prefix_set,
166                )
167                .with_deletions_retained(retain_updates);
168                let node_iter = TrieNodeIter::new(walker, hashed_account_cursor)
169                    .with_last_hashed_key(state.last_account_key);
170                (hash_builder, node_iter)
171            }
172            None => {
173                let hash_builder = HashBuilder::default().with_updates(retain_updates);
174                let walker = TrieWalker::new(trie_cursor, self.prefix_sets.account_prefix_set)
175                    .with_deletions_retained(retain_updates);
176                let node_iter = TrieNodeIter::new(walker, hashed_account_cursor);
177                (hash_builder, node_iter)
178            }
179        };
180
181        let mut account_rlp = Vec::with_capacity(TRIE_ACCOUNT_RLP_MAX_SIZE);
182        let mut hashed_entries_walked = 0;
183        let mut updated_storage_nodes = 0;
184        while let Some(node) = account_node_iter.try_next()? {
185            match node {
186                TrieElement::Branch(node) => {
187                    tracker.inc_branch();
188                    hash_builder.add_branch(node.key, node.value, node.children_are_in_trie);
189                }
190                TrieElement::Leaf(hashed_address, account) => {
191                    tracker.inc_leaf();
192                    hashed_entries_walked += 1;
193
194                    // We assume we can always calculate a storage root without
195                    // OOMing. This opens us up to a potential DOS vector if
196                    // a contract had too many storage entries and they were
197                    // all buffered w/o us returning and committing our intermediate
198                    // progress.
199                    // TODO: We can consider introducing the TrieProgress::Progress/Complete
200                    // abstraction inside StorageRoot, but let's give it a try as-is for now.
201                    let storage_root_calculator = StorageRoot::new_hashed(
202                        self.trie_cursor_factory.clone(),
203                        self.hashed_cursor_factory.clone(),
204                        hashed_address,
205                        self.prefix_sets
206                            .storage_prefix_sets
207                            .get(&hashed_address)
208                            .cloned()
209                            .unwrap_or_default(),
210                        #[cfg(feature = "metrics")]
211                        self.metrics.storage_trie.clone(),
212                    );
213
214                    let storage_root = if retain_updates {
215                        let (root, storage_slots_walked, updates) =
216                            storage_root_calculator.root_with_updates()?;
217                        hashed_entries_walked += storage_slots_walked;
218                        // We only walk over hashed address once, so it's safe to insert.
219                        updated_storage_nodes += updates.len();
220                        trie_updates.insert_storage_updates(hashed_address, updates);
221                        root
222                    } else {
223                        storage_root_calculator.root()?
224                    };
225
226                    account_rlp.clear();
227                    let account = account.into_trie_account(storage_root);
228                    account.encode(&mut account_rlp as &mut dyn BufMut);
229                    hash_builder.add_leaf(Nibbles::unpack(hashed_address), &account_rlp);
230
231                    // Decide if we need to return intermediate progress.
232                    let total_updates_len = updated_storage_nodes +
233                        account_node_iter.walker.removed_keys_len() +
234                        hash_builder.updates_len();
235                    if retain_updates && total_updates_len as u64 >= self.threshold {
236                        let (walker_stack, walker_deleted_keys) = account_node_iter.walker.split();
237                        trie_updates.removed_nodes.extend(walker_deleted_keys);
238                        let (hash_builder, hash_builder_updates) = hash_builder.split();
239                        trie_updates.account_nodes.extend(hash_builder_updates);
240
241                        let state = IntermediateStateRootState {
242                            hash_builder,
243                            walker_stack,
244                            last_account_key: hashed_address,
245                        };
246
247                        return Ok(StateRootProgress::Progress(
248                            Box::new(state),
249                            hashed_entries_walked,
250                            trie_updates,
251                        ))
252                    }
253                }
254            }
255        }
256
257        let root = hash_builder.root();
258
259        let removed_keys = account_node_iter.walker.take_removed_keys();
260        trie_updates.finalize(hash_builder, removed_keys, self.prefix_sets.destroyed_accounts);
261
262        let stats = tracker.finish();
263
264        #[cfg(feature = "metrics")]
265        self.metrics.state_trie.record(stats);
266
267        trace!(
268            target: "trie::state_root",
269            %root,
270            duration = ?stats.duration(),
271            branches_added = stats.branches_added(),
272            leaves_added = stats.leaves_added(),
273            "calculated state root"
274        );
275
276        Ok(StateRootProgress::Complete(root, hashed_entries_walked, trie_updates))
277    }
278}
279
/// `StorageRoot` is used to compute the root node of an account storage trie.
#[derive(Debug)]
pub struct StorageRoot<T, H> {
    /// The factory for trie cursors.
    pub trie_cursor_factory: T,
    /// The factory for hashed cursors.
    pub hashed_cursor_factory: H,
    /// The hashed address of an account.
    pub hashed_address: B256,
    /// The set of storage slot prefixes that have changed.
    pub prefix_set: PrefixSet,
    /// Storage root metrics.
    #[cfg(feature = "metrics")]
    metrics: TrieRootMetrics,
}
295
296impl<T, H> StorageRoot<T, H> {
297    /// Creates a new storage root calculator given a raw address.
298    pub fn new(
299        trie_cursor_factory: T,
300        hashed_cursor_factory: H,
301        address: Address,
302        prefix_set: PrefixSet,
303        #[cfg(feature = "metrics")] metrics: TrieRootMetrics,
304    ) -> Self {
305        Self::new_hashed(
306            trie_cursor_factory,
307            hashed_cursor_factory,
308            keccak256(address),
309            prefix_set,
310            #[cfg(feature = "metrics")]
311            metrics,
312        )
313    }
314
315    /// Creates a new storage root calculator given a hashed address.
316    pub const fn new_hashed(
317        trie_cursor_factory: T,
318        hashed_cursor_factory: H,
319        hashed_address: B256,
320        prefix_set: PrefixSet,
321        #[cfg(feature = "metrics")] metrics: TrieRootMetrics,
322    ) -> Self {
323        Self {
324            trie_cursor_factory,
325            hashed_cursor_factory,
326            hashed_address,
327            prefix_set,
328            #[cfg(feature = "metrics")]
329            metrics,
330        }
331    }
332
333    /// Set the changed prefixes.
334    pub fn with_prefix_set(mut self, prefix_set: PrefixSet) -> Self {
335        self.prefix_set = prefix_set;
336        self
337    }
338
339    /// Set the hashed cursor factory.
340    pub fn with_hashed_cursor_factory<HF>(self, hashed_cursor_factory: HF) -> StorageRoot<T, HF> {
341        StorageRoot {
342            trie_cursor_factory: self.trie_cursor_factory,
343            hashed_cursor_factory,
344            hashed_address: self.hashed_address,
345            prefix_set: self.prefix_set,
346            #[cfg(feature = "metrics")]
347            metrics: self.metrics,
348        }
349    }
350
351    /// Set the trie cursor factory.
352    pub fn with_trie_cursor_factory<TF>(self, trie_cursor_factory: TF) -> StorageRoot<TF, H> {
353        StorageRoot {
354            trie_cursor_factory,
355            hashed_cursor_factory: self.hashed_cursor_factory,
356            hashed_address: self.hashed_address,
357            prefix_set: self.prefix_set,
358            #[cfg(feature = "metrics")]
359            metrics: self.metrics,
360        }
361    }
362}
363
364impl<T, H> StorageRoot<T, H>
365where
366    T: TrieCursorFactory,
367    H: HashedCursorFactory,
368{
369    /// Walks the hashed storage table entries for a given address and calculates the storage root.
370    ///
371    /// # Returns
372    ///
373    /// The storage root and storage trie updates for a given address.
374    pub fn root_with_updates(self) -> Result<(B256, usize, StorageTrieUpdates), StorageRootError> {
375        self.calculate(true)
376    }
377
378    /// Walks the hashed storage table entries for a given address and calculates the storage root.
379    ///
380    /// # Returns
381    ///
382    /// The storage root.
383    pub fn root(self) -> Result<B256, StorageRootError> {
384        let (root, _, _) = self.calculate(false)?;
385        Ok(root)
386    }
387
388    /// Walks the hashed storage table entries for a given address and calculates the storage root.
389    ///
390    /// # Returns
391    ///
392    /// The storage root, number of walked entries and trie updates
393    /// for a given address if requested.
394    pub fn calculate(
395        self,
396        retain_updates: bool,
397    ) -> Result<(B256, usize, StorageTrieUpdates), StorageRootError> {
398        trace!(target: "trie::storage_root", hashed_address = ?self.hashed_address, "calculating storage root");
399
400        let mut hashed_storage_cursor =
401            self.hashed_cursor_factory.hashed_storage_cursor(self.hashed_address)?;
402
403        // short circuit on empty storage
404        if hashed_storage_cursor.is_storage_empty()? {
405            return Ok((EMPTY_ROOT_HASH, 0, StorageTrieUpdates::deleted()))
406        }
407
408        let mut tracker = TrieTracker::default();
409        let trie_cursor = self.trie_cursor_factory.storage_trie_cursor(self.hashed_address)?;
410        let walker =
411            TrieWalker::new(trie_cursor, self.prefix_set).with_deletions_retained(retain_updates);
412
413        let mut hash_builder = HashBuilder::default().with_updates(retain_updates);
414
415        let mut storage_node_iter = TrieNodeIter::new(walker, hashed_storage_cursor);
416        while let Some(node) = storage_node_iter.try_next()? {
417            match node {
418                TrieElement::Branch(node) => {
419                    tracker.inc_branch();
420                    hash_builder.add_branch(node.key, node.value, node.children_are_in_trie);
421                }
422                TrieElement::Leaf(hashed_slot, value) => {
423                    tracker.inc_leaf();
424                    hash_builder.add_leaf(
425                        Nibbles::unpack(hashed_slot),
426                        alloy_rlp::encode_fixed_size(&value).as_ref(),
427                    );
428                }
429            }
430        }
431
432        let root = hash_builder.root();
433
434        let mut trie_updates = StorageTrieUpdates::default();
435        let removed_keys = storage_node_iter.walker.take_removed_keys();
436        trie_updates.finalize(hash_builder, removed_keys);
437
438        let stats = tracker.finish();
439
440        #[cfg(feature = "metrics")]
441        self.metrics.record(stats);
442
443        trace!(
444            target: "trie::storage_root",
445            %root,
446            hashed_address = %self.hashed_address,
447            duration = ?stats.duration(),
448            branches_added = stats.branches_added(),
449            leaves_added = stats.leaves_added(),
450            "calculated storage root"
451        );
452
453        let storage_slots_walked = stats.leaves_added() as usize;
454        Ok((root, storage_slots_walked, trie_updates))
455    }
456}