// reth_trie_parallel/root.rs

1#[cfg(feature = "metrics")]
2use crate::metrics::ParallelStateRootMetrics;
3use crate::{stats::ParallelTrieTracker, storage_root_targets::StorageRootTargets};
4use alloy_primitives::B256;
5use alloy_rlp::{BufMut, Encodable};
6use itertools::Itertools;
7use reth_execution_errors::StorageRootError;
8use reth_provider::{
9    providers::ConsistentDbView, BlockReader, DBProvider, DatabaseProviderFactory, ProviderError,
10};
11use reth_storage_errors::db::DatabaseError;
12use reth_trie::{
13    hashed_cursor::{HashedCursorFactory, HashedPostStateCursorFactory},
14    node_iter::{TrieElement, TrieNodeIter},
15    trie_cursor::{InMemoryTrieCursorFactory, TrieCursorFactory},
16    updates::TrieUpdates,
17    walker::TrieWalker,
18    HashBuilder, Nibbles, StorageRoot, TrieInput, TRIE_ACCOUNT_RLP_MAX_SIZE,
19};
20use reth_trie_db::{DatabaseHashedCursorFactory, DatabaseTrieCursorFactory};
21use std::{
22    collections::HashMap,
23    sync::{mpsc, Arc, OnceLock},
24    time::Duration,
25};
26use thiserror::Error;
27use tokio::runtime::{Builder, Handle, Runtime};
28use tracing::*;
29
/// Parallel incremental state root calculator.
///
/// The calculator starts off by launching tasks to compute storage roots.
/// Then, it immediately starts walking the state trie updating the necessary trie
/// nodes in the process. Upon encountering a leaf node, it will poll the storage root
/// task for the corresponding hashed address.
///
/// Internally, the calculator uses [`ConsistentDbView`] since
/// it needs to rely on database state staying the same until
/// the last transaction is open.
/// See docs of using [`ConsistentDbView`] for caveats.
///
/// Note: This implementation only serves as a fallback for the sparse trie-based
/// state root calculation. The sparse trie approach is more efficient as it avoids traversing
/// the entire trie, only operating on the modified parts.
#[derive(Debug)]
pub struct ParallelStateRoot<Factory> {
    /// Consistent view of the database, shared with every spawned storage root task.
    view: ConsistentDbView<Factory>,
    /// Trie input: in-memory trie nodes, hashed post state, and prefix sets to
    /// overlay on top of the database state.
    input: TrieInput,
    /// Parallel state root metrics.
    #[cfg(feature = "metrics")]
    metrics: ParallelStateRootMetrics,
}
55
56impl<Factory> ParallelStateRoot<Factory> {
57    /// Create new parallel state root calculator.
58    pub fn new(view: ConsistentDbView<Factory>, input: TrieInput) -> Self {
59        Self {
60            view,
61            input,
62            #[cfg(feature = "metrics")]
63            metrics: ParallelStateRootMetrics::default(),
64        }
65    }
66}
67
impl<Factory> ParallelStateRoot<Factory>
where
    Factory: DatabaseProviderFactory<Provider: BlockReader> + Clone + Send + Sync + 'static,
{
    /// Calculate incremental state root in parallel.
    ///
    /// Convenience wrapper around [`Self::calculate`] that discards trie updates.
    pub fn incremental_root(self) -> Result<B256, ParallelStateRootError> {
        self.calculate(false).map(|(root, _)| root)
    }

    /// Calculate incremental state root with updates in parallel.
    pub fn incremental_root_with_updates(
        self,
    ) -> Result<(B256, TrieUpdates), ParallelStateRootError> {
        self.calculate(true)
    }

    /// Computes the state root by calculating storage roots in parallel for modified accounts,
    /// then walking the state trie to build the final state root hash.
    ///
    /// `retain_updates` controls whether trie node updates and deletions are collected into
    /// the returned [`TrieUpdates`].
    fn calculate(
        self,
        retain_updates: bool,
    ) -> Result<(B256, TrieUpdates), ParallelStateRootError> {
        let mut tracker = ParallelTrieTracker::default();
        // Sort the in-memory overlays once and share them with all storage root tasks via `Arc`.
        let trie_nodes_sorted = Arc::new(self.input.nodes.into_sorted());
        let hashed_state_sorted = Arc::new(self.input.state.into_sorted());
        let prefix_sets = self.input.prefix_sets.freeze();
        // Accounts whose storage changed and therefore need their storage root recomputed.
        let storage_root_targets = StorageRootTargets::new(
            prefix_sets.account_prefix_set.iter().map(|nibbles| B256::from_slice(&nibbles.pack())),
            prefix_sets.storage_prefix_sets,
        );

        // Pre-calculate storage roots in parallel for accounts which were changed.
        tracker.set_precomputed_storage_roots(storage_root_targets.len() as u64);
        debug!(target: "trie::parallel_state_root", len = storage_root_targets.len(), "pre-calculating storage roots");
        let mut storage_roots = HashMap::with_capacity(storage_root_targets.len());

        // Get runtime handle once outside the loop
        let handle = get_runtime_handle();

        // Spawn tasks in ascending address order so completion order roughly matches the
        // order in which the account walker below consumes the results.
        for (hashed_address, prefix_set) in
            storage_root_targets.into_iter().sorted_unstable_by_key(|(address, _)| *address)
        {
            let view = self.view.clone();
            let hashed_state_sorted = hashed_state_sorted.clone();
            let trie_nodes_sorted = trie_nodes_sorted.clone();
            #[cfg(feature = "metrics")]
            let metrics = self.metrics.storage_trie.clone();

            // Bounded channel of capacity 1: each task sends exactly one result.
            let (tx, rx) = mpsc::sync_channel(1);

            // Spawn a blocking task to calculate account's storage root from database I/O
            drop(handle.spawn_blocking(move || {
                let result = (|| -> Result<_, ParallelStateRootError> {
                    let provider_ro = view.provider_ro()?;
                    // Layer the in-memory overlays on top of database cursors.
                    let trie_cursor_factory = InMemoryTrieCursorFactory::new(
                        DatabaseTrieCursorFactory::new(provider_ro.tx_ref()),
                        &trie_nodes_sorted,
                    );
                    let hashed_state = HashedPostStateCursorFactory::new(
                        DatabaseHashedCursorFactory::new(provider_ro.tx_ref()),
                        &hashed_state_sorted,
                    );
                    Ok(StorageRoot::new_hashed(
                        trie_cursor_factory,
                        hashed_state,
                        hashed_address,
                        prefix_set,
                        #[cfg(feature = "metrics")]
                        metrics,
                    )
                    .calculate(retain_updates)?)
                })();
                // The receiver may already be dropped if root calculation failed early;
                // a send error is fine to ignore in that case.
                let _ = tx.send(result);
            }));
            storage_roots.insert(hashed_address, rx);
        }

        trace!(target: "trie::parallel_state_root", "calculating state root");
        let mut trie_updates = TrieUpdates::default();

        // Open a read-only provider for walking the account trie on this thread.
        let provider_ro = self.view.provider_ro()?;
        let trie_cursor_factory = InMemoryTrieCursorFactory::new(
            DatabaseTrieCursorFactory::new(provider_ro.tx_ref()),
            &trie_nodes_sorted,
        );
        let hashed_cursor_factory = HashedPostStateCursorFactory::new(
            DatabaseHashedCursorFactory::new(provider_ro.tx_ref()),
            &hashed_state_sorted,
        );

        // Walk only the parts of the account trie touched by the prefix set.
        let walker = TrieWalker::<_>::state_trie(
            trie_cursor_factory.account_trie_cursor().map_err(ProviderError::Database)?,
            prefix_sets.account_prefix_set,
        )
        .with_deletions_retained(retain_updates);
        let mut account_node_iter = TrieNodeIter::state_trie(
            walker,
            hashed_cursor_factory.hashed_account_cursor().map_err(ProviderError::Database)?,
        );

        let mut hash_builder = HashBuilder::default().with_updates(retain_updates);
        // Reusable buffer for RLP-encoding account leaves.
        let mut account_rlp = Vec::with_capacity(TRIE_ACCOUNT_RLP_MAX_SIZE);
        while let Some(node) = account_node_iter.try_next().map_err(ProviderError::Database)? {
            match node {
                TrieElement::Branch(node) => {
                    hash_builder.add_branch(node.key, node.value, node.children_are_in_trie);
                }
                TrieElement::Leaf(hashed_address, account) => {
                    let storage_root_result = match storage_roots.remove(&hashed_address) {
                        // Block until the pre-spawned task for this account delivers its root.
                        Some(rx) => rx.recv().map_err(|_| {
                            ParallelStateRootError::StorageRoot(StorageRootError::Database(
                                DatabaseError::Other(format!(
                                    "channel closed for {hashed_address}"
                                )),
                            ))
                        })??,
                        // Since we do not store all intermediate nodes in the database, there might
                        // be a possibility of re-adding a non-modified leaf to the hash builder.
                        None => {
                            tracker.inc_missed_leaves();
                            // Compute the storage root inline with an empty prefix set.
                            StorageRoot::new_hashed(
                                trie_cursor_factory.clone(),
                                hashed_cursor_factory.clone(),
                                hashed_address,
                                Default::default(),
                                #[cfg(feature = "metrics")]
                                self.metrics.storage_trie.clone(),
                            )
                            .calculate(retain_updates)?
                        }
                    };

                    // Non-chunked calculation must always complete in a single pass; a
                    // `Progress` result here indicates a logic error upstream.
                    let (storage_root, _, updates) = match storage_root_result {
                        reth_trie::StorageRootProgress::Complete(root, _, updates) => (root, (), updates),
                        reth_trie::StorageRootProgress::Progress(..) => {
                            return Err(ParallelStateRootError::StorageRoot(
                                StorageRootError::Database(DatabaseError::Other(
                                    "StorageRoot returned Progress variant in parallel trie calculation".to_string()
                                ))
                            ))
                        }
                    };

                    if retain_updates {
                        trie_updates.insert_storage_updates(hashed_address, updates);
                    }

                    // RLP-encode the account with its storage root and feed it to the builder.
                    account_rlp.clear();
                    let account = account.into_trie_account(storage_root);
                    account.encode(&mut account_rlp as &mut dyn BufMut);
                    hash_builder.add_leaf(Nibbles::unpack(hashed_address), &account_rlp);
                }
            }
        }

        let root = hash_builder.root();

        // Collect keys removed by the walker and finalize the accumulated updates.
        let removed_keys = account_node_iter.walker.take_removed_keys();
        trie_updates.finalize(hash_builder, removed_keys, prefix_sets.destroyed_accounts);

        let stats = tracker.finish();

        #[cfg(feature = "metrics")]
        self.metrics.record_state_trie(stats);

        trace!(
            target: "trie::parallel_state_root",
            %root,
            duration = ?stats.duration(),
            branches_added = stats.branches_added(),
            leaves_added = stats.leaves_added(),
            missed_leaves = stats.missed_leaves(),
            precomputed_storage_roots = stats.precomputed_storage_roots(),
            "Calculated state root"
        );

        Ok((root, trie_updates))
    }
}
247
/// Error during parallel state root calculation.
#[derive(Error, Debug)]
pub enum ParallelStateRootError {
    /// Error while calculating storage root.
    #[error(transparent)]
    StorageRoot(#[from] StorageRootError),
    /// Provider error.
    #[error(transparent)]
    Provider(#[from] ProviderError),
    /// Other unspecified error; carries a free-form description.
    #[error("{_0}")]
    Other(String),
}
261
262impl From<ParallelStateRootError> for ProviderError {
263    fn from(error: ParallelStateRootError) -> Self {
264        match error {
265            ParallelStateRootError::Provider(error) => error,
266            ParallelStateRootError::StorageRoot(StorageRootError::Database(error)) => {
267                Self::Database(error)
268            }
269            ParallelStateRootError::Other(other) => Self::Database(DatabaseError::Other(other)),
270        }
271    }
272}
273
274impl From<alloy_rlp::Error> for ParallelStateRootError {
275    fn from(error: alloy_rlp::Error) -> Self {
276        Self::Provider(ProviderError::Rlp(error))
277    }
278}
279
280/// Gets or creates a tokio runtime handle for spawning blocking tasks.
281/// This ensures we always have a runtime available for I/O operations.
282fn get_runtime_handle() -> Handle {
283    Handle::try_current().unwrap_or_else(|_| {
284        // Create a new runtime if no runtime is available
285        static RT: OnceLock<Runtime> = OnceLock::new();
286
287        let rt = RT.get_or_init(|| {
288            Builder::new_multi_thread()
289                // Keep the threads alive for at least the block time (12 seconds) plus buffer.
290                // This prevents the costly process of spawning new threads on every
291                // new block, and instead reuses the existing threads.
292                .thread_keep_alive(Duration::from_secs(15))
293                .build()
294                .expect("Failed to create tokio runtime")
295        });
296
297        rt.handle().clone()
298    })
299}
300
#[cfg(test)]
mod tests {
    use super::*;
    use alloy_primitives::{keccak256, Address, U256};
    use rand::Rng;
    use reth_primitives_traits::{Account, StorageEntry};
    use reth_provider::{test_utils::create_test_provider_factory, HashingWriter};
    use reth_trie::{test_utils, HashedPostState, HashedStorage};

    /// End-to-end check: the parallel calculator must agree with the sequential
    /// reference implementation (`test_utils::state_root`) both for a clean
    /// database state and for an incremental update on top of it.
    #[tokio::test]
    async fn random_parallel_root() {
        let factory = create_test_provider_factory();
        let consistent_view = ConsistentDbView::new(factory.clone(), None);

        // Generate 100 random accounts; ~70% of them also get 100 random storage slots.
        let mut rng = rand::rng();
        let mut state = (0..100)
            .map(|_| {
                let address = Address::random();
                let account =
                    Account { balance: U256::from(rng.random::<u64>()), ..Default::default() };
                let mut storage = HashMap::<B256, U256>::default();
                let has_storage = rng.random_bool(0.7);
                if has_storage {
                    for _ in 0..100 {
                        storage.insert(
                            B256::from(U256::from(rng.random::<u64>())),
                            U256::from(rng.random::<u64>()),
                        );
                    }
                }
                (address, (account, storage))
            })
            .collect::<HashMap<_, _>>();

        // Persist the generated accounts and storages to the test database.
        {
            let provider_rw = factory.provider_rw().unwrap();
            provider_rw
                .insert_account_for_hashing(
                    state.iter().map(|(address, (account, _))| (*address, Some(*account))),
                )
                .unwrap();
            provider_rw
                .insert_storage_for_hashing(state.iter().map(|(address, (_, storage))| {
                    (
                        *address,
                        storage
                            .iter()
                            .map(|(slot, value)| StorageEntry { key: *slot, value: *value }),
                    )
                }))
                .unwrap();
            provider_rw.commit().unwrap();
        }

        // Baseline: empty trie input, root computed purely from the database.
        assert_eq!(
            ParallelStateRoot::new(consistent_view.clone(), Default::default())
                .incremental_root()
                .unwrap(),
            test_utils::state_root(state.clone())
        );

        // Randomly mutate accounts (~50%) and storages (~30%), recording the
        // changes only in an in-memory `HashedPostState` overlay.
        let mut hashed_state = HashedPostState::default();
        for (address, (account, storage)) in &mut state {
            let hashed_address = keccak256(address);

            let should_update_account = rng.random_bool(0.5);
            if should_update_account {
                *account = Account { balance: U256::from(rng.random::<u64>()), ..*account };
                hashed_state.accounts.insert(hashed_address, Some(*account));
            }

            let should_update_storage = rng.random_bool(0.3);
            if should_update_storage {
                for (slot, value) in storage.iter_mut() {
                    let hashed_slot = keccak256(slot);
                    *value = U256::from(rng.random::<u64>());
                    hashed_state
                        .storages
                        .entry(hashed_address)
                        .or_insert_with(HashedStorage::default)
                        .storage
                        .insert(hashed_slot, *value);
                }
            }
        }

        // Incremental root over the overlay must still match the reference.
        assert_eq!(
            ParallelStateRoot::new(consistent_view, TrieInput::from_state(hashed_state))
                .incremental_root()
                .unwrap(),
            test_utils::state_root(state)
        );
    }
}