reth_trie_parallel/
root.rs

#[cfg(feature = "metrics")]
use crate::metrics::ParallelStateRootMetrics;
use crate::{stats::ParallelTrieTracker, storage_root_targets::StorageRootTargets};
use alloy_primitives::B256;
use alloy_rlp::{BufMut, Encodable};
use itertools::Itertools;
use reth_execution_errors::StorageRootError;
use reth_provider::{
    providers::ConsistentDbView, BlockReader, DBProvider, DatabaseProviderFactory, ProviderError,
    StateCommitmentProvider,
};
use reth_storage_errors::db::DatabaseError;
use reth_trie::{
    hashed_cursor::{HashedCursorFactory, HashedPostStateCursorFactory},
    node_iter::{TrieElement, TrieNodeIter},
    trie_cursor::{InMemoryTrieCursorFactory, TrieCursorFactory},
    updates::TrieUpdates,
    walker::TrieWalker,
    HashBuilder, Nibbles, StorageRoot, TrieInput, TRIE_ACCOUNT_RLP_MAX_SIZE,
};
use reth_trie_db::{DatabaseHashedCursorFactory, DatabaseTrieCursorFactory};
use std::{collections::HashMap, sync::Arc};
use thiserror::Error;
use tracing::*;

/// Parallel incremental state root calculator.
///
/// The calculator starts off by launching tasks to compute storage roots.
/// Then, it immediately starts walking the state trie, updating the necessary trie
/// nodes in the process. Upon encountering a leaf node, it polls the storage root
/// task for the corresponding hashed address.
///
/// Internally, the calculator uses [`ConsistentDbView`] since it needs to rely on the
/// database state staying the same until the last transaction is opened.
/// See the docs of [`ConsistentDbView`] for caveats.
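///
/// # Example
///
/// A minimal usage sketch, assuming a `factory` value that satisfies the trait bounds on the
/// impl below (e.g. a test provider factory) is already in scope:
///
/// ```ignore
/// use reth_provider::providers::ConsistentDbView;
/// use reth_trie::TrieInput;
///
/// // `factory` is assumed to be a provider factory; pin a consistent view of the database.
/// let consistent_view = ConsistentDbView::new(factory, None);
/// // Compute the state root over the current database state with an empty overlay,
/// // collecting the trie updates produced along the way.
/// let (root, trie_updates) = ParallelStateRoot::new(consistent_view, TrieInput::default())
///     .incremental_root_with_updates()
///     .unwrap();
/// ```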
#[derive(Debug)]
pub struct ParallelStateRoot<Factory> {
    /// Consistent view of the database.
    view: ConsistentDbView<Factory>,
    /// Trie input.
    input: TrieInput,
    /// Parallel state root metrics.
    #[cfg(feature = "metrics")]
    metrics: ParallelStateRootMetrics,
}

impl<Factory> ParallelStateRoot<Factory> {
    /// Create a new parallel state root calculator.
    pub fn new(view: ConsistentDbView<Factory>, input: TrieInput) -> Self {
        Self {
            view,
            input,
            #[cfg(feature = "metrics")]
            metrics: ParallelStateRootMetrics::default(),
        }
    }
}

impl<Factory> ParallelStateRoot<Factory>
where
    Factory: DatabaseProviderFactory<Provider: BlockReader>
        + StateCommitmentProvider
        + Clone
        + Send
        + Sync
        + 'static,
{
    /// Calculate incremental state root in parallel.
    pub fn incremental_root(self) -> Result<B256, ParallelStateRootError> {
        self.calculate(false).map(|(root, _)| root)
    }

    /// Calculate incremental state root with updates in parallel.
    pub fn incremental_root_with_updates(
        self,
    ) -> Result<(B256, TrieUpdates), ParallelStateRootError> {
        self.calculate(true)
    }

    fn calculate(
        self,
        retain_updates: bool,
    ) -> Result<(B256, TrieUpdates), ParallelStateRootError> {
        let mut tracker = ParallelTrieTracker::default();
        let trie_nodes_sorted = Arc::new(self.input.nodes.into_sorted());
        let hashed_state_sorted = Arc::new(self.input.state.into_sorted());
        let prefix_sets = self.input.prefix_sets.freeze();
        let storage_root_targets = StorageRootTargets::new(
            prefix_sets.account_prefix_set.iter().map(|nibbles| B256::from_slice(&nibbles.pack())),
            prefix_sets.storage_prefix_sets,
        );

        // Pre-calculate storage roots in parallel for accounts which were changed.
        tracker.set_precomputed_storage_roots(storage_root_targets.len() as u64);
        debug!(target: "trie::parallel_state_root", len = storage_root_targets.len(), "pre-calculating storage roots");
        let mut storage_roots = HashMap::with_capacity(storage_root_targets.len());
        for (hashed_address, prefix_set) in
            storage_root_targets.into_iter().sorted_unstable_by_key(|(address, _)| *address)
        {
            let view = self.view.clone();
            let hashed_state_sorted = hashed_state_sorted.clone();
            let trie_nodes_sorted = trie_nodes_sorted.clone();
            #[cfg(feature = "metrics")]
            let metrics = self.metrics.storage_trie.clone();

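            // Bounded channel with a capacity of one: the spawned task sends exactly one
            // result, which is collected only when the account walk below reaches the
            // corresponding leaf.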
            let (tx, rx) = std::sync::mpsc::sync_channel(1);

            rayon::spawn_fifo(move || {
                let result = (|| -> Result<_, ParallelStateRootError> {
                    let provider_ro = view.provider_ro()?;
                    let trie_cursor_factory = InMemoryTrieCursorFactory::new(
                        DatabaseTrieCursorFactory::new(provider_ro.tx_ref()),
                        &trie_nodes_sorted,
                    );
                    let hashed_state = HashedPostStateCursorFactory::new(
                        DatabaseHashedCursorFactory::new(provider_ro.tx_ref()),
                        &hashed_state_sorted,
                    );
                    Ok(StorageRoot::new_hashed(
                        trie_cursor_factory,
                        hashed_state,
                        hashed_address,
                        prefix_set,
                        #[cfg(feature = "metrics")]
                        metrics,
                    )
                    .calculate(retain_updates)?)
                })();
                let _ = tx.send(result);
            });
            storage_roots.insert(hashed_address, rx);
        }

        trace!(target: "trie::parallel_state_root", "calculating state root");
        let mut trie_updates = TrieUpdates::default();

        let provider_ro = self.view.provider_ro()?;
        let trie_cursor_factory = InMemoryTrieCursorFactory::new(
            DatabaseTrieCursorFactory::new(provider_ro.tx_ref()),
            &trie_nodes_sorted,
        );
        let hashed_cursor_factory = HashedPostStateCursorFactory::new(
            DatabaseHashedCursorFactory::new(provider_ro.tx_ref()),
            &hashed_state_sorted,
        );

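        // Walk the account trie, revisiting only the prefixes marked as changed in the
        // account prefix set; deletions are retained when trie updates are requested.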
        let walker = TrieWalker::new(
            trie_cursor_factory.account_trie_cursor().map_err(ProviderError::Database)?,
            prefix_sets.account_prefix_set,
        )
        .with_deletions_retained(retain_updates);
        let mut account_node_iter = TrieNodeIter::state_trie(
            walker,
            hashed_cursor_factory.hashed_account_cursor().map_err(ProviderError::Database)?,
        );

        let mut hash_builder = HashBuilder::default().with_updates(retain_updates);
        let mut account_rlp = Vec::with_capacity(TRIE_ACCOUNT_RLP_MAX_SIZE);
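        // Branch nodes are fed to the hash builder directly, while leaves look up their
        // precomputed storage root (or fall back to computing it inline) before being
        // RLP-encoded and added.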
        while let Some(node) = account_node_iter.try_next().map_err(ProviderError::Database)? {
            match node {
                TrieElement::Branch(node) => {
                    hash_builder.add_branch(node.key, node.value, node.children_are_in_trie);
                }
                TrieElement::Leaf(hashed_address, account) => {
                    let (storage_root, _, updates) = match storage_roots.remove(&hashed_address) {
                        Some(rx) => rx.recv().map_err(|_| {
                            ParallelStateRootError::StorageRoot(StorageRootError::Database(
                                DatabaseError::Other(format!(
                                    "channel closed for {hashed_address}"
                                )),
                            ))
                        })??,
                        // Since we do not store all intermediate nodes in the database, a
                        // non-modified leaf may be re-added to the hash builder; in that case,
                        // compute its storage root inline.
                        None => {
                            tracker.inc_missed_leaves();
                            StorageRoot::new_hashed(
                                trie_cursor_factory.clone(),
                                hashed_cursor_factory.clone(),
                                hashed_address,
                                Default::default(),
                                #[cfg(feature = "metrics")]
                                self.metrics.storage_trie.clone(),
                            )
                            .calculate(retain_updates)?
                        }
                    };

                    if retain_updates {
                        trie_updates.insert_storage_updates(hashed_address, updates);
                    }

                    account_rlp.clear();
                    let account = account.into_trie_account(storage_root);
                    account.encode(&mut account_rlp as &mut dyn BufMut);
                    hash_builder.add_leaf(Nibbles::unpack(hashed_address), &account_rlp);
                }
            }
        }

        let root = hash_builder.root();

        let removed_keys = account_node_iter.walker.take_removed_keys();
        trie_updates.finalize(hash_builder, removed_keys, prefix_sets.destroyed_accounts);

        let stats = tracker.finish();

        #[cfg(feature = "metrics")]
        self.metrics.record_state_trie(stats);

        trace!(
            target: "trie::parallel_state_root",
            %root,
            duration = ?stats.duration(),
            branches_added = stats.branches_added(),
            leaves_added = stats.leaves_added(),
            missed_leaves = stats.missed_leaves(),
            precomputed_storage_roots = stats.precomputed_storage_roots(),
            "Calculated state root"
        );

        Ok((root, trie_updates))
    }
}

/// Error during parallel state root calculation.
#[derive(Error, Debug)]
pub enum ParallelStateRootError {
    /// Error while calculating storage root.
    #[error(transparent)]
    StorageRoot(#[from] StorageRootError),
    /// Provider error.
    #[error(transparent)]
    Provider(#[from] ProviderError),
    /// Other unspecified error.
    #[error("{_0}")]
    Other(String),
}

impl From<ParallelStateRootError> for ProviderError {
    fn from(error: ParallelStateRootError) -> Self {
        match error {
            ParallelStateRootError::Provider(error) => error,
            ParallelStateRootError::StorageRoot(StorageRootError::Database(error)) => {
                Self::Database(error)
            }
            ParallelStateRootError::Other(other) => Self::Database(DatabaseError::Other(other)),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use alloy_primitives::{keccak256, Address, U256};
    use rand::Rng;
    use reth_primitives_traits::{Account, StorageEntry};
    use reth_provider::{test_utils::create_test_provider_factory, HashingWriter};
    use reth_trie::{test_utils, HashedPostState, HashedStorage};

    #[test]
    fn random_parallel_root() {
        let factory = create_test_provider_factory();
        let consistent_view = ConsistentDbView::new(factory.clone(), None);

        let mut rng = rand::rng();
        let mut state = (0..100)
            .map(|_| {
                let address = Address::random();
                let account =
                    Account { balance: U256::from(rng.random::<u64>()), ..Default::default() };
                let mut storage = HashMap::<B256, U256>::default();
                let has_storage = rng.random_bool(0.7);
                if has_storage {
                    for _ in 0..100 {
                        storage.insert(
                            B256::from(U256::from(rng.random::<u64>())),
                            U256::from(rng.random::<u64>()),
                        );
                    }
                }
                (address, (account, storage))
            })
            .collect::<HashMap<_, _>>();

        {
            let provider_rw = factory.provider_rw().unwrap();
            provider_rw
                .insert_account_for_hashing(
                    state.iter().map(|(address, (account, _))| (*address, Some(*account))),
                )
                .unwrap();
            provider_rw
                .insert_storage_for_hashing(state.iter().map(|(address, (_, storage))| {
                    (
                        *address,
                        storage
                            .iter()
                            .map(|(slot, value)| StorageEntry { key: *slot, value: *value }),
                    )
                }))
                .unwrap();
            provider_rw.commit().unwrap();
        }

        assert_eq!(
            ParallelStateRoot::new(consistent_view.clone(), Default::default())
                .incremental_root()
                .unwrap(),
            test_utils::state_root(state.clone())
        );

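        // Mutate a random subset of accounts and storage slots in memory only, then verify
        // that the root computed over the hashed-state overlay matches the reference
        // implementation applied to the updated state.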
        let mut hashed_state = HashedPostState::default();
        for (address, (account, storage)) in &mut state {
            let hashed_address = keccak256(address);

            let should_update_account = rng.random_bool(0.5);
            if should_update_account {
                *account = Account { balance: U256::from(rng.random::<u64>()), ..*account };
                hashed_state.accounts.insert(hashed_address, Some(*account));
            }

            let should_update_storage = rng.random_bool(0.3);
            if should_update_storage {
                for (slot, value) in storage.iter_mut() {
                    let hashed_slot = keccak256(slot);
                    *value = U256::from(rng.random::<u64>());
                    hashed_state
                        .storages
                        .entry(hashed_address)
                        .or_insert_with(HashedStorage::default)
                        .storage
                        .insert(hashed_slot, *value);
                }
            }
        }

        assert_eq!(
            ParallelStateRoot::new(consistent_view, TrieInput::from_state(hashed_state))
                .incremental_root()
                .unwrap(),
            test_utils::state_root(state)
        );
    }
}