// reth_trie_parallel/proof.rs

use crate::{
    metrics::ParallelTrieMetrics,
    proof_task::{ProofTaskKind, ProofTaskManagerHandle, StorageProofInput},
    root::ParallelStateRootError,
    stats::ParallelTrieTracker,
    StorageRootTargets,
};
use alloy_primitives::{
    map::{B256Map, B256Set, HashMap},
    B256,
};
use alloy_rlp::{BufMut, Encodable};
use dashmap::DashMap;
use itertools::Itertools;
use reth_execution_errors::StorageRootError;
use reth_provider::{
    providers::ConsistentDbView, BlockReader, DBProvider, DatabaseProviderFactory, FactoryTx,
    ProviderError,
};
use reth_storage_errors::db::DatabaseError;
use reth_trie::{
    hashed_cursor::{HashedCursorFactory, HashedPostStateCursorFactory},
    node_iter::{TrieElement, TrieNodeIter},
    prefix_set::{PrefixSet, PrefixSetMut, TriePrefixSetsMut},
    proof::StorageProof,
    trie_cursor::{InMemoryTrieCursorFactory, TrieCursorFactory},
    updates::TrieUpdatesSorted,
    walker::TrieWalker,
    DecodedMultiProof, DecodedStorageMultiProof, HashBuilder, HashedPostStateSorted,
    MultiProofTargets, Nibbles, TRIE_ACCOUNT_RLP_MAX_SIZE,
};
use reth_trie_common::{
    added_removed_keys::MultiAddedRemovedKeys,
    proof::{DecodedProofNodes, ProofRetainer},
};
use reth_trie_db::{DatabaseHashedCursorFactory, DatabaseTrieCursorFactory};
use std::sync::{mpsc::Receiver, Arc};
use tracing::trace;

/// Parallel proof calculator.
///
/// This can collect proof for many targets in parallel, spawning a task for each hashed address
/// that has proof targets.
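///
/// # Example
///
/// A minimal usage sketch. It assumes a `ConsistentDbView` (`consistent_view`), a handle to a
/// running proof task manager (`proof_task_handle`), and `MultiProofTargets` (`targets`) are
/// already available; see the tests at the bottom of this file for a complete setup.
///
/// ```ignore
/// let proof = ParallelProof::new(
///     consistent_view,
///     Default::default(), // nodes_sorted
///     Default::default(), // state_sorted
///     Default::default(), // prefix_sets
///     Default::default(), // missed_leaves_storage_roots
///     proof_task_handle,
/// )
/// .with_branch_node_masks(true);
/// let multiproof = proof.decoded_multiproof(targets)?;
/// ```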
#[derive(Debug)]
pub struct ParallelProof<Factory: DatabaseProviderFactory> {
    /// Consistent view of the database.
    view: ConsistentDbView<Factory>,
    /// The sorted collection of cached in-memory intermediate trie nodes that
    /// can be reused for computation.
    pub nodes_sorted: Arc<TrieUpdatesSorted>,
    /// The sorted in-memory overlay hashed state.
    pub state_sorted: Arc<HashedPostStateSorted>,
    /// The collection of prefix sets for the computation. Since the prefix sets _always_
    /// invalidate the in-memory nodes, not all keys from `state_sorted` might be present here,
    /// if we have cached nodes for them.
    pub prefix_sets: Arc<TriePrefixSetsMut>,
    /// Flag indicating whether to include branch node masks in the proof.
    collect_branch_node_masks: bool,
    /// Provided by the user to give the necessary context to retain extra proofs.
    multi_added_removed_keys: Option<Arc<MultiAddedRemovedKeys>>,
    /// Handle to the storage proof task.
    storage_proof_task_handle: ProofTaskManagerHandle<FactoryTx<Factory>>,
    /// Cached storage proof roots for missed leaves; this maps
    /// hashed (missed) addresses to their storage proof roots.
    missed_leaves_storage_roots: Arc<DashMap<B256, B256>>,
    #[cfg(feature = "metrics")]
    metrics: ParallelTrieMetrics,
}

impl<Factory: DatabaseProviderFactory> ParallelProof<Factory> {
    /// Create new state proof generator.
    pub fn new(
        view: ConsistentDbView<Factory>,
        nodes_sorted: Arc<TrieUpdatesSorted>,
        state_sorted: Arc<HashedPostStateSorted>,
        prefix_sets: Arc<TriePrefixSetsMut>,
        missed_leaves_storage_roots: Arc<DashMap<B256, B256>>,
        storage_proof_task_handle: ProofTaskManagerHandle<FactoryTx<Factory>>,
    ) -> Self {
        Self {
            view,
            nodes_sorted,
            state_sorted,
            prefix_sets,
            missed_leaves_storage_roots,
            collect_branch_node_masks: false,
            multi_added_removed_keys: None,
            storage_proof_task_handle,
            #[cfg(feature = "metrics")]
            metrics: ParallelTrieMetrics::new_with_labels(&[("type", "proof")]),
        }
    }

    /// Set the flag indicating whether to include branch node masks in the proof.
    pub const fn with_branch_node_masks(mut self, branch_node_masks: bool) -> Self {
        self.collect_branch_node_masks = branch_node_masks;
        self
    }

    /// Configure the `ParallelProof` with a [`MultiAddedRemovedKeys`], allowing for retaining
    /// extra proofs needed to add and remove leaf nodes from the tries.
    pub fn with_multi_added_removed_keys(
        mut self,
        multi_added_removed_keys: Option<Arc<MultiAddedRemovedKeys>>,
    ) -> Self {
        self.multi_added_removed_keys = multi_added_removed_keys;
        self
    }
}

impl<Factory> ParallelProof<Factory>
where
    Factory: DatabaseProviderFactory<Provider: BlockReader> + Clone + 'static,
{
    /// Queues a storage proof task and returns a receiver for the result.
    fn queue_storage_proof(
        &self,
        hashed_address: B256,
        prefix_set: PrefixSet,
        target_slots: B256Set,
    ) -> Receiver<Result<DecodedStorageMultiProof, ParallelStateRootError>> {
        let input = StorageProofInput::new(
            hashed_address,
            prefix_set,
            target_slots,
            self.collect_branch_node_masks,
            self.multi_added_removed_keys.clone(),
        );

        let (sender, receiver) = std::sync::mpsc::channel();
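        // The queue result is intentionally ignored: if the task could not be queued, the sender
        // is dropped and the caller observes the failure as a disconnected receiver on `recv`.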
        let _ =
            self.storage_proof_task_handle.queue_task(ProofTaskKind::StorageProof(input, sender));
        receiver
    }

    /// Generate a storage multiproof according to the specified targets and hashed address.
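    ///
    /// A hedged usage sketch (assumes a configured `ParallelProof`; `parallel_proof`,
    /// `hashed_address`, and `target_slots` are placeholders):
    ///
    /// ```ignore
    /// let storage_multiproof = parallel_proof.storage_proof(hashed_address, target_slots)?;
    /// ```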
    pub fn storage_proof(
        self,
        hashed_address: B256,
        target_slots: B256Set,
    ) -> Result<DecodedStorageMultiProof, ParallelStateRootError> {
        let total_targets = target_slots.len();
        let prefix_set = PrefixSetMut::from(target_slots.iter().map(Nibbles::unpack));
        let prefix_set = prefix_set.freeze();

        trace!(
            target: "trie::parallel_proof",
            total_targets,
            ?hashed_address,
            "Starting storage proof generation"
        );

        let receiver = self.queue_storage_proof(hashed_address, prefix_set, target_slots);
        let proof_result = receiver.recv().map_err(|_| {
            ParallelStateRootError::StorageRoot(StorageRootError::Database(DatabaseError::Other(
                format!("channel closed for {hashed_address}"),
            )))
        })?;

        trace!(
            target: "trie::parallel_proof",
            total_targets,
            ?hashed_address,
            "Storage proof generation completed"
        );

        proof_result
    }

    /// Generate a state multiproof according to the specified targets.
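    ///
    /// At a high level this queues one storage proof task per changed account, walks the account
    /// trie over the in-memory overlays and database cursors, lazily awaits each storage proof
    /// result when the corresponding leaf is reached, and retains the account-level proof nodes
    /// in a `HashBuilder`.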
    pub fn decoded_multiproof(
        self,
        targets: MultiProofTargets,
    ) -> Result<DecodedMultiProof, ParallelStateRootError> {
        let mut tracker = ParallelTrieTracker::default();

        // Extend prefix sets with targets
        let mut prefix_sets = (*self.prefix_sets).clone();
        prefix_sets.extend(TriePrefixSetsMut {
            account_prefix_set: PrefixSetMut::from(targets.keys().copied().map(Nibbles::unpack)),
            storage_prefix_sets: targets
                .iter()
                .filter(|&(_hashed_address, slots)| !slots.is_empty())
                .map(|(hashed_address, slots)| {
                    (*hashed_address, PrefixSetMut::from(slots.iter().map(Nibbles::unpack)))
                })
                .collect(),
            destroyed_accounts: Default::default(),
        });
        let prefix_sets = prefix_sets.freeze();

        let storage_root_targets = StorageRootTargets::new(
            prefix_sets.account_prefix_set.iter().map(|nibbles| B256::from_slice(&nibbles.pack())),
            prefix_sets.storage_prefix_sets.clone(),
        );
        let storage_root_targets_len = storage_root_targets.len();

        trace!(
            target: "trie::parallel_proof",
            total_targets = storage_root_targets_len,
            "Starting parallel proof generation"
        );

        // Pre-calculate storage roots for accounts which were changed.
        tracker.set_precomputed_storage_roots(storage_root_targets_len as u64);

        // stores the receiver for the storage proof outcome for the hashed addresses
        // this way we can lazily await the outcome when we iterate over the map
        let mut storage_proof_receivers =
            B256Map::with_capacity_and_hasher(storage_root_targets.len(), Default::default());

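        // Queue one storage proof task per storage root target; iteration is sorted by hashed
        // address so tasks are dispatched in a deterministic order.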
        for (hashed_address, prefix_set) in
            storage_root_targets.into_iter().sorted_unstable_by_key(|(address, _)| *address)
        {
            let target_slots = targets.get(&hashed_address).cloned().unwrap_or_default();
            let receiver = self.queue_storage_proof(hashed_address, prefix_set, target_slots);

            // store the receiver for that result with the hashed address so we can await this in
            // place when we iterate over the trie
            storage_proof_receivers.insert(hashed_address, receiver);
        }

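        // Open a read-only provider from the consistent view and layer the in-memory overlays
        // (sorted trie updates and hashed post state) on top of the database cursors.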
        let provider_ro = self.view.provider_ro()?;
        let trie_cursor_factory = InMemoryTrieCursorFactory::new(
            DatabaseTrieCursorFactory::new(provider_ro.tx_ref()),
            &self.nodes_sorted,
        );
        let hashed_cursor_factory = HashedPostStateCursorFactory::new(
            DatabaseHashedCursorFactory::new(provider_ro.tx_ref()),
            &self.state_sorted,
        );

        let accounts_added_removed_keys =
            self.multi_added_removed_keys.as_ref().map(|keys| keys.get_accounts());

        // Create the walker.
        let walker = TrieWalker::<_>::state_trie(
            trie_cursor_factory.account_trie_cursor().map_err(ProviderError::Database)?,
            prefix_sets.account_prefix_set,
        )
        .with_added_removed_keys(accounts_added_removed_keys)
        .with_deletions_retained(true);

        // Create a hash builder to rebuild the root node since it is not available in the
        // database.
        let retainer = targets
            .keys()
            .map(Nibbles::unpack)
            .collect::<ProofRetainer>()
            .with_added_removed_keys(accounts_added_removed_keys);
        let mut hash_builder = HashBuilder::default()
            .with_proof_retainer(retainer)
            .with_updates(self.collect_branch_node_masks);

        // Initialize all storage multiproofs as empty.
        // Storage multiproofs for non-empty tries will be overwritten if necessary.
        let mut collected_decoded_storages: B256Map<DecodedStorageMultiProof> =
            targets.keys().map(|key| (*key, DecodedStorageMultiProof::empty())).collect();
        let mut account_rlp = Vec::with_capacity(TRIE_ACCOUNT_RLP_MAX_SIZE);
        let mut account_node_iter = TrieNodeIter::state_trie(
            walker,
            hashed_cursor_factory.hashed_account_cursor().map_err(ProviderError::Database)?,
        );
        while let Some(account_node) =
            account_node_iter.try_next().map_err(ProviderError::Database)?
        {
            match account_node {
                TrieElement::Branch(node) => {
                    hash_builder.add_branch(node.key, node.value, node.children_are_in_trie);
                }
                TrieElement::Leaf(hashed_address, account) => {
                    let root = match storage_proof_receivers.remove(&hashed_address) {
                        Some(rx) => {
                            let decoded_storage_multiproof = rx.recv().map_err(|e| {
                                ParallelStateRootError::StorageRoot(StorageRootError::Database(
                                    DatabaseError::Other(format!(
                                        "channel closed for {hashed_address}: {e}"
                                    )),
                                ))
                            })??;
                            let root = decoded_storage_multiproof.root;
                            collected_decoded_storages
                                .insert(hashed_address, decoded_storage_multiproof);
                            root
                        }
                        // Since we do not store all intermediate nodes in the database, there
                        // might be a possibility of re-adding a non-modified leaf to the hash
                        // builder.
                        None => {
                            tracker.inc_missed_leaves();

                            match self.missed_leaves_storage_roots.entry(hashed_address) {
                                dashmap::Entry::Occupied(occ) => *occ.get(),
                                dashmap::Entry::Vacant(vac) => {
                                    let root = StorageProof::new_hashed(
                                        trie_cursor_factory.clone(),
                                        hashed_cursor_factory.clone(),
                                        hashed_address,
                                    )
                                    .with_prefix_set_mut(Default::default())
                                    .storage_multiproof(
                                        targets.get(&hashed_address).cloned().unwrap_or_default(),
                                    )
                                    .map_err(|e| {
                                        ParallelStateRootError::StorageRoot(
                                            StorageRootError::Database(DatabaseError::Other(
                                                e.to_string(),
                                            )),
                                        )
                                    })?
                                    .root;
                                    vac.insert(root);
                                    root
                                }
                            }
                        }
                    };

                    // Encode account
                    account_rlp.clear();
                    let account = account.into_trie_account(root);
                    account.encode(&mut account_rlp as &mut dyn BufMut);

                    hash_builder.add_leaf(Nibbles::unpack(hashed_address), &account_rlp);
                }
            }
        }
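        // Drive the hash builder to completion; only the retained proof nodes are needed here,
        // so the resulting root hash is discarded.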
        let _ = hash_builder.root();

        let stats = tracker.finish();
        #[cfg(feature = "metrics")]
        self.metrics.record(stats);

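        // Collect the retained account trie proof nodes and decode them from RLP.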
        let account_subtree_raw_nodes = hash_builder.take_proof_nodes();
        let decoded_account_subtree = DecodedProofNodes::try_from(account_subtree_raw_nodes)?;

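        // If branch node masks were requested, split the branch node updates collected by the
        // hash builder into per-path hash masks and tree masks.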
        let (branch_node_hash_masks, branch_node_tree_masks) = if self.collect_branch_node_masks {
            let updated_branch_nodes = hash_builder.updated_branch_nodes.unwrap_or_default();
            (
                updated_branch_nodes.iter().map(|(path, node)| (*path, node.hash_mask)).collect(),
                updated_branch_nodes
                    .into_iter()
                    .map(|(path, node)| (path, node.tree_mask))
                    .collect(),
            )
        } else {
            (HashMap::default(), HashMap::default())
        };

        trace!(
            target: "trie::parallel_proof",
            total_targets = storage_root_targets_len,
            duration = ?stats.duration(),
            branches_added = stats.branches_added(),
            leaves_added = stats.leaves_added(),
            missed_leaves = stats.missed_leaves(),
            precomputed_storage_roots = stats.precomputed_storage_roots(),
            "Calculated decoded proof"
        );

        Ok(DecodedMultiProof {
            account_subtree: decoded_account_subtree,
            branch_node_hash_masks,
            branch_node_tree_masks,
            storages: collected_decoded_storages,
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::proof_task::{ProofTaskCtx, ProofTaskManager};
    use alloy_primitives::{
        keccak256,
        map::{B256Set, DefaultHashBuilder},
        Address, U256,
    };
    use rand::Rng;
    use reth_primitives_traits::{Account, StorageEntry};
    use reth_provider::{test_utils::create_test_provider_factory, HashingWriter};
    use reth_trie::proof::Proof;
    use tokio::runtime::Runtime;

    #[test]
    fn random_parallel_proof() {
        let factory = create_test_provider_factory();
        let consistent_view = ConsistentDbView::new(factory.clone(), None);

        let mut rng = rand::rng();
        let state = (0..100)
            .map(|_| {
                let address = Address::random();
                let account =
                    Account { balance: U256::from(rng.random::<u64>()), ..Default::default() };
                let mut storage = HashMap::<B256, U256, DefaultHashBuilder>::default();
                let has_storage = rng.random_bool(0.7);
                if has_storage {
                    for _ in 0..100 {
                        storage.insert(
                            B256::from(U256::from(rng.random::<u64>())),
                            U256::from(rng.random::<u64>()),
                        );
                    }
                }
                (address, (account, storage))
            })
            .collect::<HashMap<_, _, DefaultHashBuilder>>();

        {
            let provider_rw = factory.provider_rw().unwrap();
            provider_rw
                .insert_account_for_hashing(
                    state.iter().map(|(address, (account, _))| (*address, Some(*account))),
                )
                .unwrap();
            provider_rw
                .insert_storage_for_hashing(state.iter().map(|(address, (_, storage))| {
                    (
                        *address,
                        storage
                            .iter()
                            .map(|(slot, value)| StorageEntry { key: *slot, value: *value }),
                    )
                }))
                .unwrap();
            provider_rw.commit().unwrap();
        }

        let mut targets = MultiProofTargets::default();
        for (address, (_, storage)) in state.iter().take(10) {
            let hashed_address = keccak256(*address);
            let mut target_slots = B256Set::default();

            for (slot, _) in storage.iter().take(5) {
                target_slots.insert(*slot);
            }

            if !target_slots.is_empty() {
                targets.insert(hashed_address, target_slots);
            }
        }

        let provider_rw = factory.provider_rw().unwrap();
        let trie_cursor_factory = DatabaseTrieCursorFactory::new(provider_rw.tx_ref());
        let hashed_cursor_factory = DatabaseHashedCursorFactory::new(provider_rw.tx_ref());

        let rt = Runtime::new().unwrap();

        let task_ctx =
            ProofTaskCtx::new(Default::default(), Default::default(), Default::default());
        let proof_task =
            ProofTaskManager::new(rt.handle().clone(), consistent_view.clone(), task_ctx, 1, 1)
                .unwrap();
        let proof_task_handle = proof_task.handle();

        // keep the join handle around to make sure it does not return any errors
        // after we compute the state root
        let join_handle = rt.spawn_blocking(move || proof_task.run());

        let parallel_result = ParallelProof::new(
            consistent_view,
            Default::default(),
            Default::default(),
            Default::default(),
            Default::default(),
            proof_task_handle.clone(),
        )
        .decoded_multiproof(targets.clone())
        .unwrap();

        let sequential_result_raw = Proof::new(trie_cursor_factory, hashed_cursor_factory)
            .multiproof(targets.clone())
            .unwrap(); // targets might be consumed by parallel_result
        let sequential_result_decoded: DecodedMultiProof = sequential_result_raw
            .try_into()
            .expect("Failed to decode sequential_result for test comparison");

        // to help narrow down what is wrong - first compare account subtries
        assert_eq!(parallel_result.account_subtree, sequential_result_decoded.account_subtree);

        // then compare length of all storage subtries
        assert_eq!(parallel_result.storages.len(), sequential_result_decoded.storages.len());

        // then compare each storage subtrie
        for (hashed_address, storage_proof) in &parallel_result.storages {
            let sequential_storage_proof =
                sequential_result_decoded.storages.get(hashed_address).unwrap();
            assert_eq!(storage_proof, sequential_storage_proof);
        }

        // then compare the entire thing for any mask differences
        assert_eq!(parallel_result, sequential_result_decoded);

        // drop the handle to terminate the task and then block on the proof task handle to make
        // sure it does not return any errors
        drop(proof_task_handle);
        rt.block_on(join_handle).unwrap().expect("The proof task should not return an error");
    }
}