1use crate::{
2 metrics::ParallelTrieMetrics,
3 proof_task::{
4 AccountMultiproofInput, ProofResult, ProofResultContext, ProofWorkerHandle,
5 StorageProofInput, StorageProofResultMessage,
6 },
7 root::ParallelStateRootError,
8 StorageRootTargets,
9};
10use alloy_primitives::{map::B256Set, B256};
11use crossbeam_channel::{unbounded as crossbeam_unbounded, Receiver as CrossbeamReceiver};
12use reth_execution_errors::StorageRootError;
13use reth_storage_errors::db::DatabaseError;
14use reth_trie::{
15 prefix_set::{PrefixSet, PrefixSetMut, TriePrefixSets, TriePrefixSetsMut},
16 DecodedMultiProof, DecodedStorageMultiProof, HashedPostState, MultiProofTargets, Nibbles,
17};
18use reth_trie_common::added_removed_keys::MultiAddedRemovedKeys;
19use std::{sync::Arc, time::Instant};
20use tracing::trace;
21
/// Generates decoded storage and account multiproofs by dispatching jobs to the proof workers
/// behind a [`ProofWorkerHandle`] and waiting for their replies.
#[derive(Debug)]
pub struct ParallelProof {
    /// Base prefix sets (changed keys). [`ParallelProof::decoded_multiproof`] extends a clone of
    /// these with the requested targets before dispatching the account multiproof job.
    pub prefix_sets: Arc<TriePrefixSetsMut>,
    /// Whether branch node masks should be collected; forwarded to legacy proof inputs.
    collect_branch_node_masks: bool,
    /// Optional added/removed-key tracking; forwarded to legacy proof inputs.
    multi_added_removed_keys: Option<Arc<MultiAddedRemovedKeys>>,
    /// Handle used to dispatch storage and account multiproof jobs to workers.
    proof_worker_handle: ProofWorkerHandle,
    /// When `true`, storage proofs are requested with the v2 input format instead of legacy.
    v2_proofs_enabled: bool,
    /// Proof metrics (labelled `type = "proof"`), recorded after each account multiproof.
    #[cfg(feature = "metrics")]
    metrics: ParallelTrieMetrics,
}
41
42impl ParallelProof {
43 pub fn new(
45 prefix_sets: Arc<TriePrefixSetsMut>,
46 proof_worker_handle: ProofWorkerHandle,
47 ) -> Self {
48 Self {
49 prefix_sets,
50 collect_branch_node_masks: false,
51 multi_added_removed_keys: None,
52 proof_worker_handle,
53 v2_proofs_enabled: false,
54 #[cfg(feature = "metrics")]
55 metrics: ParallelTrieMetrics::new_with_labels(&[("type", "proof")]),
56 }
57 }
58
59 pub const fn with_v2_proofs_enabled(mut self, v2_proofs_enabled: bool) -> Self {
61 self.v2_proofs_enabled = v2_proofs_enabled;
62 self
63 }
64
65 pub const fn with_branch_node_masks(mut self, branch_node_masks: bool) -> Self {
67 self.collect_branch_node_masks = branch_node_masks;
68 self
69 }
70
71 pub fn with_multi_added_removed_keys(
74 mut self,
75 multi_added_removed_keys: Option<Arc<MultiAddedRemovedKeys>>,
76 ) -> Self {
77 self.multi_added_removed_keys = multi_added_removed_keys;
78 self
79 }
80 fn send_storage_proof(
82 &self,
83 hashed_address: B256,
84 prefix_set: PrefixSet,
85 target_slots: B256Set,
86 ) -> Result<CrossbeamReceiver<StorageProofResultMessage>, ParallelStateRootError> {
87 let (result_tx, result_rx) = crossbeam_channel::unbounded();
88
89 let input = if self.v2_proofs_enabled {
90 StorageProofInput::new(
91 hashed_address,
92 target_slots.into_iter().map(Into::into).collect(),
93 )
94 } else {
95 StorageProofInput::legacy(
96 hashed_address,
97 prefix_set,
98 target_slots,
99 self.collect_branch_node_masks,
100 self.multi_added_removed_keys.clone(),
101 )
102 };
103
104 self.proof_worker_handle
105 .dispatch_storage_proof(input, result_tx)
106 .map_err(|e| ParallelStateRootError::Other(e.to_string()))?;
107
108 Ok(result_rx)
109 }
110
111 pub fn storage_proof(
113 self,
114 hashed_address: B256,
115 target_slots: B256Set,
116 ) -> Result<DecodedStorageMultiProof, ParallelStateRootError> {
117 let total_targets = target_slots.len();
118 let prefix_set = if self.v2_proofs_enabled {
119 PrefixSet::default()
120 } else {
121 PrefixSetMut::from(target_slots.iter().map(Nibbles::unpack)).freeze()
122 };
123
124 trace!(
125 target: "trie::parallel_proof",
126 total_targets,
127 ?hashed_address,
128 "Starting storage proof generation"
129 );
130
131 let receiver = self.send_storage_proof(hashed_address, prefix_set, target_slots)?;
132 let proof_msg = receiver.recv().map_err(|_| {
133 ParallelStateRootError::StorageRoot(StorageRootError::Database(DatabaseError::Other(
134 format!("channel closed for {hashed_address}"),
135 )))
136 })?;
137
138 let proof_result = proof_msg.result?;
140 let storage_proof = Into::<Option<DecodedStorageMultiProof>>::into(proof_result)
141 .expect("Partial proofs are not yet supported");
142
143 trace!(
144 target: "trie::parallel_proof",
145 total_targets,
146 ?hashed_address,
147 "Storage proof generation completed"
148 );
149
150 Ok(storage_proof)
151 }
152
153 pub fn extend_prefix_sets_with_targets(
158 base_prefix_sets: &TriePrefixSetsMut,
159 targets: &MultiProofTargets,
160 ) -> TriePrefixSets {
161 let mut extended = base_prefix_sets.clone();
162 extended.extend(TriePrefixSetsMut {
163 account_prefix_set: PrefixSetMut::from(targets.keys().copied().map(Nibbles::unpack)),
164 storage_prefix_sets: targets
165 .iter()
166 .filter(|&(_hashed_address, slots)| !slots.is_empty())
167 .map(|(hashed_address, slots)| {
168 (*hashed_address, PrefixSetMut::from(slots.iter().map(Nibbles::unpack)))
169 })
170 .collect(),
171 destroyed_accounts: Default::default(),
172 });
173 extended.freeze()
174 }
175
176 pub fn decoded_multiproof(
178 self,
179 targets: MultiProofTargets,
180 ) -> Result<DecodedMultiProof, ParallelStateRootError> {
181 let prefix_sets = Self::extend_prefix_sets_with_targets(&self.prefix_sets, &targets);
183
184 let storage_root_targets_len = StorageRootTargets::count(
185 &prefix_sets.account_prefix_set,
186 &prefix_sets.storage_prefix_sets,
187 );
188
189 trace!(
190 target: "trie::parallel_proof",
191 total_targets = storage_root_targets_len,
192 "Starting parallel proof generation"
193 );
194
195 let (result_tx, result_rx) = crossbeam_unbounded();
198 let account_multiproof_start_time = Instant::now();
199
200 let input = AccountMultiproofInput::Legacy {
201 targets,
202 prefix_sets,
203 collect_branch_node_masks: self.collect_branch_node_masks,
204 multi_added_removed_keys: self.multi_added_removed_keys.clone(),
205 proof_result_sender: ProofResultContext::new(
206 result_tx,
207 0,
208 HashedPostState::default(),
209 account_multiproof_start_time,
210 ),
211 };
212
213 self.proof_worker_handle
214 .dispatch_account_multiproof(input)
215 .map_err(|e| ParallelStateRootError::Other(e.to_string()))?;
216
217 let proof_result_msg = result_rx.recv().map_err(|_| {
219 ParallelStateRootError::Other(
220 "Account multiproof channel dropped: worker died or pool shutdown".to_string(),
221 )
222 })?;
223
224 let ProofResult::Legacy(multiproof, stats) = proof_result_msg.result? else {
225 panic!("AccountMultiproofInput::Legacy was submitted, expected legacy result")
226 };
227
228 #[cfg(feature = "metrics")]
229 self.metrics.record(stats);
230
231 trace!(
232 target: "trie::parallel_proof",
233 total_targets = storage_root_targets_len,
234 duration = ?stats.duration(),
235 branches_added = stats.branches_added(),
236 leaves_added = stats.leaves_added(),
237 missed_leaves = stats.missed_leaves(),
238 precomputed_storage_roots = stats.precomputed_storage_roots(),
239 "Calculated decoded proof",
240 );
241
242 Ok(multiproof)
243 }
244}
245
#[cfg(test)]
mod tests {
    use super::*;
    use crate::proof_task::{ProofTaskCtx, ProofWorkerHandle};
    use alloy_primitives::{
        keccak256,
        map::{B256Set, DefaultHashBuilder, HashMap},
        Address, U256,
    };
    use rand::Rng;
    use reth_primitives_traits::{Account, StorageEntry};
    use reth_provider::{test_utils::create_test_provider_factory, HashingWriter};
    use reth_trie::proof::Proof;
    use reth_trie_db::{DatabaseHashedCursorFactory, DatabaseTrieCursorFactory};

