1use crate::{
2 metrics::ParallelTrieMetrics,
3 proof_task::{
4 AccountMultiproofInput, ProofResultContext, ProofResultMessage, ProofWorkerHandle,
5 StorageProofInput,
6 },
7 root::ParallelStateRootError,
8 StorageRootTargets,
9};
10use alloy_primitives::{map::B256Set, B256};
11use crossbeam_channel::{unbounded as crossbeam_unbounded, Receiver as CrossbeamReceiver};
12use dashmap::DashMap;
13use reth_execution_errors::StorageRootError;
14use reth_storage_errors::db::DatabaseError;
15use reth_trie::{
16 prefix_set::{PrefixSet, PrefixSetMut, TriePrefixSets, TriePrefixSetsMut},
17 DecodedMultiProof, DecodedStorageMultiProof, HashedPostState, MultiProofTargets, Nibbles,
18};
19use reth_trie_common::added_removed_keys::MultiAddedRemovedKeys;
20use std::{sync::Arc, time::Instant};
21use tracing::trace;
22
/// Generates trie proofs by dispatching jobs to a pool of proof workers.
///
/// Storage proofs are produced by [`ParallelProof::storage_proof`] and account multiproofs by
/// [`ParallelProof::decoded_multiproof`]. Both hand the job to the [`ProofWorkerHandle`] and then
/// block on a result channel until the worker replies.
#[derive(Debug)]
pub struct ParallelProof {
    /// Base prefix sets. [`ParallelProof::decoded_multiproof`] extends a copy of these with the
    /// requested targets before dispatching; the shared value itself is not modified.
    pub prefix_sets: Arc<TriePrefixSetsMut>,
    /// Forwarded to proof workers as `collect_branch_node_masks`; `false` unless set via
    /// `with_branch_node_masks`.
    collect_branch_node_masks: bool,
    /// Added/removed key tracking forwarded to proof workers; `None` unless set via
    /// `with_multi_added_removed_keys`.
    multi_added_removed_keys: Option<Arc<MultiAddedRemovedKeys>>,
    /// Handle used to dispatch storage and account proof jobs to the worker pool.
    proof_worker_handle: ProofWorkerHandle,
    /// Shared map (hashed address -> root) passed to account multiproof workers.
    /// NOTE(review): presumably caches storage roots of leaves missed during the walk — confirm.
    missed_leaves_storage_roots: Arc<DashMap<B256, B256>>,
    /// Metrics recorded after each account multiproof, labelled `type = "proof"`.
    #[cfg(feature = "metrics")]
    metrics: ParallelTrieMetrics,
}
43
44impl ParallelProof {
45 pub fn new(
47 prefix_sets: Arc<TriePrefixSetsMut>,
48 missed_leaves_storage_roots: Arc<DashMap<B256, B256>>,
49 proof_worker_handle: ProofWorkerHandle,
50 ) -> Self {
51 Self {
52 prefix_sets,
53 missed_leaves_storage_roots,
54 collect_branch_node_masks: false,
55 multi_added_removed_keys: None,
56 proof_worker_handle,
57 #[cfg(feature = "metrics")]
58 metrics: ParallelTrieMetrics::new_with_labels(&[("type", "proof")]),
59 }
60 }
61
62 pub const fn with_branch_node_masks(mut self, branch_node_masks: bool) -> Self {
64 self.collect_branch_node_masks = branch_node_masks;
65 self
66 }
67
68 pub fn with_multi_added_removed_keys(
71 mut self,
72 multi_added_removed_keys: Option<Arc<MultiAddedRemovedKeys>>,
73 ) -> Self {
74 self.multi_added_removed_keys = multi_added_removed_keys;
75 self
76 }
77 fn send_storage_proof(
79 &self,
80 hashed_address: B256,
81 prefix_set: PrefixSet,
82 target_slots: B256Set,
83 ) -> Result<CrossbeamReceiver<ProofResultMessage>, ParallelStateRootError> {
84 let (result_tx, result_rx) = crossbeam_channel::unbounded();
85 let start = Instant::now();
86
87 let input = StorageProofInput::new(
88 hashed_address,
89 prefix_set,
90 target_slots,
91 self.collect_branch_node_masks,
92 self.multi_added_removed_keys.clone(),
93 );
94
95 self.proof_worker_handle
96 .dispatch_storage_proof(
97 input,
98 ProofResultContext::new(result_tx, 0, HashedPostState::default(), start),
99 )
100 .map_err(|e| ParallelStateRootError::Other(e.to_string()))?;
101
102 Ok(result_rx)
103 }
104
105 pub fn storage_proof(
107 self,
108 hashed_address: B256,
109 target_slots: B256Set,
110 ) -> Result<DecodedStorageMultiProof, ParallelStateRootError> {
111 let total_targets = target_slots.len();
112 let prefix_set = PrefixSetMut::from(target_slots.iter().map(Nibbles::unpack));
113 let prefix_set = prefix_set.freeze();
114
115 trace!(
116 target: "trie::parallel_proof",
117 total_targets,
118 ?hashed_address,
119 "Starting storage proof generation"
120 );
121
122 let receiver = self.send_storage_proof(hashed_address, prefix_set, target_slots)?;
123 let proof_msg = receiver.recv().map_err(|_| {
124 ParallelStateRootError::StorageRoot(StorageRootError::Database(DatabaseError::Other(
125 format!("channel closed for {hashed_address}"),
126 )))
127 })?;
128
129 let storage_proof = match proof_msg.result? {
131 crate::proof_task::ProofResult::StorageProof { hashed_address: addr, proof } => {
132 debug_assert_eq!(
133 addr,
134 hashed_address,
135 "storage worker must return same address: expected {hashed_address}, got {addr}"
136 );
137 proof
138 }
139 crate::proof_task::ProofResult::AccountMultiproof { .. } => {
140 unreachable!("storage worker only sends StorageProof variant")
141 }
142 };
143
144 trace!(
145 target: "trie::parallel_proof",
146 total_targets,
147 ?hashed_address,
148 "Storage proof generation completed"
149 );
150
151 Ok(storage_proof)
152 }
153
154 pub fn extend_prefix_sets_with_targets(
159 base_prefix_sets: &TriePrefixSetsMut,
160 targets: &MultiProofTargets,
161 ) -> TriePrefixSets {
162 let mut extended = base_prefix_sets.clone();
163 extended.extend(TriePrefixSetsMut {
164 account_prefix_set: PrefixSetMut::from(targets.keys().copied().map(Nibbles::unpack)),
165 storage_prefix_sets: targets
166 .iter()
167 .filter(|&(_hashed_address, slots)| !slots.is_empty())
168 .map(|(hashed_address, slots)| {
169 (*hashed_address, PrefixSetMut::from(slots.iter().map(Nibbles::unpack)))
170 })
171 .collect(),
172 destroyed_accounts: Default::default(),
173 });
174 extended.freeze()
175 }
176
177 pub fn decoded_multiproof(
179 self,
180 targets: MultiProofTargets,
181 ) -> Result<DecodedMultiProof, ParallelStateRootError> {
182 let prefix_sets = Self::extend_prefix_sets_with_targets(&self.prefix_sets, &targets);
184
185 let storage_root_targets_len = StorageRootTargets::count(
186 &prefix_sets.account_prefix_set,
187 &prefix_sets.storage_prefix_sets,
188 );
189
190 trace!(
191 target: "trie::parallel_proof",
192 total_targets = storage_root_targets_len,
193 "Starting parallel proof generation"
194 );
195
196 let (result_tx, result_rx) = crossbeam_unbounded();
199 let account_multiproof_start_time = Instant::now();
200
201 let input = AccountMultiproofInput {
202 targets,
203 prefix_sets,
204 collect_branch_node_masks: self.collect_branch_node_masks,
205 multi_added_removed_keys: self.multi_added_removed_keys.clone(),
206 missed_leaves_storage_roots: self.missed_leaves_storage_roots.clone(),
207 proof_result_sender: ProofResultContext::new(
208 result_tx,
209 0,
210 HashedPostState::default(),
211 account_multiproof_start_time,
212 ),
213 };
214
215 self.proof_worker_handle
216 .dispatch_account_multiproof(input)
217 .map_err(|e| ParallelStateRootError::Other(e.to_string()))?;
218
219 let proof_result_msg = result_rx.recv().map_err(|_| {
221 ParallelStateRootError::Other(
222 "Account multiproof channel dropped: worker died or pool shutdown".to_string(),
223 )
224 })?;
225
226 let (multiproof, stats) = match proof_result_msg.result? {
227 crate::proof_task::ProofResult::AccountMultiproof { proof, stats } => (proof, stats),
228 crate::proof_task::ProofResult::StorageProof { .. } => {
229 unreachable!("account worker only sends AccountMultiproof variant")
230 }
231 };
232
233 #[cfg(feature = "metrics")]
234 self.metrics.record(stats);
235
236 trace!(
237 target: "trie::parallel_proof",
238 total_targets = storage_root_targets_len,
239 duration = ?stats.duration(),
240 branches_added = stats.branches_added(),
241 leaves_added = stats.leaves_added(),
242 missed_leaves = stats.missed_leaves(),
243 precomputed_storage_roots = stats.precomputed_storage_roots(),
244 "Calculated decoded proof"
245 );
246
247 Ok(multiproof)
248 }
249}
250
#[cfg(test)]
mod tests {
    use super::*;
    use crate::proof_task::{ProofTaskCtx, ProofWorkerHandle};
    use alloy_primitives::{
        keccak256,
        map::{B256Set, DefaultHashBuilder, HashMap},
        Address, U256,
    };
    use rand::Rng;
    use reth_primitives_traits::{Account, StorageEntry};
    use reth_provider::{test_utils::create_test_provider_factory, HashingWriter};
    use reth_trie::proof::Proof;
    use reth_trie_db::{DatabaseHashedCursorFactory, DatabaseTrieCursorFactory};
    use tokio::runtime::Runtime;

