// reth_trie_parallel/proof.rs

1use crate::{
2    metrics::ParallelTrieMetrics,
3    proof_task::{
4        AccountMultiproofInput, ProofResult, ProofResultContext, ProofWorkerHandle,
5        StorageProofInput, StorageProofResultMessage,
6    },
7    root::ParallelStateRootError,
8    StorageRootTargets,
9};
10use alloy_primitives::{map::B256Set, B256};
11use crossbeam_channel::{unbounded as crossbeam_unbounded, Receiver as CrossbeamReceiver};
12use reth_execution_errors::StorageRootError;
13use reth_storage_errors::db::DatabaseError;
14use reth_trie::{
15    prefix_set::{PrefixSet, PrefixSetMut, TriePrefixSets, TriePrefixSetsMut},
16    DecodedMultiProof, DecodedStorageMultiProof, HashedPostState, MultiProofTargets, Nibbles,
17};
18use reth_trie_common::added_removed_keys::MultiAddedRemovedKeys;
19use std::{sync::Arc, time::Instant};
20use tracing::trace;
21
/// Parallel proof calculator.
///
/// This can collect proof for many targets in parallel, spawning a task for each hashed address
/// that has proof targets.
#[derive(Debug)]
pub struct ParallelProof {
    /// The collection of prefix sets for the computation.
    pub prefix_sets: Arc<TriePrefixSetsMut>,
    /// Flag indicating whether to include branch node masks in the proof.
    ///
    /// Only consulted on the legacy (non-V2) proof paths.
    collect_branch_node_masks: bool,
    /// Provided by the user to give the necessary context to retain extra proofs.
    multi_added_removed_keys: Option<Arc<MultiAddedRemovedKeys>>,
    /// Handle to the proof worker pools.
    ///
    /// Both storage proof and account multiproof requests are dispatched through this handle.
    proof_worker_handle: ProofWorkerHandle,
    /// Whether to use V2 storage proofs.
    v2_proofs_enabled: bool,
    /// Metrics recorded for parallel proof computation, labelled with `type = "proof"`.
    #[cfg(feature = "metrics")]
    metrics: ParallelTrieMetrics,
}
41
42impl ParallelProof {
43    /// Create new state proof generator.
44    pub fn new(
45        prefix_sets: Arc<TriePrefixSetsMut>,
46        proof_worker_handle: ProofWorkerHandle,
47    ) -> Self {
48        Self {
49            prefix_sets,
50            collect_branch_node_masks: false,
51            multi_added_removed_keys: None,
52            proof_worker_handle,
53            v2_proofs_enabled: false,
54            #[cfg(feature = "metrics")]
55            metrics: ParallelTrieMetrics::new_with_labels(&[("type", "proof")]),
56        }
57    }
58
59    /// Set whether to use V2 storage proofs.
60    pub const fn with_v2_proofs_enabled(mut self, v2_proofs_enabled: bool) -> Self {
61        self.v2_proofs_enabled = v2_proofs_enabled;
62        self
63    }
64
65    /// Set the flag indicating whether to include branch node masks in the proof.
66    pub const fn with_branch_node_masks(mut self, branch_node_masks: bool) -> Self {
67        self.collect_branch_node_masks = branch_node_masks;
68        self
69    }
70
71    /// Configure the `ParallelProof` with a [`MultiAddedRemovedKeys`], allowing for retaining
72    /// extra proofs needed to add and remove leaf nodes from the tries.
73    pub fn with_multi_added_removed_keys(
74        mut self,
75        multi_added_removed_keys: Option<Arc<MultiAddedRemovedKeys>>,
76    ) -> Self {
77        self.multi_added_removed_keys = multi_added_removed_keys;
78        self
79    }
80    /// Queues a storage proof task and returns a receiver for the result.
81    fn send_storage_proof(
82        &self,
83        hashed_address: B256,
84        prefix_set: PrefixSet,
85        target_slots: B256Set,
86    ) -> Result<CrossbeamReceiver<StorageProofResultMessage>, ParallelStateRootError> {
87        let (result_tx, result_rx) = crossbeam_channel::unbounded();
88
89        let input = if self.v2_proofs_enabled {
90            StorageProofInput::new(
91                hashed_address,
92                target_slots.into_iter().map(Into::into).collect(),
93            )
94        } else {
95            StorageProofInput::legacy(
96                hashed_address,
97                prefix_set,
98                target_slots,
99                self.collect_branch_node_masks,
100                self.multi_added_removed_keys.clone(),
101            )
102        };
103
104        self.proof_worker_handle
105            .dispatch_storage_proof(input, result_tx)
106            .map_err(|e| ParallelStateRootError::Other(e.to_string()))?;
107
108        Ok(result_rx)
109    }
110
111    /// Generate a storage multiproof according to the specified targets and hashed address.
112    pub fn storage_proof(
113        self,
114        hashed_address: B256,
115        target_slots: B256Set,
116    ) -> Result<DecodedStorageMultiProof, ParallelStateRootError> {
117        let total_targets = target_slots.len();
118        let prefix_set = if self.v2_proofs_enabled {
119            PrefixSet::default()
120        } else {
121            PrefixSetMut::from(target_slots.iter().map(Nibbles::unpack)).freeze()
122        };
123
124        trace!(
125            target: "trie::parallel_proof",
126            total_targets,
127            ?hashed_address,
128            "Starting storage proof generation"
129        );
130
131        let receiver = self.send_storage_proof(hashed_address, prefix_set, target_slots)?;
132        let proof_msg = receiver.recv().map_err(|_| {
133            ParallelStateRootError::StorageRoot(StorageRootError::Database(DatabaseError::Other(
134                format!("channel closed for {hashed_address}"),
135            )))
136        })?;
137
138        // Extract storage proof directly from the result
139        let proof_result = proof_msg.result?;
140        let storage_proof = Into::<Option<DecodedStorageMultiProof>>::into(proof_result)
141            .expect("Partial proofs are not yet supported");
142
143        trace!(
144            target: "trie::parallel_proof",
145            total_targets,
146            ?hashed_address,
147            "Storage proof generation completed"
148        );
149
150        Ok(storage_proof)
151    }
152
153    /// Extends prefix sets with the given multiproof targets and returns the frozen result.
154    ///
155    /// This is a helper function used to prepare prefix sets before computing multiproofs.
156    /// Returns frozen (immutable) prefix sets ready for use in proof computation.
157    pub fn extend_prefix_sets_with_targets(
158        base_prefix_sets: &TriePrefixSetsMut,
159        targets: &MultiProofTargets,
160    ) -> TriePrefixSets {
161        let mut extended = base_prefix_sets.clone();
162        extended.extend(TriePrefixSetsMut {
163            account_prefix_set: PrefixSetMut::from(targets.keys().copied().map(Nibbles::unpack)),
164            storage_prefix_sets: targets
165                .iter()
166                .filter(|&(_hashed_address, slots)| !slots.is_empty())
167                .map(|(hashed_address, slots)| {
168                    (*hashed_address, PrefixSetMut::from(slots.iter().map(Nibbles::unpack)))
169                })
170                .collect(),
171            destroyed_accounts: Default::default(),
172        });
173        extended.freeze()
174    }
175
176    /// Generate a state multiproof according to specified targets.
177    pub fn decoded_multiproof(
178        self,
179        targets: MultiProofTargets,
180    ) -> Result<DecodedMultiProof, ParallelStateRootError> {
181        // Extend prefix sets with targets
182        let prefix_sets = Self::extend_prefix_sets_with_targets(&self.prefix_sets, &targets);
183
184        let storage_root_targets_len = StorageRootTargets::count(
185            &prefix_sets.account_prefix_set,
186            &prefix_sets.storage_prefix_sets,
187        );
188
189        trace!(
190            target: "trie::parallel_proof",
191            total_targets = storage_root_targets_len,
192            "Starting parallel proof generation"
193        );
194
195        // Queue account multiproof request to account worker pool
196        // Create channel for receiving ProofResultMessage
197        let (result_tx, result_rx) = crossbeam_unbounded();
198        let account_multiproof_start_time = Instant::now();
199
200        let input = AccountMultiproofInput::Legacy {
201            targets,
202            prefix_sets,
203            collect_branch_node_masks: self.collect_branch_node_masks,
204            multi_added_removed_keys: self.multi_added_removed_keys.clone(),
205            proof_result_sender: ProofResultContext::new(
206                result_tx,
207                0,
208                HashedPostState::default(),
209                account_multiproof_start_time,
210            ),
211        };
212
213        self.proof_worker_handle
214            .dispatch_account_multiproof(input)
215            .map_err(|e| ParallelStateRootError::Other(e.to_string()))?;
216
217        // Wait for account multiproof result from worker
218        let proof_result_msg = result_rx.recv().map_err(|_| {
219            ParallelStateRootError::Other(
220                "Account multiproof channel dropped: worker died or pool shutdown".to_string(),
221            )
222        })?;
223
224        let ProofResult::Legacy(multiproof, stats) = proof_result_msg.result? else {
225            panic!("AccountMultiproofInput::Legacy was submitted, expected legacy result")
226        };
227
228        #[cfg(feature = "metrics")]
229        self.metrics.record(stats);
230
231        trace!(
232            target: "trie::parallel_proof",
233            total_targets = storage_root_targets_len,
234            duration = ?stats.duration(),
235            branches_added = stats.branches_added(),
236            leaves_added = stats.leaves_added(),
237            missed_leaves = stats.missed_leaves(),
238            precomputed_storage_roots = stats.precomputed_storage_roots(),
239            "Calculated decoded proof",
240        );
241
242        Ok(multiproof)
243    }
244}
245
#[cfg(test)]
mod tests {
    use super::*;
    use crate::proof_task::{ProofTaskCtx, ProofWorkerHandle};
    use alloy_primitives::{
        keccak256,
        map::{B256Set, DefaultHashBuilder, HashMap},
        Address, U256,
    };
    use rand::Rng;
    use reth_primitives_traits::{Account, StorageEntry};
    use reth_provider::{test_utils::create_test_provider_factory, HashingWriter};
    use reth_trie::proof::Proof;
    use reth_trie_db::{DatabaseHashedCursorFactory, DatabaseTrieCursorFactory};
    use tokio::runtime::Runtime;

