// reth_trie_parallel/root.rs

1#[cfg(feature = "metrics")]
2use crate::metrics::ParallelStateRootMetrics;
3use crate::{stats::ParallelTrieTracker, storage_root_targets::StorageRootTargets};
4use alloy_primitives::B256;
5use alloy_rlp::{BufMut, Encodable};
6use itertools::Itertools;
7use reth_execution_errors::{StateProofError, StorageRootError};
8use reth_provider::{DatabaseProviderROFactory, ProviderError};
9use reth_storage_errors::db::DatabaseError;
10use reth_trie::{
11    hashed_cursor::HashedCursorFactory,
12    node_iter::{TrieElement, TrieNodeIter},
13    prefix_set::TriePrefixSets,
14    trie_cursor::TrieCursorFactory,
15    updates::TrieUpdates,
16    walker::TrieWalker,
17    HashBuilder, Nibbles, StorageRoot, TRIE_ACCOUNT_RLP_MAX_SIZE,
18};
19use std::{
20    collections::HashMap,
21    sync::{mpsc, OnceLock},
22    time::Duration,
23};
24use thiserror::Error;
25use tokio::runtime::{Builder, Handle, Runtime};
26use tracing::*;
27
/// Parallel incremental state root calculator.
///
/// The calculator starts off by launching tasks to compute storage roots.
/// Then, it immediately starts walking the state trie updating the necessary trie
/// nodes in the process. Upon encountering a leaf node, it will poll the storage root
/// task for the corresponding hashed address.
///
/// Note: This implementation only serves as a fallback for the sparse trie-based
/// state root calculation. The sparse trie approach is more efficient as it avoids traversing
/// the entire trie, only operating on the modified parts.
#[derive(Debug)]
pub struct ParallelStateRoot<Factory> {
    /// Factory for creating state providers.
    factory: Factory,
    /// Prefix sets indicating which portions of the trie need to be recomputed.
    prefix_sets: TriePrefixSets,
    /// Parallel state root metrics.
    #[cfg(feature = "metrics")]
    metrics: ParallelStateRootMetrics,
}
48
49impl<Factory> ParallelStateRoot<Factory> {
50    /// Create new parallel state root calculator.
51    pub fn new(factory: Factory, prefix_sets: TriePrefixSets) -> Self {
52        Self {
53            factory,
54            prefix_sets,
55            #[cfg(feature = "metrics")]
56            metrics: ParallelStateRootMetrics::default(),
57        }
58    }
59}
60
impl<Factory> ParallelStateRoot<Factory>
where
    Factory: DatabaseProviderROFactory<Provider: TrieCursorFactory + HashedCursorFactory>
        + Clone
        + Send
        + 'static,
{
    /// Calculate incremental state root in parallel.
    ///
    /// Discards any collected trie updates and returns only the root hash.
    pub fn incremental_root(self) -> Result<B256, ParallelStateRootError> {
        self.calculate(false).map(|(root, _)| root)
    }

    /// Calculate incremental state root with updates in parallel.
    ///
    /// Returns the root hash together with the [`TrieUpdates`] collected during the walk.
    pub fn incremental_root_with_updates(
        self,
    ) -> Result<(B256, TrieUpdates), ParallelStateRootError> {
        self.calculate(true)
    }

    /// Computes the state root by calculating storage roots in parallel for modified accounts,
    /// then walking the state trie to build the final state root hash.
    ///
    /// When `retain_updates` is true, account and storage trie updates are collected and
    /// returned alongside the root.
    fn calculate(
        self,
        retain_updates: bool,
    ) -> Result<(B256, TrieUpdates), ParallelStateRootError> {
        let mut tracker = ParallelTrieTracker::default();
        // Derive the set of hashed addresses whose storage roots must be recomputed
        // from the account prefix set plus the per-account storage prefix sets.
        let storage_root_targets = StorageRootTargets::new(
            self.prefix_sets
                .account_prefix_set
                .iter()
                .map(|nibbles| B256::from_slice(&nibbles.pack())),
            self.prefix_sets.storage_prefix_sets,
        );

        // Pre-calculate storage roots in parallel for accounts which were changed.
        tracker.set_precomputed_storage_roots(storage_root_targets.len() as u64);
        debug!(target: "trie::parallel_state_root", len = storage_root_targets.len(), "pre-calculating storage roots");
        let mut storage_roots = HashMap::with_capacity(storage_root_targets.len());

        // Get runtime handle once outside the loop
        let handle = get_runtime_handle();

        // Spawn one blocking task per target. Each result is delivered through a
        // dedicated channel and consumed when the account walk reaches that leaf.
        for (hashed_address, prefix_set) in
            storage_root_targets.into_iter().sorted_unstable_by_key(|(address, _)| *address)
        {
            let factory = self.factory.clone();
            #[cfg(feature = "metrics")]
            let metrics = self.metrics.storage_trie.clone();

            // Bounded channel of capacity 1: each task sends exactly one result.
            let (tx, rx) = mpsc::sync_channel(1);

            // Spawn a blocking task to calculate account's storage root from database I/O.
            // The join handle is dropped on purpose; completion is observed via `rx`.
            drop(handle.spawn_blocking(move || {
                let result = (|| -> Result<_, ParallelStateRootError> {
                    let provider = factory.database_provider_ro()?;
                    Ok(StorageRoot::new_hashed(
                        &provider,
                        &provider,
                        hashed_address,
                        prefix_set,
                        #[cfg(feature = "metrics")]
                        metrics,
                    )
                    .calculate(retain_updates)?)
                })();
                // The receiver may already be gone if calculation aborted early;
                // ignoring the send error is intentional.
                let _ = tx.send(result);
            }));
            storage_roots.insert(hashed_address, rx);
        }

        trace!(target: "trie::parallel_state_root", "calculating state root");
        let mut trie_updates = TrieUpdates::default();

        let provider = self.factory.database_provider_ro()?;

        // Walk the account trie, skipping subtrees outside the account prefix set.
        let walker = TrieWalker::<_>::state_trie(
            provider.account_trie_cursor().map_err(ProviderError::Database)?,
            self.prefix_sets.account_prefix_set,
        )
        .with_deletions_retained(retain_updates);
        let mut account_node_iter = TrieNodeIter::state_trie(
            walker,
            provider.hashed_account_cursor().map_err(ProviderError::Database)?,
        );

        let mut hash_builder = HashBuilder::default().with_updates(retain_updates);
        // Reusable buffer for RLP-encoding each account leaf.
        let mut account_rlp = Vec::with_capacity(TRIE_ACCOUNT_RLP_MAX_SIZE);
        while let Some(node) = account_node_iter.try_next().map_err(ProviderError::Database)? {
            match node {
                TrieElement::Branch(node) => {
                    hash_builder.add_branch(node.key, node.value, node.children_are_in_trie);
                }
                TrieElement::Leaf(hashed_address, account) => {
                    let storage_root_result = match storage_roots.remove(&hashed_address) {
                        // A task was pre-spawned for this account: block until its
                        // storage root result arrives on the channel.
                        Some(rx) => rx.recv().map_err(|_| {
                            ParallelStateRootError::StorageRoot(StorageRootError::Database(
                                DatabaseError::Other(format!(
                                    "channel closed for {hashed_address}"
                                )),
                            ))
                        })??,
                        // Since we do not store all intermediate nodes in the database, there might
                        // be a possibility of re-adding a non-modified leaf to the hash builder.
                        // In that case compute the storage root inline with an empty prefix set.
                        None => {
                            tracker.inc_missed_leaves();
                            StorageRoot::new_hashed(
                                &provider,
                                &provider,
                                hashed_address,
                                Default::default(),
                                #[cfg(feature = "metrics")]
                                self.metrics.storage_trie.clone(),
                            )
                            .calculate(retain_updates)?
                        }
                    };

                    // Every storage root task runs to completion, so a `Progress`
                    // variant here indicates a logic error rather than valid state.
                    let (storage_root, _, updates) = match storage_root_result {
                        reth_trie::StorageRootProgress::Complete(root, _, updates) => (root, (), updates),
                        reth_trie::StorageRootProgress::Progress(..) => {
                            return Err(ParallelStateRootError::StorageRoot(
                                StorageRootError::Database(DatabaseError::Other(
                                    "StorageRoot returned Progress variant in parallel trie calculation".to_string()
                                ))
                            ))
                        }
                    };

                    if retain_updates {
                        trie_updates.insert_storage_updates(hashed_address, updates);
                    }

                    // RLP-encode the account (now carrying its storage root) and add it
                    // to the hash builder keyed by the unpacked nibbles of the address.
                    account_rlp.clear();
                    let account = account.into_trie_account(storage_root);
                    account.encode(&mut account_rlp as &mut dyn BufMut);
                    hash_builder.add_leaf(Nibbles::unpack(hashed_address), &account_rlp);
                }
            }
        }

        let root = hash_builder.root();

        // Fold removed keys and destroyed accounts into the finalized update set.
        let removed_keys = account_node_iter.walker.take_removed_keys();
        trie_updates.finalize(hash_builder, removed_keys, self.prefix_sets.destroyed_accounts);

        let stats = tracker.finish();

        #[cfg(feature = "metrics")]
        self.metrics.record_state_trie(stats);

        trace!(
            target: "trie::parallel_state_root",
            %root,
            duration = ?stats.duration(),
            branches_added = stats.branches_added(),
            leaves_added = stats.leaves_added(),
            missed_leaves = stats.missed_leaves(),
            precomputed_storage_roots = stats.precomputed_storage_roots(),
            "Calculated state root"
        );

        Ok((root, trie_updates))
    }
}
225
/// Error during parallel state root calculation.
///
/// Storage root and provider errors convert automatically via `#[from]`.
#[derive(Error, Debug)]
pub enum ParallelStateRootError {
    /// Error while calculating storage root.
    #[error(transparent)]
    StorageRoot(#[from] StorageRootError),
    /// Provider error.
    #[error(transparent)]
    Provider(#[from] ProviderError),
    /// Other unspecified error.
    #[error("{_0}")]
    Other(String),
}
239
240impl From<ParallelStateRootError> for ProviderError {
241    fn from(error: ParallelStateRootError) -> Self {
242        match error {
243            ParallelStateRootError::Provider(error) => error,
244            ParallelStateRootError::StorageRoot(StorageRootError::Database(error)) => {
245                Self::Database(error)
246            }
247            ParallelStateRootError::Other(other) => Self::Database(DatabaseError::Other(other)),
248        }
249    }
250}
251
252impl From<alloy_rlp::Error> for ParallelStateRootError {
253    fn from(error: alloy_rlp::Error) -> Self {
254        Self::Provider(ProviderError::Rlp(error))
255    }
256}
257
258impl From<StateProofError> for ParallelStateRootError {
259    fn from(error: StateProofError) -> Self {
260        match error {
261            StateProofError::Database(err) => Self::Provider(ProviderError::Database(err)),
262            StateProofError::Rlp(err) => Self::Provider(ProviderError::Rlp(err)),
263        }
264    }
265}
266
267/// Gets or creates a tokio runtime handle for spawning blocking tasks.
268/// This ensures we always have a runtime available for I/O operations.
269fn get_runtime_handle() -> Handle {
270    Handle::try_current().unwrap_or_else(|_| {
271        // Create a new runtime if no runtime is available
272        static RT: OnceLock<Runtime> = OnceLock::new();
273
274        let rt = RT.get_or_init(|| {
275            Builder::new_multi_thread()
276                // Keep the threads alive for at least the block time (12 seconds) plus buffer.
277                // This prevents the costly process of spawning new threads on every
278                // new block, and instead reuses the existing threads.
279                .thread_keep_alive(Duration::from_secs(15))
280                .build()
281                .expect("Failed to create tokio runtime")
282        });
283
284        rt.handle().clone()
285    })
286}
287
#[cfg(test)]
mod tests {
    use super::*;
    use alloy_primitives::{keccak256, Address, U256};
    use rand::Rng;
    use reth_primitives_traits::{Account, StorageEntry};
    use reth_provider::{test_utils::create_test_provider_factory, HashingWriter};
    use reth_trie::{test_utils, HashedPostState, HashedStorage};
    use std::sync::Arc;

    /// End-to-end check: the parallel calculator must match the sequential
    /// reference root for a random initial state, and again after applying a
    /// random set of account/storage mutations as a hashed state overlay.
    #[tokio::test]
    async fn random_parallel_root() {
        let factory = create_test_provider_factory();
        let changeset_cache = reth_trie_db::ChangesetCache::new();
        let mut overlay_factory = reth_provider::providers::OverlayStateProviderFactory::new(
            factory.clone(),
            changeset_cache,
        );

        // Generate 100 random accounts; ~70% of them get 100 random storage slots.
        let mut rng = rand::rng();
        let mut state = (0..100)
            .map(|_| {
                let address = Address::random();
                let account =
                    Account { balance: U256::from(rng.random::<u64>()), ..Default::default() };
                let mut storage = HashMap::<B256, U256>::default();
                let has_storage = rng.random_bool(0.7);
                if has_storage {
                    for _ in 0..100 {
                        storage.insert(
                            B256::from(U256::from(rng.random::<u64>())),
                            U256::from(rng.random::<u64>()),
                        );
                    }
                }
                (address, (account, storage))
            })
            .collect::<HashMap<_, _>>();

        // Persist the generated accounts and storage into the hashed tables.
        {
            let provider_rw = factory.provider_rw().unwrap();
            provider_rw
                .insert_account_for_hashing(
                    state.iter().map(|(address, (account, _))| (*address, Some(*account))),
                )
                .unwrap();
            provider_rw
                .insert_storage_for_hashing(state.iter().map(|(address, (_, storage))| {
                    (
                        *address,
                        storage
                            .iter()
                            .map(|(slot, value)| StorageEntry { key: *slot, value: *value }),
                    )
                }))
                .unwrap();
            provider_rw.commit().unwrap();
        }

        // Baseline: empty prefix sets, root computed from the database alone.
        assert_eq!(
            ParallelStateRoot::new(overlay_factory.clone(), Default::default())
                .incremental_root()
                .unwrap(),
            test_utils::state_root(state.clone())
        );

        // Mutate ~50% of accounts and ~30% of storages, tracking changes in a
        // hashed post state for the incremental recomputation.
        let mut hashed_state = HashedPostState::default();
        for (address, (account, storage)) in &mut state {
            let hashed_address = keccak256(address);

            let should_update_account = rng.random_bool(0.5);
            if should_update_account {
                *account = Account { balance: U256::from(rng.random::<u64>()), ..*account };
                hashed_state.accounts.insert(hashed_address, Some(*account));
            }

            let should_update_storage = rng.random_bool(0.3);
            if should_update_storage {
                for (slot, value) in storage.iter_mut() {
                    let hashed_slot = keccak256(slot);
                    *value = U256::from(rng.random::<u64>());
                    hashed_state
                        .storages
                        .entry(hashed_address)
                        .or_insert_with(HashedStorage::default)
                        .storage
                        .insert(hashed_slot, *value);
                }
            }
        }

        // Incremental: overlay the mutations and recompute only changed prefixes.
        let prefix_sets = hashed_state.construct_prefix_sets();
        overlay_factory =
            overlay_factory.with_hashed_state_overlay(Some(Arc::new(hashed_state.into_sorted())));

        assert_eq!(
            ParallelStateRoot::new(overlay_factory, prefix_sets.freeze())
                .incremental_root()
                .unwrap(),
            test_utils::state_root(state)
        );
    }
}