1use crate::{
2 metrics::ParallelTrieMetrics,
3 proof_task::{ProofTaskKind, ProofTaskManagerHandle, StorageProofInput},
4 root::ParallelStateRootError,
5 stats::ParallelTrieTracker,
6 StorageRootTargets,
7};
8use alloy_primitives::{
9 map::{B256Map, B256Set, HashMap},
10 B256,
11};
12use alloy_rlp::{BufMut, Encodable};
13use dashmap::DashMap;
14use itertools::Itertools;
15use reth_execution_errors::StorageRootError;
16use reth_provider::{
17 providers::ConsistentDbView, BlockReader, DBProvider, DatabaseProviderFactory, FactoryTx,
18 ProviderError,
19};
20use reth_storage_errors::db::DatabaseError;
21use reth_trie::{
22 hashed_cursor::{HashedCursorFactory, HashedPostStateCursorFactory},
23 node_iter::{TrieElement, TrieNodeIter},
24 prefix_set::{PrefixSet, PrefixSetMut, TriePrefixSetsMut},
25 proof::StorageProof,
26 trie_cursor::{InMemoryTrieCursorFactory, TrieCursorFactory},
27 updates::TrieUpdatesSorted,
28 walker::TrieWalker,
29 DecodedMultiProof, DecodedStorageMultiProof, HashBuilder, HashedPostStateSorted,
30 MultiProofTargets, Nibbles, TRIE_ACCOUNT_RLP_MAX_SIZE,
31};
32use reth_trie_common::{
33 added_removed_keys::MultiAddedRemovedKeys,
34 proof::{DecodedProofNodes, ProofRetainer},
35};
36use reth_trie_db::{DatabaseHashedCursorFactory, DatabaseTrieCursorFactory};
37use std::sync::{mpsc::Receiver, Arc};
38use tracing::trace;
39
40#[derive(Debug)]
45pub struct ParallelProof<Factory: DatabaseProviderFactory> {
46 view: ConsistentDbView<Factory>,
48 pub nodes_sorted: Arc<TrieUpdatesSorted>,
51 pub state_sorted: Arc<HashedPostStateSorted>,
53 pub prefix_sets: Arc<TriePrefixSetsMut>,
57 collect_branch_node_masks: bool,
59 multi_added_removed_keys: Option<Arc<MultiAddedRemovedKeys>>,
61 storage_proof_task_handle: ProofTaskManagerHandle<FactoryTx<Factory>>,
63 missed_leaves_storage_roots: Arc<DashMap<B256, B256>>,
66 #[cfg(feature = "metrics")]
67 metrics: ParallelTrieMetrics,
68}
69
70impl<Factory: DatabaseProviderFactory> ParallelProof<Factory> {
71 pub fn new(
73 view: ConsistentDbView<Factory>,
74 nodes_sorted: Arc<TrieUpdatesSorted>,
75 state_sorted: Arc<HashedPostStateSorted>,
76 prefix_sets: Arc<TriePrefixSetsMut>,
77 missed_leaves_storage_roots: Arc<DashMap<B256, B256>>,
78 storage_proof_task_handle: ProofTaskManagerHandle<FactoryTx<Factory>>,
79 ) -> Self {
80 Self {
81 view,
82 nodes_sorted,
83 state_sorted,
84 prefix_sets,
85 missed_leaves_storage_roots,
86 collect_branch_node_masks: false,
87 multi_added_removed_keys: None,
88 storage_proof_task_handle,
89 #[cfg(feature = "metrics")]
90 metrics: ParallelTrieMetrics::new_with_labels(&[("type", "proof")]),
91 }
92 }
93
94 pub const fn with_branch_node_masks(mut self, branch_node_masks: bool) -> Self {
96 self.collect_branch_node_masks = branch_node_masks;
97 self
98 }
99
100 pub fn with_multi_added_removed_keys(
103 mut self,
104 multi_added_removed_keys: Option<Arc<MultiAddedRemovedKeys>>,
105 ) -> Self {
106 self.multi_added_removed_keys = multi_added_removed_keys;
107 self
108 }
109}
110
111impl<Factory> ParallelProof<Factory>
112where
113 Factory: DatabaseProviderFactory<Provider: BlockReader> + Clone + 'static,
114{
115 fn queue_storage_proof(
117 &self,
118 hashed_address: B256,
119 prefix_set: PrefixSet,
120 target_slots: B256Set,
121 ) -> Receiver<Result<DecodedStorageMultiProof, ParallelStateRootError>> {
122 let input = StorageProofInput::new(
123 hashed_address,
124 prefix_set,
125 target_slots,
126 self.collect_branch_node_masks,
127 self.multi_added_removed_keys.clone(),
128 );
129
130 let (sender, receiver) = std::sync::mpsc::channel();
131 let _ =
132 self.storage_proof_task_handle.queue_task(ProofTaskKind::StorageProof(input, sender));
133 receiver
134 }
135
136 pub fn storage_proof(
138 self,
139 hashed_address: B256,
140 target_slots: B256Set,
141 ) -> Result<DecodedStorageMultiProof, ParallelStateRootError> {
142 let total_targets = target_slots.len();
143 let prefix_set = PrefixSetMut::from(target_slots.iter().map(Nibbles::unpack));
144 let prefix_set = prefix_set.freeze();
145
146 trace!(
147 target: "trie::parallel_proof",
148 total_targets,
149 ?hashed_address,
150 "Starting storage proof generation"
151 );
152
153 let receiver = self.queue_storage_proof(hashed_address, prefix_set, target_slots);
154 let proof_result = receiver.recv().map_err(|_| {
155 ParallelStateRootError::StorageRoot(StorageRootError::Database(DatabaseError::Other(
156 format!("channel closed for {hashed_address}"),
157 )))
158 })?;
159
160 trace!(
161 target: "trie::parallel_proof",
162 total_targets,
163 ?hashed_address,
164 "Storage proof generation completed"
165 );
166
167 proof_result
168 }
169
170 pub fn decoded_multiproof(
172 self,
173 targets: MultiProofTargets,
174 ) -> Result<DecodedMultiProof, ParallelStateRootError> {
175 let mut tracker = ParallelTrieTracker::default();
176
177 let mut prefix_sets = (*self.prefix_sets).clone();
179 prefix_sets.extend(TriePrefixSetsMut {
180 account_prefix_set: PrefixSetMut::from(targets.keys().copied().map(Nibbles::unpack)),
181 storage_prefix_sets: targets
182 .iter()
183 .filter(|&(_hashed_address, slots)| !slots.is_empty())
184 .map(|(hashed_address, slots)| {
185 (*hashed_address, PrefixSetMut::from(slots.iter().map(Nibbles::unpack)))
186 })
187 .collect(),
188 destroyed_accounts: Default::default(),
189 });
190 let prefix_sets = prefix_sets.freeze();
191
192 let storage_root_targets = StorageRootTargets::new(
193 prefix_sets.account_prefix_set.iter().map(|nibbles| B256::from_slice(&nibbles.pack())),
194 prefix_sets.storage_prefix_sets.clone(),
195 );
196 let storage_root_targets_len = storage_root_targets.len();
197
198 trace!(
199 target: "trie::parallel_proof",
200 total_targets = storage_root_targets_len,
201 "Starting parallel proof generation"
202 );
203
204 tracker.set_precomputed_storage_roots(storage_root_targets_len as u64);
206
207 let mut storage_proof_receivers =
210 B256Map::with_capacity_and_hasher(storage_root_targets.len(), Default::default());
211
212 for (hashed_address, prefix_set) in
213 storage_root_targets.into_iter().sorted_unstable_by_key(|(address, _)| *address)
214 {
215 let target_slots = targets.get(&hashed_address).cloned().unwrap_or_default();
216 let receiver = self.queue_storage_proof(hashed_address, prefix_set, target_slots);
217
218 storage_proof_receivers.insert(hashed_address, receiver);
221 }
222
223 let provider_ro = self.view.provider_ro()?;
224 let trie_cursor_factory = InMemoryTrieCursorFactory::new(
225 DatabaseTrieCursorFactory::new(provider_ro.tx_ref()),
226 &self.nodes_sorted,
227 );
228 let hashed_cursor_factory = HashedPostStateCursorFactory::new(
229 DatabaseHashedCursorFactory::new(provider_ro.tx_ref()),
230 &self.state_sorted,
231 );
232
233 let accounts_added_removed_keys =
234 self.multi_added_removed_keys.as_ref().map(|keys| keys.get_accounts());
235
236 let walker = TrieWalker::<_>::state_trie(
238 trie_cursor_factory.account_trie_cursor().map_err(ProviderError::Database)?,
239 prefix_sets.account_prefix_set,
240 )
241 .with_added_removed_keys(accounts_added_removed_keys)
242 .with_deletions_retained(true);
243
244 let retainer = targets
246 .keys()
247 .map(Nibbles::unpack)
248 .collect::<ProofRetainer>()
249 .with_added_removed_keys(accounts_added_removed_keys);
250 let mut hash_builder = HashBuilder::default()
251 .with_proof_retainer(retainer)
252 .with_updates(self.collect_branch_node_masks);
253
254 let mut collected_decoded_storages: B256Map<DecodedStorageMultiProof> =
257 targets.keys().map(|key| (*key, DecodedStorageMultiProof::empty())).collect();
258 let mut account_rlp = Vec::with_capacity(TRIE_ACCOUNT_RLP_MAX_SIZE);
259 let mut account_node_iter = TrieNodeIter::state_trie(
260 walker,
261 hashed_cursor_factory.hashed_account_cursor().map_err(ProviderError::Database)?,
262 );
263 while let Some(account_node) =
264 account_node_iter.try_next().map_err(ProviderError::Database)?
265 {
266 match account_node {
267 TrieElement::Branch(node) => {
268 hash_builder.add_branch(node.key, node.value, node.children_are_in_trie);
269 }
270 TrieElement::Leaf(hashed_address, account) => {
271 let root = match storage_proof_receivers.remove(&hashed_address) {
272 Some(rx) => {
273 let decoded_storage_multiproof = rx.recv().map_err(|e| {
274 ParallelStateRootError::StorageRoot(StorageRootError::Database(
275 DatabaseError::Other(format!(
276 "channel closed for {hashed_address}: {e}"
277 )),
278 ))
279 })??;
280 let root = decoded_storage_multiproof.root;
281 collected_decoded_storages
282 .insert(hashed_address, decoded_storage_multiproof);
283 root
284 }
285 None => {
288 tracker.inc_missed_leaves();
289
290 match self.missed_leaves_storage_roots.entry(hashed_address) {
291 dashmap::Entry::Occupied(occ) => *occ.get(),
292 dashmap::Entry::Vacant(vac) => {
293 let root = StorageProof::new_hashed(
294 trie_cursor_factory.clone(),
295 hashed_cursor_factory.clone(),
296 hashed_address,
297 )
298 .with_prefix_set_mut(Default::default())
299 .storage_multiproof(
300 targets.get(&hashed_address).cloned().unwrap_or_default(),
301 )
302 .map_err(|e| {
303 ParallelStateRootError::StorageRoot(
304 StorageRootError::Database(DatabaseError::Other(
305 e.to_string(),
306 )),
307 )
308 })?
309 .root;
310 vac.insert(root);
311 root
312 }
313 }
314 }
315 };
316
317 account_rlp.clear();
319 let account = account.into_trie_account(root);
320 account.encode(&mut account_rlp as &mut dyn BufMut);
321
322 hash_builder.add_leaf(Nibbles::unpack(hashed_address), &account_rlp);
323 }
324 }
325 }
326 let _ = hash_builder.root();
327
328 let stats = tracker.finish();
329 #[cfg(feature = "metrics")]
330 self.metrics.record(stats);
331
332 let account_subtree_raw_nodes = hash_builder.take_proof_nodes();
333 let decoded_account_subtree = DecodedProofNodes::try_from(account_subtree_raw_nodes)?;
334
335 let (branch_node_hash_masks, branch_node_tree_masks) = if self.collect_branch_node_masks {
336 let updated_branch_nodes = hash_builder.updated_branch_nodes.unwrap_or_default();
337 (
338 updated_branch_nodes.iter().map(|(path, node)| (*path, node.hash_mask)).collect(),
339 updated_branch_nodes
340 .into_iter()
341 .map(|(path, node)| (path, node.tree_mask))
342 .collect(),
343 )
344 } else {
345 (HashMap::default(), HashMap::default())
346 };
347
348 trace!(
349 target: "trie::parallel_proof",
350 total_targets = storage_root_targets_len,
351 duration = ?stats.duration(),
352 branches_added = stats.branches_added(),
353 leaves_added = stats.leaves_added(),
354 missed_leaves = stats.missed_leaves(),
355 precomputed_storage_roots = stats.precomputed_storage_roots(),
356 "Calculated decoded proof"
357 );
358
359 Ok(DecodedMultiProof {
360 account_subtree: decoded_account_subtree,
361 branch_node_hash_masks,
362 branch_node_tree_masks,
363 storages: collected_decoded_storages,
364 })
365 }
366}
367
#[cfg(test)]
mod tests {
    use super::*;
    use crate::proof_task::{ProofTaskCtx, ProofTaskManager};
    use alloy_primitives::{
        keccak256,
        map::{B256Set, DefaultHashBuilder},
        Address, U256,
    };
    use rand::Rng;
    use reth_primitives_traits::{Account, StorageEntry};
    use reth_provider::{test_utils::create_test_provider_factory, HashingWriter};
    use reth_trie::proof::Proof;
    use tokio::runtime::Runtime;

