Skip to main content

reth_trie/
witness.rs

1use crate::{
2    hashed_cursor::{HashedCursor, HashedCursorFactory},
3    proof::{Proof, ProofTrieNodeProviderFactory},
4    trie_cursor::TrieCursorFactory,
5};
6use alloy_rlp::EMPTY_STRING_CODE;
7use alloy_trie::EMPTY_ROOT_HASH;
8use reth_trie_common::HashedPostState;
9use reth_trie_sparse::SparseTrie;
10
11use alloy_primitives::{
12    keccak256,
13    map::{B256Map, B256Set, Entry, HashMap},
14    Bytes, B256,
15};
16use itertools::Itertools;
17use reth_execution_errors::{
18    SparseStateTrieErrorKind, SparseTrieError, SparseTrieErrorKind, StateProofError,
19    TrieWitnessError,
20};
21use reth_trie_common::{MultiProofTargets, Nibbles};
22use reth_trie_sparse::{
23    provider::{RevealedNode, TrieNodeProvider, TrieNodeProviderFactory},
24    SparseStateTrie,
25};
26use std::sync::mpsc;
27
/// State transition witness for the trie.
///
/// Collects the trie nodes needed to apply a [`HashedPostState`] on top of the current trie
/// state. Create it with [`TrieWitness::new`] and produce the witness with
/// [`TrieWitness::compute`].
#[derive(Debug)]
pub struct TrieWitness<T, H> {
    /// The cursor factory for traversing trie nodes.
    trie_cursor_factory: T,
    /// The factory for hashed cursors.
    hashed_cursor_factory: H,
    /// Flag indicating whether the root node should always be included (even if the target state
    /// is empty). This setting is useful if the caller wants to verify the witness against the
    /// parent state root.
    /// Set to `false` by default.
    always_include_root_node: bool,
    /// Recorded witness: encoded trie nodes keyed by the `keccak256` hash of their bytes.
    witness: B256Map<Bytes>,
}
43
44impl<T, H> TrieWitness<T, H> {
45    /// Creates a new witness generator.
46    pub fn new(trie_cursor_factory: T, hashed_cursor_factory: H) -> Self {
47        Self {
48            trie_cursor_factory,
49            hashed_cursor_factory,
50            always_include_root_node: false,
51            witness: HashMap::default(),
52        }
53    }
54
55    /// Set the trie cursor factory.
56    pub fn with_trie_cursor_factory<TF>(self, trie_cursor_factory: TF) -> TrieWitness<TF, H> {
57        TrieWitness {
58            trie_cursor_factory,
59            hashed_cursor_factory: self.hashed_cursor_factory,
60            always_include_root_node: self.always_include_root_node,
61            witness: self.witness,
62        }
63    }
64
65    /// Set the hashed cursor factory.
66    pub fn with_hashed_cursor_factory<HF>(self, hashed_cursor_factory: HF) -> TrieWitness<T, HF> {
67        TrieWitness {
68            trie_cursor_factory: self.trie_cursor_factory,
69            hashed_cursor_factory,
70            always_include_root_node: self.always_include_root_node,
71            witness: self.witness,
72        }
73    }
74
75    /// Set `always_include_root_node` to true. Root node will be included even in empty state.
76    /// This setting is useful if the caller wants to verify the witness against the
77    /// parent state root.
78    pub const fn always_include_root_node(mut self) -> Self {
79        self.always_include_root_node = true;
80        self
81    }
82}
83
impl<T, H> TrieWitness<T, H>
where
    T: TrieCursorFactory + Clone,
    H: HashedCursorFactory + Clone,
{
    /// Compute the state transition witness for the trie. Gather all required nodes
    /// to apply `state` on top of the current trie state.
    ///
    /// The returned map is keyed by `keccak256` of each encoded node. It contains every node of
    /// the account and storage multiproofs for the touched keys. It also contains any further
    /// nodes fetched while the changes are applied to a [`SparseStateTrie`].
    ///
    /// If `state` is empty, the result is an empty map unless
    /// [`always_include_root_node`](Self::always_include_root_node) was set. In that case the
    /// map holds only the account trie root node. If the proof yields no root node, the map
    /// instead holds a single-byte `EMPTY_STRING_CODE` entry keyed by [`EMPTY_ROOT_HASH`].
    ///
    /// # Arguments
    ///
    /// `state` - state transition containing both modified and touched accounts and storage slots.
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases:
    /// - Collecting proof targets or generating the multiproof fails.
    /// - Revealing the multiproof or updating the sparse trie fails. This includes a `Blind`
    ///   error when no storage trie is available for a target address.
    /// - A target address has no entry in `state.accounts`, which yields
    ///   [`TrieWitnessError::MissingAccount`].
    pub fn compute(mut self, state: HashedPostState) -> Result<B256Map<Bytes>, TrieWitnessError> {
        let is_state_empty = state.is_empty();
        if is_state_empty && !self.always_include_root_node {
            return Ok(Default::default())
        }

        // For an empty state, a proof for an arbitrary key (zero) is requested only so that the
        // root node (the entry at the empty path) can be extracted below.
        let proof_targets = if is_state_empty {
            MultiProofTargets::account(B256::ZERO)
        } else {
            self.get_proof_targets(&state)?
        };
        // The targets are cloned because they also drive the sparse trie updates further down.
        let multiproof =
            Proof::new(self.trie_cursor_factory.clone(), self.hashed_cursor_factory.clone())
                .multiproof(proof_targets.clone())?;

        // No need to reconstruct the rest of the trie, we just need to include
        // the root node and return.
        if is_state_empty {
            let (root_hash, root_node) = if let Some(root_node) =
                multiproof.account_subtree.into_inner().remove(&Nibbles::default())
            {
                (keccak256(&root_node), root_node)
            } else {
                (EMPTY_ROOT_HASH, Bytes::from([EMPTY_STRING_CODE]))
            };
            return Ok(B256Map::from_iter([(root_hash, root_node)]))
        }

        // Record all account and storage nodes from the multiproof in the witness, skipping
        // hashes that are already present.
        for account_node in multiproof.account_subtree.values() {
            if let Entry::Vacant(entry) = self.witness.entry(keccak256(account_node.as_ref())) {
                entry.insert(account_node.clone());
            }
        }
        for storage_node in multiproof.storages.values().flat_map(|s| s.subtree.values()) {
            if let Entry::Vacant(entry) = self.witness.entry(keccak256(storage_node.as_ref())) {
                entry.insert(storage_node.clone());
            }
        }

        // Every node fetched through `blinded_provider_factory` while updating the sparse trie
        // is sent over `tx`. The nodes are drained from `rx` into the witness after each
        // account.
        let (tx, rx) = mpsc::channel();
        let blinded_provider_factory = WitnessTrieNodeProviderFactory::new(
            ProofTrieNodeProviderFactory::new(self.trie_cursor_factory, self.hashed_cursor_factory),
            tx,
        );
        let mut sparse_trie = SparseStateTrie::new();
        sparse_trie.reveal_multiproof(multiproof)?;

        // Attempt to update state trie to gather additional information for the witness.
        // Addresses (and, below, slots) are processed in ascending order, presumably to keep
        // node fetching deterministic.
        for (hashed_address, hashed_slots) in
            proof_targets.into_iter().sorted_unstable_by_key(|(ha, _)| *ha)
        {
            // Update storage trie first.
            let provider = blinded_provider_factory.storage_node_provider(hashed_address);
            let storage = state.storages.get(&hashed_address);
            // No storage trie available for this address is reported as a `Blind` error.
            let storage_trie = sparse_trie.storage_trie_mut(&hashed_address).ok_or(
                SparseStateTrieErrorKind::SparseStorageTrie(
                    hashed_address,
                    SparseTrieErrorKind::Blind,
                ),
            )?;
            for hashed_slot in hashed_slots.into_iter().sorted_unstable() {
                let storage_nibbles = Nibbles::unpack(hashed_slot);
                // Slots missing from `state` (e.g. those gathered from a wiped storage) or set
                // to zero become removals. Non-zero values are encoded and written as leaves.
                let maybe_leaf_value = storage
                    .and_then(|s| s.storage.get(&hashed_slot))
                    .filter(|v| !v.is_zero())
                    .map(|v| alloy_rlp::encode_fixed_size(v).to_vec());

                if let Some(value) = maybe_leaf_value {
                    storage_trie.update_leaf(storage_nibbles, value, &provider).map_err(|err| {
                        SparseStateTrieErrorKind::SparseStorageTrie(hashed_address, err.into_kind())
                    })?;
                } else {
                    storage_trie.remove_leaf(&storage_nibbles, &provider).map_err(|err| {
                        SparseStateTrieErrorKind::SparseStorageTrie(hashed_address, err.into_kind())
                    })?;
                }
            }

            // Every target address must appear in `state.accounts`. A `None` entry is replaced
            // by the default account. NOTE(review): presumably `None` marks a destroyed account;
            // confirm against `HashedPostState`.
            let account = state
                .accounts
                .get(&hashed_address)
                .ok_or(TrieWitnessError::MissingAccount(hashed_address))?
                .unwrap_or_default();

            // NOTE(review): a `false` result from `update_account` is handled by removing the
            // account leaf. Confirm the exact meaning in `SparseStateTrie::update_account`.
            if !sparse_trie.update_account(hashed_address, account, &blinded_provider_factory)? {
                let nibbles = Nibbles::unpack(hashed_address);
                sparse_trie.remove_account_leaf(&nibbles, &blinded_provider_factory)?;
            }

            // Collect every node fetched while processing this account.
            while let Ok(node) = rx.try_recv() {
                self.witness.insert(keccak256(&node), node);
            }
        }

        Ok(self.witness)
    }

    /// Retrieve proof targets for incoming hashed state.
    /// This method will aggregate all accounts and slots present in the hash state as well as
    /// select all existing slots from the database for the accounts that have been destroyed.
    ///
    /// Every address in `state.accounts` becomes a target with no slots. Every address in
    /// `state.storages` becomes a target with its changed slots. If that storage is `wiped`, all
    /// slots currently returned by the hashed storage cursor are added as well.
    ///
    /// # Errors
    ///
    /// Returns an error if creating, seeking or advancing a hashed storage cursor fails.
    fn get_proof_targets(
        &self,
        state: &HashedPostState,
    ) -> Result<MultiProofTargets, StateProofError> {
        let mut proof_targets = MultiProofTargets::default();
        for hashed_address in state.accounts.keys() {
            proof_targets.insert(*hashed_address, B256Set::default());
        }
        for (hashed_address, storage) in &state.storages {
            let mut storage_keys = storage.storage.keys().copied().collect::<B256Set>();
            if storage.wiped {
                // storage for this account was destroyed, gather all slots from the current state
                let mut storage_cursor =
                    self.hashed_cursor_factory.hashed_storage_cursor(*hashed_address)?;
                // position cursor at the start
                let mut current_entry = storage_cursor.seek(B256::ZERO)?;
                while let Some((hashed_slot, _)) = current_entry {
                    storage_keys.insert(hashed_slot);
                    current_entry = storage_cursor.next()?;
                }
            }
            // The set for this address now holds every slot to prove. Any empty set inserted
            // above for the same account is presumably replaced.
            proof_targets.insert(*hashed_address, storage_keys);
        }
        Ok(proof_targets)
    }
}
222
/// [`TrieNodeProviderFactory`] wrapper whose providers forward every node they fetch over an
/// [`mpsc`] channel. This lets the witness generator record the nodes loaded while the sparse
/// trie is being updated.
#[derive(Debug, Clone)]
struct WitnessTrieNodeProviderFactory<F> {
    /// Inner trie node provider factory that actually fetches the nodes.
    provider_factory: F,
    /// Sender for forwarding fetched trie nodes; cloned into every provider this factory creates.
    tx: mpsc::Sender<Bytes>,
}
230
231impl<F> WitnessTrieNodeProviderFactory<F> {
232    const fn new(provider_factory: F, tx: mpsc::Sender<Bytes>) -> Self {
233        Self { provider_factory, tx }
234    }
235}
236
237impl<F> TrieNodeProviderFactory for WitnessTrieNodeProviderFactory<F>
238where
239    F: TrieNodeProviderFactory,
240    F::AccountNodeProvider: TrieNodeProvider,
241    F::StorageNodeProvider: TrieNodeProvider,
242{
243    type AccountNodeProvider = WitnessTrieNodeProvider<F::AccountNodeProvider>;
244    type StorageNodeProvider = WitnessTrieNodeProvider<F::StorageNodeProvider>;
245
246    fn account_node_provider(&self) -> Self::AccountNodeProvider {
247        let provider = self.provider_factory.account_node_provider();
248        WitnessTrieNodeProvider::new(provider, self.tx.clone())
249    }
250
251    fn storage_node_provider(&self, account: B256) -> Self::StorageNodeProvider {
252        let provider = self.provider_factory.storage_node_provider(account);
253        WitnessTrieNodeProvider::new(provider, self.tx.clone())
254    }
255}
256
/// [`TrieNodeProvider`] wrapper that sends the bytes of every node returned by the inner
/// provider over an [`mpsc`] channel.
#[derive(Debug)]
struct WitnessTrieNodeProvider<P> {
    /// Inner provider that fetches the trie nodes. In [`TrieWitness::compute`] this is a
    /// proof-based provider.
    provider: P,
    /// Sender for forwarding the bytes of each fetched node.
    tx: mpsc::Sender<Bytes>,
}
264
265impl<P> WitnessTrieNodeProvider<P> {
266    const fn new(provider: P, tx: mpsc::Sender<Bytes>) -> Self {
267        Self { provider, tx }
268    }
269}
270
271impl<P: TrieNodeProvider> TrieNodeProvider for WitnessTrieNodeProvider<P> {
272    fn trie_node(&self, path: &Nibbles) -> Result<Option<RevealedNode>, SparseTrieError> {
273        let maybe_node = self.provider.trie_node(path)?;
274        if let Some(node) = &maybe_node {
275            self.tx
276                .send(node.node.clone())
277                .map_err(|error| SparseTrieErrorKind::Other(Box::new(error)))?;
278        }
279        Ok(maybe_node)
280    }
281}