1use crate::{
2 hashed_cursor::{
3 HashedCursorFactory, HashedCursorMetricsCache, HashedStorageCursor,
4 InstrumentedHashedCursor,
5 },
6 node_iter::{TrieElement, TrieNodeIter},
7 prefix_set::{PrefixSetMut, TriePrefixSetsMut},
8 proof_v2::{self, SyncAccountValueEncoder},
9 trie_cursor::{InstrumentedTrieCursor, TrieCursorFactory, TrieCursorMetricsCache},
10 walker::TrieWalker,
11 HashBuilder, Nibbles, TRIE_ACCOUNT_RLP_MAX_SIZE,
12};
13use alloy_primitives::{
14 keccak256,
15 map::{B256Map, B256Set, HashSet},
16 Address, B256,
17};
18use alloy_rlp::{BufMut, Encodable};
19use alloy_trie::proof::AddedRemovedKeys;
20use reth_execution_errors::trie::StateProofError;
21use reth_trie_common::{
22 proof::ProofRetainer, AccountProof, BranchNodeMasks, BranchNodeMasksMap, DecodedMultiProofV2,
23 MultiProof, MultiProofTargets, MultiProofTargetsV2, StorageMultiProof,
24};
25
26mod trie_node;
27pub use trie_node::*;
28
29#[derive(Debug)]
35pub struct Proof<T, H, K = AddedRemovedKeys> {
36 trie_cursor_factory: T,
38 hashed_cursor_factory: H,
40 prefix_sets: TriePrefixSetsMut,
42 collect_branch_node_masks: bool,
44 added_removed_keys: Option<K>,
46}
47
48impl<T, H> Proof<T, H> {
49 pub fn new(t: T, h: H) -> Self {
51 Self {
52 trie_cursor_factory: t,
53 hashed_cursor_factory: h,
54 prefix_sets: TriePrefixSetsMut::default(),
55 collect_branch_node_masks: false,
56 added_removed_keys: None,
57 }
58 }
59}
60
61impl<T, H, K> Proof<T, H, K> {
62 pub fn with_trie_cursor_factory<TF>(self, trie_cursor_factory: TF) -> Proof<TF, H, K> {
64 Proof {
65 trie_cursor_factory,
66 hashed_cursor_factory: self.hashed_cursor_factory,
67 prefix_sets: self.prefix_sets,
68 collect_branch_node_masks: self.collect_branch_node_masks,
69 added_removed_keys: self.added_removed_keys,
70 }
71 }
72
73 pub fn with_hashed_cursor_factory<HF>(self, hashed_cursor_factory: HF) -> Proof<T, HF, K> {
75 Proof {
76 trie_cursor_factory: self.trie_cursor_factory,
77 hashed_cursor_factory,
78 prefix_sets: self.prefix_sets,
79 collect_branch_node_masks: self.collect_branch_node_masks,
80 added_removed_keys: self.added_removed_keys,
81 }
82 }
83
84 pub fn with_prefix_sets_mut(mut self, prefix_sets: TriePrefixSetsMut) -> Self {
86 self.prefix_sets = prefix_sets;
87 self
88 }
89
90 pub const fn with_branch_node_masks(mut self, branch_node_masks: bool) -> Self {
92 self.collect_branch_node_masks = branch_node_masks;
93 self
94 }
95
96 pub fn with_added_removed_keys<K2>(self, added_removed_keys: Option<K2>) -> Proof<T, H, K2> {
102 Proof {
103 trie_cursor_factory: self.trie_cursor_factory,
104 hashed_cursor_factory: self.hashed_cursor_factory,
105 prefix_sets: self.prefix_sets,
106 collect_branch_node_masks: self.collect_branch_node_masks,
107 added_removed_keys,
108 }
109 }
110
111 pub const fn trie_cursor_factory(&self) -> &T {
113 &self.trie_cursor_factory
114 }
115
116 pub const fn hashed_cursor_factory(&self) -> &H {
118 &self.hashed_cursor_factory
119 }
120}
121
122impl<T, H, K> Proof<T, H, K>
123where
124 T: TrieCursorFactory + Clone,
125 H: HashedCursorFactory + Clone,
126 K: AsRef<AddedRemovedKeys>,
127{
128 pub fn account_proof(
130 self,
131 address: Address,
132 slots: &[B256],
133 ) -> Result<AccountProof, StateProofError> {
134 Ok(self
135 .multiproof(MultiProofTargets::from_iter([(
136 keccak256(address),
137 slots.iter().map(keccak256).collect(),
138 )]))?
139 .account_proof(address, slots)?)
140 }
141
142 pub fn multiproof_v2(
147 self,
148 targets: MultiProofTargetsV2,
149 ) -> Result<DecodedMultiProofV2, StateProofError> {
150 let MultiProofTargetsV2 { mut account_targets, storage_targets } = targets;
151
152 let storage_prefix_sets: B256Map<_> = self
153 .prefix_sets
154 .storage_prefix_sets
155 .into_iter()
156 .map(|(addr, ps)| (addr, ps.freeze()))
157 .collect();
158
159 let account_trie_cursor = self.trie_cursor_factory.account_trie_cursor()?;
161 let hashed_account_cursor = self.hashed_cursor_factory.hashed_account_cursor()?;
162 let mut account_value_encoder = SyncAccountValueEncoder::new(
163 self.trie_cursor_factory.clone(),
164 self.hashed_cursor_factory.clone(),
165 )
166 .with_storage_prefix_sets(storage_prefix_sets.clone());
167 let mut account_calculator =
168 proof_v2::ProofCalculator::new(account_trie_cursor, hashed_account_cursor)
169 .with_prefix_set(self.prefix_sets.account_prefix_set.freeze());
170 let account_proofs =
171 account_calculator.proof(&mut account_value_encoder, &mut account_targets)?;
172
173 let mut storage_proofs =
175 B256Map::with_capacity_and_hasher(storage_targets.len(), Default::default());
176 for (hashed_address, mut targets) in storage_targets {
177 let storage_trie_cursor =
178 self.trie_cursor_factory.storage_trie_cursor(hashed_address)?;
179 let hashed_storage_cursor =
180 self.hashed_cursor_factory.hashed_storage_cursor(hashed_address)?;
181 let mut storage_calculator = proof_v2::StorageProofCalculator::new_storage(
182 storage_trie_cursor,
183 hashed_storage_cursor,
184 );
185 if let Some(prefix_set) = storage_prefix_sets.get(&hashed_address) {
186 storage_calculator = storage_calculator.with_prefix_set(prefix_set.clone());
187 }
188 let proofs = storage_calculator.storage_proof(hashed_address, &mut targets)?;
189 storage_proofs.insert(hashed_address, proofs);
190 }
191
192 Ok(DecodedMultiProofV2 { account_proofs, storage_proofs })
193 }
194
195 pub fn multiproof(
197 mut self,
198 mut targets: MultiProofTargets,
199 ) -> Result<MultiProof, StateProofError> {
200 let hashed_account_cursor = self.hashed_cursor_factory.hashed_account_cursor()?;
201 let trie_cursor = self.trie_cursor_factory.account_trie_cursor()?;
202
203 let mut prefix_set = self.prefix_sets.account_prefix_set.clone();
205 prefix_set.extend_keys(targets.keys().map(Nibbles::unpack));
206 let walker =
207 TrieWalker::<_, AddedRemovedKeys>::state_trie(trie_cursor, prefix_set.freeze())
208 .with_added_removed_keys(self.added_removed_keys.as_ref());
209
210 let retainer: ProofRetainer = targets.keys().map(Nibbles::unpack).collect();
212 let retainer = retainer.with_added_removed_keys(self.added_removed_keys.as_ref());
213 let mut hash_builder = HashBuilder::default()
214 .with_proof_retainer(retainer)
215 .with_updates(self.collect_branch_node_masks);
216
217 let mut storages: B256Map<_> =
220 targets.keys().map(|key| (*key, StorageMultiProof::empty())).collect();
221 let mut account_rlp = Vec::with_capacity(TRIE_ACCOUNT_RLP_MAX_SIZE);
222 let mut account_node_iter = TrieNodeIter::state_trie(walker, hashed_account_cursor);
223 while let Some(account_node) = account_node_iter.try_next()? {
224 match account_node {
225 TrieElement::Branch(node) => {
226 hash_builder.add_branch(node.key, node.value, node.children_are_in_trie);
227 }
228 TrieElement::Leaf(hashed_address, account) => {
229 let proof_targets = targets.remove(&hashed_address);
230 let leaf_is_proof_target = proof_targets.is_some();
231 let collect_storage_masks =
232 self.collect_branch_node_masks && leaf_is_proof_target;
233 let storage_prefix_set = self
234 .prefix_sets
235 .storage_prefix_sets
236 .remove(&hashed_address)
237 .unwrap_or_default();
238 let storage_multiproof = StorageProof::new_hashed(
239 self.trie_cursor_factory.clone(),
240 self.hashed_cursor_factory.clone(),
241 hashed_address,
242 )
243 .with_prefix_set_mut(storage_prefix_set)
244 .with_branch_node_masks(collect_storage_masks)
245 .storage_multiproof(proof_targets.unwrap_or_default())?;
246
247 account_rlp.clear();
249 let account = account.into_trie_account(storage_multiproof.root);
250 account.encode(&mut account_rlp as &mut dyn BufMut);
251
252 hash_builder.add_leaf(Nibbles::unpack(hashed_address), &account_rlp);
253
254 if leaf_is_proof_target {
256 storages.insert(hashed_address, storage_multiproof);
258 }
259 }
260 }
261 }
262 let _ = hash_builder.root();
263 let account_subtree = hash_builder.take_proof_nodes();
264 let branch_node_masks = if self.collect_branch_node_masks {
265 let updated_branch_nodes = hash_builder.updated_branch_nodes.unwrap_or_default();
266 updated_branch_nodes
267 .into_iter()
268 .map(|(path, node)| {
269 (path, BranchNodeMasks { hash_mask: node.hash_mask, tree_mask: node.tree_mask })
270 })
271 .collect()
272 } else {
273 BranchNodeMasksMap::default()
274 };
275
276 Ok(MultiProof { account_subtree, branch_node_masks, storages })
277 }
278}
279
280#[derive(Debug)]
282pub struct StorageProof<'a, T, H, K = AddedRemovedKeys> {
283 trie_cursor_factory: T,
285 hashed_cursor_factory: H,
287 hashed_address: B256,
289 prefix_set: PrefixSetMut,
291 collect_branch_node_masks: bool,
293 added_removed_keys: Option<K>,
295 trie_cursor_metrics: Option<&'a mut TrieCursorMetricsCache>,
297 hashed_cursor_metrics: Option<&'a mut HashedCursorMetricsCache>,
299}
300
301impl<T, H> StorageProof<'static, T, H> {
302 pub fn new(t: T, h: H, address: Address) -> Self {
304 Self::new_hashed(t, h, keccak256(address))
305 }
306
307 pub fn new_hashed(t: T, h: H, hashed_address: B256) -> Self {
309 Self {
310 trie_cursor_factory: t,
311 hashed_cursor_factory: h,
312 hashed_address,
313 prefix_set: PrefixSetMut::default(),
314 collect_branch_node_masks: false,
315 added_removed_keys: None,
316 trie_cursor_metrics: None,
317 hashed_cursor_metrics: None,
318 }
319 }
320}
321
322impl<'a, T, H, K> StorageProof<'a, T, H, K> {
323 pub fn with_trie_cursor_factory<TF>(
325 self,
326 trie_cursor_factory: TF,
327 ) -> StorageProof<'a, TF, H, K> {
328 StorageProof {
329 trie_cursor_factory,
330 hashed_cursor_factory: self.hashed_cursor_factory,
331 hashed_address: self.hashed_address,
332 prefix_set: self.prefix_set,
333 collect_branch_node_masks: self.collect_branch_node_masks,
334 added_removed_keys: self.added_removed_keys,
335 trie_cursor_metrics: self.trie_cursor_metrics,
336 hashed_cursor_metrics: self.hashed_cursor_metrics,
337 }
338 }
339
340 pub fn with_hashed_cursor_factory<HF>(
342 self,
343 hashed_cursor_factory: HF,
344 ) -> StorageProof<'a, T, HF, K> {
345 StorageProof {
346 trie_cursor_factory: self.trie_cursor_factory,
347 hashed_cursor_factory,
348 hashed_address: self.hashed_address,
349 prefix_set: self.prefix_set,
350 collect_branch_node_masks: self.collect_branch_node_masks,
351 added_removed_keys: self.added_removed_keys,
352 trie_cursor_metrics: self.trie_cursor_metrics,
353 hashed_cursor_metrics: self.hashed_cursor_metrics,
354 }
355 }
356
357 pub fn with_prefix_set_mut(mut self, prefix_set: PrefixSetMut) -> Self {
359 self.prefix_set = prefix_set;
360 self
361 }
362
363 pub const fn with_branch_node_masks(mut self, branch_node_masks: bool) -> Self {
365 self.collect_branch_node_masks = branch_node_masks;
366 self
367 }
368
369 pub const fn with_trie_cursor_metrics(
371 mut self,
372 metrics: &'a mut TrieCursorMetricsCache,
373 ) -> Self {
374 self.trie_cursor_metrics = Some(metrics);
375 self
376 }
377
378 pub const fn with_hashed_cursor_metrics(
380 mut self,
381 metrics: &'a mut HashedCursorMetricsCache,
382 ) -> Self {
383 self.hashed_cursor_metrics = Some(metrics);
384 self
385 }
386
387 pub fn with_added_removed_keys<K2>(
393 self,
394 added_removed_keys: Option<K2>,
395 ) -> StorageProof<'a, T, H, K2> {
396 StorageProof {
397 trie_cursor_factory: self.trie_cursor_factory,
398 hashed_cursor_factory: self.hashed_cursor_factory,
399 hashed_address: self.hashed_address,
400 prefix_set: self.prefix_set,
401 collect_branch_node_masks: self.collect_branch_node_masks,
402 added_removed_keys,
403 trie_cursor_metrics: self.trie_cursor_metrics,
404 hashed_cursor_metrics: self.hashed_cursor_metrics,
405 }
406 }
407}
408
409impl<'a, T, H, K> StorageProof<'a, T, H, K>
410where
411 T: TrieCursorFactory,
412 H: HashedCursorFactory,
413 K: AsRef<AddedRemovedKeys>,
414{
415 pub fn storage_proof(
417 self,
418 slot: B256,
419 ) -> Result<reth_trie_common::StorageProof, StateProofError> {
420 let targets = HashSet::from_iter([keccak256(slot)]);
421 Ok(self.storage_multiproof(targets)?.storage_proof(slot)?)
422 }
423
424 pub fn storage_multiproof(
426 self,
427 targets: B256Set,
428 ) -> Result<StorageMultiProof, StateProofError> {
429 let mut discard_hashed_cursor_metrics = HashedCursorMetricsCache::default();
430 let hashed_cursor_metrics =
431 self.hashed_cursor_metrics.unwrap_or(&mut discard_hashed_cursor_metrics);
432
433 let hashed_storage_cursor =
434 self.hashed_cursor_factory.hashed_storage_cursor(self.hashed_address)?;
435
436 let mut hashed_storage_cursor =
437 InstrumentedHashedCursor::new(hashed_storage_cursor, hashed_cursor_metrics);
438
439 if hashed_storage_cursor.is_storage_empty()? {
441 return Ok(StorageMultiProof::empty())
442 }
443
444 let mut discard_trie_cursor_metrics = TrieCursorMetricsCache::default();
445 let trie_cursor_metrics =
446 self.trie_cursor_metrics.unwrap_or(&mut discard_trie_cursor_metrics);
447
448 let target_nibbles = targets.into_iter().map(Nibbles::unpack).collect::<Vec<_>>();
449 let mut prefix_set = self.prefix_set;
450 prefix_set.extend_keys(target_nibbles.iter().copied());
451
452 let trie_cursor = self.trie_cursor_factory.storage_trie_cursor(self.hashed_address)?;
453
454 let trie_cursor = InstrumentedTrieCursor::new(trie_cursor, trie_cursor_metrics);
455
456 let walker = TrieWalker::<_>::storage_trie(trie_cursor, prefix_set.freeze())
457 .with_added_removed_keys(self.added_removed_keys.as_ref());
458
459 let retainer = ProofRetainer::from_iter(target_nibbles)
460 .with_added_removed_keys(self.added_removed_keys.as_ref());
461 let mut hash_builder = HashBuilder::default()
462 .with_proof_retainer(retainer)
463 .with_updates(self.collect_branch_node_masks);
464 let mut storage_node_iter = TrieNodeIter::storage_trie(walker, hashed_storage_cursor);
465 while let Some(node) = storage_node_iter.try_next()? {
466 match node {
467 TrieElement::Branch(node) => {
468 hash_builder.add_branch(node.key, node.value, node.children_are_in_trie);
469 }
470 TrieElement::Leaf(hashed_slot, value) => {
471 hash_builder.add_leaf(
472 Nibbles::unpack(hashed_slot),
473 alloy_rlp::encode_fixed_size(&value).as_ref(),
474 );
475 }
476 }
477 }
478
479 let root = hash_builder.root();
480 let subtree = hash_builder.take_proof_nodes();
481 let branch_node_masks = if self.collect_branch_node_masks {
482 let updated_branch_nodes = hash_builder.updated_branch_nodes.unwrap_or_default();
483 updated_branch_nodes
484 .into_iter()
485 .map(|(path, node)| {
486 (path, BranchNodeMasks { hash_mask: node.hash_mask, tree_mask: node.tree_mask })
487 })
488 .collect()
489 } else {
490 BranchNodeMasksMap::default()
491 };
492
493 Ok(StorageMultiProof { root, subtree, branch_node_masks })
494 }
495}