1#[cfg(feature = "metrics")]
2use crate::metrics::ParallelStateRootMetrics;
3use crate::{stats::ParallelTrieTracker, storage_root_targets::StorageRootTargets};
4use alloy_primitives::B256;
5use alloy_rlp::{BufMut, Encodable};
6use itertools::Itertools;
7use reth_execution_errors::StorageRootError;
8use reth_provider::{
9 providers::ConsistentDbView, BlockReader, DBProvider, DatabaseProviderFactory, ProviderError,
10};
11use reth_storage_errors::db::DatabaseError;
12use reth_trie::{
13 hashed_cursor::{HashedCursorFactory, HashedPostStateCursorFactory},
14 node_iter::{TrieElement, TrieNodeIter},
15 trie_cursor::{InMemoryTrieCursorFactory, TrieCursorFactory},
16 updates::TrieUpdates,
17 walker::TrieWalker,
18 HashBuilder, Nibbles, StorageRoot, TrieInput, TRIE_ACCOUNT_RLP_MAX_SIZE,
19};
20use reth_trie_db::{DatabaseHashedCursorFactory, DatabaseTrieCursorFactory};
21use std::{
22 collections::HashMap,
23 sync::{mpsc, Arc, OnceLock},
24 time::Duration,
25};
26use thiserror::Error;
27use tokio::runtime::{Builder, Handle, Runtime};
28use tracing::*;
29
30#[derive(Debug)]
46pub struct ParallelStateRoot<Factory> {
47 view: ConsistentDbView<Factory>,
49 input: TrieInput,
51 #[cfg(feature = "metrics")]
53 metrics: ParallelStateRootMetrics,
54}
55
56impl<Factory> ParallelStateRoot<Factory> {
57 pub fn new(view: ConsistentDbView<Factory>, input: TrieInput) -> Self {
59 Self {
60 view,
61 input,
62 #[cfg(feature = "metrics")]
63 metrics: ParallelStateRootMetrics::default(),
64 }
65 }
66}
67
68impl<Factory> ParallelStateRoot<Factory>
69where
70 Factory: DatabaseProviderFactory<Provider: BlockReader> + Clone + Send + Sync + 'static,
71{
72 pub fn incremental_root(self) -> Result<B256, ParallelStateRootError> {
74 self.calculate(false).map(|(root, _)| root)
75 }
76
77 pub fn incremental_root_with_updates(
79 self,
80 ) -> Result<(B256, TrieUpdates), ParallelStateRootError> {
81 self.calculate(true)
82 }
83
84 fn calculate(
87 self,
88 retain_updates: bool,
89 ) -> Result<(B256, TrieUpdates), ParallelStateRootError> {
90 let mut tracker = ParallelTrieTracker::default();
91 let trie_nodes_sorted = Arc::new(self.input.nodes.into_sorted());
92 let hashed_state_sorted = Arc::new(self.input.state.into_sorted());
93 let prefix_sets = self.input.prefix_sets.freeze();
94 let storage_root_targets = StorageRootTargets::new(
95 prefix_sets.account_prefix_set.iter().map(|nibbles| B256::from_slice(&nibbles.pack())),
96 prefix_sets.storage_prefix_sets,
97 );
98
99 tracker.set_precomputed_storage_roots(storage_root_targets.len() as u64);
101 debug!(target: "trie::parallel_state_root", len = storage_root_targets.len(), "pre-calculating storage roots");
102 let mut storage_roots = HashMap::with_capacity(storage_root_targets.len());
103
104 let handle = get_runtime_handle();
106
107 for (hashed_address, prefix_set) in
108 storage_root_targets.into_iter().sorted_unstable_by_key(|(address, _)| *address)
109 {
110 let view = self.view.clone();
111 let hashed_state_sorted = hashed_state_sorted.clone();
112 let trie_nodes_sorted = trie_nodes_sorted.clone();
113 #[cfg(feature = "metrics")]
114 let metrics = self.metrics.storage_trie.clone();
115
116 let (tx, rx) = mpsc::sync_channel(1);
117
118 drop(handle.spawn_blocking(move || {
120 let result = (|| -> Result<_, ParallelStateRootError> {
121 let provider_ro = view.provider_ro()?;
122 let trie_cursor_factory = InMemoryTrieCursorFactory::new(
123 DatabaseTrieCursorFactory::new(provider_ro.tx_ref()),
124 &trie_nodes_sorted,
125 );
126 let hashed_state = HashedPostStateCursorFactory::new(
127 DatabaseHashedCursorFactory::new(provider_ro.tx_ref()),
128 &hashed_state_sorted,
129 );
130 Ok(StorageRoot::new_hashed(
131 trie_cursor_factory,
132 hashed_state,
133 hashed_address,
134 prefix_set,
135 #[cfg(feature = "metrics")]
136 metrics,
137 )
138 .calculate(retain_updates)?)
139 })();
140 let _ = tx.send(result);
141 }));
142 storage_roots.insert(hashed_address, rx);
143 }
144
145 trace!(target: "trie::parallel_state_root", "calculating state root");
146 let mut trie_updates = TrieUpdates::default();
147
148 let provider_ro = self.view.provider_ro()?;
149 let trie_cursor_factory = InMemoryTrieCursorFactory::new(
150 DatabaseTrieCursorFactory::new(provider_ro.tx_ref()),
151 &trie_nodes_sorted,
152 );
153 let hashed_cursor_factory = HashedPostStateCursorFactory::new(
154 DatabaseHashedCursorFactory::new(provider_ro.tx_ref()),
155 &hashed_state_sorted,
156 );
157
158 let walker = TrieWalker::<_>::state_trie(
159 trie_cursor_factory.account_trie_cursor().map_err(ProviderError::Database)?,
160 prefix_sets.account_prefix_set,
161 )
162 .with_deletions_retained(retain_updates);
163 let mut account_node_iter = TrieNodeIter::state_trie(
164 walker,
165 hashed_cursor_factory.hashed_account_cursor().map_err(ProviderError::Database)?,
166 );
167
168 let mut hash_builder = HashBuilder::default().with_updates(retain_updates);
169 let mut account_rlp = Vec::with_capacity(TRIE_ACCOUNT_RLP_MAX_SIZE);
170 while let Some(node) = account_node_iter.try_next().map_err(ProviderError::Database)? {
171 match node {
172 TrieElement::Branch(node) => {
173 hash_builder.add_branch(node.key, node.value, node.children_are_in_trie);
174 }
175 TrieElement::Leaf(hashed_address, account) => {
176 let storage_root_result = match storage_roots.remove(&hashed_address) {
177 Some(rx) => rx.recv().map_err(|_| {
178 ParallelStateRootError::StorageRoot(StorageRootError::Database(
179 DatabaseError::Other(format!(
180 "channel closed for {hashed_address}"
181 )),
182 ))
183 })??,
184 None => {
187 tracker.inc_missed_leaves();
188 StorageRoot::new_hashed(
189 trie_cursor_factory.clone(),
190 hashed_cursor_factory.clone(),
191 hashed_address,
192 Default::default(),
193 #[cfg(feature = "metrics")]
194 self.metrics.storage_trie.clone(),
195 )
196 .calculate(retain_updates)?
197 }
198 };
199
200 let (storage_root, _, updates) = match storage_root_result {
201 reth_trie::StorageRootProgress::Complete(root, _, updates) => (root, (), updates),
202 reth_trie::StorageRootProgress::Progress(..) => {
203 return Err(ParallelStateRootError::StorageRoot(
204 StorageRootError::Database(DatabaseError::Other(
205 "StorageRoot returned Progress variant in parallel trie calculation".to_string()
206 ))
207 ))
208 }
209 };
210
211 if retain_updates {
212 trie_updates.insert_storage_updates(hashed_address, updates);
213 }
214
215 account_rlp.clear();
216 let account = account.into_trie_account(storage_root);
217 account.encode(&mut account_rlp as &mut dyn BufMut);
218 hash_builder.add_leaf(Nibbles::unpack(hashed_address), &account_rlp);
219 }
220 }
221 }
222
223 let root = hash_builder.root();
224
225 let removed_keys = account_node_iter.walker.take_removed_keys();
226 trie_updates.finalize(hash_builder, removed_keys, prefix_sets.destroyed_accounts);
227
228 let stats = tracker.finish();
229
230 #[cfg(feature = "metrics")]
231 self.metrics.record_state_trie(stats);
232
233 trace!(
234 target: "trie::parallel_state_root",
235 %root,
236 duration = ?stats.duration(),
237 branches_added = stats.branches_added(),
238 leaves_added = stats.leaves_added(),
239 missed_leaves = stats.missed_leaves(),
240 precomputed_storage_roots = stats.precomputed_storage_roots(),
241 "Calculated state root"
242 );
243
244 Ok((root, trie_updates))
245 }
246}
247
248#[derive(Error, Debug)]
250pub enum ParallelStateRootError {
251 #[error(transparent)]
253 StorageRoot(#[from] StorageRootError),
254 #[error(transparent)]
256 Provider(#[from] ProviderError),
257 #[error("{_0}")]
259 Other(String),
260}
261
262impl From<ParallelStateRootError> for ProviderError {
263 fn from(error: ParallelStateRootError) -> Self {
264 match error {
265 ParallelStateRootError::Provider(error) => error,
266 ParallelStateRootError::StorageRoot(StorageRootError::Database(error)) => {
267 Self::Database(error)
268 }
269 ParallelStateRootError::Other(other) => Self::Database(DatabaseError::Other(other)),
270 }
271 }
272}
273
274impl From<alloy_rlp::Error> for ParallelStateRootError {
275 fn from(error: alloy_rlp::Error) -> Self {
276 Self::Provider(ProviderError::Rlp(error))
277 }
278}
279
280fn get_runtime_handle() -> Handle {
283 Handle::try_current().unwrap_or_else(|_| {
284 static RT: OnceLock<Runtime> = OnceLock::new();
286
287 let rt = RT.get_or_init(|| {
288 Builder::new_multi_thread()
289 .thread_keep_alive(Duration::from_secs(15))
293 .build()
294 .expect("Failed to create tokio runtime")
295 });
296
297 rt.handle().clone()
298 })
299}
300
301#[cfg(test)]
302mod tests {
303 use super::*;
304 use alloy_primitives::{keccak256, Address, U256};
305 use rand::Rng;
306 use reth_primitives_traits::{Account, StorageEntry};
307 use reth_provider::{test_utils::create_test_provider_factory, HashingWriter};
308 use reth_trie::{test_utils, HashedPostState, HashedStorage};
309
310 #[tokio::test]
311 async fn random_parallel_root() {
312 let factory = create_test_provider_factory();
313 let consistent_view = ConsistentDbView::new(factory.clone(), None);
314
315 let mut rng = rand::rng();
316 let mut state = (0..100)
317 .map(|_| {
318 let address = Address::random();
319 let account =
320 Account { balance: U256::from(rng.random::<u64>()), ..Default::default() };
321 let mut storage = HashMap::<B256, U256>::default();
322 let has_storage = rng.random_bool(0.7);
323 if has_storage {
324 for _ in 0..100 {
325 storage.insert(
326 B256::from(U256::from(rng.random::<u64>())),
327 U256::from(rng.random::<u64>()),
328 );
329 }
330 }
331 (address, (account, storage))
332 })
333 .collect::<HashMap<_, _>>();
334
335 {
336 let provider_rw = factory.provider_rw().unwrap();
337 provider_rw
338 .insert_account_for_hashing(
339 state.iter().map(|(address, (account, _))| (*address, Some(*account))),
340 )
341 .unwrap();
342 provider_rw
343 .insert_storage_for_hashing(state.iter().map(|(address, (_, storage))| {
344 (
345 *address,
346 storage
347 .iter()
348 .map(|(slot, value)| StorageEntry { key: *slot, value: *value }),
349 )
350 }))
351 .unwrap();
352 provider_rw.commit().unwrap();
353 }
354
355 assert_eq!(
356 ParallelStateRoot::new(consistent_view.clone(), Default::default())
357 .incremental_root()
358 .unwrap(),
359 test_utils::state_root(state.clone())
360 );
361
362 let mut hashed_state = HashedPostState::default();
363 for (address, (account, storage)) in &mut state {
364 let hashed_address = keccak256(address);
365
366 let should_update_account = rng.random_bool(0.5);
367 if should_update_account {
368 *account = Account { balance: U256::from(rng.random::<u64>()), ..*account };
369 hashed_state.accounts.insert(hashed_address, Some(*account));
370 }
371
372 let should_update_storage = rng.random_bool(0.3);
373 if should_update_storage {
374 for (slot, value) in storage.iter_mut() {
375 let hashed_slot = keccak256(slot);
376 *value = U256::from(rng.random::<u64>());
377 hashed_state
378 .storages
379 .entry(hashed_address)
380 .or_insert_with(HashedStorage::default)
381 .storage
382 .insert(hashed_slot, *value);
383 }
384 }
385 }
386
387 assert_eq!(
388 ParallelStateRoot::new(consistent_view, TrieInput::from_state(hashed_state))
389 .incremental_root()
390 .unwrap(),
391 test_utils::state_root(state)
392 );
393 }
394}