use crate::{
    metrics::ParallelTrieMetrics,
    proof_task::{ProofTaskKind, ProofTaskManagerHandle, StorageProofInput},
    root::ParallelStateRootError,
    stats::ParallelTrieTracker,
    StorageRootTargets,
};
use alloy_primitives::{
    map::{B256Map, B256Set, HashMap},
    B256,
};
use alloy_rlp::{BufMut, Encodable};
use itertools::Itertools;
use reth_execution_errors::StorageRootError;
use reth_provider::{
    providers::ConsistentDbView, BlockReader, DBProvider, DatabaseProviderFactory, FactoryTx,
    ProviderError, StateCommitmentProvider,
};
use reth_storage_errors::db::DatabaseError;
use reth_trie::{
    hashed_cursor::{HashedCursorFactory, HashedPostStateCursorFactory},
    node_iter::{TrieElement, TrieNodeIter},
    prefix_set::{PrefixSet, PrefixSetMut, TriePrefixSetsMut},
    proof::StorageProof,
    trie_cursor::{InMemoryTrieCursorFactory, TrieCursorFactory},
    updates::TrieUpdatesSorted,
    walker::TrieWalker,
    HashBuilder, HashedPostStateSorted, MultiProof, MultiProofTargets, Nibbles, StorageMultiProof,
    TRIE_ACCOUNT_RLP_MAX_SIZE,
};
use reth_trie_common::proof::ProofRetainer;
use reth_trie_db::{DatabaseHashedCursorFactory, DatabaseTrieCursorFactory};
use std::sync::{mpsc::Receiver, Arc};
use tracing::debug;

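/// Parallel proof calculator.
///
/// Generates a multiproof for the given targets by queueing a storage proof task for each hashed
/// address that has proof targets, while the account proof is assembled on the calling thread as
/// the storage proofs complete.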
#[derive(Debug)]
pub struct ParallelProof<Factory: DatabaseProviderFactory> {
    /// Consistent view of the database.
    view: ConsistentDbView<Factory>,
    /// The sorted collection of cached in-memory intermediate trie nodes that can be reused for
    /// the computation.
    pub nodes_sorted: Arc<TrieUpdatesSorted>,
    /// The sorted in-memory overlay of hashed state.
    pub state_sorted: Arc<HashedPostStateSorted>,
    /// The collection of prefix sets for the computation.
    pub prefix_sets: Arc<TriePrefixSetsMut>,
    /// Flag indicating whether to retain branch node hash and tree masks in the proof.
    collect_branch_node_masks: bool,
    /// Handle used to dispatch storage proof tasks.
    storage_proof_task_handle: ProofTaskManagerHandle<FactoryTx<Factory>>,
    /// Parallel trie metrics, labeled with `type = "proof"`.
    #[cfg(feature = "metrics")]
    metrics: ParallelTrieMetrics,
}

impl<Factory: DatabaseProviderFactory> ParallelProof<Factory> {
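    /// Creates a new [`ParallelProof`] from the consistent database view, in-memory trie
    /// overlays, prefix sets and a handle to the storage proof task manager.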
    pub fn new(
        view: ConsistentDbView<Factory>,
        nodes_sorted: Arc<TrieUpdatesSorted>,
        state_sorted: Arc<HashedPostStateSorted>,
        prefix_sets: Arc<TriePrefixSetsMut>,
        storage_proof_task_handle: ProofTaskManagerHandle<FactoryTx<Factory>>,
    ) -> Self {
        Self {
            view,
            nodes_sorted,
            state_sorted,
            prefix_sets,
            collect_branch_node_masks: false,
            storage_proof_task_handle,
            #[cfg(feature = "metrics")]
            metrics: ParallelTrieMetrics::new_with_labels(&[("type", "proof")]),
        }
    }

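    /// Sets the flag indicating whether to collect branch node hash and tree masks while
    /// building the proof.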
    pub const fn with_branch_node_masks(mut self, branch_node_masks: bool) -> Self {
        self.collect_branch_node_masks = branch_node_masks;
        self
    }
}

impl<Factory> ParallelProof<Factory>
where
    Factory:
        DatabaseProviderFactory<Provider: BlockReader> + StateCommitmentProvider + Clone + 'static,
{
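    /// Queues a storage proof task for the given hashed address and returns the receiver on
    /// which the resulting [`StorageMultiProof`] will be delivered.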
    fn spawn_storage_proof(
        &self,
        hashed_address: B256,
        prefix_set: PrefixSet,
        target_slots: B256Set,
    ) -> Receiver<Result<StorageMultiProof, ParallelStateRootError>> {
        let input = StorageProofInput::new(
            hashed_address,
            prefix_set,
            target_slots,
            self.collect_branch_node_masks,
        );

        let (sender, receiver) = std::sync::mpsc::channel();
        let _ =
            self.storage_proof_task_handle.queue_task(ProofTaskKind::StorageProof(input, sender));
        receiver
    }

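    /// Generates a storage multiproof for the given hashed address, limited to the provided
    /// target slots.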
    pub fn storage_proof(
        self,
        hashed_address: B256,
        target_slots: B256Set,
    ) -> Result<StorageMultiProof, ParallelStateRootError> {
        let total_targets = target_slots.len();
        let prefix_set = PrefixSetMut::from(target_slots.iter().map(Nibbles::unpack));
        let prefix_set = prefix_set.freeze();

        debug!(
            target: "trie::parallel_proof",
            total_targets,
            ?hashed_address,
            "Starting storage proof generation"
        );

        let receiver = self.spawn_storage_proof(hashed_address, prefix_set, target_slots);
        let proof_result = receiver.recv().map_err(|_| {
            ParallelStateRootError::StorageRoot(StorageRootError::Database(DatabaseError::Other(
                format!("channel closed for {hashed_address}"),
            )))
        })?;

        debug!(
            target: "trie::parallel_proof",
            total_targets,
            ?hashed_address,
            "Storage proof generation completed"
        );

        proof_result
    }

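    /// Generates a state multiproof according to the specified targets, computing storage
    /// multiproofs in parallel via the storage proof task manager.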
    pub fn multiproof(
        self,
        targets: MultiProofTargets,
    ) -> Result<MultiProof, ParallelStateRootError> {
        let mut tracker = ParallelTrieTracker::default();

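        // Extend the prefix sets with the proof targets so that the account walker and the
        // storage proof tasks visit every targeted key.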
        let mut prefix_sets = (*self.prefix_sets).clone();
        prefix_sets.extend(TriePrefixSetsMut {
            account_prefix_set: PrefixSetMut::from(targets.keys().copied().map(Nibbles::unpack)),
            storage_prefix_sets: targets
                .iter()
                .filter(|&(_hashed_address, slots)| !slots.is_empty())
                .map(|(hashed_address, slots)| {
                    (*hashed_address, PrefixSetMut::from(slots.iter().map(Nibbles::unpack)))
                })
                .collect(),
            destroyed_accounts: Default::default(),
        });
        let prefix_sets = prefix_sets.freeze();
        let storage_root_targets = StorageRootTargets::new(
            prefix_sets.account_prefix_set.iter().map(|nibbles| B256::from_slice(&nibbles.pack())),
            prefix_sets.storage_prefix_sets.clone(),
        );
        let storage_root_targets_len = storage_root_targets.len();

        debug!(
            target: "trie::parallel_proof",
            total_targets = storage_root_targets_len,
            "Starting parallel proof generation"
        );

        tracker.set_precomputed_storage_roots(storage_root_targets_len as u64);

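        // Queue a storage proof task for every storage root target and keep the receivers;
        // the results are awaited as the account trie walk reaches the corresponding leaves.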
        let mut storage_proofs =
            B256Map::with_capacity_and_hasher(storage_root_targets.len(), Default::default());

        for (hashed_address, prefix_set) in
            storage_root_targets.into_iter().sorted_unstable_by_key(|(address, _)| *address)
        {
            let target_slots = targets.get(&hashed_address).cloned().unwrap_or_default();
            let receiver = self.spawn_storage_proof(hashed_address, prefix_set, target_slots);

            storage_proofs.insert(hashed_address, receiver);
        }

        let provider_ro = self.view.provider_ro()?;
        let trie_cursor_factory = InMemoryTrieCursorFactory::new(
            DatabaseTrieCursorFactory::new(provider_ro.tx_ref()),
            &self.nodes_sorted,
        );
        let hashed_cursor_factory = HashedPostStateCursorFactory::new(
            DatabaseHashedCursorFactory::new(provider_ro.tx_ref()),
            &self.state_sorted,
        );

        // Create the walker over the account trie, retaining deletions for the proof.
        let walker = TrieWalker::new(
            trie_cursor_factory.account_trie_cursor().map_err(ProviderError::Database)?,
            prefix_sets.account_prefix_set,
        )
        .with_deletions_retained(true);

        // Create a hash builder that retains proof nodes for the targeted account paths and,
        // optionally, updated branch nodes for mask collection.
        let retainer: ProofRetainer = targets.keys().map(Nibbles::unpack).collect();
        let mut hash_builder = HashBuilder::default()
            .with_proof_retainer(retainer)
            .with_updates(self.collect_branch_node_masks);

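        // Initialize storage multiproofs for all targets as empty; entries are overwritten below
        // with the proofs received for non-empty storage tries.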
        let mut storages: B256Map<_> =
            targets.keys().map(|key| (*key, StorageMultiProof::empty())).collect();
        let mut account_rlp = Vec::with_capacity(TRIE_ACCOUNT_RLP_MAX_SIZE);
        let mut account_node_iter = TrieNodeIter::state_trie(
            walker,
            hashed_cursor_factory.hashed_account_cursor().map_err(ProviderError::Database)?,
        );
        while let Some(account_node) =
            account_node_iter.try_next().map_err(ProviderError::Database)?
        {
            match account_node {
                TrieElement::Branch(node) => {
                    hash_builder.add_branch(node.key, node.value, node.children_are_in_trie);
                }
                TrieElement::Leaf(hashed_address, account) => {
                    let storage_multiproof = match storage_proofs.remove(&hashed_address) {
                        Some(rx) => rx.recv().map_err(|_| {
                            ParallelStateRootError::StorageRoot(StorageRootError::Database(
                                DatabaseError::Other(format!(
                                    "channel closed for {hashed_address}"
                                )),
                            ))
                        })??,
                        // No storage proof task was spawned for this leaf, so compute its storage
                        // multiproof inline and record it as a missed leaf.
                        None => {
                            tracker.inc_missed_leaves();
                            StorageProof::new_hashed(
                                trie_cursor_factory.clone(),
                                hashed_cursor_factory.clone(),
                                hashed_address,
                            )
                            .with_prefix_set_mut(Default::default())
                            .storage_multiproof(
                                targets.get(&hashed_address).cloned().unwrap_or_default(),
                            )
                            .map_err(|e| {
                                ParallelStateRootError::StorageRoot(StorageRootError::Database(
                                    DatabaseError::Other(e.to_string()),
                                ))
                            })?
                        }
                    };

                    // Encode the account with the storage root from the multiproof and add it as
                    // a leaf to the hash builder.
                    account_rlp.clear();
                    let account = account.into_trie_account(storage_multiproof.root);
                    account.encode(&mut account_rlp as &mut dyn BufMut);

                    hash_builder.add_leaf(Nibbles::unpack(hashed_address), &account_rlp);

                    // Only record storage multiproofs for addresses that are proof targets.
                    if targets.contains_key(&hashed_address) {
                        storages.insert(hashed_address, storage_multiproof);
                    }
                }
            }
        }
        let _ = hash_builder.root();

        let stats = tracker.finish();
        #[cfg(feature = "metrics")]
        self.metrics.record(stats);

        let account_subtree = hash_builder.take_proof_nodes();
        let (branch_node_hash_masks, branch_node_tree_masks) = if self.collect_branch_node_masks {
            let updated_branch_nodes = hash_builder.updated_branch_nodes.unwrap_or_default();
            (
                updated_branch_nodes
                    .iter()
                    .map(|(path, node)| (path.clone(), node.hash_mask))
                    .collect(),
                updated_branch_nodes
                    .into_iter()
                    .map(|(path, node)| (path, node.tree_mask))
                    .collect(),
            )
        } else {
            (HashMap::default(), HashMap::default())
        };

        debug!(
            target: "trie::parallel_proof",
            total_targets = storage_root_targets_len,
            duration = ?stats.duration(),
            branches_added = stats.branches_added(),
            leaves_added = stats.leaves_added(),
            missed_leaves = stats.missed_leaves(),
            precomputed_storage_roots = stats.precomputed_storage_roots(),
            "Calculated proof"
        );

        Ok(MultiProof { account_subtree, branch_node_hash_masks, branch_node_tree_masks, storages })
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::proof_task::{ProofTaskCtx, ProofTaskManager};
    use alloy_primitives::{
        keccak256,
        map::{B256Set, DefaultHashBuilder},
        Address, U256,
    };
    use rand::Rng;
    use reth_primitives_traits::{Account, StorageEntry};
    use reth_provider::{test_utils::create_test_provider_factory, HashingWriter};
    use reth_trie::proof::Proof;
    use tokio::runtime::Runtime;

    #[test]
    fn random_parallel_proof() {
        let factory = create_test_provider_factory();
        let consistent_view = ConsistentDbView::new(factory.clone(), None);

        let mut rng = rand::rng();
        let state = (0..100)
            .map(|_| {
                let address = Address::random();
                let account =
                    Account { balance: U256::from(rng.random::<u64>()), ..Default::default() };
                let mut storage = HashMap::<B256, U256, DefaultHashBuilder>::default();
                let has_storage = rng.random_bool(0.7);
                if has_storage {
                    for _ in 0..100 {
                        storage.insert(
                            B256::from(U256::from(rng.random::<u64>())),
                            U256::from(rng.random::<u64>()),
                        );
                    }
                }
                (address, (account, storage))
            })
            .collect::<HashMap<_, _, DefaultHashBuilder>>();

        {
            let provider_rw = factory.provider_rw().unwrap();
            provider_rw
                .insert_account_for_hashing(
                    state.iter().map(|(address, (account, _))| (*address, Some(*account))),
                )
                .unwrap();
            provider_rw
                .insert_storage_for_hashing(state.iter().map(|(address, (_, storage))| {
                    (
                        *address,
                        storage
                            .iter()
                            .map(|(slot, value)| StorageEntry { key: *slot, value: *value }),
                    )
                }))
                .unwrap();
            provider_rw.commit().unwrap();
        }

        let mut targets = MultiProofTargets::default();
        for (address, (_, storage)) in state.iter().take(10) {
            let hashed_address = keccak256(*address);
            let mut target_slots = B256Set::default();

            for (slot, _) in storage.iter().take(5) {
                target_slots.insert(*slot);
            }

            if !target_slots.is_empty() {
                targets.insert(hashed_address, target_slots);
            }
        }

        let provider_rw = factory.provider_rw().unwrap();
        let trie_cursor_factory = DatabaseTrieCursorFactory::new(provider_rw.tx_ref());
        let hashed_cursor_factory = DatabaseHashedCursorFactory::new(provider_rw.tx_ref());

        let rt = Runtime::new().unwrap();

        let task_ctx =
            ProofTaskCtx::new(Default::default(), Default::default(), Default::default());
        let proof_task =
            ProofTaskManager::new(rt.handle().clone(), consistent_view.clone(), task_ctx, 1);
        let proof_task_handle = proof_task.handle();

        // Keep the join handle so we can assert later that the proof task terminated without
        // errors.
        let join_handle = rt.spawn_blocking(move || proof_task.run());

        let parallel_result = ParallelProof::new(
            consistent_view,
            Default::default(),
            Default::default(),
            Default::default(),
            proof_task_handle.clone(),
        )
        .multiproof(targets.clone())
        .unwrap();

        let sequential_result =
            Proof::new(trie_cursor_factory, hashed_cursor_factory).multiproof(targets).unwrap();

        // Compare the parallel and sequential results piecewise first to make mismatches easier
        // to narrow down: the account subtree, then the number of storage proofs, then each
        // storage proof individually.
        assert_eq!(parallel_result.account_subtree, sequential_result.account_subtree);

        assert_eq!(parallel_result.storages.len(), sequential_result.storages.len());

        for (hashed_address, storage_proof) in &parallel_result.storages {
            let sequential_storage_proof = sequential_result.storages.get(hashed_address).unwrap();
            assert_eq!(storage_proof, sequential_storage_proof);
        }

        // Finally compare the full multiproofs.
        assert_eq!(parallel_result, sequential_result);

        // Drop our handle and wait for the proof task to finish, making sure it did not return
        // an error.
        drop(proof_task_handle);
        rt.block_on(join_handle).unwrap().expect("The proof task should not return an error");
    }
}