#[cfg(feature = "metrics")]
use crate::metrics::ParallelStateRootMetrics;
use crate::{stats::ParallelTrieTracker, storage_root_targets::StorageRootTargets};
use alloy_primitives::B256;
use alloy_rlp::{BufMut, Encodable};
use itertools::Itertools;
use reth_execution_errors::{StateProofError, StorageRootError};
use reth_provider::{DatabaseProviderROFactory, ProviderError};
use reth_storage_errors::db::DatabaseError;
use reth_trie::{
    hashed_cursor::HashedCursorFactory,
    node_iter::{TrieElement, TrieNodeIter},
    prefix_set::TriePrefixSets,
    trie_cursor::TrieCursorFactory,
    updates::TrieUpdates,
    walker::TrieWalker,
    HashBuilder, Nibbles, StorageRoot, TRIE_ACCOUNT_RLP_MAX_SIZE,
};
use std::{
    collections::HashMap,
    sync::{mpsc, OnceLock},
    time::Duration,
};
use thiserror::Error;
use tokio::runtime::{Builder, Handle, Runtime};
use tracing::*;

/// Parallel incremental state root calculator.
///
/// Storage roots for the changed accounts are pre-computed in parallel on blocking tasks, after
/// which the account trie is walked on the calling thread and the pre-computed roots are
/// aggregated into the final state root.
#[derive(Debug)]
pub struct ParallelStateRoot<Factory> {
    /// Factory for creating read-only database providers.
    factory: Factory,
    /// Trie input prefix sets.
    prefix_sets: TriePrefixSets,
    /// Parallel state root metrics.
    #[cfg(feature = "metrics")]
    metrics: ParallelStateRootMetrics,
}

impl<Factory> ParallelStateRoot<Factory> {
    /// Create a new parallel state root calculator.
    pub fn new(factory: Factory, prefix_sets: TriePrefixSets) -> Self {
        Self {
            factory,
            prefix_sets,
            #[cfg(feature = "metrics")]
            metrics: ParallelStateRootMetrics::default(),
        }
    }
}

impl<Factory> ParallelStateRoot<Factory>
where
    Factory: DatabaseProviderROFactory<Provider: TrieCursorFactory + HashedCursorFactory>
        + Clone
        + Send
        + 'static,
{
    pub fn incremental_root(self) -> Result<B256, ParallelStateRootError> {
        self.calculate(false).map(|(root, _)| root)
    }

    /// Calculate the incremental state root in parallel and return it together with the
    /// collected trie updates.
    pub fn incremental_root_with_updates(
        self,
    ) -> Result<(B256, TrieUpdates), ParallelStateRootError> {
        self.calculate(true)
    }

    /// Calculate the state root, optionally retaining trie updates.
    fn calculate(
        self,
        retain_updates: bool,
    ) -> Result<(B256, TrieUpdates), ParallelStateRootError> {
        let mut tracker = ParallelTrieTracker::default();
        let storage_root_targets = StorageRootTargets::new(
            self.prefix_sets
                .account_prefix_set
                .iter()
                .map(|nibbles| B256::from_slice(&nibbles.pack())),
            self.prefix_sets.storage_prefix_sets,
        );

        tracker.set_precomputed_storage_roots(storage_root_targets.len() as u64);
        debug!(target: "trie::parallel_state_root", len = storage_root_targets.len(), "pre-calculating storage roots");
        let mut storage_roots = HashMap::with_capacity(storage_root_targets.len());

        let handle = get_runtime_handle();
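
        // Pre-compute storage roots in parallel: one blocking task is spawned per storage root
        // target. Each task opens its own read-only provider and sends its result back over a
        // bounded channel, to be consumed during the account trie walk below.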
        for (hashed_address, prefix_set) in
            storage_root_targets.into_iter().sorted_unstable_by_key(|(address, _)| *address)
        {
            let factory = self.factory.clone();
            #[cfg(feature = "metrics")]
            let metrics = self.metrics.storage_trie.clone();

            let (tx, rx) = mpsc::sync_channel(1);

            drop(handle.spawn_blocking(move || {
                let result = (|| -> Result<_, ParallelStateRootError> {
                    let provider = factory.database_provider_ro()?;
                    Ok(StorageRoot::new_hashed(
                        &provider,
                        &provider,
                        hashed_address,
                        prefix_set,
                        #[cfg(feature = "metrics")]
                        metrics,
                    )
                    .calculate(retain_updates)?)
                })();
                let _ = tx.send(result);
            }));
            storage_roots.insert(hashed_address, rx);
        }
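
        // Walk the account trie on the calling thread, adding branch nodes and account leaves to
        // the hash builder; pre-computed storage roots are consumed as their leaves are reached.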
        trace!(target: "trie::parallel_state_root", "calculating state root");
        let mut trie_updates = TrieUpdates::default();

        let provider = self.factory.database_provider_ro()?;

        let walker = TrieWalker::<_>::state_trie(
            provider.account_trie_cursor().map_err(ProviderError::Database)?,
            self.prefix_sets.account_prefix_set,
        )
        .with_deletions_retained(retain_updates);
        let mut account_node_iter = TrieNodeIter::state_trie(
            walker,
            provider.hashed_account_cursor().map_err(ProviderError::Database)?,
        );

        let mut hash_builder = HashBuilder::default().with_updates(retain_updates);
        let mut account_rlp = Vec::with_capacity(TRIE_ACCOUNT_RLP_MAX_SIZE);
        while let Some(node) = account_node_iter.try_next().map_err(ProviderError::Database)? {
            match node {
                TrieElement::Branch(node) => {
                    hash_builder.add_branch(node.key, node.value, node.children_are_in_trie);
                }
                TrieElement::Leaf(hashed_address, account) => {
                    let storage_root_result = match storage_roots.remove(&hashed_address) {
                        Some(rx) => rx.recv().map_err(|_| {
                            ParallelStateRootError::StorageRoot(StorageRootError::Database(
                                DatabaseError::Other(format!(
                                    "channel closed for {hashed_address}"
                                )),
                            ))
                        })??,
                        None => {
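                            // No pre-computed storage root is available for this account (e.g.
                            // an unchanged leaf that still has to be re-added to the hash
                            // builder), so compute it inline and record a missed leaf.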
                            tracker.inc_missed_leaves();
                            StorageRoot::new_hashed(
                                &provider,
                                &provider,
                                hashed_address,
                                Default::default(),
                                #[cfg(feature = "metrics")]
                                self.metrics.storage_trie.clone(),
                            )
                            .calculate(retain_updates)?
                        }
                    };

                    let (storage_root, _, updates) = match storage_root_result {
                        reth_trie::StorageRootProgress::Complete(root, _, updates) => {
                            (root, (), updates)
                        }
                        reth_trie::StorageRootProgress::Progress(..) => {
                            return Err(ParallelStateRootError::StorageRoot(
                                StorageRootError::Database(DatabaseError::Other(
                                    "StorageRoot returned Progress variant in parallel trie calculation".to_string()
                                ))
                            ))
                        }
                    };

                    if retain_updates {
                        trie_updates.insert_storage_updates(hashed_address, updates);
                    }

                    account_rlp.clear();
                    let account = account.into_trie_account(storage_root);
                    account.encode(&mut account_rlp as &mut dyn BufMut);
                    hash_builder.add_leaf(Nibbles::unpack(hashed_address), &account_rlp);
                }
            }
        }

        let root = hash_builder.root();

        let removed_keys = account_node_iter.walker.take_removed_keys();
        trie_updates.finalize(hash_builder, removed_keys, self.prefix_sets.destroyed_accounts);

        let stats = tracker.finish();

        #[cfg(feature = "metrics")]
        self.metrics.record_state_trie(stats);

        trace!(
            target: "trie::parallel_state_root",
            %root,
            duration = ?stats.duration(),
            branches_added = stats.branches_added(),
            leaves_added = stats.leaves_added(),
            missed_leaves = stats.missed_leaves(),
            precomputed_storage_roots = stats.precomputed_storage_roots(),
            "Calculated state root"
        );

        Ok((root, trie_updates))
    }
}

/// Error during the parallel state root calculation.
#[derive(Error, Debug)]
pub enum ParallelStateRootError {
    /// Error while calculating a storage root.
    #[error(transparent)]
    StorageRoot(#[from] StorageRootError),
    /// Provider error.
    #[error(transparent)]
    Provider(#[from] ProviderError),
    /// Other unspecified error.
    #[error("{_0}")]
    Other(String),
}

impl From<ParallelStateRootError> for ProviderError {
    fn from(error: ParallelStateRootError) -> Self {
        match error {
            ParallelStateRootError::Provider(error) => error,
            ParallelStateRootError::StorageRoot(StorageRootError::Database(error)) => {
                Self::Database(error)
            }
            ParallelStateRootError::Other(other) => Self::Database(DatabaseError::Other(other)),
        }
    }
}

impl From<alloy_rlp::Error> for ParallelStateRootError {
    fn from(error: alloy_rlp::Error) -> Self {
        Self::Provider(ProviderError::Rlp(error))
    }
}

impl From<StateProofError> for ParallelStateRootError {
    fn from(error: StateProofError) -> Self {
        match error {
            StateProofError::Database(err) => Self::Provider(ProviderError::Database(err)),
            StateProofError::Rlp(err) => Self::Provider(ProviderError::Rlp(err)),
        }
    }
}

/// Returns a handle to the current tokio runtime if one is available, otherwise a handle to a
/// lazily-initialized, process-wide fallback runtime used for spawning blocking tasks.
fn get_runtime_handle() -> Handle {
    Handle::try_current().unwrap_or_else(|_| {
        // Not running inside a tokio runtime: fall back to a shared multi-thread runtime.
        static RT: OnceLock<Runtime> = OnceLock::new();

        let rt = RT.get_or_init(|| {
            Builder::new_multi_thread()
                .thread_keep_alive(Duration::from_secs(15))
                .build()
                .expect("Failed to create tokio runtime")
        });

        rt.handle().clone()
    })
}

#[cfg(test)]
mod tests {
    use super::*;
    use alloy_primitives::{keccak256, Address, U256};
    use rand::Rng;
    use reth_primitives_traits::{Account, StorageEntry};
    use reth_provider::{test_utils::create_test_provider_factory, HashingWriter};
    use reth_trie::{test_utils, HashedPostState, HashedStorage};
    use std::sync::Arc;

    #[tokio::test]
    async fn random_parallel_root() {
        let factory = create_test_provider_factory();
        let changeset_cache = reth_trie_db::ChangesetCache::new();
        let mut overlay_factory = reth_provider::providers::OverlayStateProviderFactory::new(
            factory.clone(),
            changeset_cache,
        );

        let mut rng = rand::rng();
        let mut state = (0..100)
            .map(|_| {
                let address = Address::random();
                let account =
                    Account { balance: U256::from(rng.random::<u64>()), ..Default::default() };
                let mut storage = HashMap::<B256, U256>::default();
                let has_storage = rng.random_bool(0.7);
                if has_storage {
                    for _ in 0..100 {
                        storage.insert(
                            B256::from(U256::from(rng.random::<u64>())),
                            U256::from(rng.random::<u64>()),
                        );
                    }
                }
                (address, (account, storage))
            })
            .collect::<HashMap<_, _>>();

        {
            let provider_rw = factory.provider_rw().unwrap();
            provider_rw
                .insert_account_for_hashing(
                    state.iter().map(|(address, (account, _))| (*address, Some(*account))),
                )
                .unwrap();
            provider_rw
                .insert_storage_for_hashing(state.iter().map(|(address, (_, storage))| {
                    (
                        *address,
                        storage
                            .iter()
                            .map(|(slot, value)| StorageEntry { key: *slot, value: *value }),
                    )
                }))
                .unwrap();
            provider_rw.commit().unwrap();
        }

        assert_eq!(
            ParallelStateRoot::new(overlay_factory.clone(), Default::default())
                .incremental_root()
                .unwrap(),
            test_utils::state_root(state.clone())
        );

        let mut hashed_state = HashedPostState::default();
        for (address, (account, storage)) in &mut state {
            let hashed_address = keccak256(address);

            let should_update_account = rng.random_bool(0.5);
            if should_update_account {
                *account = Account { balance: U256::from(rng.random::<u64>()), ..*account };
                hashed_state.accounts.insert(hashed_address, Some(*account));
            }

            let should_update_storage = rng.random_bool(0.3);
            if should_update_storage {
                for (slot, value) in storage.iter_mut() {
                    let hashed_slot = keccak256(slot);
                    *value = U256::from(rng.random::<u64>());
                    hashed_state
                        .storages
                        .entry(hashed_address)
                        .or_insert_with(HashedStorage::default)
                        .storage
                        .insert(hashed_slot, *value);
                }
            }
        }

        let prefix_sets = hashed_state.construct_prefix_sets();
        overlay_factory =
            overlay_factory.with_hashed_state_overlay(Some(Arc::new(hashed_state.into_sorted())));

        assert_eq!(
            ParallelStateRoot::new(overlay_factory, prefix_sets.freeze())
                .incremental_root()
                .unwrap(),
            test_utils::state_root(state)
        );
    }
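
    /// A minimal additional sketch (added here for illustration, mirroring the setup of
    /// `random_parallel_root` above): with no accounts in the database and empty prefix sets,
    /// the parallel calculator should agree with `test_utils::state_root` over an empty state.
    #[tokio::test]
    async fn parallel_root_empty_state() {
        let factory = create_test_provider_factory();
        let changeset_cache = reth_trie_db::ChangesetCache::new();
        let overlay_factory = reth_provider::providers::OverlayStateProviderFactory::new(
            factory,
            changeset_cache,
        );

        // Empty state: no accounts and no storage.
        let empty_state = HashMap::<Address, (Account, HashMap<B256, U256>)>::default();
        assert_eq!(
            ParallelStateRoot::new(overlay_factory, Default::default())
                .incremental_root()
                .unwrap(),
            test_utils::state_root(empty_state)
        );
    }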
}