Skip to main content

reth_trie/proof/
mod.rs

1use crate::{
2    hashed_cursor::{
3        HashedCursorFactory, HashedCursorMetricsCache, HashedStorageCursor,
4        InstrumentedHashedCursor,
5    },
6    node_iter::{TrieElement, TrieNodeIter},
7    prefix_set::{PrefixSetMut, TriePrefixSetsMut},
8    proof_v2::{self, SyncAccountValueEncoder},
9    trie_cursor::{InstrumentedTrieCursor, TrieCursorFactory, TrieCursorMetricsCache},
10    walker::TrieWalker,
11    HashBuilder, Nibbles, TRIE_ACCOUNT_RLP_MAX_SIZE,
12};
13use alloy_primitives::{
14    keccak256,
15    map::{B256Map, B256Set, HashSet},
16    Address, B256,
17};
18use alloy_rlp::{BufMut, Encodable};
19use alloy_trie::proof::AddedRemovedKeys;
20use reth_execution_errors::trie::StateProofError;
21use reth_trie_common::{
22    proof::ProofRetainer, AccountProof, BranchNodeMasks, BranchNodeMasksMap, DecodedMultiProofV2,
23    MultiProof, MultiProofTargets, MultiProofTargetsV2, StorageMultiProof,
24};
25
26/// A struct for generating merkle proofs.
27///
28/// Proof generator adds the target address and slots to the prefix set, enables the proof retainer
29/// on the hash builder and follows the same algorithm as the state root calculator.
30/// See `StateRoot::root` for more info.
31#[derive(Debug)]
32pub struct Proof<T, H, K = AddedRemovedKeys> {
33    /// The factory for traversing trie nodes.
34    trie_cursor_factory: T,
35    /// The factory for hashed cursors.
36    hashed_cursor_factory: H,
37    /// A set of prefix sets that have changes.
38    prefix_sets: TriePrefixSetsMut,
39    /// Flag indicating whether to include branch node masks in the proof.
40    collect_branch_node_masks: bool,
41    /// Added and removed keys for proof retention.
42    added_removed_keys: Option<K>,
43}
44
45impl<T, H> Proof<T, H> {
46    /// Create a new [`Proof`] instance.
47    pub fn new(t: T, h: H) -> Self {
48        Self {
49            trie_cursor_factory: t,
50            hashed_cursor_factory: h,
51            prefix_sets: TriePrefixSetsMut::default(),
52            collect_branch_node_masks: false,
53            added_removed_keys: None,
54        }
55    }
56}
57
58impl<T, H, K> Proof<T, H, K> {
59    /// Set the trie cursor factory.
60    pub fn with_trie_cursor_factory<TF>(self, trie_cursor_factory: TF) -> Proof<TF, H, K> {
61        Proof {
62            trie_cursor_factory,
63            hashed_cursor_factory: self.hashed_cursor_factory,
64            prefix_sets: self.prefix_sets,
65            collect_branch_node_masks: self.collect_branch_node_masks,
66            added_removed_keys: self.added_removed_keys,
67        }
68    }
69
70    /// Set the hashed cursor factory.
71    pub fn with_hashed_cursor_factory<HF>(self, hashed_cursor_factory: HF) -> Proof<T, HF, K> {
72        Proof {
73            trie_cursor_factory: self.trie_cursor_factory,
74            hashed_cursor_factory,
75            prefix_sets: self.prefix_sets,
76            collect_branch_node_masks: self.collect_branch_node_masks,
77            added_removed_keys: self.added_removed_keys,
78        }
79    }
80
81    /// Set the prefix sets. They have to be mutable in order to allow extension with proof target.
82    pub fn with_prefix_sets_mut(mut self, prefix_sets: TriePrefixSetsMut) -> Self {
83        self.prefix_sets = prefix_sets;
84        self
85    }
86
87    /// Set the flag indicating whether to include branch node masks in the proof.
88    pub const fn with_branch_node_masks(mut self, branch_node_masks: bool) -> Self {
89        self.collect_branch_node_masks = branch_node_masks;
90        self
91    }
92
93    /// Configures the proof to retain certain nodes which would otherwise fall outside the target
94    /// set, when those nodes might be required to calculate the state root when keys have been
95    /// added or removed to the trie.
96    ///
97    /// If None is given then retention of extra proofs is disabled.
98    pub fn with_added_removed_keys<K2>(self, added_removed_keys: Option<K2>) -> Proof<T, H, K2> {
99        Proof {
100            trie_cursor_factory: self.trie_cursor_factory,
101            hashed_cursor_factory: self.hashed_cursor_factory,
102            prefix_sets: self.prefix_sets,
103            collect_branch_node_masks: self.collect_branch_node_masks,
104            added_removed_keys,
105        }
106    }
107
108    /// Get a reference to the trie cursor factory.
109    pub const fn trie_cursor_factory(&self) -> &T {
110        &self.trie_cursor_factory
111    }
112
113    /// Get a reference to the hashed cursor factory.
114    pub const fn hashed_cursor_factory(&self) -> &H {
115        &self.hashed_cursor_factory
116    }
117}
118
119impl<T, H, K> Proof<T, H, K>
120where
121    T: TrieCursorFactory + Clone,
122    H: HashedCursorFactory + Clone,
123    K: AsRef<AddedRemovedKeys>,
124{
125    /// Generate an account proof from intermediate nodes.
126    pub fn account_proof(
127        self,
128        address: Address,
129        slots: &[B256],
130    ) -> Result<AccountProof, StateProofError> {
131        Ok(self
132            .multiproof(MultiProofTargets::from_iter([(
133                keccak256(address),
134                slots.iter().map(keccak256).collect(),
135            )]))?
136            .account_proof(address, slots)?)
137    }
138
139    /// Generate a state multiproof using the V2 proof calculator.
140    ///
141    /// This method uses `ProofCalculator` with `SyncAccountValueEncoder` for account proofs
142    /// and `StorageProofCalculator` for storage proofs.
143    pub fn multiproof_v2(
144        self,
145        targets: MultiProofTargetsV2,
146    ) -> Result<DecodedMultiProofV2, StateProofError> {
147        let MultiProofTargetsV2 { mut account_targets, storage_targets } = targets;
148
149        let storage_prefix_sets: B256Map<_> = self
150            .prefix_sets
151            .storage_prefix_sets
152            .into_iter()
153            .map(|(addr, ps)| (addr, ps.freeze()))
154            .collect();
155
156        // Compute account proofs using the V2 proof calculator with sync account encoding.
157        let account_trie_cursor = self.trie_cursor_factory.account_trie_cursor()?;
158        let hashed_account_cursor = self.hashed_cursor_factory.hashed_account_cursor()?;
159        let mut account_value_encoder = SyncAccountValueEncoder::new(
160            self.trie_cursor_factory.clone(),
161            self.hashed_cursor_factory.clone(),
162        )
163        .with_storage_prefix_sets(storage_prefix_sets.clone());
164        let mut account_calculator =
165            proof_v2::ProofCalculator::new(account_trie_cursor, hashed_account_cursor)
166                .with_prefix_set(self.prefix_sets.account_prefix_set.freeze());
167        let account_proofs =
168            account_calculator.proof(&mut account_value_encoder, &mut account_targets)?;
169
170        // Compute storage proofs for each targeted account.
171        let mut storage_proofs =
172            B256Map::with_capacity_and_hasher(storage_targets.len(), Default::default());
173        for (hashed_address, mut targets) in storage_targets {
174            let storage_trie_cursor =
175                self.trie_cursor_factory.storage_trie_cursor(hashed_address)?;
176            let hashed_storage_cursor =
177                self.hashed_cursor_factory.hashed_storage_cursor(hashed_address)?;
178            let mut storage_calculator = proof_v2::StorageProofCalculator::new_storage(
179                storage_trie_cursor,
180                hashed_storage_cursor,
181            );
182            if let Some(prefix_set) = storage_prefix_sets.get(&hashed_address) {
183                storage_calculator = storage_calculator.with_prefix_set(prefix_set.clone());
184            }
185            let proofs = storage_calculator.storage_proof(hashed_address, &mut targets)?;
186            storage_proofs.insert(hashed_address, proofs);
187        }
188
189        Ok(DecodedMultiProofV2 { account_proofs, storage_proofs })
190    }
191
192    /// Generate a state multiproof according to specified targets.
193    pub fn multiproof(
194        mut self,
195        mut targets: MultiProofTargets,
196    ) -> Result<MultiProof, StateProofError> {
197        let hashed_account_cursor = self.hashed_cursor_factory.hashed_account_cursor()?;
198        let trie_cursor = self.trie_cursor_factory.account_trie_cursor()?;
199
200        // Create the walker.
201        let mut prefix_set = self.prefix_sets.account_prefix_set.clone();
202        prefix_set.extend_keys(targets.keys().map(Nibbles::unpack));
203        let walker =
204            TrieWalker::<_, AddedRemovedKeys>::state_trie(trie_cursor, prefix_set.freeze())
205                .with_added_removed_keys(self.added_removed_keys.as_ref());
206
207        // Create a hash builder to rebuild the root node since it is not available in the database.
208        let retainer: ProofRetainer = targets.keys().map(Nibbles::unpack).collect();
209        let retainer = retainer.with_added_removed_keys(self.added_removed_keys.as_ref());
210        let mut hash_builder = HashBuilder::default()
211            .with_proof_retainer(retainer)
212            .with_updates(self.collect_branch_node_masks);
213
214        // Initialize all storage multiproofs as empty.
215        // Storage multiproofs for non-empty tries will be overwritten if necessary.
216        let mut storages: B256Map<_> =
217            targets.keys().map(|key| (*key, StorageMultiProof::empty())).collect();
218        let mut account_rlp = Vec::with_capacity(TRIE_ACCOUNT_RLP_MAX_SIZE);
219        let mut account_node_iter = TrieNodeIter::state_trie(walker, hashed_account_cursor);
220        while let Some(account_node) = account_node_iter.try_next()? {
221            match account_node {
222                TrieElement::Branch(node) => {
223                    hash_builder.add_branch(node.key, node.value, node.children_are_in_trie);
224                }
225                TrieElement::Leaf(hashed_address, account) => {
226                    let proof_targets = targets.remove(&hashed_address);
227                    let leaf_is_proof_target = proof_targets.is_some();
228                    let collect_storage_masks =
229                        self.collect_branch_node_masks && leaf_is_proof_target;
230                    let storage_prefix_set = self
231                        .prefix_sets
232                        .storage_prefix_sets
233                        .remove(&hashed_address)
234                        .unwrap_or_default();
235                    let storage_multiproof = StorageProof::new_hashed(
236                        self.trie_cursor_factory.clone(),
237                        self.hashed_cursor_factory.clone(),
238                        hashed_address,
239                    )
240                    .with_prefix_set_mut(storage_prefix_set)
241                    .with_branch_node_masks(collect_storage_masks)
242                    .storage_multiproof(proof_targets.unwrap_or_default())?;
243
244                    // Encode account
245                    account_rlp.clear();
246                    let account = account.into_trie_account(storage_multiproof.root);
247                    account.encode(&mut account_rlp as &mut dyn BufMut);
248
249                    hash_builder.add_leaf(Nibbles::unpack(hashed_address), &account_rlp);
250
251                    // We might be adding leaves that are not necessarily our proof targets.
252                    if leaf_is_proof_target {
253                        // Overwrite storage multiproof.
254                        storages.insert(hashed_address, storage_multiproof);
255                    }
256                }
257            }
258        }
259        let _ = hash_builder.root();
260        let account_subtree = hash_builder.take_proof_nodes();
261        let branch_node_masks = if self.collect_branch_node_masks {
262            let updated_branch_nodes = hash_builder.updated_branch_nodes.unwrap_or_default();
263            updated_branch_nodes
264                .into_iter()
265                .map(|(path, node)| {
266                    (path, BranchNodeMasks { hash_mask: node.hash_mask, tree_mask: node.tree_mask })
267                })
268                .collect()
269        } else {
270            BranchNodeMasksMap::default()
271        };
272
273        Ok(MultiProof { account_subtree, branch_node_masks, storages })
274    }
275}
276
277/// Generates storage merkle proofs.
278#[derive(Debug)]
279pub struct StorageProof<'a, T, H, K = AddedRemovedKeys> {
280    /// The factory for traversing trie nodes.
281    trie_cursor_factory: T,
282    /// The factory for hashed cursors.
283    hashed_cursor_factory: H,
284    /// The hashed address of an account.
285    hashed_address: B256,
286    /// The set of storage slot prefixes that have changed.
287    prefix_set: PrefixSetMut,
288    /// Flag indicating whether to include branch node masks in the proof.
289    collect_branch_node_masks: bool,
290    /// Provided by the user to give the necessary context to retain extra proofs.
291    added_removed_keys: Option<K>,
292    /// Optional reference to accumulate trie cursor metrics.
293    trie_cursor_metrics: Option<&'a mut TrieCursorMetricsCache>,
294    /// Optional reference to accumulate hashed cursor metrics.
295    hashed_cursor_metrics: Option<&'a mut HashedCursorMetricsCache>,
296}
297
298impl<T, H> StorageProof<'static, T, H> {
299    /// Create a new [`StorageProof`] instance.
300    pub fn new(t: T, h: H, address: Address) -> Self {
301        Self::new_hashed(t, h, keccak256(address))
302    }
303
304    /// Create a new [`StorageProof`] instance with hashed address.
305    pub fn new_hashed(t: T, h: H, hashed_address: B256) -> Self {
306        Self {
307            trie_cursor_factory: t,
308            hashed_cursor_factory: h,
309            hashed_address,
310            prefix_set: PrefixSetMut::default(),
311            collect_branch_node_masks: false,
312            added_removed_keys: None,
313            trie_cursor_metrics: None,
314            hashed_cursor_metrics: None,
315        }
316    }
317}
318
319impl<'a, T, H, K> StorageProof<'a, T, H, K> {
320    /// Set the trie cursor factory.
321    pub fn with_trie_cursor_factory<TF>(
322        self,
323        trie_cursor_factory: TF,
324    ) -> StorageProof<'a, TF, H, K> {
325        StorageProof {
326            trie_cursor_factory,
327            hashed_cursor_factory: self.hashed_cursor_factory,
328            hashed_address: self.hashed_address,
329            prefix_set: self.prefix_set,
330            collect_branch_node_masks: self.collect_branch_node_masks,
331            added_removed_keys: self.added_removed_keys,
332            trie_cursor_metrics: self.trie_cursor_metrics,
333            hashed_cursor_metrics: self.hashed_cursor_metrics,
334        }
335    }
336
337    /// Set the hashed cursor factory.
338    pub fn with_hashed_cursor_factory<HF>(
339        self,
340        hashed_cursor_factory: HF,
341    ) -> StorageProof<'a, T, HF, K> {
342        StorageProof {
343            trie_cursor_factory: self.trie_cursor_factory,
344            hashed_cursor_factory,
345            hashed_address: self.hashed_address,
346            prefix_set: self.prefix_set,
347            collect_branch_node_masks: self.collect_branch_node_masks,
348            added_removed_keys: self.added_removed_keys,
349            trie_cursor_metrics: self.trie_cursor_metrics,
350            hashed_cursor_metrics: self.hashed_cursor_metrics,
351        }
352    }
353
354    /// Set the changed prefixes.
355    pub fn with_prefix_set_mut(mut self, prefix_set: PrefixSetMut) -> Self {
356        self.prefix_set = prefix_set;
357        self
358    }
359
360    /// Set the flag indicating whether to include branch node masks in the proof.
361    pub const fn with_branch_node_masks(mut self, branch_node_masks: bool) -> Self {
362        self.collect_branch_node_masks = branch_node_masks;
363        self
364    }
365
366    /// Set the trie cursor metrics cache to accumulate metrics into.
367    pub const fn with_trie_cursor_metrics(
368        mut self,
369        metrics: &'a mut TrieCursorMetricsCache,
370    ) -> Self {
371        self.trie_cursor_metrics = Some(metrics);
372        self
373    }
374
375    /// Set the hashed cursor metrics cache to accumulate metrics into.
376    pub const fn with_hashed_cursor_metrics(
377        mut self,
378        metrics: &'a mut HashedCursorMetricsCache,
379    ) -> Self {
380        self.hashed_cursor_metrics = Some(metrics);
381        self
382    }
383
384    /// Configures the retainer to retain proofs for certain nodes which would otherwise fall
385    /// outside the target set, when those nodes might be required to calculate the state root when
386    /// keys have been added or removed to the trie.
387    ///
388    /// If None is given then retention of extra proofs is disabled.
389    pub fn with_added_removed_keys<K2>(
390        self,
391        added_removed_keys: Option<K2>,
392    ) -> StorageProof<'a, T, H, K2> {
393        StorageProof {
394            trie_cursor_factory: self.trie_cursor_factory,
395            hashed_cursor_factory: self.hashed_cursor_factory,
396            hashed_address: self.hashed_address,
397            prefix_set: self.prefix_set,
398            collect_branch_node_masks: self.collect_branch_node_masks,
399            added_removed_keys,
400            trie_cursor_metrics: self.trie_cursor_metrics,
401            hashed_cursor_metrics: self.hashed_cursor_metrics,
402        }
403    }
404}
405
406impl<'a, T, H, K> StorageProof<'a, T, H, K>
407where
408    T: TrieCursorFactory,
409    H: HashedCursorFactory,
410    K: AsRef<AddedRemovedKeys>,
411{
412    /// Generate an account proof from intermediate nodes.
413    pub fn storage_proof(
414        self,
415        slot: B256,
416    ) -> Result<reth_trie_common::StorageProof, StateProofError> {
417        let targets = HashSet::from_iter([keccak256(slot)]);
418        Ok(self.storage_multiproof(targets)?.storage_proof(slot)?)
419    }
420
421    /// Generate storage proof.
422    pub fn storage_multiproof(
423        self,
424        targets: B256Set,
425    ) -> Result<StorageMultiProof, StateProofError> {
426        let mut discard_hashed_cursor_metrics = HashedCursorMetricsCache::default();
427        let hashed_cursor_metrics =
428            self.hashed_cursor_metrics.unwrap_or(&mut discard_hashed_cursor_metrics);
429
430        let hashed_storage_cursor =
431            self.hashed_cursor_factory.hashed_storage_cursor(self.hashed_address)?;
432
433        let mut hashed_storage_cursor =
434            InstrumentedHashedCursor::new(hashed_storage_cursor, hashed_cursor_metrics);
435
436        // short circuit on empty storage
437        if hashed_storage_cursor.is_storage_empty()? {
438            return Ok(StorageMultiProof::empty())
439        }
440
441        let mut discard_trie_cursor_metrics = TrieCursorMetricsCache::default();
442        let trie_cursor_metrics =
443            self.trie_cursor_metrics.unwrap_or(&mut discard_trie_cursor_metrics);
444
445        let target_nibbles = targets.into_iter().map(Nibbles::unpack).collect::<Vec<_>>();
446        let mut prefix_set = self.prefix_set;
447        prefix_set.extend_keys(target_nibbles.iter().copied());
448
449        let trie_cursor = self.trie_cursor_factory.storage_trie_cursor(self.hashed_address)?;
450
451        let trie_cursor = InstrumentedTrieCursor::new(trie_cursor, trie_cursor_metrics);
452
453        let walker = TrieWalker::<_>::storage_trie(trie_cursor, prefix_set.freeze())
454            .with_added_removed_keys(self.added_removed_keys.as_ref());
455
456        let retainer = ProofRetainer::from_iter(target_nibbles)
457            .with_added_removed_keys(self.added_removed_keys.as_ref());
458        let mut hash_builder = HashBuilder::default()
459            .with_proof_retainer(retainer)
460            .with_updates(self.collect_branch_node_masks);
461        let mut storage_node_iter = TrieNodeIter::storage_trie(walker, hashed_storage_cursor);
462        while let Some(node) = storage_node_iter.try_next()? {
463            match node {
464                TrieElement::Branch(node) => {
465                    hash_builder.add_branch(node.key, node.value, node.children_are_in_trie);
466                }
467                TrieElement::Leaf(hashed_slot, value) => {
468                    hash_builder.add_leaf(
469                        Nibbles::unpack(hashed_slot),
470                        alloy_rlp::encode_fixed_size(&value).as_ref(),
471                    );
472                }
473            }
474        }
475
476        let root = hash_builder.root();
477        let subtree = hash_builder.take_proof_nodes();
478        let branch_node_masks = if self.collect_branch_node_masks {
479            let updated_branch_nodes = hash_builder.updated_branch_nodes.unwrap_or_default();
480            updated_branch_nodes
481                .into_iter()
482                .map(|(path, node)| {
483                    (path, BranchNodeMasks { hash_mask: node.hash_mask, tree_mask: node.tree_mask })
484                })
485                .collect()
486        } else {
487            BranchNodeMasksMap::default()
488        };
489
490        Ok(StorageMultiProof { root, subtree, branch_node_masks })
491    }
492}