1use crate::{
2 metrics::ParallelTrieMetrics,
3 proof_task::{
4 AccountMultiproofInput, ProofResult, ProofResultContext, ProofWorkerHandle,
5 StorageProofInput, StorageProofResultMessage,
6 },
7 root::ParallelStateRootError,
8 StorageRootTargets,
9};
10use alloy_primitives::{map::B256Set, B256};
11use crossbeam_channel::{unbounded as crossbeam_unbounded, Receiver as CrossbeamReceiver};
12use reth_execution_errors::StorageRootError;
13use reth_storage_errors::db::DatabaseError;
14use reth_trie::{
15 prefix_set::{PrefixSet, PrefixSetMut, TriePrefixSets, TriePrefixSetsMut},
16 DecodedMultiProof, DecodedStorageMultiProof, HashedPostState, MultiProofTargets, Nibbles,
17};
18use reth_trie_common::added_removed_keys::MultiAddedRemovedKeys;
19use std::{sync::Arc, time::Instant};
20use tracing::trace;
21
22#[derive(Debug)]
27pub struct ParallelProof {
28 pub prefix_sets: Arc<TriePrefixSetsMut>,
30 collect_branch_node_masks: bool,
32 multi_added_removed_keys: Option<Arc<MultiAddedRemovedKeys>>,
34 proof_worker_handle: ProofWorkerHandle,
36 v2_proofs_enabled: bool,
38 #[cfg(feature = "metrics")]
39 metrics: ParallelTrieMetrics,
40}
41
42impl ParallelProof {
43 pub fn new(
45 prefix_sets: Arc<TriePrefixSetsMut>,
46 proof_worker_handle: ProofWorkerHandle,
47 ) -> Self {
48 Self {
49 prefix_sets,
50 collect_branch_node_masks: false,
51 multi_added_removed_keys: None,
52 proof_worker_handle,
53 v2_proofs_enabled: false,
54 #[cfg(feature = "metrics")]
55 metrics: ParallelTrieMetrics::new_with_labels(&[("type", "proof")]),
56 }
57 }
58
59 pub const fn with_v2_proofs_enabled(mut self, v2_proofs_enabled: bool) -> Self {
61 self.v2_proofs_enabled = v2_proofs_enabled;
62 self
63 }
64
65 pub const fn with_branch_node_masks(mut self, branch_node_masks: bool) -> Self {
67 self.collect_branch_node_masks = branch_node_masks;
68 self
69 }
70
71 pub fn with_multi_added_removed_keys(
74 mut self,
75 multi_added_removed_keys: Option<Arc<MultiAddedRemovedKeys>>,
76 ) -> Self {
77 self.multi_added_removed_keys = multi_added_removed_keys;
78 self
79 }
80 fn send_storage_proof(
82 &self,
83 hashed_address: B256,
84 prefix_set: PrefixSet,
85 target_slots: B256Set,
86 ) -> Result<CrossbeamReceiver<StorageProofResultMessage>, ParallelStateRootError> {
87 let (result_tx, result_rx) = crossbeam_channel::unbounded();
88
89 let input = if self.v2_proofs_enabled {
90 StorageProofInput::new(
91 hashed_address,
92 target_slots.into_iter().map(Into::into).collect(),
93 )
94 } else {
95 StorageProofInput::legacy(
96 hashed_address,
97 prefix_set,
98 target_slots,
99 self.collect_branch_node_masks,
100 self.multi_added_removed_keys.clone(),
101 )
102 };
103
104 self.proof_worker_handle
105 .dispatch_storage_proof(input, result_tx)
106 .map_err(|e| ParallelStateRootError::Other(e.to_string()))?;
107
108 Ok(result_rx)
109 }
110
111 pub fn storage_proof(
113 self,
114 hashed_address: B256,
115 target_slots: B256Set,
116 ) -> Result<DecodedStorageMultiProof, ParallelStateRootError> {
117 let total_targets = target_slots.len();
118 let prefix_set = if self.v2_proofs_enabled {
119 PrefixSet::default()
120 } else {
121 PrefixSetMut::from(target_slots.iter().map(Nibbles::unpack)).freeze()
122 };
123
124 trace!(
125 target: "trie::parallel_proof",
126 total_targets,
127 ?hashed_address,
128 "Starting storage proof generation"
129 );
130
131 let receiver = self.send_storage_proof(hashed_address, prefix_set, target_slots)?;
132 let proof_msg = receiver.recv().map_err(|_| {
133 ParallelStateRootError::StorageRoot(StorageRootError::Database(DatabaseError::Other(
134 format!("channel closed for {hashed_address}"),
135 )))
136 })?;
137
138 let proof_result = proof_msg.result?;
140 let storage_proof = Into::<Option<DecodedStorageMultiProof>>::into(proof_result)
141 .expect("Partial proofs are not yet supported");
142
143 trace!(
144 target: "trie::parallel_proof",
145 total_targets,
146 ?hashed_address,
147 "Storage proof generation completed"
148 );
149
150 Ok(storage_proof)
151 }
152
153 pub fn extend_prefix_sets_with_targets(
158 base_prefix_sets: &TriePrefixSetsMut,
159 targets: &MultiProofTargets,
160 ) -> TriePrefixSets {
161 let mut extended = base_prefix_sets.clone();
162 extended.extend(TriePrefixSetsMut {
163 account_prefix_set: PrefixSetMut::from(targets.keys().copied().map(Nibbles::unpack)),
164 storage_prefix_sets: targets
165 .iter()
166 .filter(|&(_hashed_address, slots)| !slots.is_empty())
167 .map(|(hashed_address, slots)| {
168 (*hashed_address, PrefixSetMut::from(slots.iter().map(Nibbles::unpack)))
169 })
170 .collect(),
171 destroyed_accounts: Default::default(),
172 });
173 extended.freeze()
174 }
175
176 pub fn decoded_multiproof(
178 self,
179 targets: MultiProofTargets,
180 ) -> Result<DecodedMultiProof, ParallelStateRootError> {
181 let prefix_sets = Self::extend_prefix_sets_with_targets(&self.prefix_sets, &targets);
183
184 let storage_root_targets_len = StorageRootTargets::count(
185 &prefix_sets.account_prefix_set,
186 &prefix_sets.storage_prefix_sets,
187 );
188
189 trace!(
190 target: "trie::parallel_proof",
191 total_targets = storage_root_targets_len,
192 "Starting parallel proof generation"
193 );
194
195 let (result_tx, result_rx) = crossbeam_unbounded();
198 let account_multiproof_start_time = Instant::now();
199
200 let input = AccountMultiproofInput::Legacy {
201 targets,
202 prefix_sets,
203 collect_branch_node_masks: self.collect_branch_node_masks,
204 multi_added_removed_keys: self.multi_added_removed_keys.clone(),
205 proof_result_sender: ProofResultContext::new(
206 result_tx,
207 0,
208 HashedPostState::default(),
209 account_multiproof_start_time,
210 ),
211 };
212
213 self.proof_worker_handle
214 .dispatch_account_multiproof(input)
215 .map_err(|e| ParallelStateRootError::Other(e.to_string()))?;
216
217 let proof_result_msg = result_rx.recv().map_err(|_| {
219 ParallelStateRootError::Other(
220 "Account multiproof channel dropped: worker died or pool shutdown".to_string(),
221 )
222 })?;
223
224 let ProofResult::Legacy(multiproof, stats) = proof_result_msg.result? else {
225 panic!("AccountMultiproofInput::Legacy was submitted, expected legacy result")
226 };
227
228 #[cfg(feature = "metrics")]
229 self.metrics.record(stats);
230
231 trace!(
232 target: "trie::parallel_proof",
233 total_targets = storage_root_targets_len,
234 duration = ?stats.duration(),
235 branches_added = stats.branches_added(),
236 leaves_added = stats.leaves_added(),
237 missed_leaves = stats.missed_leaves(),
238 precomputed_storage_roots = stats.precomputed_storage_roots(),
239 "Calculated decoded proof",
240 );
241
242 Ok(multiproof)
243 }
244}
245
#[cfg(test)]
mod tests {
    use super::*;
    use crate::proof_task::{ProofTaskCtx, ProofWorkerHandle};
    use alloy_primitives::{
        keccak256,
        map::{B256Set, DefaultHashBuilder, HashMap},
        Address, U256,
    };
    use rand::Rng;
    use reth_primitives_traits::{Account, StorageEntry};
    use reth_provider::{test_utils::create_test_provider_factory, HashingWriter};
    use reth_trie::proof::Proof;
    use reth_trie_db::{DatabaseHashedCursorFactory, DatabaseTrieCursorFactory};
    use tokio::runtime::Runtime;

