1use crate::{
2 hashed_cursor::{
3 HashedCursorFactory, HashedCursorMetricsCache, HashedStorageCursor,
4 InstrumentedHashedCursor,
5 },
6 node_iter::{TrieElement, TrieNodeIter},
7 prefix_set::{PrefixSetMut, TriePrefixSetsMut},
8 proof_v2::{self, SyncAccountValueEncoder},
9 trie_cursor::{InstrumentedTrieCursor, TrieCursorFactory, TrieCursorMetricsCache},
10 walker::TrieWalker,
11 HashBuilder, Nibbles, TRIE_ACCOUNT_RLP_MAX_SIZE,
12};
13use alloy_primitives::{
14 keccak256,
15 map::{B256Map, B256Set, HashSet},
16 Address, B256,
17};
18use alloy_rlp::{BufMut, Encodable};
19use alloy_trie::proof::AddedRemovedKeys;
20use reth_execution_errors::trie::StateProofError;
21use reth_trie_common::{
22 proof::ProofRetainer, AccountProof, BranchNodeMasks, BranchNodeMasksMap, DecodedMultiProofV2,
23 MultiProof, MultiProofTargets, MultiProofTargetsV2, StorageMultiProof,
24};
25
/// Generates account proofs and state multiproofs from the account and storage tries.
///
/// Generic parameters:
/// - `T`: trie cursor factory (must implement [`TrieCursorFactory`] for proof generation).
/// - `H`: hashed state cursor factory (must implement [`HashedCursorFactory`] for proof
///   generation).
/// - `K`: added/removed keys tracking type, defaulting to [`AddedRemovedKeys`].
#[derive(Debug)]
pub struct Proof<T, H, K = AddedRemovedKeys> {
    /// Factory used to open the account trie cursor and per-account storage trie cursors.
    trie_cursor_factory: T,
    /// Factory used to open the hashed account cursor and per-account hashed storage cursors.
    hashed_cursor_factory: H,
    /// Account and storage prefix sets handed to the trie walkers / proof calculators.
    /// NOTE(review): presumably these mark paths changed since the stored trie was written so
    /// the walker re-descends into them instead of reusing stored hashes — confirm.
    prefix_sets: TriePrefixSetsMut,
    /// Whether `multiproof` collects the hash/tree masks of branch nodes updated by the hash
    /// builder (for the account trie, and for storage tries of targeted accounts only).
    collect_branch_node_masks: bool,
    /// Optional added/removed keys; `multiproof` forwards them to the account trie walker and
    /// the proof retainer.
    added_removed_keys: Option<K>,
}
44
45impl<T, H> Proof<T, H> {
46 pub fn new(t: T, h: H) -> Self {
48 Self {
49 trie_cursor_factory: t,
50 hashed_cursor_factory: h,
51 prefix_sets: TriePrefixSetsMut::default(),
52 collect_branch_node_masks: false,
53 added_removed_keys: None,
54 }
55 }
56}
57
58impl<T, H, K> Proof<T, H, K> {
59 pub fn with_trie_cursor_factory<TF>(self, trie_cursor_factory: TF) -> Proof<TF, H, K> {
61 Proof {
62 trie_cursor_factory,
63 hashed_cursor_factory: self.hashed_cursor_factory,
64 prefix_sets: self.prefix_sets,
65 collect_branch_node_masks: self.collect_branch_node_masks,
66 added_removed_keys: self.added_removed_keys,
67 }
68 }
69
70 pub fn with_hashed_cursor_factory<HF>(self, hashed_cursor_factory: HF) -> Proof<T, HF, K> {
72 Proof {
73 trie_cursor_factory: self.trie_cursor_factory,
74 hashed_cursor_factory,
75 prefix_sets: self.prefix_sets,
76 collect_branch_node_masks: self.collect_branch_node_masks,
77 added_removed_keys: self.added_removed_keys,
78 }
79 }
80
81 pub fn with_prefix_sets_mut(mut self, prefix_sets: TriePrefixSetsMut) -> Self {
83 self.prefix_sets = prefix_sets;
84 self
85 }
86
87 pub const fn with_branch_node_masks(mut self, branch_node_masks: bool) -> Self {
89 self.collect_branch_node_masks = branch_node_masks;
90 self
91 }
92
93 pub fn with_added_removed_keys<K2>(self, added_removed_keys: Option<K2>) -> Proof<T, H, K2> {
99 Proof {
100 trie_cursor_factory: self.trie_cursor_factory,
101 hashed_cursor_factory: self.hashed_cursor_factory,
102 prefix_sets: self.prefix_sets,
103 collect_branch_node_masks: self.collect_branch_node_masks,
104 added_removed_keys,
105 }
106 }
107
108 pub const fn trie_cursor_factory(&self) -> &T {
110 &self.trie_cursor_factory
111 }
112
113 pub const fn hashed_cursor_factory(&self) -> &H {
115 &self.hashed_cursor_factory
116 }
117}
118
119impl<T, H, K> Proof<T, H, K>
120where
121 T: TrieCursorFactory + Clone,
122 H: HashedCursorFactory + Clone,
123 K: AsRef<AddedRemovedKeys>,
124{
125 pub fn account_proof(
127 self,
128 address: Address,
129 slots: &[B256],
130 ) -> Result<AccountProof, StateProofError> {
131 Ok(self
132 .multiproof(MultiProofTargets::from_iter([(
133 keccak256(address),
134 slots.iter().map(keccak256).collect(),
135 )]))?
136 .account_proof(address, slots)?)
137 }
138
139 pub fn multiproof_v2(
144 self,
145 targets: MultiProofTargetsV2,
146 ) -> Result<DecodedMultiProofV2, StateProofError> {
147 let MultiProofTargetsV2 { mut account_targets, storage_targets } = targets;
148
149 let storage_prefix_sets: B256Map<_> = self
150 .prefix_sets
151 .storage_prefix_sets
152 .into_iter()
153 .map(|(addr, ps)| (addr, ps.freeze()))
154 .collect();
155
156 let account_trie_cursor = self.trie_cursor_factory.account_trie_cursor()?;
158 let hashed_account_cursor = self.hashed_cursor_factory.hashed_account_cursor()?;
159 let mut account_value_encoder = SyncAccountValueEncoder::new(
160 self.trie_cursor_factory.clone(),
161 self.hashed_cursor_factory.clone(),
162 )
163 .with_storage_prefix_sets(storage_prefix_sets.clone());
164 let mut account_calculator =
165 proof_v2::ProofCalculator::new(account_trie_cursor, hashed_account_cursor)
166 .with_prefix_set(self.prefix_sets.account_prefix_set.freeze());
167 let account_proofs =
168 account_calculator.proof(&mut account_value_encoder, &mut account_targets)?;
169
170 let mut storage_proofs =
172 B256Map::with_capacity_and_hasher(storage_targets.len(), Default::default());
173 for (hashed_address, mut targets) in storage_targets {
174 let storage_trie_cursor =
175 self.trie_cursor_factory.storage_trie_cursor(hashed_address)?;
176 let hashed_storage_cursor =
177 self.hashed_cursor_factory.hashed_storage_cursor(hashed_address)?;
178 let mut storage_calculator = proof_v2::StorageProofCalculator::new_storage(
179 storage_trie_cursor,
180 hashed_storage_cursor,
181 );
182 if let Some(prefix_set) = storage_prefix_sets.get(&hashed_address) {
183 storage_calculator = storage_calculator.with_prefix_set(prefix_set.clone());
184 }
185 let proofs = storage_calculator.storage_proof(hashed_address, &mut targets)?;
186 storage_proofs.insert(hashed_address, proofs);
187 }
188
189 Ok(DecodedMultiProofV2 { account_proofs, storage_proofs })
190 }
191
192 pub fn multiproof(
194 mut self,
195 mut targets: MultiProofTargets,
196 ) -> Result<MultiProof, StateProofError> {
197 let hashed_account_cursor = self.hashed_cursor_factory.hashed_account_cursor()?;
198 let trie_cursor = self.trie_cursor_factory.account_trie_cursor()?;
199
200 let mut prefix_set = self.prefix_sets.account_prefix_set.clone();
202 prefix_set.extend_keys(targets.keys().map(Nibbles::unpack));
203 let walker =
204 TrieWalker::<_, AddedRemovedKeys>::state_trie(trie_cursor, prefix_set.freeze())
205 .with_added_removed_keys(self.added_removed_keys.as_ref());
206
207 let retainer: ProofRetainer = targets.keys().map(Nibbles::unpack).collect();
209 let retainer = retainer.with_added_removed_keys(self.added_removed_keys.as_ref());
210 let mut hash_builder = HashBuilder::default()
211 .with_proof_retainer(retainer)
212 .with_updates(self.collect_branch_node_masks);
213
214 let mut storages: B256Map<_> =
217 targets.keys().map(|key| (*key, StorageMultiProof::empty())).collect();
218 let mut account_rlp = Vec::with_capacity(TRIE_ACCOUNT_RLP_MAX_SIZE);
219 let mut account_node_iter = TrieNodeIter::state_trie(walker, hashed_account_cursor);
220 while let Some(account_node) = account_node_iter.try_next()? {
221 match account_node {
222 TrieElement::Branch(node) => {
223 hash_builder.add_branch(node.key, node.value, node.children_are_in_trie);
224 }
225 TrieElement::Leaf(hashed_address, account) => {
226 let proof_targets = targets.remove(&hashed_address);
227 let leaf_is_proof_target = proof_targets.is_some();
228 let collect_storage_masks =
229 self.collect_branch_node_masks && leaf_is_proof_target;
230 let storage_prefix_set = self
231 .prefix_sets
232 .storage_prefix_sets
233 .remove(&hashed_address)
234 .unwrap_or_default();
235 let storage_multiproof = StorageProof::new_hashed(
236 self.trie_cursor_factory.clone(),
237 self.hashed_cursor_factory.clone(),
238 hashed_address,
239 )
240 .with_prefix_set_mut(storage_prefix_set)
241 .with_branch_node_masks(collect_storage_masks)
242 .storage_multiproof(proof_targets.unwrap_or_default())?;
243
244 account_rlp.clear();
246 let account = account.into_trie_account(storage_multiproof.root);
247 account.encode(&mut account_rlp as &mut dyn BufMut);
248
249 hash_builder.add_leaf(Nibbles::unpack(hashed_address), &account_rlp);
250
251 if leaf_is_proof_target {
253 storages.insert(hashed_address, storage_multiproof);
255 }
256 }
257 }
258 }
259 let _ = hash_builder.root();
260 let account_subtree = hash_builder.take_proof_nodes();
261 let branch_node_masks = if self.collect_branch_node_masks {
262 let updated_branch_nodes = hash_builder.updated_branch_nodes.unwrap_or_default();
263 updated_branch_nodes
264 .into_iter()
265 .map(|(path, node)| {
266 (path, BranchNodeMasks { hash_mask: node.hash_mask, tree_mask: node.tree_mask })
267 })
268 .collect()
269 } else {
270 BranchNodeMasksMap::default()
271 };
272
273 Ok(MultiProof { account_subtree, branch_node_masks, storages })
274 }
275}
276
/// Generates storage proofs and storage multiproofs for a single account's storage trie.
///
/// The `'a` lifetime bounds the optional, caller-owned cursor metrics caches.
#[derive(Debug)]
pub struct StorageProof<'a, T, H, K = AddedRemovedKeys> {
    /// Factory used to open the storage trie cursor for `hashed_address`.
    trie_cursor_factory: T,
    /// Factory used to open the hashed storage cursor for `hashed_address`.
    hashed_cursor_factory: H,
    /// Hashed address of the account whose storage trie is proven (`keccak256(address)` when
    /// constructed via `StorageProof::new`).
    hashed_address: B256,
    /// Prefix set for the storage trie walker; `storage_multiproof` extends it with the target
    /// keys before walking.
    prefix_set: PrefixSetMut,
    /// Whether to collect hash/tree masks of branch nodes updated by the hash builder.
    collect_branch_node_masks: bool,
    /// Optional added/removed keys, forwarded to the storage trie walker and proof retainer.
    added_removed_keys: Option<K>,
    /// Optional metrics cache for the instrumented storage trie cursor; when `None`,
    /// `storage_multiproof` uses a local cache that is discarded afterwards.
    trie_cursor_metrics: Option<&'a mut TrieCursorMetricsCache>,
    /// Optional metrics cache for the instrumented hashed storage cursor; when `None`,
    /// `storage_multiproof` uses a local cache that is discarded afterwards.
    hashed_cursor_metrics: Option<&'a mut HashedCursorMetricsCache>,
}
297
298impl<T, H> StorageProof<'static, T, H> {
299 pub fn new(t: T, h: H, address: Address) -> Self {
301 Self::new_hashed(t, h, keccak256(address))
302 }
303
304 pub fn new_hashed(t: T, h: H, hashed_address: B256) -> Self {
306 Self {
307 trie_cursor_factory: t,
308 hashed_cursor_factory: h,
309 hashed_address,
310 prefix_set: PrefixSetMut::default(),
311 collect_branch_node_masks: false,
312 added_removed_keys: None,
313 trie_cursor_metrics: None,
314 hashed_cursor_metrics: None,
315 }
316 }
317}
318
319impl<'a, T, H, K> StorageProof<'a, T, H, K> {
320 pub fn with_trie_cursor_factory<TF>(
322 self,
323 trie_cursor_factory: TF,
324 ) -> StorageProof<'a, TF, H, K> {
325 StorageProof {
326 trie_cursor_factory,
327 hashed_cursor_factory: self.hashed_cursor_factory,
328 hashed_address: self.hashed_address,
329 prefix_set: self.prefix_set,
330 collect_branch_node_masks: self.collect_branch_node_masks,
331 added_removed_keys: self.added_removed_keys,
332 trie_cursor_metrics: self.trie_cursor_metrics,
333 hashed_cursor_metrics: self.hashed_cursor_metrics,
334 }
335 }
336
337 pub fn with_hashed_cursor_factory<HF>(
339 self,
340 hashed_cursor_factory: HF,
341 ) -> StorageProof<'a, T, HF, K> {
342 StorageProof {
343 trie_cursor_factory: self.trie_cursor_factory,
344 hashed_cursor_factory,
345 hashed_address: self.hashed_address,
346 prefix_set: self.prefix_set,
347 collect_branch_node_masks: self.collect_branch_node_masks,
348 added_removed_keys: self.added_removed_keys,
349 trie_cursor_metrics: self.trie_cursor_metrics,
350 hashed_cursor_metrics: self.hashed_cursor_metrics,
351 }
352 }
353
354 pub fn with_prefix_set_mut(mut self, prefix_set: PrefixSetMut) -> Self {
356 self.prefix_set = prefix_set;
357 self
358 }
359
360 pub const fn with_branch_node_masks(mut self, branch_node_masks: bool) -> Self {
362 self.collect_branch_node_masks = branch_node_masks;
363 self
364 }
365
366 pub const fn with_trie_cursor_metrics(
368 mut self,
369 metrics: &'a mut TrieCursorMetricsCache,
370 ) -> Self {
371 self.trie_cursor_metrics = Some(metrics);
372 self
373 }
374
375 pub const fn with_hashed_cursor_metrics(
377 mut self,
378 metrics: &'a mut HashedCursorMetricsCache,
379 ) -> Self {
380 self.hashed_cursor_metrics = Some(metrics);
381 self
382 }
383
384 pub fn with_added_removed_keys<K2>(
390 self,
391 added_removed_keys: Option<K2>,
392 ) -> StorageProof<'a, T, H, K2> {
393 StorageProof {
394 trie_cursor_factory: self.trie_cursor_factory,
395 hashed_cursor_factory: self.hashed_cursor_factory,
396 hashed_address: self.hashed_address,
397 prefix_set: self.prefix_set,
398 collect_branch_node_masks: self.collect_branch_node_masks,
399 added_removed_keys,
400 trie_cursor_metrics: self.trie_cursor_metrics,
401 hashed_cursor_metrics: self.hashed_cursor_metrics,
402 }
403 }
404}
405
406impl<'a, T, H, K> StorageProof<'a, T, H, K>
407where
408 T: TrieCursorFactory,
409 H: HashedCursorFactory,
410 K: AsRef<AddedRemovedKeys>,
411{
412 pub fn storage_proof(
414 self,
415 slot: B256,
416 ) -> Result<reth_trie_common::StorageProof, StateProofError> {
417 let targets = HashSet::from_iter([keccak256(slot)]);
418 Ok(self.storage_multiproof(targets)?.storage_proof(slot)?)
419 }
420
421 pub fn storage_multiproof(
423 self,
424 targets: B256Set,
425 ) -> Result<StorageMultiProof, StateProofError> {
426 let mut discard_hashed_cursor_metrics = HashedCursorMetricsCache::default();
427 let hashed_cursor_metrics =
428 self.hashed_cursor_metrics.unwrap_or(&mut discard_hashed_cursor_metrics);
429
430 let hashed_storage_cursor =
431 self.hashed_cursor_factory.hashed_storage_cursor(self.hashed_address)?;
432
433 let mut hashed_storage_cursor =
434 InstrumentedHashedCursor::new(hashed_storage_cursor, hashed_cursor_metrics);
435
436 if hashed_storage_cursor.is_storage_empty()? {
438 return Ok(StorageMultiProof::empty())
439 }
440
441 let mut discard_trie_cursor_metrics = TrieCursorMetricsCache::default();
442 let trie_cursor_metrics =
443 self.trie_cursor_metrics.unwrap_or(&mut discard_trie_cursor_metrics);
444
445 let target_nibbles = targets.into_iter().map(Nibbles::unpack).collect::<Vec<_>>();
446 let mut prefix_set = self.prefix_set;
447 prefix_set.extend_keys(target_nibbles.iter().copied());
448
449 let trie_cursor = self.trie_cursor_factory.storage_trie_cursor(self.hashed_address)?;
450
451 let trie_cursor = InstrumentedTrieCursor::new(trie_cursor, trie_cursor_metrics);
452
453 let walker = TrieWalker::<_>::storage_trie(trie_cursor, prefix_set.freeze())
454 .with_added_removed_keys(self.added_removed_keys.as_ref());
455
456 let retainer = ProofRetainer::from_iter(target_nibbles)
457 .with_added_removed_keys(self.added_removed_keys.as_ref());
458 let mut hash_builder = HashBuilder::default()
459 .with_proof_retainer(retainer)
460 .with_updates(self.collect_branch_node_masks);
461 let mut storage_node_iter = TrieNodeIter::storage_trie(walker, hashed_storage_cursor);
462 while let Some(node) = storage_node_iter.try_next()? {
463 match node {
464 TrieElement::Branch(node) => {
465 hash_builder.add_branch(node.key, node.value, node.children_are_in_trie);
466 }
467 TrieElement::Leaf(hashed_slot, value) => {
468 hash_builder.add_leaf(
469 Nibbles::unpack(hashed_slot),
470 alloy_rlp::encode_fixed_size(&value).as_ref(),
471 );
472 }
473 }
474 }
475
476 let root = hash_builder.root();
477 let subtree = hash_builder.take_proof_nodes();
478 let branch_node_masks = if self.collect_branch_node_masks {
479 let updated_branch_nodes = hash_builder.updated_branch_nodes.unwrap_or_default();
480 updated_branch_nodes
481 .into_iter()
482 .map(|(path, node)| {
483 (path, BranchNodeMasks { hash_mask: node.hash_mask, tree_mask: node.tree_mask })
484 })
485 .collect()
486 } else {
487 BranchNodeMasksMap::default()
488 };
489
490 Ok(StorageMultiProof { root, subtree, branch_node_masks })
491 }
492}