// reth_trie_parallel/proof.rs

1use crate::{
2    metrics::ParallelTrieMetrics,
3    proof_task::{ProofTaskKind, ProofTaskManagerHandle, StorageProofInput},
4    root::ParallelStateRootError,
5    stats::ParallelTrieTracker,
6    StorageRootTargets,
7};
8use alloy_primitives::{
9    map::{B256Map, B256Set, HashMap},
10    B256,
11};
12use alloy_rlp::{BufMut, Encodable};
13use itertools::Itertools;
14use reth_execution_errors::StorageRootError;
15use reth_provider::{
16    providers::ConsistentDbView, BlockReader, DBProvider, DatabaseProviderFactory, FactoryTx,
17    ProviderError,
18};
19use reth_storage_errors::db::DatabaseError;
20use reth_trie::{
21    hashed_cursor::{HashedCursorFactory, HashedPostStateCursorFactory},
22    node_iter::{TrieElement, TrieNodeIter},
23    prefix_set::{PrefixSet, PrefixSetMut, TriePrefixSetsMut},
24    proof::StorageProof,
25    trie_cursor::{InMemoryTrieCursorFactory, TrieCursorFactory},
26    updates::TrieUpdatesSorted,
27    walker::TrieWalker,
28    DecodedMultiProof, DecodedStorageMultiProof, HashBuilder, HashedPostStateSorted,
29    MultiProofTargets, Nibbles, TRIE_ACCOUNT_RLP_MAX_SIZE,
30};
31use reth_trie_common::{
32    added_removed_keys::MultiAddedRemovedKeys,
33    proof::{DecodedProofNodes, ProofRetainer},
34};
35use reth_trie_db::{DatabaseHashedCursorFactory, DatabaseTrieCursorFactory};
36use std::sync::{mpsc::Receiver, Arc};
37use tracing::debug;
38
/// Parallel proof calculator.
///
/// This can collect proof for many targets in parallel, spawning a task for each hashed address
/// that has proof targets.
#[derive(Debug)]
pub struct ParallelProof<Factory: DatabaseProviderFactory> {
    /// Consistent view of the database.
    view: ConsistentDbView<Factory>,
    /// The sorted collection of cached in-memory intermediate trie nodes that
    /// can be reused for computation.
    pub nodes_sorted: Arc<TrieUpdatesSorted>,
    /// The sorted in-memory overlay hashed state.
    pub state_sorted: Arc<HashedPostStateSorted>,
    /// The collection of prefix sets for the computation. Since the prefix sets _always_
    /// invalidate the in-memory nodes, not all keys from `state_sorted` might be present here,
    /// if we have cached nodes for them.
    pub prefix_sets: Arc<TriePrefixSetsMut>,
    /// Flag indicating whether to include branch node masks in the proof.
    /// `false` after construction; toggled via [`Self::with_branch_node_masks`].
    collect_branch_node_masks: bool,
    /// Provided by the user to give the necessary context to retain extra proofs.
    /// `None` after construction; set via [`Self::with_multi_added_removed_keys`].
    multi_added_removed_keys: Option<Arc<MultiAddedRemovedKeys>>,
    /// Handle to the storage proof task.
    storage_proof_task_handle: ProofTaskManagerHandle<FactoryTx<Factory>>,
    /// Metrics for parallel trie operations, labeled with `type = "proof"`.
    #[cfg(feature = "metrics")]
    metrics: ParallelTrieMetrics,
}
65
66impl<Factory: DatabaseProviderFactory> ParallelProof<Factory> {
67    /// Create new state proof generator.
68    pub fn new(
69        view: ConsistentDbView<Factory>,
70        nodes_sorted: Arc<TrieUpdatesSorted>,
71        state_sorted: Arc<HashedPostStateSorted>,
72        prefix_sets: Arc<TriePrefixSetsMut>,
73        storage_proof_task_handle: ProofTaskManagerHandle<FactoryTx<Factory>>,
74    ) -> Self {
75        Self {
76            view,
77            nodes_sorted,
78            state_sorted,
79            prefix_sets,
80            collect_branch_node_masks: false,
81            multi_added_removed_keys: None,
82            storage_proof_task_handle,
83            #[cfg(feature = "metrics")]
84            metrics: ParallelTrieMetrics::new_with_labels(&[("type", "proof")]),
85        }
86    }
87
88    /// Set the flag indicating whether to include branch node masks in the proof.
89    pub const fn with_branch_node_masks(mut self, branch_node_masks: bool) -> Self {
90        self.collect_branch_node_masks = branch_node_masks;
91        self
92    }
93
94    /// Configure the `ParallelProof` with a [`MultiAddedRemovedKeys`], allowing for retaining
95    /// extra proofs needed to add and remove leaf nodes from the tries.
96    pub fn with_multi_added_removed_keys(
97        mut self,
98        multi_added_removed_keys: Option<Arc<MultiAddedRemovedKeys>>,
99    ) -> Self {
100        self.multi_added_removed_keys = multi_added_removed_keys;
101        self
102    }
103}
104
105impl<Factory> ParallelProof<Factory>
106where
107    Factory: DatabaseProviderFactory<Provider: BlockReader> + Clone + 'static,
108{
109    /// Spawns a storage proof on the storage proof task and returns a receiver for the result.
110    fn spawn_storage_proof(
111        &self,
112        hashed_address: B256,
113        prefix_set: PrefixSet,
114        target_slots: B256Set,
115    ) -> Receiver<Result<DecodedStorageMultiProof, ParallelStateRootError>> {
116        let input = StorageProofInput::new(
117            hashed_address,
118            prefix_set,
119            target_slots,
120            self.collect_branch_node_masks,
121            self.multi_added_removed_keys.clone(),
122        );
123
124        let (sender, receiver) = std::sync::mpsc::channel();
125        let _ =
126            self.storage_proof_task_handle.queue_task(ProofTaskKind::StorageProof(input, sender));
127        receiver
128    }
129
130    /// Generate a storage multiproof according to the specified targets and hashed address.
131    pub fn storage_proof(
132        self,
133        hashed_address: B256,
134        target_slots: B256Set,
135    ) -> Result<DecodedStorageMultiProof, ParallelStateRootError> {
136        let total_targets = target_slots.len();
137        let prefix_set = PrefixSetMut::from(target_slots.iter().map(Nibbles::unpack));
138        let prefix_set = prefix_set.freeze();
139
140        debug!(
141            target: "trie::parallel_proof",
142            total_targets,
143            ?hashed_address,
144            "Starting storage proof generation"
145        );
146
147        let receiver = self.spawn_storage_proof(hashed_address, prefix_set, target_slots);
148        let proof_result = receiver.recv().map_err(|_| {
149            ParallelStateRootError::StorageRoot(StorageRootError::Database(DatabaseError::Other(
150                format!("channel closed for {hashed_address}"),
151            )))
152        })?;
153
154        debug!(
155            target: "trie::parallel_proof",
156            total_targets,
157            ?hashed_address,
158            "Storage proof generation completed"
159        );
160
161        proof_result
162    }
163
164    /// Generate a [`DecodedStorageMultiProof`] for the given proof by first calling
165    /// `storage_proof`, then decoding the proof nodes.
166    pub fn decoded_storage_proof(
167        self,
168        hashed_address: B256,
169        target_slots: B256Set,
170    ) -> Result<DecodedStorageMultiProof, ParallelStateRootError> {
171        self.storage_proof(hashed_address, target_slots)
172    }
173
174    /// Generate a state multiproof according to specified targets.
175    pub fn decoded_multiproof(
176        self,
177        targets: MultiProofTargets,
178    ) -> Result<DecodedMultiProof, ParallelStateRootError> {
179        let mut tracker = ParallelTrieTracker::default();
180
181        // Extend prefix sets with targets
182        let mut prefix_sets = (*self.prefix_sets).clone();
183        prefix_sets.extend(TriePrefixSetsMut {
184            account_prefix_set: PrefixSetMut::from(targets.keys().copied().map(Nibbles::unpack)),
185            storage_prefix_sets: targets
186                .iter()
187                .filter(|&(_hashed_address, slots)| !slots.is_empty())
188                .map(|(hashed_address, slots)| {
189                    (*hashed_address, PrefixSetMut::from(slots.iter().map(Nibbles::unpack)))
190                })
191                .collect(),
192            destroyed_accounts: Default::default(),
193        });
194        let prefix_sets = prefix_sets.freeze();
195
196        let storage_root_targets = StorageRootTargets::new(
197            prefix_sets.account_prefix_set.iter().map(|nibbles| B256::from_slice(&nibbles.pack())),
198            prefix_sets.storage_prefix_sets.clone(),
199        );
200        let storage_root_targets_len = storage_root_targets.len();
201
202        debug!(
203            target: "trie::parallel_proof",
204            total_targets = storage_root_targets_len,
205            "Starting parallel proof generation"
206        );
207
208        // Pre-calculate storage roots for accounts which were changed.
209        tracker.set_precomputed_storage_roots(storage_root_targets_len as u64);
210
211        // stores the receiver for the storage proof outcome for the hashed addresses
212        // this way we can lazily await the outcome when we iterate over the map
213        let mut storage_proof_receivers =
214            B256Map::with_capacity_and_hasher(storage_root_targets.len(), Default::default());
215
216        for (hashed_address, prefix_set) in
217            storage_root_targets.into_iter().sorted_unstable_by_key(|(address, _)| *address)
218        {
219            let target_slots = targets.get(&hashed_address).cloned().unwrap_or_default();
220            let receiver = self.spawn_storage_proof(hashed_address, prefix_set, target_slots);
221
222            // store the receiver for that result with the hashed address so we can await this in
223            // place when we iterate over the trie
224            storage_proof_receivers.insert(hashed_address, receiver);
225        }
226
227        let provider_ro = self.view.provider_ro()?;
228        let trie_cursor_factory = InMemoryTrieCursorFactory::new(
229            DatabaseTrieCursorFactory::new(provider_ro.tx_ref()),
230            &self.nodes_sorted,
231        );
232        let hashed_cursor_factory = HashedPostStateCursorFactory::new(
233            DatabaseHashedCursorFactory::new(provider_ro.tx_ref()),
234            &self.state_sorted,
235        );
236
237        let accounts_added_removed_keys =
238            self.multi_added_removed_keys.as_ref().map(|keys| keys.get_accounts());
239
240        // Create the walker.
241        let walker = TrieWalker::<_>::state_trie(
242            trie_cursor_factory.account_trie_cursor().map_err(ProviderError::Database)?,
243            prefix_sets.account_prefix_set,
244        )
245        .with_added_removed_keys(accounts_added_removed_keys)
246        .with_deletions_retained(true);
247
248        // Create a hash builder to rebuild the root node since it is not available in the database.
249        let retainer = targets
250            .keys()
251            .map(Nibbles::unpack)
252            .collect::<ProofRetainer>()
253            .with_added_removed_keys(accounts_added_removed_keys);
254        let mut hash_builder = HashBuilder::default()
255            .with_proof_retainer(retainer)
256            .with_updates(self.collect_branch_node_masks);
257
258        // Initialize all storage multiproofs as empty.
259        // Storage multiproofs for non empty tries will be overwritten if necessary.
260        let mut collected_decoded_storages: B256Map<DecodedStorageMultiProof> =
261            targets.keys().map(|key| (*key, DecodedStorageMultiProof::empty())).collect();
262        let mut account_rlp = Vec::with_capacity(TRIE_ACCOUNT_RLP_MAX_SIZE);
263        let mut account_node_iter = TrieNodeIter::state_trie(
264            walker,
265            hashed_cursor_factory.hashed_account_cursor().map_err(ProviderError::Database)?,
266        );
267        while let Some(account_node) =
268            account_node_iter.try_next().map_err(ProviderError::Database)?
269        {
270            match account_node {
271                TrieElement::Branch(node) => {
272                    hash_builder.add_branch(node.key, node.value, node.children_are_in_trie);
273                }
274                TrieElement::Leaf(hashed_address, account) => {
275                    let decoded_storage_multiproof = match storage_proof_receivers
276                        .remove(&hashed_address)
277                    {
278                        Some(rx) => rx.recv().map_err(|e| {
279                            ParallelStateRootError::StorageRoot(StorageRootError::Database(
280                                DatabaseError::Other(format!(
281                                    "channel closed for {hashed_address}: {e}"
282                                )),
283                            ))
284                        })??,
285                        // Since we do not store all intermediate nodes in the database, there might
286                        // be a possibility of re-adding a non-modified leaf to the hash builder.
287                        None => {
288                            tracker.inc_missed_leaves();
289
290                            let raw_fallback_proof = StorageProof::new_hashed(
291                                trie_cursor_factory.clone(),
292                                hashed_cursor_factory.clone(),
293                                hashed_address,
294                            )
295                            .with_prefix_set_mut(Default::default())
296                            .storage_multiproof(
297                                targets.get(&hashed_address).cloned().unwrap_or_default(),
298                            )
299                            .map_err(|e| {
300                                ParallelStateRootError::StorageRoot(StorageRootError::Database(
301                                    DatabaseError::Other(e.to_string()),
302                                ))
303                            })?;
304
305                            raw_fallback_proof.try_into()?
306                        }
307                    };
308
309                    // Encode account
310                    account_rlp.clear();
311                    let account = account.into_trie_account(decoded_storage_multiproof.root);
312                    account.encode(&mut account_rlp as &mut dyn BufMut);
313
314                    hash_builder.add_leaf(Nibbles::unpack(hashed_address), &account_rlp);
315
316                    // We might be adding leaves that are not necessarily our proof targets.
317                    if targets.contains_key(&hashed_address) {
318                        collected_decoded_storages
319                            .insert(hashed_address, decoded_storage_multiproof);
320                    }
321                }
322            }
323        }
324        let _ = hash_builder.root();
325
326        let stats = tracker.finish();
327        #[cfg(feature = "metrics")]
328        self.metrics.record(stats);
329
330        let account_subtree_raw_nodes = hash_builder.take_proof_nodes();
331        let decoded_account_subtree = DecodedProofNodes::try_from(account_subtree_raw_nodes)?;
332
333        let (branch_node_hash_masks, branch_node_tree_masks) = if self.collect_branch_node_masks {
334            let updated_branch_nodes = hash_builder.updated_branch_nodes.unwrap_or_default();
335            (
336                updated_branch_nodes.iter().map(|(path, node)| (*path, node.hash_mask)).collect(),
337                updated_branch_nodes
338                    .into_iter()
339                    .map(|(path, node)| (path, node.tree_mask))
340                    .collect(),
341            )
342        } else {
343            (HashMap::default(), HashMap::default())
344        };
345
346        debug!(
347            target: "trie::parallel_proof",
348            total_targets = storage_root_targets_len,
349            duration = ?stats.duration(),
350            branches_added = stats.branches_added(),
351            leaves_added = stats.leaves_added(),
352            missed_leaves = stats.missed_leaves(),
353            precomputed_storage_roots = stats.precomputed_storage_roots(),
354            "Calculated decoded proof"
355        );
356
357        Ok(DecodedMultiProof {
358            account_subtree: decoded_account_subtree,
359            branch_node_hash_masks,
360            branch_node_tree_masks,
361            storages: collected_decoded_storages,
362        })
363    }
364}
365
#[cfg(test)]
mod tests {
    use super::*;
    use crate::proof_task::{ProofTaskCtx, ProofTaskManager};
    use alloy_primitives::{
        keccak256,
        map::{B256Set, DefaultHashBuilder},
        Address, U256,
    };
    use rand::Rng;
    use reth_primitives_traits::{Account, StorageEntry};
    use reth_provider::{test_utils::create_test_provider_factory, HashingWriter};
    use reth_trie::proof::Proof;
    use tokio::runtime::Runtime;

