// reth_trie_parallel/proof.rs

use crate::{
    metrics::ParallelTrieMetrics,
    proof_task::{
        AccountMultiproofInput, ProofResult, ProofResultContext, ProofResultMessage,
        ProofWorkerHandle, StorageProofInput,
    },
    root::ParallelStateRootError,
    StorageRootTargets,
};
use alloy_primitives::{map::B256Set, B256};
use crossbeam_channel::{unbounded as crossbeam_unbounded, Receiver as CrossbeamReceiver};
use dashmap::DashMap;
use reth_execution_errors::StorageRootError;
use reth_storage_errors::db::DatabaseError;
use reth_trie::{
    prefix_set::{PrefixSet, PrefixSetMut, TriePrefixSets, TriePrefixSetsMut},
    DecodedMultiProof, DecodedStorageMultiProof, HashedPostState, MultiProofTargets, Nibbles,
};
use reth_trie_common::added_removed_keys::MultiAddedRemovedKeys;
use std::{sync::Arc, time::Instant};
use tracing::trace;
22
/// Parallel proof calculator.
///
/// This can collect proof for many targets in parallel, spawning a task for each hashed address
/// that has proof targets.
#[derive(Debug)]
pub struct ParallelProof {
    /// The collection of prefix sets for the computation.
    pub prefix_sets: Arc<TriePrefixSetsMut>,
    /// Flag indicating whether to include branch node masks in the proof.
    ///
    /// Defaults to `false` in [`Self::new`]; toggled via [`Self::with_branch_node_masks`].
    collect_branch_node_masks: bool,
    /// Provided by the user to give the necessary context to retain extra proofs.
    ///
    /// `None` by default; set via [`Self::with_multi_added_removed_keys`].
    multi_added_removed_keys: Option<Arc<MultiAddedRemovedKeys>>,
    /// Handle to the proof worker pools, used to dispatch both storage proof and
    /// account multiproof jobs.
    proof_worker_handle: ProofWorkerHandle,
    /// Cached storage proof roots for missed leaves; this maps
    /// hashed (missed) addresses to their storage proof roots.
    missed_leaves_storage_roots: Arc<DashMap<B256, B256>>,
    /// Metrics for parallel proof computation (compiled in only with the `metrics` feature).
    #[cfg(feature = "metrics")]
    metrics: ParallelTrieMetrics,
}
43
44impl ParallelProof {
45    /// Create new state proof generator.
46    pub fn new(
47        prefix_sets: Arc<TriePrefixSetsMut>,
48        missed_leaves_storage_roots: Arc<DashMap<B256, B256>>,
49        proof_worker_handle: ProofWorkerHandle,
50    ) -> Self {
51        Self {
52            prefix_sets,
53            missed_leaves_storage_roots,
54            collect_branch_node_masks: false,
55            multi_added_removed_keys: None,
56            proof_worker_handle,
57            #[cfg(feature = "metrics")]
58            metrics: ParallelTrieMetrics::new_with_labels(&[("type", "proof")]),
59        }
60    }
61
62    /// Set the flag indicating whether to include branch node masks in the proof.
63    pub const fn with_branch_node_masks(mut self, branch_node_masks: bool) -> Self {
64        self.collect_branch_node_masks = branch_node_masks;
65        self
66    }
67
68    /// Configure the `ParallelProof` with a [`MultiAddedRemovedKeys`], allowing for retaining
69    /// extra proofs needed to add and remove leaf nodes from the tries.
70    pub fn with_multi_added_removed_keys(
71        mut self,
72        multi_added_removed_keys: Option<Arc<MultiAddedRemovedKeys>>,
73    ) -> Self {
74        self.multi_added_removed_keys = multi_added_removed_keys;
75        self
76    }
77    /// Queues a storage proof task and returns a receiver for the result.
78    fn send_storage_proof(
79        &self,
80        hashed_address: B256,
81        prefix_set: PrefixSet,
82        target_slots: B256Set,
83    ) -> Result<CrossbeamReceiver<ProofResultMessage>, ParallelStateRootError> {
84        let (result_tx, result_rx) = crossbeam_channel::unbounded();
85        let start = Instant::now();
86
87        let input = StorageProofInput::new(
88            hashed_address,
89            prefix_set,
90            target_slots,
91            self.collect_branch_node_masks,
92            self.multi_added_removed_keys.clone(),
93        );
94
95        self.proof_worker_handle
96            .dispatch_storage_proof(
97                input,
98                ProofResultContext::new(result_tx, 0, HashedPostState::default(), start),
99            )
100            .map_err(|e| ParallelStateRootError::Other(e.to_string()))?;
101
102        Ok(result_rx)
103    }
104
105    /// Generate a storage multiproof according to the specified targets and hashed address.
106    pub fn storage_proof(
107        self,
108        hashed_address: B256,
109        target_slots: B256Set,
110    ) -> Result<DecodedStorageMultiProof, ParallelStateRootError> {
111        let total_targets = target_slots.len();
112        let prefix_set = PrefixSetMut::from(target_slots.iter().map(Nibbles::unpack));
113        let prefix_set = prefix_set.freeze();
114
115        trace!(
116            target: "trie::parallel_proof",
117            total_targets,
118            ?hashed_address,
119            "Starting storage proof generation"
120        );
121
122        let receiver = self.send_storage_proof(hashed_address, prefix_set, target_slots)?;
123        let proof_msg = receiver.recv().map_err(|_| {
124            ParallelStateRootError::StorageRoot(StorageRootError::Database(DatabaseError::Other(
125                format!("channel closed for {hashed_address}"),
126            )))
127        })?;
128
129        // Extract storage proof directly from the result
130        let storage_proof = match proof_msg.result? {
131            crate::proof_task::ProofResult::StorageProof { hashed_address: addr, proof } => {
132                debug_assert_eq!(
133                    addr,
134                    hashed_address,
135                    "storage worker must return same address: expected {hashed_address}, got {addr}"
136                );
137                proof
138            }
139            crate::proof_task::ProofResult::AccountMultiproof { .. } => {
140                unreachable!("storage worker only sends StorageProof variant")
141            }
142        };
143
144        trace!(
145            target: "trie::parallel_proof",
146            total_targets,
147            ?hashed_address,
148            "Storage proof generation completed"
149        );
150
151        Ok(storage_proof)
152    }
153
154    /// Extends prefix sets with the given multiproof targets and returns the frozen result.
155    ///
156    /// This is a helper function used to prepare prefix sets before computing multiproofs.
157    /// Returns frozen (immutable) prefix sets ready for use in proof computation.
158    pub fn extend_prefix_sets_with_targets(
159        base_prefix_sets: &TriePrefixSetsMut,
160        targets: &MultiProofTargets,
161    ) -> TriePrefixSets {
162        let mut extended = base_prefix_sets.clone();
163        extended.extend(TriePrefixSetsMut {
164            account_prefix_set: PrefixSetMut::from(targets.keys().copied().map(Nibbles::unpack)),
165            storage_prefix_sets: targets
166                .iter()
167                .filter(|&(_hashed_address, slots)| !slots.is_empty())
168                .map(|(hashed_address, slots)| {
169                    (*hashed_address, PrefixSetMut::from(slots.iter().map(Nibbles::unpack)))
170                })
171                .collect(),
172            destroyed_accounts: Default::default(),
173        });
174        extended.freeze()
175    }
176
177    /// Generate a state multiproof according to specified targets.
178    pub fn decoded_multiproof(
179        self,
180        targets: MultiProofTargets,
181    ) -> Result<DecodedMultiProof, ParallelStateRootError> {
182        // Extend prefix sets with targets
183        let prefix_sets = Self::extend_prefix_sets_with_targets(&self.prefix_sets, &targets);
184
185        let storage_root_targets_len = StorageRootTargets::count(
186            &prefix_sets.account_prefix_set,
187            &prefix_sets.storage_prefix_sets,
188        );
189
190        trace!(
191            target: "trie::parallel_proof",
192            total_targets = storage_root_targets_len,
193            "Starting parallel proof generation"
194        );
195
196        // Queue account multiproof request to account worker pool
197        // Create channel for receiving ProofResultMessage
198        let (result_tx, result_rx) = crossbeam_unbounded();
199        let account_multiproof_start_time = Instant::now();
200
201        let input = AccountMultiproofInput {
202            targets,
203            prefix_sets,
204            collect_branch_node_masks: self.collect_branch_node_masks,
205            multi_added_removed_keys: self.multi_added_removed_keys.clone(),
206            missed_leaves_storage_roots: self.missed_leaves_storage_roots.clone(),
207            proof_result_sender: ProofResultContext::new(
208                result_tx,
209                0,
210                HashedPostState::default(),
211                account_multiproof_start_time,
212            ),
213        };
214
215        self.proof_worker_handle
216            .dispatch_account_multiproof(input)
217            .map_err(|e| ParallelStateRootError::Other(e.to_string()))?;
218
219        // Wait for account multiproof result from worker
220        let proof_result_msg = result_rx.recv().map_err(|_| {
221            ParallelStateRootError::Other(
222                "Account multiproof channel dropped: worker died or pool shutdown".to_string(),
223            )
224        })?;
225
226        let (multiproof, stats) = match proof_result_msg.result? {
227            crate::proof_task::ProofResult::AccountMultiproof { proof, stats } => (proof, stats),
228            crate::proof_task::ProofResult::StorageProof { .. } => {
229                unreachable!("account worker only sends AccountMultiproof variant")
230            }
231        };
232
233        #[cfg(feature = "metrics")]
234        self.metrics.record(stats);
235
236        trace!(
237            target: "trie::parallel_proof",
238            total_targets = storage_root_targets_len,
239            duration = ?stats.duration(),
240            branches_added = stats.branches_added(),
241            leaves_added = stats.leaves_added(),
242            missed_leaves = stats.missed_leaves(),
243            precomputed_storage_roots = stats.precomputed_storage_roots(),
244            "Calculated decoded proof"
245        );
246
247        Ok(multiproof)
248    }
249}
250
#[cfg(test)]
mod tests {
    use super::*;
    use crate::proof_task::{ProofTaskCtx, ProofWorkerHandle};
    use alloy_primitives::{
        keccak256,
        map::{B256Set, DefaultHashBuilder, HashMap},
        Address, U256,
    };
    use rand::Rng;
    use reth_primitives_traits::{Account, StorageEntry};
    use reth_provider::{test_utils::create_test_provider_factory, HashingWriter};
    use reth_trie::proof::Proof;
    use reth_trie_db::{DatabaseHashedCursorFactory, DatabaseTrieCursorFactory};
    use tokio::runtime::Runtime;

