1use crate::{
2 metrics::ParallelTrieMetrics, root::ParallelStateRootError, stats::ParallelTrieTracker,
3 StorageRootTargets,
4};
5use alloy_primitives::{
6 map::{B256Map, HashMap},
7 B256,
8};
9use alloy_rlp::{BufMut, Encodable};
10use itertools::Itertools;
11use reth_execution_errors::StorageRootError;
12use reth_provider::{
13 providers::ConsistentDbView, BlockReader, DBProvider, DatabaseProviderFactory, ProviderError,
14 StateCommitmentProvider,
15};
16use reth_storage_errors::db::DatabaseError;
17use reth_trie::{
18 hashed_cursor::{HashedCursorFactory, HashedPostStateCursorFactory},
19 node_iter::{TrieElement, TrieNodeIter},
20 prefix_set::{PrefixSetMut, TriePrefixSetsMut},
21 proof::StorageProof,
22 trie_cursor::{InMemoryTrieCursorFactory, TrieCursorFactory},
23 updates::TrieUpdatesSorted,
24 walker::TrieWalker,
25 HashBuilder, HashedPostStateSorted, MultiProof, MultiProofTargets, Nibbles, StorageMultiProof,
26 TRIE_ACCOUNT_RLP_MAX_SIZE,
27};
28use reth_trie_common::proof::ProofRetainer;
29use reth_trie_db::{DatabaseHashedCursorFactory, DatabaseTrieCursorFactory};
30use std::{sync::Arc, time::Instant};
31use tokio::runtime::Handle;
32use tracing::{debug, trace};
33
/// Calculator for a state [`MultiProof`].
///
/// Storage proofs for the targeted accounts are computed on blocking tasks spawned through
/// `executor`. The account trie is walked on the calling thread, consuming those storage proofs
/// as account leaves are reached.
#[derive(Debug)]
pub struct ParallelProof<Factory> {
    /// Consistent view of the database. Every storage proof task, and the account trie walk,
    /// opens its own read-only provider from this view.
    view: ConsistentDbView<Factory>,
    /// Sorted in-memory trie node updates, combined with the database trie cursors through
    /// `InMemoryTrieCursorFactory`.
    pub nodes_sorted: Arc<TrieUpdatesSorted>,
    /// Sorted in-memory hashed post state, combined with the database hashed cursors through
    /// `HashedPostStateCursorFactory`.
    pub state_sorted: Arc<HashedPostStateSorted>,
    /// Prefix sets of changed paths. They are cloned and extended with the proof targets before
    /// the tries are walked.
    pub prefix_sets: Arc<TriePrefixSetsMut>,
    /// Whether branch node hash and tree masks are collected into the resulting proof.
    collect_branch_node_masks: bool,
    /// Tokio runtime handle used to `spawn_blocking` the storage proof tasks.
    executor: Handle,
    /// Proof metrics, labelled with `type = "proof"` at construction.
    #[cfg(feature = "metrics")]
    metrics: ParallelTrieMetrics,
}
58
59impl<Factory> ParallelProof<Factory> {
60 pub fn new(
62 view: ConsistentDbView<Factory>,
63 nodes_sorted: Arc<TrieUpdatesSorted>,
64 state_sorted: Arc<HashedPostStateSorted>,
65 prefix_sets: Arc<TriePrefixSetsMut>,
66 executor: Handle,
67 ) -> Self {
68 Self {
69 view,
70 nodes_sorted,
71 state_sorted,
72 prefix_sets,
73 collect_branch_node_masks: false,
74 executor,
75 #[cfg(feature = "metrics")]
76 metrics: ParallelTrieMetrics::new_with_labels(&[("type", "proof")]),
77 }
78 }
79
80 pub const fn with_branch_node_masks(mut self, branch_node_masks: bool) -> Self {
82 self.collect_branch_node_masks = branch_node_masks;
83 self
84 }
85}
86
87impl<Factory> ParallelProof<Factory>
88where
89 Factory:
90 DatabaseProviderFactory<Provider: BlockReader> + StateCommitmentProvider + Clone + 'static,
91{
92 pub fn multiproof(
94 self,
95 targets: MultiProofTargets,
96 ) -> Result<MultiProof, ParallelStateRootError> {
97 let mut tracker = ParallelTrieTracker::default();
98
99 let mut prefix_sets = (*self.prefix_sets).clone();
101 prefix_sets.extend(TriePrefixSetsMut {
102 account_prefix_set: PrefixSetMut::from(targets.keys().copied().map(Nibbles::unpack)),
103 storage_prefix_sets: targets
104 .iter()
105 .filter(|&(_hashed_address, slots)| !slots.is_empty())
106 .map(|(hashed_address, slots)| {
107 (*hashed_address, PrefixSetMut::from(slots.iter().map(Nibbles::unpack)))
108 })
109 .collect(),
110 destroyed_accounts: Default::default(),
111 });
112 let prefix_sets = prefix_sets.freeze();
113
114 let storage_root_targets = StorageRootTargets::new(
115 prefix_sets.account_prefix_set.iter().map(|nibbles| B256::from_slice(&nibbles.pack())),
116 prefix_sets.storage_prefix_sets.clone(),
117 );
118 let storage_root_targets_len = storage_root_targets.len();
119
120 debug!(
121 target: "trie::parallel_proof",
122 total_targets = storage_root_targets_len,
123 "Starting parallel proof generation"
124 );
125
126 tracker.set_precomputed_storage_roots(storage_root_targets_len as u64);
128
129 let mut storage_proofs =
132 B256Map::with_capacity_and_hasher(storage_root_targets.len(), Default::default());
133
134 for (hashed_address, prefix_set) in
135 storage_root_targets.into_iter().sorted_unstable_by_key(|(address, _)| *address)
136 {
137 let view = self.view.clone();
138 let target_slots = targets.get(&hashed_address).cloned().unwrap_or_default();
139 let trie_nodes_sorted = self.nodes_sorted.clone();
140 let hashed_state_sorted = self.state_sorted.clone();
141 let collect_masks = self.collect_branch_node_masks;
142
143 let (tx, rx) = std::sync::mpsc::sync_channel(1);
144
145 self.executor.spawn_blocking(move || {
147 debug!(
148 target: "trie::parallel_proof",
149 ?hashed_address,
150 "Starting proof calculation"
151 );
152
153 let task_start = Instant::now();
154 let result = (|| -> Result<_, ParallelStateRootError> {
155 let provider_start = Instant::now();
156 let provider_ro = view.provider_ro()?;
157 trace!(
158 target: "trie::parallel_proof",
159 ?hashed_address,
160 provider_time = ?provider_start.elapsed(),
161 "Got provider"
162 );
163
164 let cursor_start = Instant::now();
165 let trie_cursor_factory = InMemoryTrieCursorFactory::new(
166 DatabaseTrieCursorFactory::new(provider_ro.tx_ref()),
167 &trie_nodes_sorted,
168 );
169 let hashed_cursor_factory = HashedPostStateCursorFactory::new(
170 DatabaseHashedCursorFactory::new(provider_ro.tx_ref()),
171 &hashed_state_sorted,
172 );
173 trace!(
174 target: "trie::parallel_proof",
175 ?hashed_address,
176 cursor_time = ?cursor_start.elapsed(),
177 "Created cursors"
178 );
179
180 let target_slots_len = target_slots.len();
181 let proof_start = Instant::now();
182 let proof_result = StorageProof::new_hashed(
183 trie_cursor_factory,
184 hashed_cursor_factory,
185 hashed_address,
186 )
187 .with_prefix_set_mut(PrefixSetMut::from(prefix_set.iter().cloned()))
188 .with_branch_node_masks(collect_masks)
189 .storage_multiproof(target_slots)
190 .map_err(|e| ParallelStateRootError::Other(e.to_string()));
191
192 trace!(
193 target: "trie::parallel_proof",
194 ?hashed_address,
195 prefix_set = ?prefix_set.len(),
196 target_slots = ?target_slots_len,
197 proof_time = ?proof_start.elapsed(),
198 "Completed proof calculation"
199 );
200
201 proof_result
202 })();
203
204 if let Err(e) = tx.send(result) {
208 debug!(
209 target: "trie::parallel_proof",
210 ?hashed_address,
211 error = ?e,
212 task_time = ?task_start.elapsed(),
213 "Failed to send proof result"
214 );
215 }
216 });
217
218 storage_proofs.insert(hashed_address, rx);
221 }
222
223 let provider_ro = self.view.provider_ro()?;
224 let trie_cursor_factory = InMemoryTrieCursorFactory::new(
225 DatabaseTrieCursorFactory::new(provider_ro.tx_ref()),
226 &self.nodes_sorted,
227 );
228 let hashed_cursor_factory = HashedPostStateCursorFactory::new(
229 DatabaseHashedCursorFactory::new(provider_ro.tx_ref()),
230 &self.state_sorted,
231 );
232
233 let walker = TrieWalker::new(
235 trie_cursor_factory.account_trie_cursor().map_err(ProviderError::Database)?,
236 prefix_sets.account_prefix_set,
237 )
238 .with_deletions_retained(true);
239
240 let retainer: ProofRetainer = targets.keys().map(Nibbles::unpack).collect();
242 let mut hash_builder = HashBuilder::default()
243 .with_proof_retainer(retainer)
244 .with_updates(self.collect_branch_node_masks);
245
246 let mut storages: B256Map<_> =
249 targets.keys().map(|key| (*key, StorageMultiProof::empty())).collect();
250 let mut account_rlp = Vec::with_capacity(TRIE_ACCOUNT_RLP_MAX_SIZE);
251 let mut account_node_iter = TrieNodeIter::new(
252 walker,
253 hashed_cursor_factory.hashed_account_cursor().map_err(ProviderError::Database)?,
254 );
255 while let Some(account_node) =
256 account_node_iter.try_next().map_err(ProviderError::Database)?
257 {
258 match account_node {
259 TrieElement::Branch(node) => {
260 hash_builder.add_branch(node.key, node.value, node.children_are_in_trie);
261 }
262 TrieElement::Leaf(hashed_address, account) => {
263 let storage_multiproof = match storage_proofs.remove(&hashed_address) {
264 Some(rx) => rx.recv().map_err(|_| {
265 ParallelStateRootError::StorageRoot(StorageRootError::Database(
266 DatabaseError::Other(format!(
267 "channel closed for {hashed_address}"
268 )),
269 ))
270 })??,
271 None => {
274 tracker.inc_missed_leaves();
275 StorageProof::new_hashed(
276 trie_cursor_factory.clone(),
277 hashed_cursor_factory.clone(),
278 hashed_address,
279 )
280 .with_prefix_set_mut(Default::default())
281 .storage_multiproof(
282 targets.get(&hashed_address).cloned().unwrap_or_default(),
283 )
284 .map_err(|e| {
285 ParallelStateRootError::StorageRoot(StorageRootError::Database(
286 DatabaseError::Other(e.to_string()),
287 ))
288 })?
289 }
290 };
291
292 account_rlp.clear();
294 let account = account.into_trie_account(storage_multiproof.root);
295 account.encode(&mut account_rlp as &mut dyn BufMut);
296
297 hash_builder.add_leaf(Nibbles::unpack(hashed_address), &account_rlp);
298
299 if targets.contains_key(&hashed_address) {
301 storages.insert(hashed_address, storage_multiproof);
302 }
303 }
304 }
305 }
306 let _ = hash_builder.root();
307
308 let stats = tracker.finish();
309 #[cfg(feature = "metrics")]
310 self.metrics.record(stats);
311
312 let account_subtree = hash_builder.take_proof_nodes();
313 let (branch_node_hash_masks, branch_node_tree_masks) = if self.collect_branch_node_masks {
314 let updated_branch_nodes = hash_builder.updated_branch_nodes.unwrap_or_default();
315 (
316 updated_branch_nodes
317 .iter()
318 .map(|(path, node)| (path.clone(), node.hash_mask))
319 .collect(),
320 updated_branch_nodes
321 .into_iter()
322 .map(|(path, node)| (path, node.tree_mask))
323 .collect(),
324 )
325 } else {
326 (HashMap::default(), HashMap::default())
327 };
328
329 debug!(
330 target: "trie::parallel_proof",
331 total_targets = storage_root_targets_len,
332 duration = ?stats.duration(),
333 branches_added = stats.branches_added(),
334 leaves_added = stats.leaves_added(),
335 missed_leaves = stats.missed_leaves(),
336 precomputed_storage_roots = stats.precomputed_storage_roots(),
337 "Calculated proof"
338 );
339
340 Ok(MultiProof { account_subtree, branch_node_hash_masks, branch_node_tree_masks, storages })
341 }
342}
343
#[cfg(test)]
mod tests {
    use super::*;
    use alloy_primitives::{
        keccak256,
        map::{B256Set, DefaultHashBuilder},
        Address, U256,
    };
    use rand::Rng;
    use reth_primitives::{Account, StorageEntry};
    use reth_provider::{test_utils::create_test_provider_factory, HashingWriter};
    use reth_trie::proof::Proof;
    use tokio::runtime::Runtime;

