1use crate::{
2 hashed_cursor::{
3 HashedCursorFactory, HashedCursorMetricsCache, HashedStorageCursor,
4 InstrumentedHashedCursor,
5 },
6 node_iter::{TrieElement, TrieNodeIter},
7 prefix_set::{PrefixSetMut, TriePrefixSetsMut},
8 trie_cursor::{InstrumentedTrieCursor, TrieCursorFactory, TrieCursorMetricsCache},
9 walker::TrieWalker,
10 HashBuilder, Nibbles, TRIE_ACCOUNT_RLP_MAX_SIZE,
11};
12use alloy_primitives::{
13 keccak256,
14 map::{B256Map, B256Set, HashMap, HashSet},
15 Address, B256,
16};
17use alloy_rlp::{BufMut, Encodable};
18use alloy_trie::proof::AddedRemovedKeys;
19use reth_execution_errors::trie::StateProofError;
20use reth_trie_common::{
21 proof::ProofRetainer, AccountProof, MultiProof, MultiProofTargets, StorageMultiProof,
22};
23
24mod trie_node;
25pub use trie_node::*;
26
/// Generates merkle proofs for accounts and their storage.
///
/// The proof targets are added to the prefix sets and retained by the hash builder while the
/// account trie is walked; see [`Proof::multiproof`].
#[derive(Debug)]
pub struct Proof<T, H> {
    /// Factory used to open the account trie cursor and, per account, the storage trie cursors.
    trie_cursor_factory: T,
    /// Factory used to open the hashed account cursor and, per account, the hashed storage
    /// cursors.
    hashed_cursor_factory: H,
    /// Account prefix set and per-account storage prefix sets. The proof targets are added to
    /// them before the corresponding trie is walked.
    prefix_sets: TriePrefixSetsMut,
    /// Whether branch node hash and tree masks are collected into the resulting proof.
    collect_branch_node_masks: bool,
}
43
44impl<T, H> Proof<T, H> {
45 pub fn new(t: T, h: H) -> Self {
47 Self {
48 trie_cursor_factory: t,
49 hashed_cursor_factory: h,
50 prefix_sets: TriePrefixSetsMut::default(),
51 collect_branch_node_masks: false,
52 }
53 }
54
55 pub fn with_trie_cursor_factory<TF>(self, trie_cursor_factory: TF) -> Proof<TF, H> {
57 Proof {
58 trie_cursor_factory,
59 hashed_cursor_factory: self.hashed_cursor_factory,
60 prefix_sets: self.prefix_sets,
61 collect_branch_node_masks: self.collect_branch_node_masks,
62 }
63 }
64
65 pub fn with_hashed_cursor_factory<HF>(self, hashed_cursor_factory: HF) -> Proof<T, HF> {
67 Proof {
68 trie_cursor_factory: self.trie_cursor_factory,
69 hashed_cursor_factory,
70 prefix_sets: self.prefix_sets,
71 collect_branch_node_masks: self.collect_branch_node_masks,
72 }
73 }
74
75 pub fn with_prefix_sets_mut(mut self, prefix_sets: TriePrefixSetsMut) -> Self {
77 self.prefix_sets = prefix_sets;
78 self
79 }
80
81 pub const fn with_branch_node_masks(mut self, branch_node_masks: bool) -> Self {
83 self.collect_branch_node_masks = branch_node_masks;
84 self
85 }
86
87 pub const fn trie_cursor_factory(&self) -> &T {
89 &self.trie_cursor_factory
90 }
91
92 pub const fn hashed_cursor_factory(&self) -> &H {
94 &self.hashed_cursor_factory
95 }
96}
97
98impl<T, H> Proof<T, H>
99where
100 T: TrieCursorFactory + Clone,
101 H: HashedCursorFactory + Clone,
102{
103 pub fn account_proof(
105 self,
106 address: Address,
107 slots: &[B256],
108 ) -> Result<AccountProof, StateProofError> {
109 Ok(self
110 .multiproof(MultiProofTargets::from_iter([(
111 keccak256(address),
112 slots.iter().map(keccak256).collect(),
113 )]))?
114 .account_proof(address, slots)?)
115 }
116
117 pub fn multiproof(
119 mut self,
120 mut targets: MultiProofTargets,
121 ) -> Result<MultiProof, StateProofError> {
122 let hashed_account_cursor = self.hashed_cursor_factory.hashed_account_cursor()?;
123 let trie_cursor = self.trie_cursor_factory.account_trie_cursor()?;
124
125 let mut prefix_set = self.prefix_sets.account_prefix_set.clone();
127 prefix_set.extend_keys(targets.keys().map(Nibbles::unpack));
128 let walker = TrieWalker::<_>::state_trie(trie_cursor, prefix_set.freeze());
129
130 let retainer = targets.keys().map(Nibbles::unpack).collect();
132 let mut hash_builder = HashBuilder::default()
133 .with_proof_retainer(retainer)
134 .with_updates(self.collect_branch_node_masks);
135
136 let mut storages: B256Map<_> =
139 targets.keys().map(|key| (*key, StorageMultiProof::empty())).collect();
140 let mut account_rlp = Vec::with_capacity(TRIE_ACCOUNT_RLP_MAX_SIZE);
141 let mut account_node_iter = TrieNodeIter::state_trie(walker, hashed_account_cursor);
142 while let Some(account_node) = account_node_iter.try_next()? {
143 match account_node {
144 TrieElement::Branch(node) => {
145 hash_builder.add_branch(node.key, node.value, node.children_are_in_trie);
146 }
147 TrieElement::Leaf(hashed_address, account) => {
148 let proof_targets = targets.remove(&hashed_address);
149 let leaf_is_proof_target = proof_targets.is_some();
150 let collect_storage_masks =
151 self.collect_branch_node_masks && leaf_is_proof_target;
152 let storage_prefix_set = self
153 .prefix_sets
154 .storage_prefix_sets
155 .remove(&hashed_address)
156 .unwrap_or_default();
157 let storage_multiproof = StorageProof::new_hashed(
158 self.trie_cursor_factory.clone(),
159 self.hashed_cursor_factory.clone(),
160 hashed_address,
161 )
162 .with_prefix_set_mut(storage_prefix_set)
163 .with_branch_node_masks(collect_storage_masks)
164 .storage_multiproof(proof_targets.unwrap_or_default())?;
165
166 account_rlp.clear();
168 let account = account.into_trie_account(storage_multiproof.root);
169 account.encode(&mut account_rlp as &mut dyn BufMut);
170
171 hash_builder.add_leaf(Nibbles::unpack(hashed_address), &account_rlp);
172
173 if leaf_is_proof_target {
175 storages.insert(hashed_address, storage_multiproof);
177 }
178 }
179 }
180 }
181 let _ = hash_builder.root();
182 let account_subtree = hash_builder.take_proof_nodes();
183 let (branch_node_hash_masks, branch_node_tree_masks) = if self.collect_branch_node_masks {
184 let updated_branch_nodes = hash_builder.updated_branch_nodes.unwrap_or_default();
185 (
186 updated_branch_nodes.iter().map(|(path, node)| (*path, node.hash_mask)).collect(),
187 updated_branch_nodes
188 .into_iter()
189 .map(|(path, node)| (path, node.tree_mask))
190 .collect(),
191 )
192 } else {
193 (HashMap::default(), HashMap::default())
194 };
195
196 Ok(MultiProof { account_subtree, branch_node_hash_masks, branch_node_tree_masks, storages })
197 }
198}
199
/// Generates merkle proofs for the storage trie of a single account.
///
/// Created with [`StorageProof::new`] or [`StorageProof::new_hashed`], and configured through
/// the `with_*` builder methods.
#[derive(Debug)]
pub struct StorageProof<'a, T, H, K = AddedRemovedKeys> {
    /// Factory used to open the storage trie cursor for `hashed_address`.
    trie_cursor_factory: T,
    /// Factory used to open the hashed storage cursor for `hashed_address`.
    hashed_cursor_factory: H,
    /// Hashed address of the account whose storage is proven.
    hashed_address: B256,
    /// Storage prefix set for the walker. The proof targets are added to it before walking.
    prefix_set: PrefixSetMut,
    /// Whether branch node hash and tree masks are collected into the resulting proof.
    collect_branch_node_masks: bool,
    /// Optional added/removed keys, passed to both the trie walker and the proof retainer.
    added_removed_keys: Option<K>,
    /// Optional cache that receives trie cursor metrics. When `None`, the metrics are recorded
    /// into a throwaway cache.
    trie_cursor_metrics: Option<&'a mut TrieCursorMetricsCache>,
    /// Optional cache that receives hashed cursor metrics. When `None`, the metrics are
    /// recorded into a throwaway cache.
    hashed_cursor_metrics: Option<&'a mut HashedCursorMetricsCache>,
}
220
221impl<T, H> StorageProof<'static, T, H> {
222 pub fn new(t: T, h: H, address: Address) -> Self {
224 Self::new_hashed(t, h, keccak256(address))
225 }
226
227 pub fn new_hashed(t: T, h: H, hashed_address: B256) -> Self {
229 Self {
230 trie_cursor_factory: t,
231 hashed_cursor_factory: h,
232 hashed_address,
233 prefix_set: PrefixSetMut::default(),
234 collect_branch_node_masks: false,
235 added_removed_keys: None,
236 trie_cursor_metrics: None,
237 hashed_cursor_metrics: None,
238 }
239 }
240}
241
242impl<'a, T, H, K> StorageProof<'a, T, H, K> {
243 pub fn with_trie_cursor_factory<TF>(
245 self,
246 trie_cursor_factory: TF,
247 ) -> StorageProof<'a, TF, H, K> {
248 StorageProof {
249 trie_cursor_factory,
250 hashed_cursor_factory: self.hashed_cursor_factory,
251 hashed_address: self.hashed_address,
252 prefix_set: self.prefix_set,
253 collect_branch_node_masks: self.collect_branch_node_masks,
254 added_removed_keys: self.added_removed_keys,
255 trie_cursor_metrics: self.trie_cursor_metrics,
256 hashed_cursor_metrics: self.hashed_cursor_metrics,
257 }
258 }
259
260 pub fn with_hashed_cursor_factory<HF>(
262 self,
263 hashed_cursor_factory: HF,
264 ) -> StorageProof<'a, T, HF, K> {
265 StorageProof {
266 trie_cursor_factory: self.trie_cursor_factory,
267 hashed_cursor_factory,
268 hashed_address: self.hashed_address,
269 prefix_set: self.prefix_set,
270 collect_branch_node_masks: self.collect_branch_node_masks,
271 added_removed_keys: self.added_removed_keys,
272 trie_cursor_metrics: self.trie_cursor_metrics,
273 hashed_cursor_metrics: self.hashed_cursor_metrics,
274 }
275 }
276
277 pub fn with_prefix_set_mut(mut self, prefix_set: PrefixSetMut) -> Self {
279 self.prefix_set = prefix_set;
280 self
281 }
282
283 pub const fn with_branch_node_masks(mut self, branch_node_masks: bool) -> Self {
285 self.collect_branch_node_masks = branch_node_masks;
286 self
287 }
288
289 pub const fn with_trie_cursor_metrics(
291 mut self,
292 metrics: &'a mut TrieCursorMetricsCache,
293 ) -> Self {
294 self.trie_cursor_metrics = Some(metrics);
295 self
296 }
297
298 pub const fn with_hashed_cursor_metrics(
300 mut self,
301 metrics: &'a mut HashedCursorMetricsCache,
302 ) -> Self {
303 self.hashed_cursor_metrics = Some(metrics);
304 self
305 }
306
307 pub fn with_added_removed_keys<K2>(
313 self,
314 added_removed_keys: Option<K2>,
315 ) -> StorageProof<'a, T, H, K2> {
316 StorageProof {
317 trie_cursor_factory: self.trie_cursor_factory,
318 hashed_cursor_factory: self.hashed_cursor_factory,
319 hashed_address: self.hashed_address,
320 prefix_set: self.prefix_set,
321 collect_branch_node_masks: self.collect_branch_node_masks,
322 added_removed_keys,
323 trie_cursor_metrics: self.trie_cursor_metrics,
324 hashed_cursor_metrics: self.hashed_cursor_metrics,
325 }
326 }
327}
328
329impl<'a, T, H, K> StorageProof<'a, T, H, K>
330where
331 T: TrieCursorFactory,
332 H: HashedCursorFactory,
333 K: AsRef<AddedRemovedKeys>,
334{
335 pub fn storage_proof(
337 self,
338 slot: B256,
339 ) -> Result<reth_trie_common::StorageProof, StateProofError> {
340 let targets = HashSet::from_iter([keccak256(slot)]);
341 Ok(self.storage_multiproof(targets)?.storage_proof(slot)?)
342 }
343
344 pub fn storage_multiproof(
346 self,
347 targets: B256Set,
348 ) -> Result<StorageMultiProof, StateProofError> {
349 let mut discard_hashed_cursor_metrics = HashedCursorMetricsCache::default();
350 let hashed_cursor_metrics =
351 self.hashed_cursor_metrics.unwrap_or(&mut discard_hashed_cursor_metrics);
352
353 let hashed_storage_cursor =
354 self.hashed_cursor_factory.hashed_storage_cursor(self.hashed_address)?;
355
356 let mut hashed_storage_cursor =
357 InstrumentedHashedCursor::new(hashed_storage_cursor, hashed_cursor_metrics);
358
359 if hashed_storage_cursor.is_storage_empty()? {
361 return Ok(StorageMultiProof::empty())
362 }
363
364 let mut discard_trie_cursor_metrics = TrieCursorMetricsCache::default();
365 let trie_cursor_metrics =
366 self.trie_cursor_metrics.unwrap_or(&mut discard_trie_cursor_metrics);
367
368 let target_nibbles = targets.into_iter().map(Nibbles::unpack).collect::<Vec<_>>();
369 let mut prefix_set = self.prefix_set;
370 prefix_set.extend_keys(target_nibbles.clone());
371
372 let trie_cursor = self.trie_cursor_factory.storage_trie_cursor(self.hashed_address)?;
373
374 let trie_cursor = InstrumentedTrieCursor::new(trie_cursor, trie_cursor_metrics);
375
376 let walker = TrieWalker::<_>::storage_trie(trie_cursor, prefix_set.freeze())
377 .with_added_removed_keys(self.added_removed_keys.as_ref());
378
379 let retainer = ProofRetainer::from_iter(target_nibbles)
380 .with_added_removed_keys(self.added_removed_keys.as_ref());
381 let mut hash_builder = HashBuilder::default()
382 .with_proof_retainer(retainer)
383 .with_updates(self.collect_branch_node_masks);
384 let mut storage_node_iter = TrieNodeIter::storage_trie(walker, hashed_storage_cursor);
385 while let Some(node) = storage_node_iter.try_next()? {
386 match node {
387 TrieElement::Branch(node) => {
388 hash_builder.add_branch(node.key, node.value, node.children_are_in_trie);
389 }
390 TrieElement::Leaf(hashed_slot, value) => {
391 hash_builder.add_leaf(
392 Nibbles::unpack(hashed_slot),
393 alloy_rlp::encode_fixed_size(&value).as_ref(),
394 );
395 }
396 }
397 }
398
399 let root = hash_builder.root();
400 let subtree = hash_builder.take_proof_nodes();
401 let (branch_node_hash_masks, branch_node_tree_masks) = if self.collect_branch_node_masks {
402 let updated_branch_nodes = hash_builder.updated_branch_nodes.unwrap_or_default();
403 (
404 updated_branch_nodes.iter().map(|(path, node)| (*path, node.hash_mask)).collect(),
405 updated_branch_nodes
406 .into_iter()
407 .map(|(path, node)| (path, node.tree_mask))
408 .collect(),
409 )
410 } else {
411 (HashMap::default(), HashMap::default())
412 };
413
414 Ok(StorageMultiProof { root, subtree, branch_node_hash_masks, branch_node_tree_masks })
415 }
416}