    /// Writes random account and storage state, then checks that
    /// `ParallelProof::decoded_multiproof` produces the same result as the sequential
    /// `Proof::multiproof` (after decoding) for the same targets.
    #[test]
    fn random_parallel_proof() {
        let factory = create_test_provider_factory();

        let mut rng = rand::rng();
        let state = (0..100)
            .map(|_| {
                let address = Address::random();
                let account =
                    Account { balance: U256::from(rng.random::<u64>()), ..Default::default() };
                let mut storage = HashMap::<B256, U256, DefaultHashBuilder>::default();
                let has_storage = rng.random_bool(0.7);
                if has_storage {
                    for _ in 0..100 {
                        storage.insert(
                            B256::from(U256::from(rng.random::<u64>())),
                            U256::from(rng.random::<u64>()),
                        );
                    }
                }
                (address, (account, storage))
            })
            .collect::<HashMap<_, _, DefaultHashBuilder>>();

        {
            let provider_rw = factory.provider_rw().unwrap();
            provider_rw
                .insert_account_for_hashing(
                    state.iter().map(|(address, (account, _))| (*address, Some(*account))),
                )
                .unwrap();
            provider_rw
                .insert_storage_for_hashing(state.iter().map(|(address, (_, storage))| {
                    (
                        *address,
                        storage
                            .iter()
                            .map(|(slot, value)| StorageEntry { key: *slot, value: *value }),
                    )
                }))
                .unwrap();
            provider_rw.commit().unwrap();
        }

        // Target up to 5 slots from each of the first 10 accounts; accounts with no selected
        // slots are not added as targets.
        let mut targets = MultiProofTargets::default();
        for (address, (_, storage)) in state.iter().take(10) {
            let hashed_address = keccak256(*address);
            let mut target_slots = B256Set::default();

            for (slot, _) in storage.iter().take(5) {
                target_slots.insert(*slot);
            }

            if !target_slots.is_empty() {
                targets.insert(hashed_address, target_slots);
            }
        }

        // A fresh provider opened after the commit above, used for the sequential reference.
        let provider_rw = factory.provider_rw().unwrap();
        let trie_cursor_factory = DatabaseTrieCursorFactory::new(provider_rw.tx_ref());
        let hashed_cursor_factory = DatabaseHashedCursorFactory::new(provider_rw.tx_ref());

        let changeset_cache = reth_trie_db::ChangesetCache::new();
        let factory =
            reth_provider::providers::OverlayStateProviderFactory::new(factory, changeset_cache);
        let task_ctx = ProofTaskCtx::new(factory);
        let runtime = reth_tasks::Runtime::test();
        let proof_worker_handle = ProofWorkerHandle::new(&runtime, task_ctx, false);

        let parallel_result = ParallelProof::new(Default::default(), proof_worker_handle.clone())
            .decoded_multiproof(targets.clone())
            .unwrap();

        let sequential_result_raw = Proof::new(trie_cursor_factory, hashed_cursor_factory)
            .multiproof(targets.clone())
            .unwrap();
        let sequential_result_decoded: DecodedMultiProof = sequential_result_raw
            .try_into()
            .expect("Failed to decode sequential_result for test comparison");

        assert_eq!(parallel_result.account_subtree, sequential_result_decoded.account_subtree);

        assert_eq!(parallel_result.storages.len(), sequential_result_decoded.storages.len());

        for (hashed_address, storage_proof) in &parallel_result.storages {
            let sequential_storage_proof =
                sequential_result_decoded.storages.get(hashed_address).unwrap();
            assert_eq!(storage_proof, sequential_storage_proof);
        }

        assert_eq!(parallel_result, sequential_result_decoded);

        drop(proof_worker_handle);
    }
}
360}