    /// Cross-checks [`ParallelProof::decoded_multiproof`] against the sequential
    /// [`Proof`] implementation on a randomly generated state.
    #[test]
    fn random_parallel_proof() {
        let factory = create_test_provider_factory();

        // Generate 100 random accounts; ~70% of them get 100 random storage slots each.
        let mut rng = rand::rng();
        let state = (0..100)
            .map(|_| {
                let address = Address::random();
                let account =
                    Account { balance: U256::from(rng.random::<u64>()), ..Default::default() };
                let mut storage = HashMap::<B256, U256, DefaultHashBuilder>::default();
                let has_storage = rng.random_bool(0.7);
                if has_storage {
                    for _ in 0..100 {
                        storage.insert(
                            B256::from(U256::from(rng.random::<u64>())),
                            U256::from(rng.random::<u64>()),
                        );
                    }
                }
                (address, (account, storage))
            })
            .collect::<HashMap<_, _, DefaultHashBuilder>>();

        // Persist the generated state into the hashed tables and commit, so both
        // proof implementations read the same committed data.
        {
            let provider_rw = factory.provider_rw().unwrap();
            provider_rw
                .insert_account_for_hashing(
                    state.iter().map(|(address, (account, _))| (*address, Some(*account))),
                )
                .unwrap();
            provider_rw
                .insert_storage_for_hashing(state.iter().map(|(address, (_, storage))| {
                    (
                        *address,
                        storage
                            .iter()
                            .map(|(slot, value)| StorageEntry { key: *slot, value: *value }),
                    )
                }))
                .unwrap();
            provider_rw.commit().unwrap();
        }

        // Build proof targets from the first 10 accounts, taking up to 5 slots each.
        let mut targets = MultiProofTargets::default();
        for (address, (_, storage)) in state.iter().take(10) {
            let hashed_address = keccak256(*address);
            let mut target_slots = B256Set::default();

            for (slot, _) in storage.iter().take(5) {
                target_slots.insert(*slot);
            }

            if !target_slots.is_empty() {
                targets.insert(hashed_address, target_slots);
            }
        }

        // Cursor factories for the sequential proof are bound to this transaction.
        let provider_rw = factory.provider_rw().unwrap();
        let trie_cursor_factory = DatabaseTrieCursorFactory::new(provider_rw.tx_ref());
        let hashed_cursor_factory = DatabaseHashedCursorFactory::new(provider_rw.tx_ref());

        let rt = Runtime::new().unwrap();

        // Spin up a worker pool with one storage and one account worker.
        let factory = reth_provider::providers::OverlayStateProviderFactory::new(factory);
        let task_ctx = ProofTaskCtx::new(factory, Default::default());
        let proof_worker_handle = ProofWorkerHandle::new(rt.handle().clone(), task_ctx, 1, 1);

        let parallel_result =
            ParallelProof::new(Default::default(), Default::default(), proof_worker_handle.clone())
                .decoded_multiproof(targets.clone())
                .unwrap();

        let sequential_result_raw = Proof::new(trie_cursor_factory, hashed_cursor_factory)
            .multiproof(targets.clone())
            .unwrap(); // targets might be consumed by parallel_result
        let sequential_result_decoded: DecodedMultiProof = sequential_result_raw
            .try_into()
            .expect("Failed to decode sequential_result for test comparison");

        // to help narrow down what is wrong - first compare account subtries
        assert_eq!(parallel_result.account_subtree, sequential_result_decoded.account_subtree);

        // then compare length of all storage subtries
        assert_eq!(parallel_result.storages.len(), sequential_result_decoded.storages.len());

        // then compare each storage subtrie
        for (hashed_address, storage_proof) in &parallel_result.storages {
            let sequential_storage_proof =
                sequential_result_decoded.storages.get(hashed_address).unwrap();
            assert_eq!(storage_proof, sequential_storage_proof);
        }

        // then compare the entire thing for any mask differences
        assert_eq!(parallel_result, sequential_result_decoded);

        // Workers shut down automatically when handle is dropped
        drop(proof_worker_handle);
    }
}