use crate::{
    metrics::ParallelTrieMetrics,
    proof_task::{ProofTaskKind, ProofTaskManagerHandle, StorageProofInput},
    root::ParallelStateRootError,
    stats::ParallelTrieTracker,
    StorageRootTargets,
};
use alloy_primitives::{
    map::{B256Map, B256Set, HashMap},
    B256,
};
use alloy_rlp::{BufMut, Encodable};
use itertools::Itertools;
use reth_execution_errors::StorageRootError;
use reth_provider::{
    providers::ConsistentDbView, BlockReader, DBProvider, DatabaseProviderFactory, FactoryTx,
    ProviderError,
};
use reth_storage_errors::db::DatabaseError;
use reth_trie::{
    hashed_cursor::{HashedCursorFactory, HashedPostStateCursorFactory},
    node_iter::{TrieElement, TrieNodeIter},
    prefix_set::{PrefixSet, PrefixSetMut, TriePrefixSetsMut},
    proof::StorageProof,
    trie_cursor::{InMemoryTrieCursorFactory, TrieCursorFactory},
    updates::TrieUpdatesSorted,
    walker::TrieWalker,
    DecodedMultiProof, DecodedStorageMultiProof, HashBuilder, HashedPostStateSorted,
    MultiProofTargets, Nibbles, TRIE_ACCOUNT_RLP_MAX_SIZE,
};
use reth_trie_common::{
    added_removed_keys::MultiAddedRemovedKeys,
    proof::{DecodedProofNodes, ProofRetainer},
};
use reth_trie_db::{DatabaseHashedCursorFactory, DatabaseTrieCursorFactory};
use std::sync::{mpsc::Receiver, Arc};
use tracing::debug;

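/// Parallel proof calculator.
///
/// Storage proofs for the targeted accounts are queued onto the proof task manager and computed
/// in parallel, while the account trie is walked on the calling thread and the storage results
/// are awaited lazily as their leaves are reached.
///
/// A minimal usage sketch, mirroring the test at the bottom of this file (the consistent view,
/// proof task handle, and `targets` are assumed to be supplied by the caller):
///
/// ```ignore
/// let multiproof = ParallelProof::new(
///     consistent_view,
///     Default::default(), // nodes_sorted
///     Default::default(), // state_sorted
///     Default::default(), // prefix_sets
///     proof_task_handle,
/// )
/// .decoded_multiproof(targets)?;
/// ```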
#[derive(Debug)]
pub struct ParallelProof<Factory: DatabaseProviderFactory> {
    /// Consistent view of the database.
    view: ConsistentDbView<Factory>,
    /// Sorted in-memory overlay of trie node updates, layered over the database trie cursors.
    pub nodes_sorted: Arc<TrieUpdatesSorted>,
    /// Sorted in-memory overlay of the hashed post-state, layered over the database hashed cursors.
    pub state_sorted: Arc<HashedPostStateSorted>,
    /// Prefix sets of trie paths to revisit; extended with the proof targets before walking.
    pub prefix_sets: Arc<TriePrefixSetsMut>,
    /// Whether to retain branch node hash and tree masks from the hash builder.
    collect_branch_node_masks: bool,
    /// Optional added/removed keys hint for the trie walker, proof retainer, and storage proofs.
    multi_added_removed_keys: Option<Arc<MultiAddedRemovedKeys>>,
    /// Handle used to queue storage proof tasks onto the proof task manager.
    storage_proof_task_handle: ProofTaskManagerHandle<FactoryTx<Factory>>,
    /// Parallel trie metrics.
    #[cfg(feature = "metrics")]
    metrics: ParallelTrieMetrics,
}

impl<Factory: DatabaseProviderFactory> ParallelProof<Factory> {
    /// Create a new parallel proof calculator.
    pub fn new(
        view: ConsistentDbView<Factory>,
        nodes_sorted: Arc<TrieUpdatesSorted>,
        state_sorted: Arc<HashedPostStateSorted>,
        prefix_sets: Arc<TriePrefixSetsMut>,
        storage_proof_task_handle: ProofTaskManagerHandle<FactoryTx<Factory>>,
    ) -> Self {
        Self {
            view,
            nodes_sorted,
            state_sorted,
            prefix_sets,
            collect_branch_node_masks: false,
            multi_added_removed_keys: None,
            storage_proof_task_handle,
            #[cfg(feature = "metrics")]
            metrics: ParallelTrieMetrics::new_with_labels(&[("type", "proof")]),
        }
    }

    /// Set the flag indicating whether to retain branch node hash and tree masks.
    pub const fn with_branch_node_masks(mut self, branch_node_masks: bool) -> Self {
        self.collect_branch_node_masks = branch_node_masks;
        self
    }

    /// Set the optional added/removed keys hint used when computing the proof.
    pub fn with_multi_added_removed_keys(
        mut self,
        multi_added_removed_keys: Option<Arc<MultiAddedRemovedKeys>>,
    ) -> Self {
        self.multi_added_removed_keys = multi_added_removed_keys;
        self
    }
}

impl<Factory> ParallelProof<Factory>
where
    Factory: DatabaseProviderFactory<Provider: BlockReader> + Clone + 'static,
{
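    /// Queue a storage proof task for the given hashed address and return a receiver for the
    /// decoded storage multiproof result.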
    fn spawn_storage_proof(
        &self,
        hashed_address: B256,
        prefix_set: PrefixSet,
        target_slots: B256Set,
    ) -> Receiver<Result<DecodedStorageMultiProof, ParallelStateRootError>> {
        let input = StorageProofInput::new(
            hashed_address,
            prefix_set,
            target_slots,
            self.collect_branch_node_masks,
            self.multi_added_removed_keys.clone(),
        );

        let (sender, receiver) = std::sync::mpsc::channel();
        // The queueing result is ignored here; if the task cannot be queued, the receiver will
        // observe a closed channel when the result is awaited.
        let _ =
            self.storage_proof_task_handle.queue_task(ProofTaskKind::StorageProof(input, sender));
        receiver
    }

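    /// Generate a decoded storage multiproof for the target slots of the given hashed address,
    /// blocking until the queued storage proof task completes.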
    pub fn storage_proof(
        self,
        hashed_address: B256,
        target_slots: B256Set,
    ) -> Result<DecodedStorageMultiProof, ParallelStateRootError> {
        let total_targets = target_slots.len();
        let prefix_set = PrefixSetMut::from(target_slots.iter().map(Nibbles::unpack));
        let prefix_set = prefix_set.freeze();

        debug!(
            target: "trie::parallel_proof",
            total_targets,
            ?hashed_address,
            "Starting storage proof generation"
        );

        let receiver = self.spawn_storage_proof(hashed_address, prefix_set, target_slots);
        let proof_result = receiver.recv().map_err(|_| {
            ParallelStateRootError::StorageRoot(StorageRootError::Database(DatabaseError::Other(
                format!("channel closed for {hashed_address}"),
            )))
        })?;

        debug!(
            target: "trie::parallel_proof",
            total_targets,
            ?hashed_address,
            "Storage proof generation completed"
        );

        proof_result
    }

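    /// Generate a decoded storage multiproof; this simply delegates to [`Self::storage_proof`].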
    pub fn decoded_storage_proof(
        self,
        hashed_address: B256,
        target_slots: B256Set,
    ) -> Result<DecodedStorageMultiProof, ParallelStateRootError> {
        self.storage_proof(hashed_address, target_slots)
    }

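    /// Generate a decoded multiproof for the given proof targets.
    ///
    /// Storage proofs for every targeted account are queued onto the proof task manager up front
    /// and computed in parallel, while the account trie walk below consumes their results as it
    /// reaches the corresponding leaves.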
    pub fn decoded_multiproof(
        self,
        targets: MultiProofTargets,
    ) -> Result<DecodedMultiProof, ParallelStateRootError> {
        let mut tracker = ParallelTrieTracker::default();

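        // Extend the cached prefix sets with prefix sets derived from the proof targets.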
        let mut prefix_sets = (*self.prefix_sets).clone();
        prefix_sets.extend(TriePrefixSetsMut {
            account_prefix_set: PrefixSetMut::from(targets.keys().copied().map(Nibbles::unpack)),
            storage_prefix_sets: targets
                .iter()
                .filter(|&(_hashed_address, slots)| !slots.is_empty())
                .map(|(hashed_address, slots)| {
                    (*hashed_address, PrefixSetMut::from(slots.iter().map(Nibbles::unpack)))
                })
                .collect(),
            destroyed_accounts: Default::default(),
        });
        let prefix_sets = prefix_sets.freeze();

        let storage_root_targets = StorageRootTargets::new(
            prefix_sets.account_prefix_set.iter().map(|nibbles| B256::from_slice(&nibbles.pack())),
            prefix_sets.storage_prefix_sets.clone(),
        );
        let storage_root_targets_len = storage_root_targets.len();

        debug!(
            target: "trie::parallel_proof",
            total_targets = storage_root_targets_len,
            "Starting parallel proof generation"
        );

        tracker.set_precomputed_storage_roots(storage_root_targets_len as u64);

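        // Queue a storage proof task for every storage root target and keep the receivers so the
        // results can be awaited lazily while the account trie is walked below.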
        let mut storage_proof_receivers =
            B256Map::with_capacity_and_hasher(storage_root_targets.len(), Default::default());

        for (hashed_address, prefix_set) in
            storage_root_targets.into_iter().sorted_unstable_by_key(|(address, _)| *address)
        {
            let target_slots = targets.get(&hashed_address).cloned().unwrap_or_default();
            let receiver = self.spawn_storage_proof(hashed_address, prefix_set, target_slots);

            storage_proof_receivers.insert(hashed_address, receiver);
        }

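        // Create trie and hashed cursor factories that overlay the in-memory sorted nodes and
        // state on top of the database transaction.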
        let provider_ro = self.view.provider_ro()?;
        let trie_cursor_factory = InMemoryTrieCursorFactory::new(
            DatabaseTrieCursorFactory::new(provider_ro.tx_ref()),
            &self.nodes_sorted,
        );
        let hashed_cursor_factory = HashedPostStateCursorFactory::new(
            DatabaseHashedCursorFactory::new(provider_ro.tx_ref()),
            &self.state_sorted,
        );

        let accounts_added_removed_keys =
            self.multi_added_removed_keys.as_ref().map(|keys| keys.get_accounts());

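        // Build the account trie walker over the account prefix set, retaining deletions.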
        let walker = TrieWalker::<_>::state_trie(
            trie_cursor_factory.account_trie_cursor().map_err(ProviderError::Database)?,
            prefix_sets.account_prefix_set,
        )
        .with_added_removed_keys(accounts_added_removed_keys)
        .with_deletions_retained(true);

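        // Retain proof nodes for all target account paths.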
        let retainer = targets
            .keys()
            .map(Nibbles::unpack)
            .collect::<ProofRetainer>()
            .with_added_removed_keys(accounts_added_removed_keys);
        let mut hash_builder = HashBuilder::default()
            .with_proof_retainer(retainer)
            .with_updates(self.collect_branch_node_masks);

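        // Pre-populate an empty storage multiproof for every target account; entries are replaced
        // below as the corresponding storage proofs are received.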
        let mut collected_decoded_storages: B256Map<DecodedStorageMultiProof> =
            targets.keys().map(|key| (*key, DecodedStorageMultiProof::empty())).collect();
        let mut account_rlp = Vec::with_capacity(TRIE_ACCOUNT_RLP_MAX_SIZE);
        let mut account_node_iter = TrieNodeIter::state_trie(
            walker,
            hashed_cursor_factory.hashed_account_cursor().map_err(ProviderError::Database)?,
        );
        while let Some(account_node) =
            account_node_iter.try_next().map_err(ProviderError::Database)?
        {
            match account_node {
                TrieElement::Branch(node) => {
                    hash_builder.add_branch(node.key, node.value, node.children_are_in_trie);
                }
                TrieElement::Leaf(hashed_address, account) => {
                    let decoded_storage_multiproof = match storage_proof_receivers
                        .remove(&hashed_address)
                    {
                        Some(rx) => rx.recv().map_err(|e| {
                            ParallelStateRootError::StorageRoot(StorageRootError::Database(
                                DatabaseError::Other(format!(
                                    "channel closed for {hashed_address}: {e}"
                                )),
                            ))
                        })??,
                        None => {
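                            // No storage proof was queued for this leaf, so compute its storage
                            // multiproof inline and count it as a missed leaf.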
                            tracker.inc_missed_leaves();

                            let raw_fallback_proof = StorageProof::new_hashed(
                                trie_cursor_factory.clone(),
                                hashed_cursor_factory.clone(),
                                hashed_address,
                            )
                            .with_prefix_set_mut(Default::default())
                            .storage_multiproof(
                                targets.get(&hashed_address).cloned().unwrap_or_default(),
                            )
                            .map_err(|e| {
                                ParallelStateRootError::StorageRoot(StorageRootError::Database(
                                    DatabaseError::Other(e.to_string()),
                                ))
                            })?;

                            raw_fallback_proof.try_into()?
                        }
                    };

                    // RLP-encode the account with the storage root from its storage multiproof.
                    account_rlp.clear();
                    let account = account.into_trie_account(decoded_storage_multiproof.root);
                    account.encode(&mut account_rlp as &mut dyn BufMut);

                    hash_builder.add_leaf(Nibbles::unpack(hashed_address), &account_rlp);

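                    // Only retain the storage multiproof if the account was an explicit target.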
                    if targets.contains_key(&hashed_address) {
                        collected_decoded_storages
                            .insert(hashed_address, decoded_storage_multiproof);
                    }
                }
            }
        }
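        // Finish the hash builder to compute the root and materialize the retained proof nodes.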
        let _ = hash_builder.root();

        let stats = tracker.finish();
        #[cfg(feature = "metrics")]
        self.metrics.record(stats);

        let account_subtree_raw_nodes = hash_builder.take_proof_nodes();
        let decoded_account_subtree = DecodedProofNodes::try_from(account_subtree_raw_nodes)?;

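        // Extract branch node hash and tree masks if they were requested.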
        let (branch_node_hash_masks, branch_node_tree_masks) = if self.collect_branch_node_masks {
            let updated_branch_nodes = hash_builder.updated_branch_nodes.unwrap_or_default();
            (
                updated_branch_nodes.iter().map(|(path, node)| (*path, node.hash_mask)).collect(),
                updated_branch_nodes
                    .into_iter()
                    .map(|(path, node)| (path, node.tree_mask))
                    .collect(),
            )
        } else {
            (HashMap::default(), HashMap::default())
        };

        debug!(
            target: "trie::parallel_proof",
            total_targets = storage_root_targets_len,
            duration = ?stats.duration(),
            branches_added = stats.branches_added(),
            leaves_added = stats.leaves_added(),
            missed_leaves = stats.missed_leaves(),
            precomputed_storage_roots = stats.precomputed_storage_roots(),
            "Calculated decoded proof"
        );

        Ok(DecodedMultiProof {
            account_subtree: decoded_account_subtree,
            branch_node_hash_masks,
            branch_node_tree_masks,
            storages: collected_decoded_storages,
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::proof_task::{ProofTaskCtx, ProofTaskManager};
    use alloy_primitives::{
        keccak256,
        map::{B256Set, DefaultHashBuilder},
        Address, U256,
    };
    use rand::Rng;
    use reth_primitives_traits::{Account, StorageEntry};
    use reth_provider::{test_utils::create_test_provider_factory, HashingWriter};
    use reth_trie::proof::Proof;
    use tokio::runtime::Runtime;

    #[test]
    fn random_parallel_proof() {
        let factory = create_test_provider_factory();
        let consistent_view = ConsistentDbView::new(factory.clone(), None);

        let mut rng = rand::rng();
        let state = (0..100)
            .map(|_| {
                let address = Address::random();
                let account =
                    Account { balance: U256::from(rng.random::<u64>()), ..Default::default() };
                let mut storage = HashMap::<B256, U256, DefaultHashBuilder>::default();
                let has_storage = rng.random_bool(0.7);
                if has_storage {
                    for _ in 0..100 {
                        storage.insert(
                            B256::from(U256::from(rng.random::<u64>())),
                            U256::from(rng.random::<u64>()),
                        );
                    }
                }
                (address, (account, storage))
            })
            .collect::<HashMap<_, _, DefaultHashBuilder>>();

        {
            let provider_rw = factory.provider_rw().unwrap();
            provider_rw
                .insert_account_for_hashing(
                    state.iter().map(|(address, (account, _))| (*address, Some(*account))),
                )
                .unwrap();
            provider_rw
                .insert_storage_for_hashing(state.iter().map(|(address, (_, storage))| {
                    (
                        *address,
                        storage
                            .iter()
                            .map(|(slot, value)| StorageEntry { key: *slot, value: *value }),
                    )
                }))
                .unwrap();
            provider_rw.commit().unwrap();
        }

        let mut targets = MultiProofTargets::default();
        for (address, (_, storage)) in state.iter().take(10) {
            let hashed_address = keccak256(*address);
            let mut target_slots = B256Set::default();

            for (slot, _) in storage.iter().take(5) {
                target_slots.insert(*slot);
            }

            if !target_slots.is_empty() {
                targets.insert(hashed_address, target_slots);
            }
        }

        let provider_rw = factory.provider_rw().unwrap();
        let trie_cursor_factory = DatabaseTrieCursorFactory::new(provider_rw.tx_ref());
        let hashed_cursor_factory = DatabaseHashedCursorFactory::new(provider_rw.tx_ref());

        let rt = Runtime::new().unwrap();

        let task_ctx =
            ProofTaskCtx::new(Default::default(), Default::default(), Default::default());
        let proof_task =
            ProofTaskManager::new(rt.handle().clone(), consistent_view.clone(), task_ctx, 1);
        let proof_task_handle = proof_task.handle();

        let join_handle = rt.spawn_blocking(move || proof_task.run());

        let parallel_result = ParallelProof::new(
            consistent_view,
            Default::default(),
            Default::default(),
            Default::default(),
            proof_task_handle.clone(),
        )
        .decoded_multiproof(targets.clone())
        .unwrap();

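        // Compute the same multiproof sequentially and compare it against the parallel result.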
        let sequential_result_raw = Proof::new(trie_cursor_factory, hashed_cursor_factory)
            .multiproof(targets.clone())
            .unwrap();
        let sequential_result_decoded: DecodedMultiProof = sequential_result_raw
            .try_into()
            .expect("Failed to decode sequential_result for test comparison");

        assert_eq!(parallel_result.account_subtree, sequential_result_decoded.account_subtree);

        assert_eq!(parallel_result.storages.len(), sequential_result_decoded.storages.len());

        for (hashed_address, storage_proof) in &parallel_result.storages {
            let sequential_storage_proof =
                sequential_result_decoded.storages.get(hashed_address).unwrap();
            assert_eq!(storage_proof, sequential_storage_proof);
        }

        assert_eq!(parallel_result, sequential_result_decoded);

        drop(proof_task_handle);
        rt.block_on(join_handle).unwrap().expect("The proof task should not return an error");
    }
}