Skip to main content

reth_trie/proof/
mod.rs

1use crate::{
2    hashed_cursor::{
3        HashedCursorFactory, HashedCursorMetricsCache, HashedStorageCursor,
4        InstrumentedHashedCursor,
5    },
6    node_iter::{TrieElement, TrieNodeIter},
7    prefix_set::{PrefixSetMut, TriePrefixSetsMut},
8    proof_v2::{self, SyncAccountValueEncoder},
9    trie_cursor::{InstrumentedTrieCursor, TrieCursorFactory, TrieCursorMetricsCache},
10    walker::TrieWalker,
11    HashBuilder, Nibbles, TRIE_ACCOUNT_RLP_MAX_SIZE,
12};
13use alloy_primitives::{
14    keccak256,
15    map::{B256Map, B256Set, HashSet},
16    Address, B256,
17};
18use alloy_rlp::{BufMut, Encodable};
19use alloy_trie::proof::AddedRemovedKeys;
20use reth_execution_errors::trie::StateProofError;
21use reth_trie_common::{
22    proof::ProofRetainer, AccountProof, BranchNodeMasks, BranchNodeMasksMap, DecodedMultiProofV2,
23    MultiProof, MultiProofTargets, MultiProofTargetsV2, StorageMultiProof,
24};
25
26mod trie_node;
27pub use trie_node::*;
28
29/// A struct for generating merkle proofs.
30///
31/// Proof generator adds the target address and slots to the prefix set, enables the proof retainer
32/// on the hash builder and follows the same algorithm as the state root calculator.
33/// See `StateRoot::root` for more info.
34#[derive(Debug)]
35pub struct Proof<T, H, K = AddedRemovedKeys> {
36    /// The factory for traversing trie nodes.
37    trie_cursor_factory: T,
38    /// The factory for hashed cursors.
39    hashed_cursor_factory: H,
40    /// A set of prefix sets that have changes.
41    prefix_sets: TriePrefixSetsMut,
42    /// Flag indicating whether to include branch node masks in the proof.
43    collect_branch_node_masks: bool,
44    /// Added and removed keys for proof retention.
45    added_removed_keys: Option<K>,
46}
47
48impl<T, H> Proof<T, H> {
49    /// Create a new [`Proof`] instance.
50    pub fn new(t: T, h: H) -> Self {
51        Self {
52            trie_cursor_factory: t,
53            hashed_cursor_factory: h,
54            prefix_sets: TriePrefixSetsMut::default(),
55            collect_branch_node_masks: false,
56            added_removed_keys: None,
57        }
58    }
59}
60
61impl<T, H, K> Proof<T, H, K> {
62    /// Set the trie cursor factory.
63    pub fn with_trie_cursor_factory<TF>(self, trie_cursor_factory: TF) -> Proof<TF, H, K> {
64        Proof {
65            trie_cursor_factory,
66            hashed_cursor_factory: self.hashed_cursor_factory,
67            prefix_sets: self.prefix_sets,
68            collect_branch_node_masks: self.collect_branch_node_masks,
69            added_removed_keys: self.added_removed_keys,
70        }
71    }
72
73    /// Set the hashed cursor factory.
74    pub fn with_hashed_cursor_factory<HF>(self, hashed_cursor_factory: HF) -> Proof<T, HF, K> {
75        Proof {
76            trie_cursor_factory: self.trie_cursor_factory,
77            hashed_cursor_factory,
78            prefix_sets: self.prefix_sets,
79            collect_branch_node_masks: self.collect_branch_node_masks,
80            added_removed_keys: self.added_removed_keys,
81        }
82    }
83
84    /// Set the prefix sets. They have to be mutable in order to allow extension with proof target.
85    pub fn with_prefix_sets_mut(mut self, prefix_sets: TriePrefixSetsMut) -> Self {
86        self.prefix_sets = prefix_sets;
87        self
88    }
89
90    /// Set the flag indicating whether to include branch node masks in the proof.
91    pub const fn with_branch_node_masks(mut self, branch_node_masks: bool) -> Self {
92        self.collect_branch_node_masks = branch_node_masks;
93        self
94    }
95
96    /// Configures the proof to retain certain nodes which would otherwise fall outside the target
97    /// set, when those nodes might be required to calculate the state root when keys have been
98    /// added or removed to the trie.
99    ///
100    /// If None is given then retention of extra proofs is disabled.
101    pub fn with_added_removed_keys<K2>(self, added_removed_keys: Option<K2>) -> Proof<T, H, K2> {
102        Proof {
103            trie_cursor_factory: self.trie_cursor_factory,
104            hashed_cursor_factory: self.hashed_cursor_factory,
105            prefix_sets: self.prefix_sets,
106            collect_branch_node_masks: self.collect_branch_node_masks,
107            added_removed_keys,
108        }
109    }
110
111    /// Get a reference to the trie cursor factory.
112    pub const fn trie_cursor_factory(&self) -> &T {
113        &self.trie_cursor_factory
114    }
115
116    /// Get a reference to the hashed cursor factory.
117    pub const fn hashed_cursor_factory(&self) -> &H {
118        &self.hashed_cursor_factory
119    }
120}
121
122impl<T, H, K> Proof<T, H, K>
123where
124    T: TrieCursorFactory + Clone,
125    H: HashedCursorFactory + Clone,
126    K: AsRef<AddedRemovedKeys>,
127{
128    /// Generate an account proof from intermediate nodes.
129    pub fn account_proof(
130        self,
131        address: Address,
132        slots: &[B256],
133    ) -> Result<AccountProof, StateProofError> {
134        Ok(self
135            .multiproof(MultiProofTargets::from_iter([(
136                keccak256(address),
137                slots.iter().map(keccak256).collect(),
138            )]))?
139            .account_proof(address, slots)?)
140    }
141
142    /// Generate a state multiproof using the V2 proof calculator.
143    ///
144    /// This method uses `ProofCalculator` with `SyncAccountValueEncoder` for account proofs
145    /// and `StorageProofCalculator` for storage proofs.
146    pub fn multiproof_v2(
147        self,
148        targets: MultiProofTargetsV2,
149    ) -> Result<DecodedMultiProofV2, StateProofError> {
150        let MultiProofTargetsV2 { mut account_targets, storage_targets } = targets;
151
152        let storage_prefix_sets: B256Map<_> = self
153            .prefix_sets
154            .storage_prefix_sets
155            .into_iter()
156            .map(|(addr, ps)| (addr, ps.freeze()))
157            .collect();
158
159        // Compute account proofs using the V2 proof calculator with sync account encoding.
160        let account_trie_cursor = self.trie_cursor_factory.account_trie_cursor()?;
161        let hashed_account_cursor = self.hashed_cursor_factory.hashed_account_cursor()?;
162        let mut account_value_encoder = SyncAccountValueEncoder::new(
163            self.trie_cursor_factory.clone(),
164            self.hashed_cursor_factory.clone(),
165        )
166        .with_storage_prefix_sets(storage_prefix_sets.clone());
167        let mut account_calculator =
168            proof_v2::ProofCalculator::new(account_trie_cursor, hashed_account_cursor)
169                .with_prefix_set(self.prefix_sets.account_prefix_set.freeze());
170        let account_proofs =
171            account_calculator.proof(&mut account_value_encoder, &mut account_targets)?;
172
173        // Compute storage proofs for each targeted account.
174        let mut storage_proofs =
175            B256Map::with_capacity_and_hasher(storage_targets.len(), Default::default());
176        for (hashed_address, mut targets) in storage_targets {
177            let storage_trie_cursor =
178                self.trie_cursor_factory.storage_trie_cursor(hashed_address)?;
179            let hashed_storage_cursor =
180                self.hashed_cursor_factory.hashed_storage_cursor(hashed_address)?;
181            let mut storage_calculator = proof_v2::StorageProofCalculator::new_storage(
182                storage_trie_cursor,
183                hashed_storage_cursor,
184            );
185            if let Some(prefix_set) = storage_prefix_sets.get(&hashed_address) {
186                storage_calculator = storage_calculator.with_prefix_set(prefix_set.clone());
187            }
188            let proofs = storage_calculator.storage_proof(hashed_address, &mut targets)?;
189            storage_proofs.insert(hashed_address, proofs);
190        }
191
192        Ok(DecodedMultiProofV2 { account_proofs, storage_proofs })
193    }
194
195    /// Generate a state multiproof according to specified targets.
196    pub fn multiproof(
197        mut self,
198        mut targets: MultiProofTargets,
199    ) -> Result<MultiProof, StateProofError> {
200        let hashed_account_cursor = self.hashed_cursor_factory.hashed_account_cursor()?;
201        let trie_cursor = self.trie_cursor_factory.account_trie_cursor()?;
202
203        // Create the walker.
204        let mut prefix_set = self.prefix_sets.account_prefix_set.clone();
205        prefix_set.extend_keys(targets.keys().map(Nibbles::unpack));
206        let walker =
207            TrieWalker::<_, AddedRemovedKeys>::state_trie(trie_cursor, prefix_set.freeze())
208                .with_added_removed_keys(self.added_removed_keys.as_ref());
209
210        // Create a hash builder to rebuild the root node since it is not available in the database.
211        let retainer: ProofRetainer = targets.keys().map(Nibbles::unpack).collect();
212        let retainer = retainer.with_added_removed_keys(self.added_removed_keys.as_ref());
213        let mut hash_builder = HashBuilder::default()
214            .with_proof_retainer(retainer)
215            .with_updates(self.collect_branch_node_masks);
216
217        // Initialize all storage multiproofs as empty.
218        // Storage multiproofs for non-empty tries will be overwritten if necessary.
219        let mut storages: B256Map<_> =
220            targets.keys().map(|key| (*key, StorageMultiProof::empty())).collect();
221        let mut account_rlp = Vec::with_capacity(TRIE_ACCOUNT_RLP_MAX_SIZE);
222        let mut account_node_iter = TrieNodeIter::state_trie(walker, hashed_account_cursor);
223        while let Some(account_node) = account_node_iter.try_next()? {
224            match account_node {
225                TrieElement::Branch(node) => {
226                    hash_builder.add_branch(node.key, node.value, node.children_are_in_trie);
227                }
228                TrieElement::Leaf(hashed_address, account) => {
229                    let proof_targets = targets.remove(&hashed_address);
230                    let leaf_is_proof_target = proof_targets.is_some();
231                    let collect_storage_masks =
232                        self.collect_branch_node_masks && leaf_is_proof_target;
233                    let storage_prefix_set = self
234                        .prefix_sets
235                        .storage_prefix_sets
236                        .remove(&hashed_address)
237                        .unwrap_or_default();
238                    let storage_multiproof = StorageProof::new_hashed(
239                        self.trie_cursor_factory.clone(),
240                        self.hashed_cursor_factory.clone(),
241                        hashed_address,
242                    )
243                    .with_prefix_set_mut(storage_prefix_set)
244                    .with_branch_node_masks(collect_storage_masks)
245                    .storage_multiproof(proof_targets.unwrap_or_default())?;
246
247                    // Encode account
248                    account_rlp.clear();
249                    let account = account.into_trie_account(storage_multiproof.root);
250                    account.encode(&mut account_rlp as &mut dyn BufMut);
251
252                    hash_builder.add_leaf(Nibbles::unpack(hashed_address), &account_rlp);
253
254                    // We might be adding leaves that are not necessarily our proof targets.
255                    if leaf_is_proof_target {
256                        // Overwrite storage multiproof.
257                        storages.insert(hashed_address, storage_multiproof);
258                    }
259                }
260            }
261        }
262        let _ = hash_builder.root();
263        let account_subtree = hash_builder.take_proof_nodes();
264        let branch_node_masks = if self.collect_branch_node_masks {
265            let updated_branch_nodes = hash_builder.updated_branch_nodes.unwrap_or_default();
266            updated_branch_nodes
267                .into_iter()
268                .map(|(path, node)| {
269                    (path, BranchNodeMasks { hash_mask: node.hash_mask, tree_mask: node.tree_mask })
270                })
271                .collect()
272        } else {
273            BranchNodeMasksMap::default()
274        };
275
276        Ok(MultiProof { account_subtree, branch_node_masks, storages })
277    }
278}
279
280/// Generates storage merkle proofs.
281#[derive(Debug)]
282pub struct StorageProof<'a, T, H, K = AddedRemovedKeys> {
283    /// The factory for traversing trie nodes.
284    trie_cursor_factory: T,
285    /// The factory for hashed cursors.
286    hashed_cursor_factory: H,
287    /// The hashed address of an account.
288    hashed_address: B256,
289    /// The set of storage slot prefixes that have changed.
290    prefix_set: PrefixSetMut,
291    /// Flag indicating whether to include branch node masks in the proof.
292    collect_branch_node_masks: bool,
293    /// Provided by the user to give the necessary context to retain extra proofs.
294    added_removed_keys: Option<K>,
295    /// Optional reference to accumulate trie cursor metrics.
296    trie_cursor_metrics: Option<&'a mut TrieCursorMetricsCache>,
297    /// Optional reference to accumulate hashed cursor metrics.
298    hashed_cursor_metrics: Option<&'a mut HashedCursorMetricsCache>,
299}
300
301impl<T, H> StorageProof<'static, T, H> {
302    /// Create a new [`StorageProof`] instance.
303    pub fn new(t: T, h: H, address: Address) -> Self {
304        Self::new_hashed(t, h, keccak256(address))
305    }
306
307    /// Create a new [`StorageProof`] instance with hashed address.
308    pub fn new_hashed(t: T, h: H, hashed_address: B256) -> Self {
309        Self {
310            trie_cursor_factory: t,
311            hashed_cursor_factory: h,
312            hashed_address,
313            prefix_set: PrefixSetMut::default(),
314            collect_branch_node_masks: false,
315            added_removed_keys: None,
316            trie_cursor_metrics: None,
317            hashed_cursor_metrics: None,
318        }
319    }
320}
321
322impl<'a, T, H, K> StorageProof<'a, T, H, K> {
323    /// Set the trie cursor factory.
324    pub fn with_trie_cursor_factory<TF>(
325        self,
326        trie_cursor_factory: TF,
327    ) -> StorageProof<'a, TF, H, K> {
328        StorageProof {
329            trie_cursor_factory,
330            hashed_cursor_factory: self.hashed_cursor_factory,
331            hashed_address: self.hashed_address,
332            prefix_set: self.prefix_set,
333            collect_branch_node_masks: self.collect_branch_node_masks,
334            added_removed_keys: self.added_removed_keys,
335            trie_cursor_metrics: self.trie_cursor_metrics,
336            hashed_cursor_metrics: self.hashed_cursor_metrics,
337        }
338    }
339
340    /// Set the hashed cursor factory.
341    pub fn with_hashed_cursor_factory<HF>(
342        self,
343        hashed_cursor_factory: HF,
344    ) -> StorageProof<'a, T, HF, K> {
345        StorageProof {
346            trie_cursor_factory: self.trie_cursor_factory,
347            hashed_cursor_factory,
348            hashed_address: self.hashed_address,
349            prefix_set: self.prefix_set,
350            collect_branch_node_masks: self.collect_branch_node_masks,
351            added_removed_keys: self.added_removed_keys,
352            trie_cursor_metrics: self.trie_cursor_metrics,
353            hashed_cursor_metrics: self.hashed_cursor_metrics,
354        }
355    }
356
357    /// Set the changed prefixes.
358    pub fn with_prefix_set_mut(mut self, prefix_set: PrefixSetMut) -> Self {
359        self.prefix_set = prefix_set;
360        self
361    }
362
363    /// Set the flag indicating whether to include branch node masks in the proof.
364    pub const fn with_branch_node_masks(mut self, branch_node_masks: bool) -> Self {
365        self.collect_branch_node_masks = branch_node_masks;
366        self
367    }
368
369    /// Set the trie cursor metrics cache to accumulate metrics into.
370    pub const fn with_trie_cursor_metrics(
371        mut self,
372        metrics: &'a mut TrieCursorMetricsCache,
373    ) -> Self {
374        self.trie_cursor_metrics = Some(metrics);
375        self
376    }
377
378    /// Set the hashed cursor metrics cache to accumulate metrics into.
379    pub const fn with_hashed_cursor_metrics(
380        mut self,
381        metrics: &'a mut HashedCursorMetricsCache,
382    ) -> Self {
383        self.hashed_cursor_metrics = Some(metrics);
384        self
385    }
386
387    /// Configures the retainer to retain proofs for certain nodes which would otherwise fall
388    /// outside the target set, when those nodes might be required to calculate the state root when
389    /// keys have been added or removed to the trie.
390    ///
391    /// If None is given then retention of extra proofs is disabled.
392    pub fn with_added_removed_keys<K2>(
393        self,
394        added_removed_keys: Option<K2>,
395    ) -> StorageProof<'a, T, H, K2> {
396        StorageProof {
397            trie_cursor_factory: self.trie_cursor_factory,
398            hashed_cursor_factory: self.hashed_cursor_factory,
399            hashed_address: self.hashed_address,
400            prefix_set: self.prefix_set,
401            collect_branch_node_masks: self.collect_branch_node_masks,
402            added_removed_keys,
403            trie_cursor_metrics: self.trie_cursor_metrics,
404            hashed_cursor_metrics: self.hashed_cursor_metrics,
405        }
406    }
407}
408
409impl<'a, T, H, K> StorageProof<'a, T, H, K>
410where
411    T: TrieCursorFactory,
412    H: HashedCursorFactory,
413    K: AsRef<AddedRemovedKeys>,
414{
415    /// Generate an account proof from intermediate nodes.
416    pub fn storage_proof(
417        self,
418        slot: B256,
419    ) -> Result<reth_trie_common::StorageProof, StateProofError> {
420        let targets = HashSet::from_iter([keccak256(slot)]);
421        Ok(self.storage_multiproof(targets)?.storage_proof(slot)?)
422    }
423
424    /// Generate storage proof.
425    pub fn storage_multiproof(
426        self,
427        targets: B256Set,
428    ) -> Result<StorageMultiProof, StateProofError> {
429        let mut discard_hashed_cursor_metrics = HashedCursorMetricsCache::default();
430        let hashed_cursor_metrics =
431            self.hashed_cursor_metrics.unwrap_or(&mut discard_hashed_cursor_metrics);
432
433        let hashed_storage_cursor =
434            self.hashed_cursor_factory.hashed_storage_cursor(self.hashed_address)?;
435
436        let mut hashed_storage_cursor =
437            InstrumentedHashedCursor::new(hashed_storage_cursor, hashed_cursor_metrics);
438
439        // short circuit on empty storage
440        if hashed_storage_cursor.is_storage_empty()? {
441            return Ok(StorageMultiProof::empty())
442        }
443
444        let mut discard_trie_cursor_metrics = TrieCursorMetricsCache::default();
445        let trie_cursor_metrics =
446            self.trie_cursor_metrics.unwrap_or(&mut discard_trie_cursor_metrics);
447
448        let target_nibbles = targets.into_iter().map(Nibbles::unpack).collect::<Vec<_>>();
449        let mut prefix_set = self.prefix_set;
450        prefix_set.extend_keys(target_nibbles.iter().copied());
451
452        let trie_cursor = self.trie_cursor_factory.storage_trie_cursor(self.hashed_address)?;
453
454        let trie_cursor = InstrumentedTrieCursor::new(trie_cursor, trie_cursor_metrics);
455
456        let walker = TrieWalker::<_>::storage_trie(trie_cursor, prefix_set.freeze())
457            .with_added_removed_keys(self.added_removed_keys.as_ref());
458
459        let retainer = ProofRetainer::from_iter(target_nibbles)
460            .with_added_removed_keys(self.added_removed_keys.as_ref());
461        let mut hash_builder = HashBuilder::default()
462            .with_proof_retainer(retainer)
463            .with_updates(self.collect_branch_node_masks);
464        let mut storage_node_iter = TrieNodeIter::storage_trie(walker, hashed_storage_cursor);
465        while let Some(node) = storage_node_iter.try_next()? {
466            match node {
467                TrieElement::Branch(node) => {
468                    hash_builder.add_branch(node.key, node.value, node.children_are_in_trie);
469                }
470                TrieElement::Leaf(hashed_slot, value) => {
471                    hash_builder.add_leaf(
472                        Nibbles::unpack(hashed_slot),
473                        alloy_rlp::encode_fixed_size(&value).as_ref(),
474                    );
475                }
476            }
477        }
478
479        let root = hash_builder.root();
480        let subtree = hash_builder.take_proof_nodes();
481        let branch_node_masks = if self.collect_branch_node_masks {
482            let updated_branch_nodes = hash_builder.updated_branch_nodes.unwrap_or_default();
483            updated_branch_nodes
484                .into_iter()
485                .map(|(path, node)| {
486                    (path, BranchNodeMasks { hash_mask: node.hash_mask, tree_mask: node.tree_mask })
487                })
488                .collect()
489        } else {
490            BranchNodeMasksMap::default()
491        };
492
493        Ok(StorageMultiProof { root, subtree, branch_node_masks })
494    }
495}