// reth_trie_parallel/proof.rs

1use crate::{
2    metrics::ParallelTrieMetrics,
3    proof_task::{ProofTaskKind, ProofTaskManagerHandle, StorageProofInput},
4    root::ParallelStateRootError,
5    stats::ParallelTrieTracker,
6    StorageRootTargets,
7};
8use alloy_primitives::{
9    map::{B256Map, B256Set, HashMap},
10    B256,
11};
12use alloy_rlp::{BufMut, Encodable};
13use itertools::Itertools;
14use reth_execution_errors::StorageRootError;
15use reth_provider::{
16    providers::ConsistentDbView, BlockReader, DBProvider, DatabaseProviderFactory, FactoryTx,
17    ProviderError, StateCommitmentProvider,
18};
19use reth_storage_errors::db::DatabaseError;
20use reth_trie::{
21    hashed_cursor::{HashedCursorFactory, HashedPostStateCursorFactory},
22    node_iter::{TrieElement, TrieNodeIter},
23    prefix_set::{PrefixSet, PrefixSetMut, TriePrefixSetsMut},
24    proof::StorageProof,
25    trie_cursor::{InMemoryTrieCursorFactory, TrieCursorFactory},
26    updates::TrieUpdatesSorted,
27    walker::TrieWalker,
28    HashBuilder, HashedPostStateSorted, MultiProof, MultiProofTargets, Nibbles, StorageMultiProof,
29    TRIE_ACCOUNT_RLP_MAX_SIZE,
30};
31use reth_trie_common::proof::ProofRetainer;
32use reth_trie_db::{DatabaseHashedCursorFactory, DatabaseTrieCursorFactory};
33use std::sync::{mpsc::Receiver, Arc};
34use tracing::debug;
35
36/// Parallel proof calculator.
37///
38/// This can collect proof for many targets in parallel, spawning a task for each hashed address
39/// that has proof targets.
40#[derive(Debug)]
41pub struct ParallelProof<Factory: DatabaseProviderFactory> {
42    /// Consistent view of the database.
43    view: ConsistentDbView<Factory>,
44    /// The sorted collection of cached in-memory intermediate trie nodes that
45    /// can be reused for computation.
46    pub nodes_sorted: Arc<TrieUpdatesSorted>,
47    /// The sorted in-memory overlay hashed state.
48    pub state_sorted: Arc<HashedPostStateSorted>,
49    /// The collection of prefix sets for the computation. Since the prefix sets _always_
50    /// invalidate the in-memory nodes, not all keys from `state_sorted` might be present here,
51    /// if we have cached nodes for them.
52    pub prefix_sets: Arc<TriePrefixSetsMut>,
53    /// Flag indicating whether to include branch node masks in the proof.
54    collect_branch_node_masks: bool,
55    /// Handle to the storage proof task.
56    storage_proof_task_handle: ProofTaskManagerHandle<FactoryTx<Factory>>,
57    #[cfg(feature = "metrics")]
58    metrics: ParallelTrieMetrics,
59}
60
61impl<Factory: DatabaseProviderFactory> ParallelProof<Factory> {
62    /// Create new state proof generator.
63    pub fn new(
64        view: ConsistentDbView<Factory>,
65        nodes_sorted: Arc<TrieUpdatesSorted>,
66        state_sorted: Arc<HashedPostStateSorted>,
67        prefix_sets: Arc<TriePrefixSetsMut>,
68        storage_proof_task_handle: ProofTaskManagerHandle<FactoryTx<Factory>>,
69    ) -> Self {
70        Self {
71            view,
72            nodes_sorted,
73            state_sorted,
74            prefix_sets,
75            collect_branch_node_masks: false,
76            storage_proof_task_handle,
77            #[cfg(feature = "metrics")]
78            metrics: ParallelTrieMetrics::new_with_labels(&[("type", "proof")]),
79        }
80    }
81
82    /// Set the flag indicating whether to include branch node masks in the proof.
83    pub const fn with_branch_node_masks(mut self, branch_node_masks: bool) -> Self {
84        self.collect_branch_node_masks = branch_node_masks;
85        self
86    }
87}
88
89impl<Factory> ParallelProof<Factory>
90where
91    Factory:
92        DatabaseProviderFactory<Provider: BlockReader> + StateCommitmentProvider + Clone + 'static,
93{
94    /// Spawns a storage proof on the storage proof task and returns a receiver for the result.
95    fn spawn_storage_proof(
96        &self,
97        hashed_address: B256,
98        prefix_set: PrefixSet,
99        target_slots: B256Set,
100    ) -> Receiver<Result<StorageMultiProof, ParallelStateRootError>> {
101        let input = StorageProofInput::new(
102            hashed_address,
103            prefix_set,
104            target_slots,
105            self.collect_branch_node_masks,
106        );
107
108        let (sender, receiver) = std::sync::mpsc::channel();
109        let _ =
110            self.storage_proof_task_handle.queue_task(ProofTaskKind::StorageProof(input, sender));
111        receiver
112    }
113
114    /// Generate a storage multiproof according to the specified targets and hashed address.
115    pub fn storage_proof(
116        self,
117        hashed_address: B256,
118        target_slots: B256Set,
119    ) -> Result<StorageMultiProof, ParallelStateRootError> {
120        let total_targets = target_slots.len();
121        let prefix_set = PrefixSetMut::from(target_slots.iter().map(Nibbles::unpack));
122        let prefix_set = prefix_set.freeze();
123
124        debug!(
125            target: "trie::parallel_proof",
126            total_targets,
127            ?hashed_address,
128            "Starting storage proof generation"
129        );
130
131        let receiver = self.spawn_storage_proof(hashed_address, prefix_set, target_slots);
132        let proof_result = receiver.recv().map_err(|_| {
133            ParallelStateRootError::StorageRoot(StorageRootError::Database(DatabaseError::Other(
134                format!("channel closed for {hashed_address}"),
135            )))
136        })?;
137
138        debug!(
139            target: "trie::parallel_proof",
140            total_targets,
141            ?hashed_address,
142            "Storage proof generation completed"
143        );
144
145        proof_result
146    }
147
148    /// Generate a state multiproof according to specified targets.
149    pub fn multiproof(
150        self,
151        targets: MultiProofTargets,
152    ) -> Result<MultiProof, ParallelStateRootError> {
153        let mut tracker = ParallelTrieTracker::default();
154
155        // Extend prefix sets with targets
156        let mut prefix_sets = (*self.prefix_sets).clone();
157        prefix_sets.extend(TriePrefixSetsMut {
158            account_prefix_set: PrefixSetMut::from(targets.keys().copied().map(Nibbles::unpack)),
159            storage_prefix_sets: targets
160                .iter()
161                .filter(|&(_hashed_address, slots)| !slots.is_empty())
162                .map(|(hashed_address, slots)| {
163                    (*hashed_address, PrefixSetMut::from(slots.iter().map(Nibbles::unpack)))
164                })
165                .collect(),
166            destroyed_accounts: Default::default(),
167        });
168        let prefix_sets = prefix_sets.freeze();
169
170        let storage_root_targets = StorageRootTargets::new(
171            prefix_sets.account_prefix_set.iter().map(|nibbles| B256::from_slice(&nibbles.pack())),
172            prefix_sets.storage_prefix_sets.clone(),
173        );
174        let storage_root_targets_len = storage_root_targets.len();
175
176        debug!(
177            target: "trie::parallel_proof",
178            total_targets = storage_root_targets_len,
179            "Starting parallel proof generation"
180        );
181
182        // Pre-calculate storage roots for accounts which were changed.
183        tracker.set_precomputed_storage_roots(storage_root_targets_len as u64);
184
185        // stores the receiver for the storage proof outcome for the hashed addresses
186        // this way we can lazily await the outcome when we iterate over the map
187        let mut storage_proofs =
188            B256Map::with_capacity_and_hasher(storage_root_targets.len(), Default::default());
189
190        for (hashed_address, prefix_set) in
191            storage_root_targets.into_iter().sorted_unstable_by_key(|(address, _)| *address)
192        {
193            let target_slots = targets.get(&hashed_address).cloned().unwrap_or_default();
194            let receiver = self.spawn_storage_proof(hashed_address, prefix_set, target_slots);
195
196            // store the receiver for that result with the hashed address so we can await this in
197            // place when we iterate over the trie
198            storage_proofs.insert(hashed_address, receiver);
199        }
200
201        let provider_ro = self.view.provider_ro()?;
202        let trie_cursor_factory = InMemoryTrieCursorFactory::new(
203            DatabaseTrieCursorFactory::new(provider_ro.tx_ref()),
204            &self.nodes_sorted,
205        );
206        let hashed_cursor_factory = HashedPostStateCursorFactory::new(
207            DatabaseHashedCursorFactory::new(provider_ro.tx_ref()),
208            &self.state_sorted,
209        );
210
211        // Create the walker.
212        let walker = TrieWalker::new(
213            trie_cursor_factory.account_trie_cursor().map_err(ProviderError::Database)?,
214            prefix_sets.account_prefix_set,
215        )
216        .with_deletions_retained(true);
217
218        // Create a hash builder to rebuild the root node since it is not available in the database.
219        let retainer: ProofRetainer = targets.keys().map(Nibbles::unpack).collect();
220        let mut hash_builder = HashBuilder::default()
221            .with_proof_retainer(retainer)
222            .with_updates(self.collect_branch_node_masks);
223
224        // Initialize all storage multiproofs as empty.
225        // Storage multiproofs for non empty tries will be overwritten if necessary.
226        let mut storages: B256Map<_> =
227            targets.keys().map(|key| (*key, StorageMultiProof::empty())).collect();
228        let mut account_rlp = Vec::with_capacity(TRIE_ACCOUNT_RLP_MAX_SIZE);
229        let mut account_node_iter = TrieNodeIter::state_trie(
230            walker,
231            hashed_cursor_factory.hashed_account_cursor().map_err(ProviderError::Database)?,
232        );
233        while let Some(account_node) =
234            account_node_iter.try_next().map_err(ProviderError::Database)?
235        {
236            match account_node {
237                TrieElement::Branch(node) => {
238                    hash_builder.add_branch(node.key, node.value, node.children_are_in_trie);
239                }
240                TrieElement::Leaf(hashed_address, account) => {
241                    let storage_multiproof = match storage_proofs.remove(&hashed_address) {
242                        Some(rx) => rx.recv().map_err(|_| {
243                            ParallelStateRootError::StorageRoot(StorageRootError::Database(
244                                DatabaseError::Other(format!(
245                                    "channel closed for {hashed_address}"
246                                )),
247                            ))
248                        })??,
249                        // Since we do not store all intermediate nodes in the database, there might
250                        // be a possibility of re-adding a non-modified leaf to the hash builder.
251                        None => {
252                            tracker.inc_missed_leaves();
253                            StorageProof::new_hashed(
254                                trie_cursor_factory.clone(),
255                                hashed_cursor_factory.clone(),
256                                hashed_address,
257                            )
258                            .with_prefix_set_mut(Default::default())
259                            .storage_multiproof(
260                                targets.get(&hashed_address).cloned().unwrap_or_default(),
261                            )
262                            .map_err(|e| {
263                                ParallelStateRootError::StorageRoot(StorageRootError::Database(
264                                    DatabaseError::Other(e.to_string()),
265                                ))
266                            })?
267                        }
268                    };
269
270                    // Encode account
271                    account_rlp.clear();
272                    let account = account.into_trie_account(storage_multiproof.root);
273                    account.encode(&mut account_rlp as &mut dyn BufMut);
274
275                    hash_builder.add_leaf(Nibbles::unpack(hashed_address), &account_rlp);
276
277                    // We might be adding leaves that are not necessarily our proof targets.
278                    if targets.contains_key(&hashed_address) {
279                        storages.insert(hashed_address, storage_multiproof);
280                    }
281                }
282            }
283        }
284        let _ = hash_builder.root();
285
286        let stats = tracker.finish();
287        #[cfg(feature = "metrics")]
288        self.metrics.record(stats);
289
290        let account_subtree = hash_builder.take_proof_nodes();
291        let (branch_node_hash_masks, branch_node_tree_masks) = if self.collect_branch_node_masks {
292            let updated_branch_nodes = hash_builder.updated_branch_nodes.unwrap_or_default();
293            (
294                updated_branch_nodes
295                    .iter()
296                    .map(|(path, node)| (path.clone(), node.hash_mask))
297                    .collect(),
298                updated_branch_nodes
299                    .into_iter()
300                    .map(|(path, node)| (path, node.tree_mask))
301                    .collect(),
302            )
303        } else {
304            (HashMap::default(), HashMap::default())
305        };
306
307        debug!(
308            target: "trie::parallel_proof",
309            total_targets = storage_root_targets_len,
310            duration = ?stats.duration(),
311            branches_added = stats.branches_added(),
312            leaves_added = stats.leaves_added(),
313            missed_leaves = stats.missed_leaves(),
314            precomputed_storage_roots = stats.precomputed_storage_roots(),
315            "Calculated proof"
316        );
317
318        Ok(MultiProof { account_subtree, branch_node_hash_masks, branch_node_tree_masks, storages })
319    }
320}
321
#[cfg(test)]
mod tests {
    use super::*;
    use crate::proof_task::{ProofTaskCtx, ProofTaskManager};
    use alloy_primitives::{
        keccak256,
        map::{B256Set, DefaultHashBuilder},
        Address, U256,
    };
    use rand::Rng;
    use reth_primitives_traits::{Account, StorageEntry};
    use reth_provider::{test_utils::create_test_provider_factory, HashingWriter};
    use reth_trie::proof::Proof;
    use tokio::runtime::Runtime;