    /// Writes random hashed accounts and storage, then checks that
    /// [`ParallelProof::multiproof`] matches the sequential [`Proof::multiproof`].
    #[test]
    fn random_parallel_proof() {
        let factory = create_test_provider_factory();
        let consistent_view = ConsistentDbView::new(factory.clone(), None);

        // 100 random accounts. Each has a 70% chance of getting 100 random storage writes.
        let mut rng = rand::thread_rng();
        let mut state = HashMap::<_, _, DefaultHashBuilder>::default();
        for _ in 0..100 {
            let address = Address::random();
            let balance = U256::from(rng.gen::<u64>());
            let account = Account { balance, ..Default::default() };

            let mut storage = HashMap::<B256, U256, DefaultHashBuilder>::default();
            if rng.gen_bool(0.7) {
                for _ in 0..100 {
                    let slot = B256::from(U256::from(rng.gen::<u64>()));
                    let value = U256::from(rng.gen::<u64>());
                    storage.insert(slot, value);
                }
            }

            state.insert(address, (account, storage));
        }

        {
            let provider_rw = factory.provider_rw().unwrap();

            let account_entries =
                state.iter().map(|(address, (account, _))| (*address, Some(*account)));
            provider_rw.insert_account_for_hashing(account_entries).unwrap();

            let storage_entries = state.iter().map(|(address, (_, storage))| {
                let entries =
                    storage.iter().map(|(slot, value)| StorageEntry { key: *slot, value: *value });
                (*address, entries)
            });
            provider_rw.insert_storage_for_hashing(storage_entries).unwrap();

            provider_rw.commit().unwrap();
        }

        // For the first ten accounts in map iteration order, target up to five of their slots.
        // Accounts without storage are not targeted at all.
        let mut targets = MultiProofTargets::default();
        for (address, (_, storage)) in state.iter().take(10) {
            let target_slots: B256Set = storage.keys().take(5).copied().collect();
            if target_slots.is_empty() {
                continue;
            }
            targets.insert(keccak256(*address), target_slots);
        }

        let provider_rw = factory.provider_rw().unwrap();
        let trie_cursor_factory = DatabaseTrieCursorFactory::new(provider_rw.tx_ref());
        let hashed_cursor_factory = DatabaseHashedCursorFactory::new(provider_rw.tx_ref());

        let rt = Runtime::new().unwrap();

        let parallel = ParallelProof::new(
            consistent_view,
            Default::default(),
            Default::default(),
            Default::default(),
            rt.handle().clone(),
        )
        .multiproof(targets.clone())
        .unwrap();
        let sequential =
            Proof::new(trie_cursor_factory, hashed_cursor_factory).multiproof(targets).unwrap();

        assert_eq!(parallel, sequential);
    }
}