1#[cfg(feature = "metrics")]
2use crate::metrics::ParallelStateRootMetrics;
3use crate::{stats::ParallelTrieTracker, storage_root_targets::StorageRootTargets};
4use alloy_primitives::B256;
5use alloy_rlp::{BufMut, Encodable};
6use itertools::Itertools;
7use reth_execution_errors::{SparseTrieError, StateProofError, StorageRootError};
8use reth_provider::{DatabaseProviderROFactory, ProviderError};
9use reth_storage_errors::db::DatabaseError;
10use reth_tasks::Runtime;
11use reth_trie::{
12 hashed_cursor::HashedCursorFactory,
13 node_iter::{TrieElement, TrieNodeIter},
14 prefix_set::TriePrefixSets,
15 trie_cursor::TrieCursorFactory,
16 updates::TrieUpdates,
17 walker::TrieWalker,
18 HashBuilder, Nibbles, StorageRoot, TRIE_ACCOUNT_RLP_MAX_SIZE,
19};
20use std::{collections::HashMap, sync::mpsc};
21use thiserror::Error;
22use tracing::*;
23
/// Parallel incremental state root calculator.
///
/// Pre-computes storage roots for changed accounts on blocking runtime tasks
/// and then walks the account trie on the calling thread, consuming the
/// pre-computed storage roots as each changed leaf is reached.
#[derive(Debug)]
pub struct ParallelStateRoot<Factory> {
    /// Factory for creating read-only database providers; cloned into each
    /// spawned storage-root task.
    factory: Factory,
    /// Changed hashed state prefixes (accounts, per-account storage slots,
    /// and destroyed accounts).
    prefix_sets: TriePrefixSets,
    /// Runtime used to spawn the blocking storage-root computations.
    runtime: Runtime,
    /// Parallel state root metrics.
    #[cfg(feature = "metrics")]
    metrics: ParallelStateRootMetrics,
}
46
47impl<Factory> ParallelStateRoot<Factory> {
48 pub fn new(factory: Factory, prefix_sets: TriePrefixSets, runtime: Runtime) -> Self {
50 Self {
51 factory,
52 prefix_sets,
53 runtime,
54 #[cfg(feature = "metrics")]
55 metrics: ParallelStateRootMetrics::default(),
56 }
57 }
58}
59
impl<Factory> ParallelStateRoot<Factory>
where
    Factory: DatabaseProviderROFactory<Provider: TrieCursorFactory + HashedCursorFactory>
        + Clone
        + Send
        + 'static,
{
    /// Calculate incremental state root in parallel.
    ///
    /// Returns only the root hash; trie updates are not retained.
    pub fn incremental_root(self) -> Result<B256, ParallelStateRootError> {
        self.calculate(false).map(|(root, _)| root)
    }

    /// Calculate incremental state root with updates in parallel.
    ///
    /// Returns the root hash together with the collected [`TrieUpdates`].
    pub fn incremental_root_with_updates(
        self,
    ) -> Result<(B256, TrieUpdates), ParallelStateRootError> {
        self.calculate(true)
    }

    /// Shared implementation behind both entry points.
    ///
    /// Spawns one blocking task per changed account to pre-compute its storage
    /// root, then walks the account trie on the current thread, receiving each
    /// pre-computed storage root as its leaf is reached.
    fn calculate(
        self,
        retain_updates: bool,
    ) -> Result<(B256, TrieUpdates), ParallelStateRootError> {
        let mut tracker = ParallelTrieTracker::default();
        // Accounts needing a storage root recomputation: everything in the
        // account prefix set, plus all accounts with changed storage slots.
        let storage_root_targets = StorageRootTargets::new(
            self.prefix_sets
                .account_prefix_set
                .iter()
                .map(|nibbles| B256::from_slice(&nibbles.pack())),
            self.prefix_sets.storage_prefix_sets,
        );

        // Pre-calculate storage roots in parallel for accounts which were changed.
        tracker.set_precomputed_storage_roots(storage_root_targets.len() as u64);
        debug!(target: "trie::parallel_state_root", len = storage_root_targets.len(), "pre-calculating storage roots");
        let mut storage_roots = HashMap::with_capacity(storage_root_targets.len());

        let handle = self.runtime.handle().clone();

        // Sort targets so tasks are spawned in a deterministic address order.
        for (hashed_address, prefix_set) in
            storage_root_targets.into_iter().sorted_unstable_by_key(|(address, _)| *address)
        {
            let factory = self.factory.clone();
            #[cfg(feature = "metrics")]
            let metrics = self.metrics.storage_trie.clone();

            // Capacity 1 suffices: each task sends exactly one result.
            let (tx, rx) = mpsc::sync_channel(1);

            // The join handle is intentionally dropped; the task delivers its
            // result through the channel instead.
            drop(handle.spawn_blocking(move || {
                let result = (|| -> Result<_, ParallelStateRootError> {
                    // Each task opens its own read-only provider.
                    let provider = factory.database_provider_ro()?;
                    Ok(StorageRoot::new_hashed(
                        &provider,
                        &provider,
                        hashed_address,
                        prefix_set,
                        #[cfg(feature = "metrics")]
                        metrics,
                    )
                    .calculate(retain_updates)?)
                })();
                // The receiver may already be gone if the caller bailed out
                // early; a failed send is harmless.
                let _ = tx.send(result);
            }));
            storage_roots.insert(hashed_address, rx);
        }

        trace!(target: "trie::parallel_state_root", "calculating state root");
        let mut trie_updates = TrieUpdates::default();

        let provider = self.factory.database_provider_ro()?;

        // Walk the account trie restricted to the changed account prefixes.
        let walker = TrieWalker::<_>::state_trie(
            provider.account_trie_cursor().map_err(ProviderError::Database)?,
            self.prefix_sets.account_prefix_set,
        )
        .with_deletions_retained(retain_updates);
        let mut account_node_iter = TrieNodeIter::state_trie(
            walker,
            provider.hashed_account_cursor().map_err(ProviderError::Database)?,
        );

        let mut hash_builder = HashBuilder::default().with_updates(retain_updates);
        // RLP buffer reused across all account leaf encodings.
        let mut account_rlp = Vec::with_capacity(TRIE_ACCOUNT_RLP_MAX_SIZE);
        while let Some(node) = account_node_iter.try_next().map_err(ProviderError::Database)? {
            match node {
                TrieElement::Branch(node) => {
                    hash_builder.add_branch(node.key, node.value, node.children_are_in_trie);
                }
                TrieElement::Leaf(hashed_address, account) => {
                    let storage_root_result = match storage_roots.remove(&hashed_address) {
                        // The storage root was pre-computed by a spawned task;
                        // block on its channel until the result arrives.
                        Some(rx) => rx.recv().map_err(|_| {
                            ParallelStateRootError::StorageRoot(StorageRootError::Database(
                                DatabaseError::Other(format!(
                                    "channel closed for {hashed_address}"
                                )),
                            ))
                        })??,
                        // The leaf was not among the pre-computed targets
                        // (e.g. an unmodified leaf re-added by the walker);
                        // compute its storage root inline and record the miss.
                        None => {
                            tracker.inc_missed_leaves();
                            StorageRoot::new_hashed(
                                &provider,
                                &provider,
                                hashed_address,
                                Default::default(),
                                #[cfg(feature = "metrics")]
                                self.metrics.storage_trie.clone(),
                            )
                            .calculate(retain_updates)?
                        }
                    };

                    // Only `Complete` is expected: the parallel path never
                    // requests a partial (chunked) storage root computation.
                    let (storage_root, _, updates) = match storage_root_result {
                        reth_trie::StorageRootProgress::Complete(root, _, updates) => (root, (), updates),
                        reth_trie::StorageRootProgress::Progress(..) => {
                            return Err(ParallelStateRootError::StorageRoot(
                                StorageRootError::Database(DatabaseError::Other(
                                    "StorageRoot returned Progress variant in parallel trie calculation".to_string()
                                ))
                            ))
                        }
                    };

                    if retain_updates {
                        trie_updates.insert_storage_updates(hashed_address, updates);
                    }

                    // Encode the account with its storage root and add the leaf.
                    account_rlp.clear();
                    let account = account.into_trie_account(storage_root);
                    account.encode(&mut account_rlp as &mut dyn BufMut);
                    hash_builder.add_leaf(Nibbles::unpack(hashed_address), &account_rlp);
                }
            }
        }

        let root = hash_builder.root();

        // Fold deleted trie keys and destroyed accounts into the update set.
        let removed_keys = account_node_iter.walker.take_removed_keys();
        trie_updates.finalize(hash_builder, removed_keys, self.prefix_sets.destroyed_accounts);

        let stats = tracker.finish();

        #[cfg(feature = "metrics")]
        self.metrics.record_state_trie(stats);

        trace!(
            target: "trie::parallel_state_root",
            %root,
            duration = ?stats.duration(),
            branches_added = stats.branches_added(),
            leaves_added = stats.leaves_added(),
            missed_leaves = stats.missed_leaves(),
            precomputed_storage_roots = stats.precomputed_storage_roots(),
            "Calculated state root"
        );

        Ok((root, trie_updates))
    }
}
223
/// Error returned by the parallel state root calculation.
#[derive(Error, Debug)]
pub enum ParallelStateRootError {
    /// Storage root calculation failed.
    #[error(transparent)]
    StorageRoot(#[from] StorageRootError),
    /// Provider error.
    #[error(transparent)]
    Provider(#[from] ProviderError),
    /// Sparse trie error.
    #[error(transparent)]
    SparseTrie(#[from] SparseTrieError),
    /// Any other error with a free-form message.
    #[error("{_0}")]
    Other(String),
}
240
241impl From<ParallelStateRootError> for ProviderError {
242 fn from(error: ParallelStateRootError) -> Self {
243 match error {
244 ParallelStateRootError::Provider(error) => error,
245 ParallelStateRootError::StorageRoot(StorageRootError::Database(error)) => {
246 Self::Database(error)
247 }
248 ParallelStateRootError::SparseTrie(error) => Self::other(error),
249 ParallelStateRootError::Other(other) => Self::Database(DatabaseError::Other(other)),
250 }
251 }
252}
253
254impl From<alloy_rlp::Error> for ParallelStateRootError {
255 fn from(error: alloy_rlp::Error) -> Self {
256 Self::Provider(ProviderError::Rlp(error))
257 }
258}
259
260impl From<StateProofError> for ParallelStateRootError {
261 fn from(error: StateProofError) -> Self {
262 match error {
263 StateProofError::Database(err) => Self::Provider(ProviderError::Database(err)),
264 StateProofError::Rlp(err) => Self::Provider(ProviderError::Rlp(err)),
265 StateProofError::TrieInconsistency(msg) => {
266 Self::Provider(ProviderError::TrieWitnessError(msg))
267 }
268 }
269 }
270}
271
#[cfg(test)]
mod tests {
    use super::*;
    use alloy_primitives::{keccak256, Address, U256};
    use rand::Rng;
    use reth_chainspec::{ChainSpec, EthChainSpec};
    use reth_ethereum_primitives::{Block, BlockBody};
    use reth_primitives_traits::{Account, RecoveredBlock, SealedBlock, StorageEntry};
    use reth_provider::{
        test_utils::create_test_provider_factory_with_chain_spec, BlockWriter, ExecutionOutcome,
        HashingWriter,
    };
    use reth_trie::{test_utils, HashedPostState, HashedStorage};
    use std::sync::Arc;

    /// End-to-end check: the parallel incremental root must match the
    /// reference `test_utils::state_root` implementation, both for a random
    /// initial state and after a round of random account/storage updates
    /// applied through a hashed-state overlay.
    #[tokio::test]
    async fn random_parallel_root() {
        // Provider factory anchored at the genesis block, wrapped in an
        // overlay factory so hashed-state changes can be layered on later.
        let chain_spec = Arc::new(ChainSpec::default());
        let anchor_hash = chain_spec.genesis_hash();
        let factory = create_test_provider_factory_with_chain_spec(chain_spec.clone());
        let changeset_cache = reth_trie_db::ChangesetCache::new();
        let overlay_builder = reth_provider::providers::OverlayBuilder::<
            reth_ethereum_primitives::EthPrimitives,
        >::new(anchor_hash, changeset_cache);
        let mut overlay_factory = reth_provider::providers::OverlayStateProviderFactory::new(
            factory.clone(),
            overlay_builder.clone(),
        );

        // Generate 100 random accounts; ~70% of them get 100 random storage slots.
        let mut rng = rand::rng();
        let mut state = (0..100)
            .map(|_| {
                let address = Address::random();
                let account =
                    Account { balance: U256::from(rng.random::<u64>()), ..Default::default() };
                let mut storage = HashMap::<B256, U256>::default();
                let has_storage = rng.random_bool(0.7);
                if has_storage {
                    for _ in 0..100 {
                        storage.insert(
                            B256::from(U256::from(rng.random::<u64>())),
                            U256::from(rng.random::<u64>()),
                        );
                    }
                }
                (address, (account, storage))
            })
            .collect::<HashMap<_, _>>();

        // Persist the genesis block plus the hashed accounts/storages so the
        // provider can serve them during the root calculation.
        {
            let provider_rw = factory.provider_rw().unwrap();
            let genesis_block = RecoveredBlock::new_sealed(
                SealedBlock::<Block>::seal_parts(
                    chain_spec.genesis_header().clone(),
                    BlockBody::default(),
                ),
                vec![],
            );
            provider_rw
                .append_blocks_with_state(
                    vec![genesis_block],
                    &ExecutionOutcome::default(),
                    Default::default(),
                )
                .unwrap();
            provider_rw
                .insert_account_for_hashing(
                    state.iter().map(|(address, (account, _))| (*address, Some(*account))),
                )
                .unwrap();
            provider_rw
                .insert_storage_for_hashing(state.iter().map(|(address, (_, storage))| {
                    (
                        *address,
                        storage
                            .iter()
                            .map(|(slot, value)| StorageEntry { key: *slot, value: *value }),
                    )
                }))
                .unwrap();
            provider_rw.commit().unwrap();
        }

        // First check: parallel root over the persisted state (empty prefix
        // sets) must equal the reference implementation's root.
        let runtime = reth_tasks::Runtime::test();
        assert_eq!(
            ParallelStateRoot::new(overlay_factory.clone(), Default::default(), runtime.clone())
                .incremental_root()
                .unwrap(),
            test_utils::state_root(state.clone())
        );

        // Randomly mutate ~50% of accounts and ~30% of storages, recording
        // the changes in a hashed post state.
        let mut hashed_state = HashedPostState::default();
        for (address, (account, storage)) in &mut state {
            let hashed_address = keccak256(address);

            let should_update_account = rng.random_bool(0.5);
            if should_update_account {
                *account = Account { balance: U256::from(rng.random::<u64>()), ..*account };
                hashed_state.accounts.insert(hashed_address, Some(*account));
            }

            let should_update_storage = rng.random_bool(0.3);
            if should_update_storage {
                for (slot, value) in storage.iter_mut() {
                    let hashed_slot = keccak256(slot);
                    *value = U256::from(rng.random::<u64>());
                    hashed_state
                        .storages
                        .entry(hashed_address)
                        .or_insert_with(HashedStorage::default)
                        .storage
                        .insert(hashed_slot, *value);
                }
            }
        }

        // Second check: layer the mutations as a hashed-state overlay and
        // verify the incremental root (driven by the prefix sets) still
        // matches the reference root of the mutated state.
        let prefix_sets = hashed_state.construct_prefix_sets();
        overlay_factory = reth_provider::providers::OverlayStateProviderFactory::new(
            factory,
            overlay_builder.with_hashed_state_overlay(Some(Arc::new(hashed_state.into_sorted()))),
        );

        assert_eq!(
            ParallelStateRoot::new(overlay_factory, prefix_sets.freeze(), runtime)
                .incremental_root()
                .unwrap(),
            test_utils::state_root(state)
        );
    }
}