    /// End-to-end check that the parallel multiproof matches the sequential reference
    /// implementation (`Proof::multiproof`) on randomly generated state.
    #[test]
    fn random_parallel_proof() {
        let factory = create_test_provider_factory();

        // Generate 100 random accounts; ~70% of them get 100 random storage slots each.
        let mut rng = rand::rng();
        let state = (0..100)
            .map(|_| {
                let address = Address::random();
                let account =
                    Account { balance: U256::from(rng.random::<u64>()), ..Default::default() };
                let mut storage = HashMap::<B256, U256, DefaultHashBuilder>::default();
                let has_storage = rng.random_bool(0.7);
                if has_storage {
                    for _ in 0..100 {
                        storage.insert(
                            B256::from(U256::from(rng.random::<u64>())),
                            U256::from(rng.random::<u64>()),
                        );
                    }
                }
                (address, (account, storage))
            })
            .collect::<HashMap<_, _, DefaultHashBuilder>>();

        // Persist the generated accounts and storage into the hashed tables and commit.
        {
            let provider_rw = factory.provider_rw().unwrap();
            provider_rw
                .insert_account_for_hashing(
                    state.iter().map(|(address, (account, _))| (*address, Some(*account))),
                )
                .unwrap();
            provider_rw
                .insert_storage_for_hashing(state.iter().map(|(address, (_, storage))| {
                    (
                        *address,
                        storage
                            .iter()
                            .map(|(slot, value)| StorageEntry { key: *slot, value: *value }),
                    )
                }))
                .unwrap();
            provider_rw.commit().unwrap();
        }

        // Build proof targets: up to 10 accounts with up to 5 storage slots each.
        // NOTE(review): addresses are keccak-hashed here while slots are inserted raw —
        // presumably the proof APIs hash slots internally; the parallel/sequential
        // comparison below exercises both paths identically either way.
        let mut targets = MultiProofTargets::default();
        for (address, (_, storage)) in state.iter().take(10) {
            let hashed_address = keccak256(*address);
            let mut target_slots = B256Set::default();

            for (slot, _) in storage.iter().take(5) {
                target_slots.insert(*slot);
            }

            if !target_slots.is_empty() {
                targets.insert(hashed_address, target_slots);
            }
        }

        // Cursor factories for the sequential reference proof over the same transaction.
        let provider_rw = factory.provider_rw().unwrap();
        let trie_cursor_factory = DatabaseTrieCursorFactory::new(provider_rw.tx_ref());
        let hashed_cursor_factory = DatabaseHashedCursorFactory::new(provider_rw.tx_ref());

        let rt = Runtime::new().unwrap();

        // Spin up a minimal worker pool: 1 storage worker, 1 account worker, V2 disabled.
        let changeset_cache = reth_trie_db::ChangesetCache::new();
        let factory =
            reth_provider::providers::OverlayStateProviderFactory::new(factory, changeset_cache);
        let task_ctx = ProofTaskCtx::new(factory);
        let proof_worker_handle =
            ProofWorkerHandle::new(rt.handle().clone(), task_ctx, 1, 1, false);

        let parallel_result = ParallelProof::new(Default::default(), proof_worker_handle.clone())
            .decoded_multiproof(targets.clone())
            .unwrap();

        let sequential_result_raw = Proof::new(trie_cursor_factory, hashed_cursor_factory)
            .multiproof(targets.clone())
            .unwrap(); // targets might be consumed by parallel_result
        let sequential_result_decoded: DecodedMultiProof = sequential_result_raw
            .try_into()
            .expect("Failed to decode sequential_result for test comparison");

        // to help narrow down what is wrong - first compare account subtries
        assert_eq!(parallel_result.account_subtree, sequential_result_decoded.account_subtree);

        // then compare length of all storage subtries
        assert_eq!(parallel_result.storages.len(), sequential_result_decoded.storages.len());

        // then compare each storage subtrie
        for (hashed_address, storage_proof) in &parallel_result.storages {
            let sequential_storage_proof =
                sequential_result_decoded.storages.get(hashed_address).unwrap();
            assert_eq!(storage_proof, sequential_storage_proof);
        }

        // then compare the entire thing for any mask differences
        assert_eq!(parallel_result, sequential_result_decoded);

        // Workers shut down automatically when handle is dropped
        drop(proof_worker_handle);
    }
}