    /// The parallel multiproof must match the sequential [`Proof`] implementation on a randomly
    /// generated state.
    #[test]
    fn random_parallel_proof() {
        let factory = create_test_provider_factory();
        let consistent_view = ConsistentDbView::new(factory.clone(), None);

        // 100 random accounts; each gets 100 random storage slots with probability 0.7.
        let mut rng = rand::rng();
        let state = (0..100)
            .map(|_| {
                let address = Address::random();
                let account =
                    Account { balance: U256::from(rng.random::<u64>()), ..Default::default() };
                let storage: HashMap<B256, U256, DefaultHashBuilder> = if rng.random_bool(0.7) {
                    (0..100)
                        .map(|_| {
                            let slot = B256::from(U256::from(rng.random::<u64>()));
                            let value = U256::from(rng.random::<u64>());
                            (slot, value)
                        })
                        .collect()
                } else {
                    HashMap::default()
                };
                (address, (account, storage))
            })
            .collect::<HashMap<_, _, DefaultHashBuilder>>();

        // Persist the hashed accounts and storages so both implementations read the same data.
        let writer = factory.provider_rw().unwrap();
        let accounts = state.iter().map(|(address, (account, _))| (*address, Some(*account)));
        writer.insert_account_for_hashing(accounts).unwrap();
        let storages = state.iter().map(|(address, (_, storage))| {
            let entries =
                storage.iter().map(|(slot, value)| StorageEntry { key: *slot, value: *value });
            (*address, entries)
        });
        writer.insert_storage_for_hashing(storages).unwrap();
        writer.commit().unwrap();

        // Prove up to five slots for each of the first ten accounts; storage-less accounts are
        // skipped.
        let selected = state.iter().take(10).filter_map(|(address, (_, storage))| {
            let slots: B256Set = storage.keys().take(5).copied().collect();
            (!slots.is_empty()).then(|| (keccak256(*address), slots))
        });
        let mut targets = MultiProofTargets::default();
        for (hashed_address, slots) in selected {
            targets.insert(hashed_address, slots);
        }

        let provider_rw = factory.provider_rw().unwrap();
        let trie_cursor_factory = DatabaseTrieCursorFactory::new(provider_rw.tx_ref());
        let hashed_cursor_factory = DatabaseHashedCursorFactory::new(provider_rw.tx_ref());

        let rt = Runtime::new().unwrap();

        let task_ctx =
            ProofTaskCtx::new(Default::default(), Default::default(), Default::default());
        let proof_task =
            ProofTaskManager::new(rt.handle().clone(), consistent_view.clone(), task_ctx, 1);
        let proof_task_handle = proof_task.handle();

        // Keep the join handle so we can check below that the task exits without an error once
        // the proof has been computed.
        let join_handle = rt.spawn_blocking(move || proof_task.run());

        let parallel_result = ParallelProof::new(
            consistent_view,
            Default::default(),
            Default::default(),
            Default::default(),
            proof_task_handle.clone(),
        )
        .multiproof(targets.clone())
        .unwrap();

        let sequential_result =
            Proof::new(trie_cursor_factory, hashed_cursor_factory).multiproof(targets).unwrap();

        // Compare piece by piece first so a failure points at the part that diverged.
        assert_eq!(parallel_result.account_subtree, sequential_result.account_subtree);
        assert_eq!(parallel_result.storages.len(), sequential_result.storages.len());
        for (hashed_address, storage_proof) in &parallel_result.storages {
            assert_eq!(Some(storage_proof), sequential_result.storages.get(hashed_address));
        }

        // Finally compare everything, which also covers the branch node masks.
        assert_eq!(parallel_result, sequential_result);

        // Dropping the last handle lets the proof task terminate; it must not report an error.
        drop(proof_task_handle);
        rt.block_on(join_handle).unwrap().expect("The proof task should not return an error");
    }
}