    /// Checks that [`ParallelProof::decoded_multiproof`] matches the sequential
    /// [`Proof::multiproof`] on a random state of 100 accounts, about 70% of which have up to
    /// 100 random storage slots.
    #[test]
    fn random_parallel_proof() {
        let factory = create_test_provider_factory();
        let consistent_view = ConsistentDbView::new(factory.clone(), None);

        let mut rng = rand::rng();
        let state = (0..100)
            .map(|_| {
                let address = Address::random();
                let account =
                    Account { balance: U256::from(rng.random::<u64>()), ..Default::default() };
                let mut storage = HashMap::<B256, U256, DefaultHashBuilder>::default();
                let has_storage = rng.random_bool(0.7);
                if has_storage {
                    for _ in 0..100 {
                        storage.insert(
                            B256::from(U256::from(rng.random::<u64>())),
                            U256::from(rng.random::<u64>()),
                        );
                    }
                }
                (address, (account, storage))
            })
            .collect::<HashMap<_, _, DefaultHashBuilder>>();

        {
            let provider_rw = factory.provider_rw().unwrap();
            provider_rw
                .insert_account_for_hashing(
                    state.iter().map(|(address, (account, _))| (*address, Some(*account))),
                )
                .unwrap();
            provider_rw
                .insert_storage_for_hashing(state.iter().map(|(address, (_, storage))| {
                    (
                        *address,
                        storage
                            .iter()
                            .map(|(slot, value)| StorageEntry { key: *slot, value: *value }),
                    )
                }))
                .unwrap();
            provider_rw.commit().unwrap();
        }

        // Target up to 5 slots in each of at most 10 accounts; accounts without storage would
        // have an empty slot set and are skipped.
        let mut targets = MultiProofTargets::default();
        for (address, (_, storage)) in state.iter().take(10) {
            let hashed_address = keccak256(*address);
            let mut target_slots = B256Set::default();

            for (slot, _) in storage.iter().take(5) {
                target_slots.insert(*slot);
            }

            if !target_slots.is_empty() {
                targets.insert(hashed_address, target_slots);
            }
        }

        let provider_rw = factory.provider_rw().unwrap();
        let trie_cursor_factory = DatabaseTrieCursorFactory::new(provider_rw.tx_ref());
        let hashed_cursor_factory = DatabaseHashedCursorFactory::new(provider_rw.tx_ref());

        let rt = Runtime::new().unwrap();

        let task_ctx =
            ProofTaskCtx::new(Default::default(), Default::default(), Default::default());
        let proof_task =
            ProofTaskManager::new(rt.handle().clone(), consistent_view.clone(), task_ctx, 1, 1)
                .unwrap();
        let proof_task_handle = proof_task.handle();

        // Run the proof task manager in the background; it is joined at the end of the test.
        let join_handle = rt.spawn_blocking(move || proof_task.run());

        let parallel_result = ParallelProof::new(
            consistent_view,
            Default::default(),
            Default::default(),
            Default::default(),
            Default::default(),
            proof_task_handle.clone(),
        )
        .decoded_multiproof(targets.clone())
        .unwrap();

        let sequential_result_raw = Proof::new(trie_cursor_factory, hashed_cursor_factory)
            .multiproof(targets.clone())
            .unwrap();
        let sequential_result_decoded: DecodedMultiProof = sequential_result_raw
            .try_into()
            .expect("Failed to decode sequential_result for test comparison");

        // Compare piecewise first so a mismatch points at the differing part.
        assert_eq!(parallel_result.account_subtree, sequential_result_decoded.account_subtree);

        assert_eq!(parallel_result.storages.len(), sequential_result_decoded.storages.len());

        for (hashed_address, storage_proof) in &parallel_result.storages {
            let sequential_storage_proof =
                sequential_result_decoded.storages.get(hashed_address).unwrap();
            assert_eq!(storage_proof, sequential_storage_proof);
        }

        assert_eq!(parallel_result, sequential_result_decoded);

        // Dropping the last handle lets the proof task manager's run loop finish.
        drop(proof_task_handle);
        rt.block_on(join_handle).unwrap().expect("The proof task should not return an error");
    }
}