1use crate::{
2 hashed_cursor::{HashedCursorFactory, HashedStorageCursor},
3 node_iter::{TrieElement, TrieNodeIter},
4 prefix_set::{PrefixSetMut, TriePrefixSetsMut},
5 trie_cursor::TrieCursorFactory,
6 walker::TrieWalker,
7 HashBuilder, Nibbles, TRIE_ACCOUNT_RLP_MAX_SIZE,
8};
9use alloy_primitives::{
10 keccak256,
11 map::{B256Map, B256Set, HashMap, HashSet},
12 Address, B256,
13};
14use alloy_rlp::{BufMut, Encodable};
15use alloy_trie::proof::AddedRemovedKeys;
16use reth_execution_errors::trie::StateProofError;
17use reth_trie_common::{
18 proof::ProofRetainer, AccountProof, MultiProof, MultiProofTargets, StorageMultiProof,
19};
20
21mod trie_node;
22pub use trie_node::*;
23
24#[derive(Debug)]
30pub struct Proof<T, H> {
31 trie_cursor_factory: T,
33 hashed_cursor_factory: H,
35 prefix_sets: TriePrefixSetsMut,
37 collect_branch_node_masks: bool,
39}
40
41impl<T, H> Proof<T, H> {
42 pub fn new(t: T, h: H) -> Self {
44 Self {
45 trie_cursor_factory: t,
46 hashed_cursor_factory: h,
47 prefix_sets: TriePrefixSetsMut::default(),
48 collect_branch_node_masks: false,
49 }
50 }
51
52 pub fn with_trie_cursor_factory<TF>(self, trie_cursor_factory: TF) -> Proof<TF, H> {
54 Proof {
55 trie_cursor_factory,
56 hashed_cursor_factory: self.hashed_cursor_factory,
57 prefix_sets: self.prefix_sets,
58 collect_branch_node_masks: self.collect_branch_node_masks,
59 }
60 }
61
62 pub fn with_hashed_cursor_factory<HF>(self, hashed_cursor_factory: HF) -> Proof<T, HF> {
64 Proof {
65 trie_cursor_factory: self.trie_cursor_factory,
66 hashed_cursor_factory,
67 prefix_sets: self.prefix_sets,
68 collect_branch_node_masks: self.collect_branch_node_masks,
69 }
70 }
71
72 pub fn with_prefix_sets_mut(mut self, prefix_sets: TriePrefixSetsMut) -> Self {
74 self.prefix_sets = prefix_sets;
75 self
76 }
77
78 pub const fn with_branch_node_masks(mut self, branch_node_masks: bool) -> Self {
80 self.collect_branch_node_masks = branch_node_masks;
81 self
82 }
83
84 pub const fn trie_cursor_factory(&self) -> &T {
86 &self.trie_cursor_factory
87 }
88
89 pub const fn hashed_cursor_factory(&self) -> &H {
91 &self.hashed_cursor_factory
92 }
93}
94
95impl<T, H> Proof<T, H>
96where
97 T: TrieCursorFactory + Clone,
98 H: HashedCursorFactory + Clone,
99{
100 pub fn account_proof(
102 self,
103 address: Address,
104 slots: &[B256],
105 ) -> Result<AccountProof, StateProofError> {
106 Ok(self
107 .multiproof(MultiProofTargets::from_iter([(
108 keccak256(address),
109 slots.iter().map(keccak256).collect(),
110 )]))?
111 .account_proof(address, slots)?)
112 }
113
114 pub fn multiproof(
116 mut self,
117 mut targets: MultiProofTargets,
118 ) -> Result<MultiProof, StateProofError> {
119 let hashed_account_cursor = self.hashed_cursor_factory.hashed_account_cursor()?;
120 let trie_cursor = self.trie_cursor_factory.account_trie_cursor()?;
121
122 let mut prefix_set = self.prefix_sets.account_prefix_set.clone();
124 prefix_set.extend_keys(targets.keys().map(Nibbles::unpack));
125 let walker = TrieWalker::<_>::state_trie(trie_cursor, prefix_set.freeze());
126
127 let retainer = targets.keys().map(Nibbles::unpack).collect();
129 let mut hash_builder = HashBuilder::default()
130 .with_proof_retainer(retainer)
131 .with_updates(self.collect_branch_node_masks);
132
133 let mut storages: B256Map<_> =
136 targets.keys().map(|key| (*key, StorageMultiProof::empty())).collect();
137 let mut account_rlp = Vec::with_capacity(TRIE_ACCOUNT_RLP_MAX_SIZE);
138 let mut account_node_iter = TrieNodeIter::state_trie(walker, hashed_account_cursor);
139 while let Some(account_node) = account_node_iter.try_next()? {
140 match account_node {
141 TrieElement::Branch(node) => {
142 hash_builder.add_branch(node.key, node.value, node.children_are_in_trie);
143 }
144 TrieElement::Leaf(hashed_address, account) => {
145 let proof_targets = targets.remove(&hashed_address);
146 let leaf_is_proof_target = proof_targets.is_some();
147 let storage_prefix_set = self
148 .prefix_sets
149 .storage_prefix_sets
150 .remove(&hashed_address)
151 .unwrap_or_default();
152 let storage_multiproof = StorageProof::new_hashed(
153 self.trie_cursor_factory.clone(),
154 self.hashed_cursor_factory.clone(),
155 hashed_address,
156 )
157 .with_prefix_set_mut(storage_prefix_set)
158 .with_branch_node_masks(self.collect_branch_node_masks)
159 .storage_multiproof(proof_targets.unwrap_or_default())?;
160
161 account_rlp.clear();
163 let account = account.into_trie_account(storage_multiproof.root);
164 account.encode(&mut account_rlp as &mut dyn BufMut);
165
166 hash_builder.add_leaf(Nibbles::unpack(hashed_address), &account_rlp);
167
168 if leaf_is_proof_target {
170 storages.insert(hashed_address, storage_multiproof);
172 }
173 }
174 }
175 }
176 let _ = hash_builder.root();
177 let account_subtree = hash_builder.take_proof_nodes();
178 let (branch_node_hash_masks, branch_node_tree_masks) = if self.collect_branch_node_masks {
179 let updated_branch_nodes = hash_builder.updated_branch_nodes.unwrap_or_default();
180 (
181 updated_branch_nodes.iter().map(|(path, node)| (*path, node.hash_mask)).collect(),
182 updated_branch_nodes
183 .into_iter()
184 .map(|(path, node)| (path, node.tree_mask))
185 .collect(),
186 )
187 } else {
188 (HashMap::default(), HashMap::default())
189 };
190
191 Ok(MultiProof { account_subtree, branch_node_hash_masks, branch_node_tree_masks, storages })
192 }
193}
194
195#[derive(Debug)]
197pub struct StorageProof<T, H, K = AddedRemovedKeys> {
198 trie_cursor_factory: T,
200 hashed_cursor_factory: H,
202 hashed_address: B256,
204 prefix_set: PrefixSetMut,
206 collect_branch_node_masks: bool,
208 added_removed_keys: Option<K>,
210}
211
212impl<T, H> StorageProof<T, H> {
213 pub fn new(t: T, h: H, address: Address) -> Self {
215 Self::new_hashed(t, h, keccak256(address))
216 }
217
218 pub fn new_hashed(t: T, h: H, hashed_address: B256) -> Self {
220 Self {
221 trie_cursor_factory: t,
222 hashed_cursor_factory: h,
223 hashed_address,
224 prefix_set: PrefixSetMut::default(),
225 collect_branch_node_masks: false,
226 added_removed_keys: None,
227 }
228 }
229}
230
231impl<T, H, K> StorageProof<T, H, K> {
232 pub fn with_trie_cursor_factory<TF>(self, trie_cursor_factory: TF) -> StorageProof<TF, H, K> {
234 StorageProof {
235 trie_cursor_factory,
236 hashed_cursor_factory: self.hashed_cursor_factory,
237 hashed_address: self.hashed_address,
238 prefix_set: self.prefix_set,
239 collect_branch_node_masks: self.collect_branch_node_masks,
240 added_removed_keys: self.added_removed_keys,
241 }
242 }
243
244 pub fn with_hashed_cursor_factory<HF>(
246 self,
247 hashed_cursor_factory: HF,
248 ) -> StorageProof<T, HF, K> {
249 StorageProof {
250 trie_cursor_factory: self.trie_cursor_factory,
251 hashed_cursor_factory,
252 hashed_address: self.hashed_address,
253 prefix_set: self.prefix_set,
254 collect_branch_node_masks: self.collect_branch_node_masks,
255 added_removed_keys: self.added_removed_keys,
256 }
257 }
258
259 pub fn with_prefix_set_mut(mut self, prefix_set: PrefixSetMut) -> Self {
261 self.prefix_set = prefix_set;
262 self
263 }
264
265 pub const fn with_branch_node_masks(mut self, branch_node_masks: bool) -> Self {
267 self.collect_branch_node_masks = branch_node_masks;
268 self
269 }
270
271 pub fn with_added_removed_keys<K2>(
277 self,
278 added_removed_keys: Option<K2>,
279 ) -> StorageProof<T, H, K2> {
280 StorageProof {
281 trie_cursor_factory: self.trie_cursor_factory,
282 hashed_cursor_factory: self.hashed_cursor_factory,
283 hashed_address: self.hashed_address,
284 prefix_set: self.prefix_set,
285 collect_branch_node_masks: self.collect_branch_node_masks,
286 added_removed_keys,
287 }
288 }
289}
290
291impl<T, H, K> StorageProof<T, H, K>
292where
293 T: TrieCursorFactory,
294 H: HashedCursorFactory,
295 K: AsRef<AddedRemovedKeys>,
296{
297 pub fn storage_proof(
299 self,
300 slot: B256,
301 ) -> Result<reth_trie_common::StorageProof, StateProofError> {
302 let targets = HashSet::from_iter([keccak256(slot)]);
303 Ok(self.storage_multiproof(targets)?.storage_proof(slot)?)
304 }
305
306 pub fn storage_multiproof(
308 mut self,
309 targets: B256Set,
310 ) -> Result<StorageMultiProof, StateProofError> {
311 let mut hashed_storage_cursor =
312 self.hashed_cursor_factory.hashed_storage_cursor(self.hashed_address)?;
313
314 if hashed_storage_cursor.is_storage_empty()? {
316 return Ok(StorageMultiProof::empty())
317 }
318
319 let target_nibbles = targets.into_iter().map(Nibbles::unpack).collect::<Vec<_>>();
320 self.prefix_set.extend_keys(target_nibbles.clone());
321
322 let trie_cursor = self.trie_cursor_factory.storage_trie_cursor(self.hashed_address)?;
323 let walker = TrieWalker::<_>::storage_trie(trie_cursor, self.prefix_set.freeze())
324 .with_added_removed_keys(self.added_removed_keys.as_ref());
325
326 let retainer = ProofRetainer::from_iter(target_nibbles)
327 .with_added_removed_keys(self.added_removed_keys.as_ref());
328 let mut hash_builder = HashBuilder::default()
329 .with_proof_retainer(retainer)
330 .with_updates(self.collect_branch_node_masks);
331 let mut storage_node_iter = TrieNodeIter::storage_trie(walker, hashed_storage_cursor);
332 while let Some(node) = storage_node_iter.try_next()? {
333 match node {
334 TrieElement::Branch(node) => {
335 hash_builder.add_branch(node.key, node.value, node.children_are_in_trie);
336 }
337 TrieElement::Leaf(hashed_slot, value) => {
338 hash_builder.add_leaf(
339 Nibbles::unpack(hashed_slot),
340 alloy_rlp::encode_fixed_size(&value).as_ref(),
341 );
342 }
343 }
344 }
345
346 let root = hash_builder.root();
347 let subtree = hash_builder.take_proof_nodes();
348 let (branch_node_hash_masks, branch_node_tree_masks) = if self.collect_branch_node_masks {
349 let updated_branch_nodes = hash_builder.updated_branch_nodes.unwrap_or_default();
350 (
351 updated_branch_nodes.iter().map(|(path, node)| (*path, node.hash_mask)).collect(),
352 updated_branch_nodes
353 .into_iter()
354 .map(|(path, node)| (path, node.tree_mask))
355 .collect(),
356 )
357 } else {
358 (HashMap::default(), HashMap::default())
359 };
360
361 Ok(StorageMultiProof { root, subtree, branch_node_hash_masks, branch_node_tree_masks })
362 }
363}