reth_trie/proof/
mod.rs

1use crate::{
2    hashed_cursor::{HashedCursorFactory, HashedStorageCursor},
3    node_iter::{TrieElement, TrieNodeIter},
4    prefix_set::{PrefixSetMut, TriePrefixSetsMut},
5    trie_cursor::TrieCursorFactory,
6    walker::TrieWalker,
7    HashBuilder, Nibbles, TRIE_ACCOUNT_RLP_MAX_SIZE,
8};
9use alloy_primitives::{
10    keccak256,
11    map::{B256Map, B256Set, HashMap, HashSet},
12    Address, B256,
13};
14use alloy_rlp::{BufMut, Encodable};
15use alloy_trie::proof::AddedRemovedKeys;
16use reth_execution_errors::trie::StateProofError;
17use reth_trie_common::{
18    proof::ProofRetainer, AccountProof, MultiProof, MultiProofTargets, StorageMultiProof,
19};
20
21mod trie_node;
22pub use trie_node::*;
23
24/// A struct for generating merkle proofs.
25///
26/// Proof generator adds the target address and slots to the prefix set, enables the proof retainer
27/// on the hash builder and follows the same algorithm as the state root calculator.
28/// See `StateRoot::root` for more info.
29#[derive(Debug)]
30pub struct Proof<T, H> {
31    /// The factory for traversing trie nodes.
32    trie_cursor_factory: T,
33    /// The factory for hashed cursors.
34    hashed_cursor_factory: H,
35    /// A set of prefix sets that have changes.
36    prefix_sets: TriePrefixSetsMut,
37    /// Flag indicating whether to include branch node masks in the proof.
38    collect_branch_node_masks: bool,
39}
40
41impl<T, H> Proof<T, H> {
42    /// Create a new [`Proof`] instance.
43    pub fn new(t: T, h: H) -> Self {
44        Self {
45            trie_cursor_factory: t,
46            hashed_cursor_factory: h,
47            prefix_sets: TriePrefixSetsMut::default(),
48            collect_branch_node_masks: false,
49        }
50    }
51
52    /// Set the trie cursor factory.
53    pub fn with_trie_cursor_factory<TF>(self, trie_cursor_factory: TF) -> Proof<TF, H> {
54        Proof {
55            trie_cursor_factory,
56            hashed_cursor_factory: self.hashed_cursor_factory,
57            prefix_sets: self.prefix_sets,
58            collect_branch_node_masks: self.collect_branch_node_masks,
59        }
60    }
61
62    /// Set the hashed cursor factory.
63    pub fn with_hashed_cursor_factory<HF>(self, hashed_cursor_factory: HF) -> Proof<T, HF> {
64        Proof {
65            trie_cursor_factory: self.trie_cursor_factory,
66            hashed_cursor_factory,
67            prefix_sets: self.prefix_sets,
68            collect_branch_node_masks: self.collect_branch_node_masks,
69        }
70    }
71
72    /// Set the prefix sets. They have to be mutable in order to allow extension with proof target.
73    pub fn with_prefix_sets_mut(mut self, prefix_sets: TriePrefixSetsMut) -> Self {
74        self.prefix_sets = prefix_sets;
75        self
76    }
77
78    /// Set the flag indicating whether to include branch node masks in the proof.
79    pub const fn with_branch_node_masks(mut self, branch_node_masks: bool) -> Self {
80        self.collect_branch_node_masks = branch_node_masks;
81        self
82    }
83
84    /// Get a reference to the trie cursor factory.
85    pub const fn trie_cursor_factory(&self) -> &T {
86        &self.trie_cursor_factory
87    }
88
89    /// Get a reference to the hashed cursor factory.
90    pub const fn hashed_cursor_factory(&self) -> &H {
91        &self.hashed_cursor_factory
92    }
93}
94
95impl<T, H> Proof<T, H>
96where
97    T: TrieCursorFactory + Clone,
98    H: HashedCursorFactory + Clone,
99{
100    /// Generate an account proof from intermediate nodes.
101    pub fn account_proof(
102        self,
103        address: Address,
104        slots: &[B256],
105    ) -> Result<AccountProof, StateProofError> {
106        Ok(self
107            .multiproof(MultiProofTargets::from_iter([(
108                keccak256(address),
109                slots.iter().map(keccak256).collect(),
110            )]))?
111            .account_proof(address, slots)?)
112    }
113
114    /// Generate a state multiproof according to specified targets.
115    pub fn multiproof(
116        mut self,
117        mut targets: MultiProofTargets,
118    ) -> Result<MultiProof, StateProofError> {
119        let hashed_account_cursor = self.hashed_cursor_factory.hashed_account_cursor()?;
120        let trie_cursor = self.trie_cursor_factory.account_trie_cursor()?;
121
122        // Create the walker.
123        let mut prefix_set = self.prefix_sets.account_prefix_set.clone();
124        prefix_set.extend_keys(targets.keys().map(Nibbles::unpack));
125        let walker = TrieWalker::<_>::state_trie(trie_cursor, prefix_set.freeze());
126
127        // Create a hash builder to rebuild the root node since it is not available in the database.
128        let retainer = targets.keys().map(Nibbles::unpack).collect();
129        let mut hash_builder = HashBuilder::default()
130            .with_proof_retainer(retainer)
131            .with_updates(self.collect_branch_node_masks);
132
133        // Initialize all storage multiproofs as empty.
134        // Storage multiproofs for non-empty tries will be overwritten if necessary.
135        let mut storages: B256Map<_> =
136            targets.keys().map(|key| (*key, StorageMultiProof::empty())).collect();
137        let mut account_rlp = Vec::with_capacity(TRIE_ACCOUNT_RLP_MAX_SIZE);
138        let mut account_node_iter = TrieNodeIter::state_trie(walker, hashed_account_cursor);
139        while let Some(account_node) = account_node_iter.try_next()? {
140            match account_node {
141                TrieElement::Branch(node) => {
142                    hash_builder.add_branch(node.key, node.value, node.children_are_in_trie);
143                }
144                TrieElement::Leaf(hashed_address, account) => {
145                    let proof_targets = targets.remove(&hashed_address);
146                    let leaf_is_proof_target = proof_targets.is_some();
147                    let storage_prefix_set = self
148                        .prefix_sets
149                        .storage_prefix_sets
150                        .remove(&hashed_address)
151                        .unwrap_or_default();
152                    let storage_multiproof = StorageProof::new_hashed(
153                        self.trie_cursor_factory.clone(),
154                        self.hashed_cursor_factory.clone(),
155                        hashed_address,
156                    )
157                    .with_prefix_set_mut(storage_prefix_set)
158                    .with_branch_node_masks(self.collect_branch_node_masks)
159                    .storage_multiproof(proof_targets.unwrap_or_default())?;
160
161                    // Encode account
162                    account_rlp.clear();
163                    let account = account.into_trie_account(storage_multiproof.root);
164                    account.encode(&mut account_rlp as &mut dyn BufMut);
165
166                    hash_builder.add_leaf(Nibbles::unpack(hashed_address), &account_rlp);
167
168                    // We might be adding leaves that are not necessarily our proof targets.
169                    if leaf_is_proof_target {
170                        // Overwrite storage multiproof.
171                        storages.insert(hashed_address, storage_multiproof);
172                    }
173                }
174            }
175        }
176        let _ = hash_builder.root();
177        let account_subtree = hash_builder.take_proof_nodes();
178        let (branch_node_hash_masks, branch_node_tree_masks) = if self.collect_branch_node_masks {
179            let updated_branch_nodes = hash_builder.updated_branch_nodes.unwrap_or_default();
180            (
181                updated_branch_nodes.iter().map(|(path, node)| (*path, node.hash_mask)).collect(),
182                updated_branch_nodes
183                    .into_iter()
184                    .map(|(path, node)| (path, node.tree_mask))
185                    .collect(),
186            )
187        } else {
188            (HashMap::default(), HashMap::default())
189        };
190
191        Ok(MultiProof { account_subtree, branch_node_hash_masks, branch_node_tree_masks, storages })
192    }
193}
194
195/// Generates storage merkle proofs.
196#[derive(Debug)]
197pub struct StorageProof<T, H, K = AddedRemovedKeys> {
198    /// The factory for traversing trie nodes.
199    trie_cursor_factory: T,
200    /// The factory for hashed cursors.
201    hashed_cursor_factory: H,
202    /// The hashed address of an account.
203    hashed_address: B256,
204    /// The set of storage slot prefixes that have changed.
205    prefix_set: PrefixSetMut,
206    /// Flag indicating whether to include branch node masks in the proof.
207    collect_branch_node_masks: bool,
208    /// Provided by the user to give the necessary context to retain extra proofs.
209    added_removed_keys: Option<K>,
210}
211
212impl<T, H> StorageProof<T, H> {
213    /// Create a new [`StorageProof`] instance.
214    pub fn new(t: T, h: H, address: Address) -> Self {
215        Self::new_hashed(t, h, keccak256(address))
216    }
217
218    /// Create a new [`StorageProof`] instance with hashed address.
219    pub fn new_hashed(t: T, h: H, hashed_address: B256) -> Self {
220        Self {
221            trie_cursor_factory: t,
222            hashed_cursor_factory: h,
223            hashed_address,
224            prefix_set: PrefixSetMut::default(),
225            collect_branch_node_masks: false,
226            added_removed_keys: None,
227        }
228    }
229}
230
231impl<T, H, K> StorageProof<T, H, K> {
232    /// Set the trie cursor factory.
233    pub fn with_trie_cursor_factory<TF>(self, trie_cursor_factory: TF) -> StorageProof<TF, H, K> {
234        StorageProof {
235            trie_cursor_factory,
236            hashed_cursor_factory: self.hashed_cursor_factory,
237            hashed_address: self.hashed_address,
238            prefix_set: self.prefix_set,
239            collect_branch_node_masks: self.collect_branch_node_masks,
240            added_removed_keys: self.added_removed_keys,
241        }
242    }
243
244    /// Set the hashed cursor factory.
245    pub fn with_hashed_cursor_factory<HF>(
246        self,
247        hashed_cursor_factory: HF,
248    ) -> StorageProof<T, HF, K> {
249        StorageProof {
250            trie_cursor_factory: self.trie_cursor_factory,
251            hashed_cursor_factory,
252            hashed_address: self.hashed_address,
253            prefix_set: self.prefix_set,
254            collect_branch_node_masks: self.collect_branch_node_masks,
255            added_removed_keys: self.added_removed_keys,
256        }
257    }
258
259    /// Set the changed prefixes.
260    pub fn with_prefix_set_mut(mut self, prefix_set: PrefixSetMut) -> Self {
261        self.prefix_set = prefix_set;
262        self
263    }
264
265    /// Set the flag indicating whether to include branch node masks in the proof.
266    pub const fn with_branch_node_masks(mut self, branch_node_masks: bool) -> Self {
267        self.collect_branch_node_masks = branch_node_masks;
268        self
269    }
270
271    /// Configures the retainer to retain proofs for certain nodes which would otherwise fall
272    /// outside the target set, when those nodes might be required to calculate the state root when
273    /// keys have been added or removed to the trie.
274    ///
275    /// If None is given then retention of extra proofs is disabled.
276    pub fn with_added_removed_keys<K2>(
277        self,
278        added_removed_keys: Option<K2>,
279    ) -> StorageProof<T, H, K2> {
280        StorageProof {
281            trie_cursor_factory: self.trie_cursor_factory,
282            hashed_cursor_factory: self.hashed_cursor_factory,
283            hashed_address: self.hashed_address,
284            prefix_set: self.prefix_set,
285            collect_branch_node_masks: self.collect_branch_node_masks,
286            added_removed_keys,
287        }
288    }
289}
290
291impl<T, H, K> StorageProof<T, H, K>
292where
293    T: TrieCursorFactory,
294    H: HashedCursorFactory,
295    K: AsRef<AddedRemovedKeys>,
296{
297    /// Generate an account proof from intermediate nodes.
298    pub fn storage_proof(
299        self,
300        slot: B256,
301    ) -> Result<reth_trie_common::StorageProof, StateProofError> {
302        let targets = HashSet::from_iter([keccak256(slot)]);
303        Ok(self.storage_multiproof(targets)?.storage_proof(slot)?)
304    }
305
306    /// Generate storage proof.
307    pub fn storage_multiproof(
308        mut self,
309        targets: B256Set,
310    ) -> Result<StorageMultiProof, StateProofError> {
311        let mut hashed_storage_cursor =
312            self.hashed_cursor_factory.hashed_storage_cursor(self.hashed_address)?;
313
314        // short circuit on empty storage
315        if hashed_storage_cursor.is_storage_empty()? {
316            return Ok(StorageMultiProof::empty())
317        }
318
319        let target_nibbles = targets.into_iter().map(Nibbles::unpack).collect::<Vec<_>>();
320        self.prefix_set.extend_keys(target_nibbles.clone());
321
322        let trie_cursor = self.trie_cursor_factory.storage_trie_cursor(self.hashed_address)?;
323        let walker = TrieWalker::<_>::storage_trie(trie_cursor, self.prefix_set.freeze())
324            .with_added_removed_keys(self.added_removed_keys.as_ref());
325
326        let retainer = ProofRetainer::from_iter(target_nibbles)
327            .with_added_removed_keys(self.added_removed_keys.as_ref());
328        let mut hash_builder = HashBuilder::default()
329            .with_proof_retainer(retainer)
330            .with_updates(self.collect_branch_node_masks);
331        let mut storage_node_iter = TrieNodeIter::storage_trie(walker, hashed_storage_cursor);
332        while let Some(node) = storage_node_iter.try_next()? {
333            match node {
334                TrieElement::Branch(node) => {
335                    hash_builder.add_branch(node.key, node.value, node.children_are_in_trie);
336                }
337                TrieElement::Leaf(hashed_slot, value) => {
338                    hash_builder.add_leaf(
339                        Nibbles::unpack(hashed_slot),
340                        alloy_rlp::encode_fixed_size(&value).as_ref(),
341                    );
342                }
343            }
344        }
345
346        let root = hash_builder.root();
347        let subtree = hash_builder.take_proof_nodes();
348        let (branch_node_hash_masks, branch_node_tree_masks) = if self.collect_branch_node_masks {
349            let updated_branch_nodes = hash_builder.updated_branch_nodes.unwrap_or_default();
350            (
351                updated_branch_nodes.iter().map(|(path, node)| (*path, node.hash_mask)).collect(),
352                updated_branch_nodes
353                    .into_iter()
354                    .map(|(path, node)| (path, node.tree_mask))
355                    .collect(),
356            )
357        } else {
358            (HashMap::default(), HashMap::default())
359        };
360
361        Ok(StorageMultiProof { root, subtree, branch_node_hash_masks, branch_node_tree_masks })
362    }
363}