    /// Cross-checks [`ParallelProof::decoded_multiproof`] against the sequential [`Proof`]
    /// implementation on a randomly generated state: both must produce identical decoded
    /// multiproofs for the same targets.
    #[test]
    fn random_parallel_proof() {
        let factory = create_test_provider_factory();
        let consistent_view = ConsistentDbView::new(factory.clone(), None);

        // Generate 100 random accounts; each has ~70% probability of carrying 100 random
        // storage slots.
        let mut rng = rand::rng();
        let state = (0..100)
            .map(|_| {
                let address = Address::random();
                let account =
                    Account { balance: U256::from(rng.random::<u64>()), ..Default::default() };
                let mut storage = HashMap::<B256, U256, DefaultHashBuilder>::default();
                let has_storage = rng.random_bool(0.7);
                if has_storage {
                    for _ in 0..100 {
                        storage.insert(
                            B256::from(U256::from(rng.random::<u64>())),
                            U256::from(rng.random::<u64>()),
                        );
                    }
                }
                (address, (account, storage))
            })
            .collect::<HashMap<_, _, DefaultHashBuilder>>();

        // Persist the generated accounts and storage through the hashing writer and commit,
        // so both proof implementations read the same database state.
        {
            let provider_rw = factory.provider_rw().unwrap();
            provider_rw
                .insert_account_for_hashing(
                    state.iter().map(|(address, (account, _))| (*address, Some(*account))),
                )
                .unwrap();
            provider_rw
                .insert_storage_for_hashing(state.iter().map(|(address, (_, storage))| {
                    (
                        *address,
                        storage
                            .iter()
                            .map(|(slot, value)| StorageEntry { key: *slot, value: *value }),
                    )
                }))
                .unwrap();
            provider_rw.commit().unwrap();
        }

        // Build proof targets from the first 10 accounts, taking up to 5 slots each.
        // NOTE(review): slots are inserted as raw keys here (not hashed), while addresses are
        // hashed via keccak256 — presumably `storage_multiproof` hashes slots internally.
        let mut targets = MultiProofTargets::default();
        for (address, (_, storage)) in state.iter().take(10) {
            let hashed_address = keccak256(*address);
            let mut target_slots = B256Set::default();

            for (slot, _) in storage.iter().take(5) {
                target_slots.insert(*slot);
            }

            if !target_slots.is_empty() {
                targets.insert(hashed_address, target_slots);
            }
        }

        // Cursor factories for the sequential reference proof, backed by a fresh transaction.
        let provider_rw = factory.provider_rw().unwrap();
        let trie_cursor_factory = DatabaseTrieCursorFactory::new(provider_rw.tx_ref());
        let hashed_cursor_factory = DatabaseHashedCursorFactory::new(provider_rw.tx_ref());

        let rt = Runtime::new().unwrap();

        // Spin up the proof task manager with empty overlays and a single worker.
        let task_ctx =
            ProofTaskCtx::new(Default::default(), Default::default(), Default::default());
        let proof_task =
            ProofTaskManager::new(rt.handle().clone(), consistent_view.clone(), task_ctx, 1);
        let proof_task_handle = proof_task.handle();

        // keep the join handle around to make sure it does not return any errors
        // after we compute the state root
        let join_handle = rt.spawn_blocking(move || proof_task.run());

        let parallel_result = ParallelProof::new(
            consistent_view,
            Default::default(),
            Default::default(),
            Default::default(),
            proof_task_handle.clone(),
        )
        .decoded_multiproof(targets.clone())
        .unwrap();

        let sequential_result_raw = Proof::new(trie_cursor_factory, hashed_cursor_factory)
            .multiproof(targets.clone())
            .unwrap(); // targets might be consumed by parallel_result
        let sequential_result_decoded: DecodedMultiProof = sequential_result_raw
            .try_into()
            .expect("Failed to decode sequential_result for test comparison");

        // to help narrow down what is wrong - first compare account subtries
        assert_eq!(parallel_result.account_subtree, sequential_result_decoded.account_subtree);

        // then compare length of all storage subtries
        assert_eq!(parallel_result.storages.len(), sequential_result_decoded.storages.len());

        // then compare each storage subtrie
        for (hashed_address, storage_proof) in &parallel_result.storages {
            let sequential_storage_proof =
                sequential_result_decoded.storages.get(hashed_address).unwrap();
            assert_eq!(storage_proof, sequential_storage_proof);
        }

        // then compare the entire thing for any mask differences
        assert_eq!(parallel_result, sequential_result_decoded);

        // drop the handle to terminate the task and then block on the proof task handle to make
        // sure it does not return any errors
        drop(proof_task_handle);
        rt.block_on(join_handle).unwrap().expect("The proof task should not return an error");
    }
}
494}