reth_trie/proof/mod.rs

1use crate::{
2    hashed_cursor::{
3        HashedCursorFactory, HashedCursorMetricsCache, HashedStorageCursor,
4        InstrumentedHashedCursor,
5    },
6    node_iter::{TrieElement, TrieNodeIter},
7    prefix_set::{PrefixSetMut, TriePrefixSetsMut},
8    trie_cursor::{InstrumentedTrieCursor, TrieCursorFactory, TrieCursorMetricsCache},
9    walker::TrieWalker,
10    HashBuilder, Nibbles, TRIE_ACCOUNT_RLP_MAX_SIZE,
11};
12use alloy_primitives::{
13    keccak256,
14    map::{B256Map, B256Set, HashMap, HashSet},
15    Address, B256,
16};
17use alloy_rlp::{BufMut, Encodable};
18use alloy_trie::proof::AddedRemovedKeys;
19use reth_execution_errors::trie::StateProofError;
20use reth_trie_common::{
21    proof::ProofRetainer, AccountProof, MultiProof, MultiProofTargets, StorageMultiProof,
22};
23
24mod trie_node;
25pub use trie_node::*;
26
27/// A struct for generating merkle proofs.
28///
29/// Proof generator adds the target address and slots to the prefix set, enables the proof retainer
30/// on the hash builder and follows the same algorithm as the state root calculator.
31/// See `StateRoot::root` for more info.
32#[derive(Debug)]
33pub struct Proof<T, H> {
34    /// The factory for traversing trie nodes.
35    trie_cursor_factory: T,
36    /// The factory for hashed cursors.
37    hashed_cursor_factory: H,
38    /// A set of prefix sets that have changes.
39    prefix_sets: TriePrefixSetsMut,
40    /// Flag indicating whether to include branch node masks in the proof.
41    collect_branch_node_masks: bool,
42}
43
44impl<T, H> Proof<T, H> {
45    /// Create a new [`Proof`] instance.
46    pub fn new(t: T, h: H) -> Self {
47        Self {
48            trie_cursor_factory: t,
49            hashed_cursor_factory: h,
50            prefix_sets: TriePrefixSetsMut::default(),
51            collect_branch_node_masks: false,
52        }
53    }
54
55    /// Set the trie cursor factory.
56    pub fn with_trie_cursor_factory<TF>(self, trie_cursor_factory: TF) -> Proof<TF, H> {
57        Proof {
58            trie_cursor_factory,
59            hashed_cursor_factory: self.hashed_cursor_factory,
60            prefix_sets: self.prefix_sets,
61            collect_branch_node_masks: self.collect_branch_node_masks,
62        }
63    }
64
65    /// Set the hashed cursor factory.
66    pub fn with_hashed_cursor_factory<HF>(self, hashed_cursor_factory: HF) -> Proof<T, HF> {
67        Proof {
68            trie_cursor_factory: self.trie_cursor_factory,
69            hashed_cursor_factory,
70            prefix_sets: self.prefix_sets,
71            collect_branch_node_masks: self.collect_branch_node_masks,
72        }
73    }
74
75    /// Set the prefix sets. They have to be mutable in order to allow extension with proof target.
76    pub fn with_prefix_sets_mut(mut self, prefix_sets: TriePrefixSetsMut) -> Self {
77        self.prefix_sets = prefix_sets;
78        self
79    }
80
81    /// Set the flag indicating whether to include branch node masks in the proof.
82    pub const fn with_branch_node_masks(mut self, branch_node_masks: bool) -> Self {
83        self.collect_branch_node_masks = branch_node_masks;
84        self
85    }
86
87    /// Get a reference to the trie cursor factory.
88    pub const fn trie_cursor_factory(&self) -> &T {
89        &self.trie_cursor_factory
90    }
91
92    /// Get a reference to the hashed cursor factory.
93    pub const fn hashed_cursor_factory(&self) -> &H {
94        &self.hashed_cursor_factory
95    }
96}
97
98impl<T, H> Proof<T, H>
99where
100    T: TrieCursorFactory + Clone,
101    H: HashedCursorFactory + Clone,
102{
103    /// Generate an account proof from intermediate nodes.
104    pub fn account_proof(
105        self,
106        address: Address,
107        slots: &[B256],
108    ) -> Result<AccountProof, StateProofError> {
109        Ok(self
110            .multiproof(MultiProofTargets::from_iter([(
111                keccak256(address),
112                slots.iter().map(keccak256).collect(),
113            )]))?
114            .account_proof(address, slots)?)
115    }
116
117    /// Generate a state multiproof according to specified targets.
118    pub fn multiproof(
119        mut self,
120        mut targets: MultiProofTargets,
121    ) -> Result<MultiProof, StateProofError> {
122        let hashed_account_cursor = self.hashed_cursor_factory.hashed_account_cursor()?;
123        let trie_cursor = self.trie_cursor_factory.account_trie_cursor()?;
124
125        // Create the walker.
126        let mut prefix_set = self.prefix_sets.account_prefix_set.clone();
127        prefix_set.extend_keys(targets.keys().map(Nibbles::unpack));
128        let walker = TrieWalker::<_>::state_trie(trie_cursor, prefix_set.freeze());
129
130        // Create a hash builder to rebuild the root node since it is not available in the database.
131        let retainer = targets.keys().map(Nibbles::unpack).collect();
132        let mut hash_builder = HashBuilder::default()
133            .with_proof_retainer(retainer)
134            .with_updates(self.collect_branch_node_masks);
135
136        // Initialize all storage multiproofs as empty.
137        // Storage multiproofs for non-empty tries will be overwritten if necessary.
138        let mut storages: B256Map<_> =
139            targets.keys().map(|key| (*key, StorageMultiProof::empty())).collect();
140        let mut account_rlp = Vec::with_capacity(TRIE_ACCOUNT_RLP_MAX_SIZE);
141        let mut account_node_iter = TrieNodeIter::state_trie(walker, hashed_account_cursor);
142        while let Some(account_node) = account_node_iter.try_next()? {
143            match account_node {
144                TrieElement::Branch(node) => {
145                    hash_builder.add_branch(node.key, node.value, node.children_are_in_trie);
146                }
147                TrieElement::Leaf(hashed_address, account) => {
148                    let proof_targets = targets.remove(&hashed_address);
149                    let leaf_is_proof_target = proof_targets.is_some();
150                    let collect_storage_masks =
151                        self.collect_branch_node_masks && leaf_is_proof_target;
152                    let storage_prefix_set = self
153                        .prefix_sets
154                        .storage_prefix_sets
155                        .remove(&hashed_address)
156                        .unwrap_or_default();
157                    let storage_multiproof = StorageProof::new_hashed(
158                        self.trie_cursor_factory.clone(),
159                        self.hashed_cursor_factory.clone(),
160                        hashed_address,
161                    )
162                    .with_prefix_set_mut(storage_prefix_set)
163                    .with_branch_node_masks(collect_storage_masks)
164                    .storage_multiproof(proof_targets.unwrap_or_default())?;
165
166                    // Encode account
167                    account_rlp.clear();
168                    let account = account.into_trie_account(storage_multiproof.root);
169                    account.encode(&mut account_rlp as &mut dyn BufMut);
170
171                    hash_builder.add_leaf(Nibbles::unpack(hashed_address), &account_rlp);
172
173                    // We might be adding leaves that are not necessarily our proof targets.
174                    if leaf_is_proof_target {
175                        // Overwrite storage multiproof.
176                        storages.insert(hashed_address, storage_multiproof);
177                    }
178                }
179            }
180        }
181        let _ = hash_builder.root();
182        let account_subtree = hash_builder.take_proof_nodes();
183        let (branch_node_hash_masks, branch_node_tree_masks) = if self.collect_branch_node_masks {
184            let updated_branch_nodes = hash_builder.updated_branch_nodes.unwrap_or_default();
185            (
186                updated_branch_nodes.iter().map(|(path, node)| (*path, node.hash_mask)).collect(),
187                updated_branch_nodes
188                    .into_iter()
189                    .map(|(path, node)| (path, node.tree_mask))
190                    .collect(),
191            )
192        } else {
193            (HashMap::default(), HashMap::default())
194        };
195
196        Ok(MultiProof { account_subtree, branch_node_hash_masks, branch_node_tree_masks, storages })
197    }
198}
199
200/// Generates storage merkle proofs.
201#[derive(Debug)]
202pub struct StorageProof<'a, T, H, K = AddedRemovedKeys> {
203    /// The factory for traversing trie nodes.
204    trie_cursor_factory: T,
205    /// The factory for hashed cursors.
206    hashed_cursor_factory: H,
207    /// The hashed address of an account.
208    hashed_address: B256,
209    /// The set of storage slot prefixes that have changed.
210    prefix_set: PrefixSetMut,
211    /// Flag indicating whether to include branch node masks in the proof.
212    collect_branch_node_masks: bool,
213    /// Provided by the user to give the necessary context to retain extra proofs.
214    added_removed_keys: Option<K>,
215    /// Optional reference to accumulate trie cursor metrics.
216    trie_cursor_metrics: Option<&'a mut TrieCursorMetricsCache>,
217    /// Optional reference to accumulate hashed cursor metrics.
218    hashed_cursor_metrics: Option<&'a mut HashedCursorMetricsCache>,
219}
220
221impl<T, H> StorageProof<'static, T, H> {
222    /// Create a new [`StorageProof`] instance.
223    pub fn new(t: T, h: H, address: Address) -> Self {
224        Self::new_hashed(t, h, keccak256(address))
225    }
226
227    /// Create a new [`StorageProof`] instance with hashed address.
228    pub fn new_hashed(t: T, h: H, hashed_address: B256) -> Self {
229        Self {
230            trie_cursor_factory: t,
231            hashed_cursor_factory: h,
232            hashed_address,
233            prefix_set: PrefixSetMut::default(),
234            collect_branch_node_masks: false,
235            added_removed_keys: None,
236            trie_cursor_metrics: None,
237            hashed_cursor_metrics: None,
238        }
239    }
240}
241
242impl<'a, T, H, K> StorageProof<'a, T, H, K> {
243    /// Set the trie cursor factory.
244    pub fn with_trie_cursor_factory<TF>(
245        self,
246        trie_cursor_factory: TF,
247    ) -> StorageProof<'a, TF, H, K> {
248        StorageProof {
249            trie_cursor_factory,
250            hashed_cursor_factory: self.hashed_cursor_factory,
251            hashed_address: self.hashed_address,
252            prefix_set: self.prefix_set,
253            collect_branch_node_masks: self.collect_branch_node_masks,
254            added_removed_keys: self.added_removed_keys,
255            trie_cursor_metrics: self.trie_cursor_metrics,
256            hashed_cursor_metrics: self.hashed_cursor_metrics,
257        }
258    }
259
260    /// Set the hashed cursor factory.
261    pub fn with_hashed_cursor_factory<HF>(
262        self,
263        hashed_cursor_factory: HF,
264    ) -> StorageProof<'a, T, HF, K> {
265        StorageProof {
266            trie_cursor_factory: self.trie_cursor_factory,
267            hashed_cursor_factory,
268            hashed_address: self.hashed_address,
269            prefix_set: self.prefix_set,
270            collect_branch_node_masks: self.collect_branch_node_masks,
271            added_removed_keys: self.added_removed_keys,
272            trie_cursor_metrics: self.trie_cursor_metrics,
273            hashed_cursor_metrics: self.hashed_cursor_metrics,
274        }
275    }
276
277    /// Set the changed prefixes.
278    pub fn with_prefix_set_mut(mut self, prefix_set: PrefixSetMut) -> Self {
279        self.prefix_set = prefix_set;
280        self
281    }
282
283    /// Set the flag indicating whether to include branch node masks in the proof.
284    pub const fn with_branch_node_masks(mut self, branch_node_masks: bool) -> Self {
285        self.collect_branch_node_masks = branch_node_masks;
286        self
287    }
288
289    /// Set the trie cursor metrics cache to accumulate metrics into.
290    pub const fn with_trie_cursor_metrics(
291        mut self,
292        metrics: &'a mut TrieCursorMetricsCache,
293    ) -> Self {
294        self.trie_cursor_metrics = Some(metrics);
295        self
296    }
297
298    /// Set the hashed cursor metrics cache to accumulate metrics into.
299    pub const fn with_hashed_cursor_metrics(
300        mut self,
301        metrics: &'a mut HashedCursorMetricsCache,
302    ) -> Self {
303        self.hashed_cursor_metrics = Some(metrics);
304        self
305    }
306
307    /// Configures the retainer to retain proofs for certain nodes which would otherwise fall
308    /// outside the target set, when those nodes might be required to calculate the state root when
309    /// keys have been added or removed to the trie.
310    ///
311    /// If None is given then retention of extra proofs is disabled.
312    pub fn with_added_removed_keys<K2>(
313        self,
314        added_removed_keys: Option<K2>,
315    ) -> StorageProof<'a, T, H, K2> {
316        StorageProof {
317            trie_cursor_factory: self.trie_cursor_factory,
318            hashed_cursor_factory: self.hashed_cursor_factory,
319            hashed_address: self.hashed_address,
320            prefix_set: self.prefix_set,
321            collect_branch_node_masks: self.collect_branch_node_masks,
322            added_removed_keys,
323            trie_cursor_metrics: self.trie_cursor_metrics,
324            hashed_cursor_metrics: self.hashed_cursor_metrics,
325        }
326    }
327}
328
329impl<'a, T, H, K> StorageProof<'a, T, H, K>
330where
331    T: TrieCursorFactory,
332    H: HashedCursorFactory,
333    K: AsRef<AddedRemovedKeys>,
334{
335    /// Generate an account proof from intermediate nodes.
336    pub fn storage_proof(
337        self,
338        slot: B256,
339    ) -> Result<reth_trie_common::StorageProof, StateProofError> {
340        let targets = HashSet::from_iter([keccak256(slot)]);
341        Ok(self.storage_multiproof(targets)?.storage_proof(slot)?)
342    }
343
344    /// Generate storage proof.
345    pub fn storage_multiproof(
346        self,
347        targets: B256Set,
348    ) -> Result<StorageMultiProof, StateProofError> {
349        let mut discard_hashed_cursor_metrics = HashedCursorMetricsCache::default();
350        let hashed_cursor_metrics =
351            self.hashed_cursor_metrics.unwrap_or(&mut discard_hashed_cursor_metrics);
352
353        let hashed_storage_cursor =
354            self.hashed_cursor_factory.hashed_storage_cursor(self.hashed_address)?;
355
356        let mut hashed_storage_cursor =
357            InstrumentedHashedCursor::new(hashed_storage_cursor, hashed_cursor_metrics);
358
359        // short circuit on empty storage
360        if hashed_storage_cursor.is_storage_empty()? {
361            return Ok(StorageMultiProof::empty())
362        }
363
364        let mut discard_trie_cursor_metrics = TrieCursorMetricsCache::default();
365        let trie_cursor_metrics =
366            self.trie_cursor_metrics.unwrap_or(&mut discard_trie_cursor_metrics);
367
368        let target_nibbles = targets.into_iter().map(Nibbles::unpack).collect::<Vec<_>>();
369        let mut prefix_set = self.prefix_set;
370        prefix_set.extend_keys(target_nibbles.clone());
371
372        let trie_cursor = self.trie_cursor_factory.storage_trie_cursor(self.hashed_address)?;
373
374        let trie_cursor = InstrumentedTrieCursor::new(trie_cursor, trie_cursor_metrics);
375
376        let walker = TrieWalker::<_>::storage_trie(trie_cursor, prefix_set.freeze())
377            .with_added_removed_keys(self.added_removed_keys.as_ref());
378
379        let retainer = ProofRetainer::from_iter(target_nibbles)
380            .with_added_removed_keys(self.added_removed_keys.as_ref());
381        let mut hash_builder = HashBuilder::default()
382            .with_proof_retainer(retainer)
383            .with_updates(self.collect_branch_node_masks);
384        let mut storage_node_iter = TrieNodeIter::storage_trie(walker, hashed_storage_cursor);
385        while let Some(node) = storage_node_iter.try_next()? {
386            match node {
387                TrieElement::Branch(node) => {
388                    hash_builder.add_branch(node.key, node.value, node.children_are_in_trie);
389                }
390                TrieElement::Leaf(hashed_slot, value) => {
391                    hash_builder.add_leaf(
392                        Nibbles::unpack(hashed_slot),
393                        alloy_rlp::encode_fixed_size(&value).as_ref(),
394                    );
395                }
396            }
397        }
398
399        let root = hash_builder.root();
400        let subtree = hash_builder.take_proof_nodes();
401        let (branch_node_hash_masks, branch_node_tree_masks) = if self.collect_branch_node_masks {
402            let updated_branch_nodes = hash_builder.updated_branch_nodes.unwrap_or_default();
403            (
404                updated_branch_nodes.iter().map(|(path, node)| (*path, node.hash_mask)).collect(),
405                updated_branch_nodes
406                    .into_iter()
407                    .map(|(path, node)| (path, node.tree_mask))
408                    .collect(),
409            )
410        } else {
411            (HashMap::default(), HashMap::default())
412        };
413
414        Ok(StorageMultiProof { root, subtree, branch_node_hash_masks, branch_node_tree_masks })
415    }
416}