Skip to main content

reth_trie_parallel/
root.rs

1#[cfg(feature = "metrics")]
2use crate::metrics::ParallelStateRootMetrics;
3use crate::{stats::ParallelTrieTracker, storage_root_targets::StorageRootTargets};
4use alloy_primitives::B256;
5use alloy_rlp::{BufMut, Encodable};
6use itertools::Itertools;
7use reth_execution_errors::{SparseTrieError, StateProofError, StorageRootError};
8use reth_provider::{DatabaseProviderROFactory, ProviderError};
9use reth_storage_errors::db::DatabaseError;
10use reth_tasks::Runtime;
11use reth_trie::{
12    hashed_cursor::HashedCursorFactory,
13    node_iter::{TrieElement, TrieNodeIter},
14    prefix_set::TriePrefixSets,
15    trie_cursor::TrieCursorFactory,
16    updates::TrieUpdates,
17    walker::TrieWalker,
18    HashBuilder, Nibbles, StorageRoot, TRIE_ACCOUNT_RLP_MAX_SIZE,
19};
20use std::{collections::HashMap, sync::mpsc};
21use thiserror::Error;
22use tracing::*;
23
24/// Parallel incremental state root calculator.
25///
26/// The calculator starts off by launching tasks to compute storage roots.
27/// Then, it immediately starts walking the state trie updating the necessary trie
28/// nodes in the process. Upon encountering a leaf node, it will poll the storage root
29/// task for the corresponding hashed address.
30///
31/// Note: This implementation only serves as a fallback for the sparse trie-based
32/// state root calculation. The sparse trie approach is more efficient as it avoids traversing
33/// the entire trie, only operating on the modified parts.
34#[derive(Debug)]
35pub struct ParallelStateRoot<Factory> {
36    /// Factory for creating state providers.
37    factory: Factory,
38    // Prefix sets indicating which portions of the trie need to be recomputed.
39    prefix_sets: TriePrefixSets,
40    /// The runtime handle for spawning blocking tasks.
41    runtime: Runtime,
42    /// Parallel state root metrics.
43    #[cfg(feature = "metrics")]
44    metrics: ParallelStateRootMetrics,
45}
46
47impl<Factory> ParallelStateRoot<Factory> {
48    /// Create new parallel state root calculator.
49    pub fn new(factory: Factory, prefix_sets: TriePrefixSets, runtime: Runtime) -> Self {
50        Self {
51            factory,
52            prefix_sets,
53            runtime,
54            #[cfg(feature = "metrics")]
55            metrics: ParallelStateRootMetrics::default(),
56        }
57    }
58}
59
60impl<Factory> ParallelStateRoot<Factory>
61where
62    Factory: DatabaseProviderROFactory<Provider: TrieCursorFactory + HashedCursorFactory>
63        + Clone
64        + Send
65        + 'static,
66{
67    /// Calculate incremental state root in parallel.
68    pub fn incremental_root(self) -> Result<B256, ParallelStateRootError> {
69        self.calculate(false).map(|(root, _)| root)
70    }
71
72    /// Calculate incremental state root with updates in parallel.
73    pub fn incremental_root_with_updates(
74        self,
75    ) -> Result<(B256, TrieUpdates), ParallelStateRootError> {
76        self.calculate(true)
77    }
78
79    /// Computes the state root by calculating storage roots in parallel for modified accounts,
80    /// then walking the state trie to build the final state root hash.
81    fn calculate(
82        self,
83        retain_updates: bool,
84    ) -> Result<(B256, TrieUpdates), ParallelStateRootError> {
85        let mut tracker = ParallelTrieTracker::default();
86        let storage_root_targets = StorageRootTargets::new(
87            self.prefix_sets
88                .account_prefix_set
89                .iter()
90                .map(|nibbles| B256::from_slice(&nibbles.pack())),
91            self.prefix_sets.storage_prefix_sets,
92        );
93
94        // Pre-calculate storage roots in parallel for accounts which were changed.
95        tracker.set_precomputed_storage_roots(storage_root_targets.len() as u64);
96        debug!(target: "trie::parallel_state_root", len = storage_root_targets.len(), "pre-calculating storage roots");
97        let mut storage_roots = HashMap::with_capacity(storage_root_targets.len());
98
99        let handle = self.runtime.handle().clone();
100
101        for (hashed_address, prefix_set) in
102            storage_root_targets.into_iter().sorted_unstable_by_key(|(address, _)| *address)
103        {
104            let factory = self.factory.clone();
105            #[cfg(feature = "metrics")]
106            let metrics = self.metrics.storage_trie.clone();
107
108            let (tx, rx) = mpsc::sync_channel(1);
109
110            // Spawn a blocking task to calculate account's storage root from database I/O
111            drop(handle.spawn_blocking(move || {
112                let result = (|| -> Result<_, ParallelStateRootError> {
113                    let provider = factory.database_provider_ro()?;
114                    Ok(StorageRoot::new_hashed(
115                        &provider,
116                        &provider,
117                        hashed_address,
118                        prefix_set,
119                        #[cfg(feature = "metrics")]
120                        metrics,
121                    )
122                    .calculate(retain_updates)?)
123                })();
124                let _ = tx.send(result);
125            }));
126            storage_roots.insert(hashed_address, rx);
127        }
128
129        trace!(target: "trie::parallel_state_root", "calculating state root");
130        let mut trie_updates = TrieUpdates::default();
131
132        let provider = self.factory.database_provider_ro()?;
133
134        let walker = TrieWalker::<_>::state_trie(
135            provider.account_trie_cursor().map_err(ProviderError::Database)?,
136            self.prefix_sets.account_prefix_set,
137        )
138        .with_deletions_retained(retain_updates);
139        let mut account_node_iter = TrieNodeIter::state_trie(
140            walker,
141            provider.hashed_account_cursor().map_err(ProviderError::Database)?,
142        );
143
144        let mut hash_builder = HashBuilder::default().with_updates(retain_updates);
145        let mut account_rlp = Vec::with_capacity(TRIE_ACCOUNT_RLP_MAX_SIZE);
146        while let Some(node) = account_node_iter.try_next().map_err(ProviderError::Database)? {
147            match node {
148                TrieElement::Branch(node) => {
149                    hash_builder.add_branch(node.key, node.value, node.children_are_in_trie);
150                }
151                TrieElement::Leaf(hashed_address, account) => {
152                    let storage_root_result = match storage_roots.remove(&hashed_address) {
153                        Some(rx) => rx.recv().map_err(|_| {
154                            ParallelStateRootError::StorageRoot(StorageRootError::Database(
155                                DatabaseError::Other(format!(
156                                    "channel closed for {hashed_address}"
157                                )),
158                            ))
159                        })??,
160                        // Since we do not store all intermediate nodes in the database, there might
161                        // be a possibility of re-adding a non-modified leaf to the hash builder.
162                        None => {
163                            tracker.inc_missed_leaves();
164                            StorageRoot::new_hashed(
165                                &provider,
166                                &provider,
167                                hashed_address,
168                                Default::default(),
169                                #[cfg(feature = "metrics")]
170                                self.metrics.storage_trie.clone(),
171                            )
172                            .calculate(retain_updates)?
173                        }
174                    };
175
176                    let (storage_root, _, updates) = match storage_root_result {
177                        reth_trie::StorageRootProgress::Complete(root, _, updates) => (root, (), updates),
178                        reth_trie::StorageRootProgress::Progress(..) => {
179                            return Err(ParallelStateRootError::StorageRoot(
180                                StorageRootError::Database(DatabaseError::Other(
181                                    "StorageRoot returned Progress variant in parallel trie calculation".to_string()
182                                ))
183                            ))
184                        }
185                    };
186
187                    if retain_updates {
188                        trie_updates.insert_storage_updates(hashed_address, updates);
189                    }
190
191                    account_rlp.clear();
192                    let account = account.into_trie_account(storage_root);
193                    account.encode(&mut account_rlp as &mut dyn BufMut);
194                    hash_builder.add_leaf(Nibbles::unpack(hashed_address), &account_rlp);
195                }
196            }
197        }
198
199        let root = hash_builder.root();
200
201        let removed_keys = account_node_iter.walker.take_removed_keys();
202        trie_updates.finalize(hash_builder, removed_keys, self.prefix_sets.destroyed_accounts);
203
204        let stats = tracker.finish();
205
206        #[cfg(feature = "metrics")]
207        self.metrics.record_state_trie(stats);
208
209        trace!(
210            target: "trie::parallel_state_root",
211            %root,
212            duration = ?stats.duration(),
213            branches_added = stats.branches_added(),
214            leaves_added = stats.leaves_added(),
215            missed_leaves = stats.missed_leaves(),
216            precomputed_storage_roots = stats.precomputed_storage_roots(),
217            "Calculated state root"
218        );
219
220        Ok((root, trie_updates))
221    }
222}
223
224/// Error during parallel state root calculation.
225#[derive(Error, Debug)]
226pub enum ParallelStateRootError {
227    /// Error while calculating storage root.
228    #[error(transparent)]
229    StorageRoot(#[from] StorageRootError),
230    /// Provider error.
231    #[error(transparent)]
232    Provider(#[from] ProviderError),
233    /// Sparse trie error.
234    #[error(transparent)]
235    SparseTrie(#[from] SparseTrieError),
236    /// Other unspecified error.
237    #[error("{_0}")]
238    Other(String),
239}
240
241impl From<ParallelStateRootError> for ProviderError {
242    fn from(error: ParallelStateRootError) -> Self {
243        match error {
244            ParallelStateRootError::Provider(error) => error,
245            ParallelStateRootError::StorageRoot(StorageRootError::Database(error)) => {
246                Self::Database(error)
247            }
248            ParallelStateRootError::SparseTrie(error) => Self::other(error),
249            ParallelStateRootError::Other(other) => Self::Database(DatabaseError::Other(other)),
250        }
251    }
252}
253
254impl From<alloy_rlp::Error> for ParallelStateRootError {
255    fn from(error: alloy_rlp::Error) -> Self {
256        Self::Provider(ProviderError::Rlp(error))
257    }
258}
259
260impl From<StateProofError> for ParallelStateRootError {
261    fn from(error: StateProofError) -> Self {
262        match error {
263            StateProofError::Database(err) => Self::Provider(ProviderError::Database(err)),
264            StateProofError::Rlp(err) => Self::Provider(ProviderError::Rlp(err)),
265            StateProofError::TrieInconsistency(msg) => {
266                Self::Provider(ProviderError::TrieWitnessError(msg))
267            }
268        }
269    }
270}
271
#[cfg(test)]
mod tests {
    use super::*;
    use alloy_primitives::{keccak256, Address, U256};
    use rand::Rng;
    use reth_chainspec::{ChainSpec, EthChainSpec};
    use reth_ethereum_primitives::{Block, BlockBody};
    use reth_primitives_traits::{Account, RecoveredBlock, SealedBlock, StorageEntry};
    use reth_provider::{
        test_utils::create_test_provider_factory_with_chain_spec, BlockWriter, ExecutionOutcome,
        HashingWriter,
    };
    use reth_trie::{test_utils, HashedPostState, HashedStorage};
    use std::sync::Arc;

    /// Draws a random `u64` from `rng` and widens it to a `U256`.
    fn random_u256(rng: &mut impl Rng) -> U256 {
        U256::from(rng.random::<u64>())
    }

    /// Checks the parallel state root against the reference implementation, first for a
    /// randomly generated state written to the database, then again after a random subset of
    /// accounts and storage slots has been modified through a hashed state overlay.
    #[tokio::test]
    async fn random_parallel_root() {
        let chain_spec = Arc::new(ChainSpec::default());
        let anchor_hash = chain_spec.genesis_hash();
        let factory = create_test_provider_factory_with_chain_spec(chain_spec.clone());
        let changeset_cache = reth_trie_db::ChangesetCache::new();
        let overlay_builder = reth_provider::providers::OverlayBuilder::<
            reth_ethereum_primitives::EthPrimitives,
        >::new(anchor_hash, changeset_cache);
        let initial_overlay_factory = reth_provider::providers::OverlayStateProviderFactory::new(
            factory.clone(),
            overlay_builder.clone(),
        );

        // Generate 100 random accounts; each has a 70% chance of receiving up to 100 slots.
        let mut rng = rand::rng();
        let mut state = HashMap::new();
        for _ in 0..100 {
            let address = Address::random();
            let account = Account { balance: random_u256(&mut rng), ..Default::default() };
            let mut storage = HashMap::<B256, U256>::default();
            if rng.random_bool(0.7) {
                for _ in 0..100 {
                    let slot = B256::from(random_u256(&mut rng));
                    let value = random_u256(&mut rng);
                    storage.insert(slot, value);
                }
            }
            state.insert(address, (account, storage));
        }

        // Persist the genesis block and the generated state.
        {
            let provider_rw = factory.provider_rw().unwrap();
            let sealed_genesis = SealedBlock::<Block>::seal_parts(
                chain_spec.genesis_header().clone(),
                BlockBody::default(),
            );
            let genesis_block = RecoveredBlock::new_sealed(sealed_genesis, vec![]);
            provider_rw
                .append_blocks_with_state(
                    vec![genesis_block],
                    &ExecutionOutcome::default(),
                    Default::default(),
                )
                .unwrap();

            let accounts = state.iter().map(|(address, (account, _))| (*address, Some(*account)));
            provider_rw.insert_account_for_hashing(accounts).unwrap();

            let storages = state.iter().map(|(address, (_, storage))| {
                let entries = storage
                    .iter()
                    .map(|(slot, value)| StorageEntry { key: *slot, value: *value });
                (*address, entries)
            });
            provider_rw.insert_storage_for_hashing(storages).unwrap();

            provider_rw.commit().unwrap();
        }

        // With empty prefix sets the root must match the reference root of the full state.
        let runtime = reth_tasks::Runtime::test();
        let initial_root =
            ParallelStateRoot::new(initial_overlay_factory, Default::default(), runtime.clone())
                .incremental_root()
                .unwrap();
        assert_eq!(initial_root, test_utils::state_root(state.clone()));

        // Randomly mutate balances (50% of accounts) and storage values (30% of accounts),
        // mirroring every change into the hashed post state.
        let mut hashed_state = HashedPostState::default();
        for (address, (account, storage)) in &mut state {
            let hashed_address = keccak256(address);

            if rng.random_bool(0.5) {
                account.balance = random_u256(&mut rng);
                hashed_state.accounts.insert(hashed_address, Some(*account));
            }

            if rng.random_bool(0.3) {
                for (slot, value) in storage.iter_mut() {
                    *value = random_u256(&mut rng);
                    // The entry is created per slot so that accounts without storage never
                    // get an (empty) storage entry.
                    hashed_state
                        .storages
                        .entry(hashed_address)
                        .or_insert_with(HashedStorage::default)
                        .storage
                        .insert(keccak256(slot), *value);
                }
            }
        }

        let prefix_sets = hashed_state.construct_prefix_sets();
        let updated_overlay_factory = reth_provider::providers::OverlayStateProviderFactory::new(
            factory,
            overlay_builder.with_hashed_state_overlay(Some(Arc::new(hashed_state.into_sorted()))),
        );

        // The incremental root over the overlay must match the reference root of the
        // mutated state.
        let updated_root =
            ParallelStateRoot::new(updated_overlay_factory, prefix_sets.freeze(), runtime)
                .incremental_root()
                .unwrap();
        assert_eq!(updated_root, test_utils::state_root(state));
    }
}