    /// Checks that the worker-pool based `ParallelProof` produces the same decoded multiproof
    /// as the sequential `Proof` implementation for randomly generated hashed state.
    #[test]
    fn random_parallel_proof() {
        let factory = create_test_provider_factory();

        let mut rng = rand::rng();
        let state = (0..100)
            .map(|_| {
                let address = Address::random();
                let account =
                    Account { balance: U256::from(rng.random::<u64>()), ..Default::default() };
                let mut storage = HashMap::<B256, U256, DefaultHashBuilder>::default();
                let has_storage = rng.random_bool(0.7);
                if has_storage {
                    for _ in 0..100 {
                        storage.insert(
                            B256::from(U256::from(rng.random::<u64>())),
                            U256::from(rng.random::<u64>()),
                        );
                    }
                }
                (address, (account, storage))
            })
            .collect::<HashMap<_, _, DefaultHashBuilder>>();

        {
            let provider_rw = factory.provider_rw().unwrap();
            provider_rw
                .insert_account_for_hashing(
                    state.iter().map(|(address, (account, _))| (*address, Some(*account))),
                )
                .unwrap();
            provider_rw
                .insert_storage_for_hashing(state.iter().map(|(address, (_, storage))| {
                    (
                        *address,
                        storage
                            .iter()
                            .map(|(slot, value)| StorageEntry { key: *slot, value: *value }),
                    )
                }))
                .unwrap();
            provider_rw.commit().unwrap();
        }

        let mut targets = MultiProofTargets::default();
        for (address, (_, storage)) in state.iter().take(10) {
            let hashed_address = keccak256(*address);
            let mut target_slots = B256Set::default();

            for (slot, _) in storage.iter().take(5) {
                target_slots.insert(*slot);
            }

            if !target_slots.is_empty() {
                targets.insert(hashed_address, target_slots);
            }
        }

        let provider_rw = factory.provider_rw().unwrap();
        let trie_cursor_factory = DatabaseTrieCursorFactory::new(provider_rw.tx_ref());
        let hashed_cursor_factory = DatabaseHashedCursorFactory::new(provider_rw.tx_ref());

        let rt = Runtime::new().unwrap();

        let changeset_cache = reth_trie_db::ChangesetCache::new();
        let factory =
            reth_provider::providers::OverlayStateProviderFactory::new(factory, changeset_cache);
        let task_ctx = ProofTaskCtx::new(factory);
        let proof_worker_handle =
            ProofWorkerHandle::new(rt.handle().clone(), task_ctx, 1, 1, false);

        let parallel_result = ParallelProof::new(Default::default(), proof_worker_handle.clone())
            .decoded_multiproof(targets.clone())
            .unwrap();

        let sequential_result_raw = Proof::new(trie_cursor_factory, hashed_cursor_factory)
            .multiproof(targets)
            .unwrap();
        let sequential_result_decoded: DecodedMultiProof = sequential_result_raw
            .try_into()
            .expect("Failed to decode sequential_result for test comparison");

        // Compare the account subtrees first for a more targeted failure message.
        assert_eq!(parallel_result.account_subtree, sequential_result_decoded.account_subtree);

        assert_eq!(parallel_result.storages.len(), sequential_result_decoded.storages.len());

        for (hashed_address, storage_proof) in &parallel_result.storages {
            let sequential_storage_proof =
                sequential_result_decoded.storages.get(hashed_address).unwrap();
            assert_eq!(storage_proof, sequential_storage_proof);
        }

        assert_eq!(parallel_result, sequential_result_decoded);

        drop(proof_worker_handle);
    }
}