    /// Checks the parallel prover against the sequential [`Proof::multiproof`].
    ///
    /// Writes 100 random accounts to a test database; each has a 70% chance of receiving up to
    /// 100 random storage slots. It then asserts that [`ParallelProof::decoded_multiproof`]
    /// returns the same decoded multiproof as the sequential prover for a subset of targets.
    #[test]
    fn random_parallel_proof() {
        let factory = create_test_provider_factory();

        let mut rng = rand::rng();
        let state = (0..100)
            .map(|_| {
                let address = Address::random();
                let account =
                    Account { balance: U256::from(rng.random::<u64>()), ..Default::default() };
                let mut storage = HashMap::<B256, U256, DefaultHashBuilder>::default();
                let has_storage = rng.random_bool(0.7);
                if has_storage {
                    for _ in 0..100 {
                        storage.insert(
                            B256::from(U256::from(rng.random::<u64>())),
                            U256::from(rng.random::<u64>()),
                        );
                    }
                }
                (address, (account, storage))
            })
            .collect::<HashMap<_, _, DefaultHashBuilder>>();

        {
            let provider_rw = factory.provider_rw().unwrap();
            provider_rw
                .insert_account_for_hashing(
                    state.iter().map(|(address, (account, _))| (*address, Some(*account))),
                )
                .unwrap();
            provider_rw
                .insert_storage_for_hashing(state.iter().map(|(address, (_, storage))| {
                    (
                        *address,
                        storage
                            .iter()
                            .map(|(slot, value)| StorageEntry { key: *slot, value: *value }),
                    )
                }))
                .unwrap();
            provider_rw.commit().unwrap();
        }

        // Target up to 5 slots of each of the first 10 accounts that have storage.
        let mut targets = MultiProofTargets::default();
        for (address, (_, storage)) in state.iter().take(10) {
            let hashed_address = keccak256(*address);
            let mut target_slots = B256Set::default();

            for (slot, _) in storage.iter().take(5) {
                target_slots.insert(*slot);
            }

            if !target_slots.is_empty() {
                targets.insert(hashed_address, target_slots);
            }
        }

        let provider_rw = factory.provider_rw().unwrap();
        let trie_cursor_factory = DatabaseTrieCursorFactory::new(provider_rw.tx_ref());
        let hashed_cursor_factory = DatabaseHashedCursorFactory::new(provider_rw.tx_ref());

        let rt = Runtime::new().unwrap();

        let factory = reth_provider::providers::OverlayStateProviderFactory::new(factory);
        let task_ctx = ProofTaskCtx::new(factory, Default::default());
        let proof_worker_handle = ProofWorkerHandle::new(rt.handle().clone(), task_ctx, 1, 1);

        let parallel_result =
            ParallelProof::new(Default::default(), Default::default(), proof_worker_handle.clone())
                .decoded_multiproof(targets.clone())
                .unwrap();

        let sequential_result_raw = Proof::new(trie_cursor_factory, hashed_cursor_factory)
            .multiproof(targets.clone())
            .unwrap();
        let sequential_result_decoded: DecodedMultiProof = sequential_result_raw
            .try_into()
            .expect("Failed to decode sequential_result for test comparison");

        // Compare piecewise first so a mismatch points at the differing part.
        assert_eq!(parallel_result.account_subtree, sequential_result_decoded.account_subtree);

        assert_eq!(parallel_result.storages.len(), sequential_result_decoded.storages.len());

        for (hashed_address, storage_proof) in &parallel_result.storages {
            let sequential_storage_proof = sequential_result_decoded
                .storages
                .get(hashed_address)
                .expect("sequential result must contain every storage proof the parallel one has");
            assert_eq!(storage_proof, sequential_storage_proof);
        }

        assert_eq!(parallel_result, sequential_result_decoded);

        // Release the worker handle explicitly at the end of the test.
        drop(proof_worker_handle);
    }
}
366}