1#[cfg(feature = "metrics")]
2use crate::metrics::ParallelStateRootMetrics;
3use crate::{stats::ParallelTrieTracker, storage_root_targets::StorageRootTargets};
4use alloy_primitives::B256;
5use alloy_rlp::{BufMut, Encodable};
6use itertools::Itertools;
7use reth_execution_errors::{SparseTrieError, StateProofError, StorageRootError};
8use reth_provider::{DatabaseProviderROFactory, ProviderError};
9use reth_storage_errors::db::DatabaseError;
10use reth_tasks::Runtime;
11use reth_trie::{
12 hashed_cursor::HashedCursorFactory,
13 node_iter::{TrieElement, TrieNodeIter},
14 prefix_set::TriePrefixSets,
15 trie_cursor::TrieCursorFactory,
16 updates::TrieUpdates,
17 walker::TrieWalker,
18 HashBuilder, Nibbles, StorageRoot, TRIE_ACCOUNT_RLP_MAX_SIZE,
19};
20use std::{collections::HashMap, sync::mpsc};
21use thiserror::Error;
22use tracing::*;
23
/// Incremental state root calculator that computes storage roots on blocking tasks spawned
/// onto a [`Runtime`] while the account trie is walked on the calling thread.
///
/// The calculator is consumed by the methods that compute the root.
#[derive(Debug)]
pub struct ParallelStateRoot<Factory> {
    /// Factory used to open a read-only database provider, both for each spawned storage root
    /// task and for the account trie walk.
    factory: Factory,
    /// Prefix sets describing the changed account and storage keys, plus destroyed accounts.
    prefix_sets: TriePrefixSets,
    /// Runtime whose handle is used to spawn the blocking storage root tasks.
    runtime: Runtime,
    /// Metrics recorded for the state trie and handed to the storage trie calculations.
    #[cfg(feature = "metrics")]
    metrics: ParallelStateRootMetrics,
}
46
47impl<Factory> ParallelStateRoot<Factory> {
48 pub fn new(factory: Factory, prefix_sets: TriePrefixSets, runtime: Runtime) -> Self {
50 Self {
51 factory,
52 prefix_sets,
53 runtime,
54 #[cfg(feature = "metrics")]
55 metrics: ParallelStateRootMetrics::default(),
56 }
57 }
58}
59
60impl<Factory> ParallelStateRoot<Factory>
61where
62 Factory: DatabaseProviderROFactory<Provider: TrieCursorFactory + HashedCursorFactory>
63 + Clone
64 + Send
65 + 'static,
66{
67 pub fn incremental_root(self) -> Result<B256, ParallelStateRootError> {
69 self.calculate(false).map(|(root, _)| root)
70 }
71
72 pub fn incremental_root_with_updates(
74 self,
75 ) -> Result<(B256, TrieUpdates), ParallelStateRootError> {
76 self.calculate(true)
77 }
78
79 fn calculate(
82 self,
83 retain_updates: bool,
84 ) -> Result<(B256, TrieUpdates), ParallelStateRootError> {
85 let mut tracker = ParallelTrieTracker::default();
86 let storage_root_targets = StorageRootTargets::new(
87 self.prefix_sets
88 .account_prefix_set
89 .iter()
90 .map(|nibbles| B256::from_slice(&nibbles.pack())),
91 self.prefix_sets.storage_prefix_sets,
92 );
93
94 tracker.set_precomputed_storage_roots(storage_root_targets.len() as u64);
96 debug!(target: "trie::parallel_state_root", len = storage_root_targets.len(), "pre-calculating storage roots");
97 let mut storage_roots = HashMap::with_capacity(storage_root_targets.len());
98
99 let handle = self.runtime.handle().clone();
100
101 for (hashed_address, prefix_set) in
102 storage_root_targets.into_iter().sorted_unstable_by_key(|(address, _)| *address)
103 {
104 let factory = self.factory.clone();
105 #[cfg(feature = "metrics")]
106 let metrics = self.metrics.storage_trie.clone();
107
108 let (tx, rx) = mpsc::sync_channel(1);
109
110 drop(handle.spawn_blocking(move || {
112 let result = (|| -> Result<_, ParallelStateRootError> {
113 let provider = factory.database_provider_ro()?;
114 Ok(StorageRoot::new_hashed(
115 &provider,
116 &provider,
117 hashed_address,
118 prefix_set,
119 #[cfg(feature = "metrics")]
120 metrics,
121 )
122 .calculate(retain_updates)?)
123 })();
124 let _ = tx.send(result);
125 }));
126 storage_roots.insert(hashed_address, rx);
127 }
128
129 trace!(target: "trie::parallel_state_root", "calculating state root");
130 let mut trie_updates = TrieUpdates::default();
131
132 let provider = self.factory.database_provider_ro()?;
133
134 let walker = TrieWalker::<_>::state_trie(
135 provider.account_trie_cursor().map_err(ProviderError::Database)?,
136 self.prefix_sets.account_prefix_set,
137 )
138 .with_deletions_retained(retain_updates);
139 let mut account_node_iter = TrieNodeIter::state_trie(
140 walker,
141 provider.hashed_account_cursor().map_err(ProviderError::Database)?,
142 );
143
144 let mut hash_builder = HashBuilder::default().with_updates(retain_updates);
145 let mut account_rlp = Vec::with_capacity(TRIE_ACCOUNT_RLP_MAX_SIZE);
146 while let Some(node) = account_node_iter.try_next().map_err(ProviderError::Database)? {
147 match node {
148 TrieElement::Branch(node) => {
149 hash_builder.add_branch(node.key, node.value, node.children_are_in_trie);
150 }
151 TrieElement::Leaf(hashed_address, account) => {
152 let storage_root_result = match storage_roots.remove(&hashed_address) {
153 Some(rx) => rx.recv().map_err(|_| {
154 ParallelStateRootError::StorageRoot(StorageRootError::Database(
155 DatabaseError::Other(format!(
156 "channel closed for {hashed_address}"
157 )),
158 ))
159 })??,
160 None => {
163 tracker.inc_missed_leaves();
164 StorageRoot::new_hashed(
165 &provider,
166 &provider,
167 hashed_address,
168 Default::default(),
169 #[cfg(feature = "metrics")]
170 self.metrics.storage_trie.clone(),
171 )
172 .calculate(retain_updates)?
173 }
174 };
175
176 let (storage_root, _, updates) = match storage_root_result {
177 reth_trie::StorageRootProgress::Complete(root, _, updates) => (root, (), updates),
178 reth_trie::StorageRootProgress::Progress(..) => {
179 return Err(ParallelStateRootError::StorageRoot(
180 StorageRootError::Database(DatabaseError::Other(
181 "StorageRoot returned Progress variant in parallel trie calculation".to_string()
182 ))
183 ))
184 }
185 };
186
187 if retain_updates {
188 trie_updates.insert_storage_updates(hashed_address, updates);
189 }
190
191 account_rlp.clear();
192 let account = account.into_trie_account(storage_root);
193 account.encode(&mut account_rlp as &mut dyn BufMut);
194 hash_builder.add_leaf(Nibbles::unpack(hashed_address), &account_rlp);
195 }
196 }
197 }
198
199 let root = hash_builder.root();
200
201 let removed_keys = account_node_iter.walker.take_removed_keys();
202 trie_updates.finalize(hash_builder, removed_keys, self.prefix_sets.destroyed_accounts);
203
204 let stats = tracker.finish();
205
206 #[cfg(feature = "metrics")]
207 self.metrics.record_state_trie(stats);
208
209 trace!(
210 target: "trie::parallel_state_root",
211 %root,
212 duration = ?stats.duration(),
213 branches_added = stats.branches_added(),
214 leaves_added = stats.leaves_added(),
215 missed_leaves = stats.missed_leaves(),
216 precomputed_storage_roots = stats.precomputed_storage_roots(),
217 "Calculated state root"
218 );
219
220 Ok((root, trie_updates))
221 }
222}
223
/// Error returned while computing the state root in parallel.
#[derive(Error, Debug)]
pub enum ParallelStateRootError {
    /// Error raised by a storage root calculation.
    #[error(transparent)]
    StorageRoot(#[from] StorageRootError),
    /// Error raised by the database provider.
    #[error(transparent)]
    Provider(#[from] ProviderError),
    /// Error raised by the sparse trie.
    #[error(transparent)]
    SparseTrie(#[from] SparseTrieError),
    /// Any other error, carried as a plain message.
    #[error("{_0}")]
    Other(String),
}
240
241impl From<ParallelStateRootError> for ProviderError {
242 fn from(error: ParallelStateRootError) -> Self {
243 match error {
244 ParallelStateRootError::Provider(error) => error,
245 ParallelStateRootError::StorageRoot(StorageRootError::Database(error)) => {
246 Self::Database(error)
247 }
248 ParallelStateRootError::SparseTrie(error) => Self::other(error),
249 ParallelStateRootError::Other(other) => Self::Database(DatabaseError::Other(other)),
250 }
251 }
252}
253
254impl From<alloy_rlp::Error> for ParallelStateRootError {
255 fn from(error: alloy_rlp::Error) -> Self {
256 Self::Provider(ProviderError::Rlp(error))
257 }
258}
259
260impl From<StateProofError> for ParallelStateRootError {
261 fn from(error: StateProofError) -> Self {
262 match error {
263 StateProofError::Database(err) => Self::Provider(ProviderError::Database(err)),
264 StateProofError::Rlp(err) => Self::Provider(ProviderError::Rlp(err)),
265 StateProofError::TrieInconsistency(msg) => {
266 Self::Provider(ProviderError::TrieWitnessError(msg))
267 }
268 }
269 }
270}
271
#[cfg(test)]
mod tests {
    use super::*;
    use alloy_primitives::{keccak256, Address, U256};
    use rand::Rng;
    use reth_primitives_traits::{Account, StorageEntry};
    use reth_provider::{test_utils::create_test_provider_factory, HashingWriter};
    use reth_trie::{test_utils, HashedPostState, HashedStorage};
    use std::sync::Arc;

    /// Checks the parallel root against the reference `test_utils::state_root` twice: once
    /// for a freshly written random state, and once after random account and storage changes
    /// supplied through a hashed state overlay and its prefix sets.
    #[tokio::test]
    async fn random_parallel_root() {
        let factory = create_test_provider_factory();
        let changeset_cache = reth_trie_db::ChangesetCache::new();
        let mut overlay_factory = reth_provider::providers::OverlayStateProviderFactory::new(
            factory.clone(),
            changeset_cache,
        );

        let mut rng = rand::rng();

        // 100 random accounts, roughly 70% of which get 100 random storage slots.
        let mut state = HashMap::new();
        for _ in 0..100 {
            let address = Address::random();
            let account =
                Account { balance: U256::from(rng.random::<u64>()), ..Default::default() };
            let mut storage = HashMap::<B256, U256>::default();
            if rng.random_bool(0.7) {
                for _ in 0..100 {
                    let slot = B256::from(U256::from(rng.random::<u64>()));
                    let value = U256::from(rng.random::<u64>());
                    storage.insert(slot, value);
                }
            }
            state.insert(address, (account, storage));
        }

        // Persist the initial state through the hashing writer.
        {
            let provider_rw = factory.provider_rw().unwrap();

            let accounts = state.iter().map(|(address, (account, _))| (*address, Some(*account)));
            provider_rw.insert_account_for_hashing(accounts).unwrap();

            let storages = state.iter().map(|(address, (_, storage))| {
                let entries =
                    storage.iter().map(|(slot, value)| StorageEntry { key: *slot, value: *value });
                (*address, entries)
            });
            provider_rw.insert_storage_for_hashing(storages).unwrap();

            provider_rw.commit().unwrap();
        }

        let runtime = reth_tasks::Runtime::test();

        // With default prefix sets, the result must equal the reference root.
        let initial_root =
            ParallelStateRoot::new(overlay_factory.clone(), Default::default(), runtime.clone())
                .incremental_root()
                .unwrap();
        assert_eq!(initial_root, test_utils::state_root(state.clone()));

        // Randomly mutate accounts and storage, mirroring each change into the hashed state.
        let mut hashed_state = HashedPostState::default();
        for (address, (account, storage)) in &mut state {
            let hashed_address = keccak256(address);

            if rng.random_bool(0.5) {
                account.balance = U256::from(rng.random::<u64>());
                hashed_state.accounts.insert(hashed_address, Some(*account));
            }

            if rng.random_bool(0.3) {
                for (slot, value) in storage.iter_mut() {
                    let hashed_slot = keccak256(slot);
                    *value = U256::from(rng.random::<u64>());
                    // The entry is created per slot so accounts without storage stay absent
                    // from the hashed storages.
                    let hashed_storage: &mut HashedStorage =
                        hashed_state.storages.entry(hashed_address).or_default();
                    hashed_storage.storage.insert(hashed_slot, *value);
                }
            }
        }

        let prefix_sets = hashed_state.construct_prefix_sets();
        overlay_factory =
            overlay_factory.with_hashed_state_overlay(Some(Arc::new(hashed_state.into_sorted())));

        // The incremental root over the overlay must equal the reference root of the new state.
        let updated_root = ParallelStateRoot::new(overlay_factory, prefix_sets.freeze(), runtime)
            .incremental_root()
            .unwrap();
        assert_eq!(updated_root, test_utils::state_root(state));
    }
}