reth_trie/
witness.rs

1use crate::{
2    hashed_cursor::{HashedCursor, HashedCursorFactory},
3    prefix_set::TriePrefixSetsMut,
4    proof::{Proof, ProofTrieNodeProviderFactory},
5    trie_cursor::TrieCursorFactory,
6};
7use alloy_rlp::EMPTY_STRING_CODE;
8use alloy_trie::EMPTY_ROOT_HASH;
9use reth_trie_common::HashedPostState;
10use reth_trie_sparse::SparseTrieInterface;
11
12use alloy_primitives::{
13    keccak256,
14    map::{B256Map, B256Set, Entry, HashMap},
15    Bytes, B256,
16};
17use itertools::Itertools;
18use reth_execution_errors::{
19    SparseStateTrieErrorKind, SparseTrieError, SparseTrieErrorKind, StateProofError,
20    TrieWitnessError,
21};
22use reth_trie_common::{MultiProofTargets, Nibbles};
23use reth_trie_sparse::{
24    provider::{RevealedNode, TrieNodeProvider, TrieNodeProviderFactory},
25    SerialSparseTrie, SparseStateTrie,
26};
27use std::sync::{mpsc, Arc};
28
29/// State transition witness for the trie.
30#[derive(Debug)]
31pub struct TrieWitness<T, H> {
32    /// The cursor factory for traversing trie nodes.
33    trie_cursor_factory: T,
34    /// The factory for hashed cursors.
35    hashed_cursor_factory: H,
36    /// A set of prefix sets that have changes.
37    prefix_sets: TriePrefixSetsMut,
38    /// Flag indicating whether the root node should always be included (even if the target state
39    /// is empty). This setting is useful if the caller wants to verify the witness against the
40    /// parent state root.
41    /// Set to `false` by default.
42    always_include_root_node: bool,
43    /// Recorded witness.
44    witness: B256Map<Bytes>,
45}
46
47impl<T, H> TrieWitness<T, H> {
48    /// Creates a new witness generator.
49    pub fn new(trie_cursor_factory: T, hashed_cursor_factory: H) -> Self {
50        Self {
51            trie_cursor_factory,
52            hashed_cursor_factory,
53            prefix_sets: TriePrefixSetsMut::default(),
54            always_include_root_node: false,
55            witness: HashMap::default(),
56        }
57    }
58
59    /// Set the trie cursor factory.
60    pub fn with_trie_cursor_factory<TF>(self, trie_cursor_factory: TF) -> TrieWitness<TF, H> {
61        TrieWitness {
62            trie_cursor_factory,
63            hashed_cursor_factory: self.hashed_cursor_factory,
64            prefix_sets: self.prefix_sets,
65            always_include_root_node: self.always_include_root_node,
66            witness: self.witness,
67        }
68    }
69
70    /// Set the hashed cursor factory.
71    pub fn with_hashed_cursor_factory<HF>(self, hashed_cursor_factory: HF) -> TrieWitness<T, HF> {
72        TrieWitness {
73            trie_cursor_factory: self.trie_cursor_factory,
74            hashed_cursor_factory,
75            prefix_sets: self.prefix_sets,
76            always_include_root_node: self.always_include_root_node,
77            witness: self.witness,
78        }
79    }
80
81    /// Set the prefix sets. They have to be mutable in order to allow extension with proof target.
82    pub fn with_prefix_sets_mut(mut self, prefix_sets: TriePrefixSetsMut) -> Self {
83        self.prefix_sets = prefix_sets;
84        self
85    }
86
87    /// Set `always_include_root_node` to true. Root node will be included even on empty state.
88    /// This setting is useful if the caller wants to verify the witness against the
89    /// parent state root.
90    pub const fn always_include_root_node(mut self) -> Self {
91        self.always_include_root_node = true;
92        self
93    }
94}
95
96impl<T, H> TrieWitness<T, H>
97where
98    T: TrieCursorFactory + Clone + Send + Sync,
99    H: HashedCursorFactory + Clone + Send + Sync,
100{
101    /// Compute the state transition witness for the trie. Gather all required nodes
102    /// to apply `state` on top of the current trie state.
103    ///
104    /// # Arguments
105    ///
106    /// `state` - state transition containing both modified and touched accounts and storage slots.
107    pub fn compute(mut self, state: HashedPostState) -> Result<B256Map<Bytes>, TrieWitnessError> {
108        let is_state_empty = state.is_empty();
109        if is_state_empty && !self.always_include_root_node {
110            return Ok(Default::default())
111        }
112
113        let proof_targets = if is_state_empty {
114            MultiProofTargets::account(B256::ZERO)
115        } else {
116            self.get_proof_targets(&state)?
117        };
118        let multiproof =
119            Proof::new(self.trie_cursor_factory.clone(), self.hashed_cursor_factory.clone())
120                .with_prefix_sets_mut(self.prefix_sets.clone())
121                .multiproof(proof_targets.clone())?;
122
123        // No need to reconstruct the rest of the trie, we just need to include
124        // the root node and return.
125        if is_state_empty {
126            let (root_hash, root_node) = if let Some(root_node) =
127                multiproof.account_subtree.into_inner().remove(&Nibbles::default())
128            {
129                (keccak256(&root_node), root_node)
130            } else {
131                (EMPTY_ROOT_HASH, Bytes::from([EMPTY_STRING_CODE]))
132            };
133            return Ok(B256Map::from_iter([(root_hash, root_node)]))
134        }
135
136        // Record all nodes from multiproof in the witness
137        for account_node in multiproof.account_subtree.values() {
138            if let Entry::Vacant(entry) = self.witness.entry(keccak256(account_node.as_ref())) {
139                entry.insert(account_node.clone());
140            }
141        }
142        for storage_node in multiproof.storages.values().flat_map(|s| s.subtree.values()) {
143            if let Entry::Vacant(entry) = self.witness.entry(keccak256(storage_node.as_ref())) {
144                entry.insert(storage_node.clone());
145            }
146        }
147
148        let (tx, rx) = mpsc::channel();
149        let blinded_provider_factory = WitnessTrieNodeProviderFactory::new(
150            ProofTrieNodeProviderFactory::new(
151                self.trie_cursor_factory,
152                self.hashed_cursor_factory,
153                Arc::new(self.prefix_sets),
154            ),
155            tx,
156        );
157        let mut sparse_trie = SparseStateTrie::<SerialSparseTrie>::new();
158        sparse_trie.reveal_multiproof(multiproof)?;
159
160        // Attempt to update state trie to gather additional information for the witness.
161        for (hashed_address, hashed_slots) in
162            proof_targets.into_iter().sorted_unstable_by_key(|(ha, _)| *ha)
163        {
164            // Update storage trie first.
165            let provider = blinded_provider_factory.storage_node_provider(hashed_address);
166            let storage = state.storages.get(&hashed_address);
167            let storage_trie = sparse_trie.storage_trie_mut(&hashed_address).ok_or(
168                SparseStateTrieErrorKind::SparseStorageTrie(
169                    hashed_address,
170                    SparseTrieErrorKind::Blind,
171                ),
172            )?;
173            for hashed_slot in hashed_slots.into_iter().sorted_unstable() {
174                let storage_nibbles = Nibbles::unpack(hashed_slot);
175                let maybe_leaf_value = storage
176                    .and_then(|s| s.storage.get(&hashed_slot))
177                    .filter(|v| !v.is_zero())
178                    .map(|v| alloy_rlp::encode_fixed_size(v).to_vec());
179
180                if let Some(value) = maybe_leaf_value {
181                    storage_trie.update_leaf(storage_nibbles, value, &provider).map_err(|err| {
182                        SparseStateTrieErrorKind::SparseStorageTrie(hashed_address, err.into_kind())
183                    })?;
184                } else {
185                    storage_trie.remove_leaf(&storage_nibbles, &provider).map_err(|err| {
186                        SparseStateTrieErrorKind::SparseStorageTrie(hashed_address, err.into_kind())
187                    })?;
188                }
189            }
190
191            // Calculate storage root after updates.
192            storage_trie.root();
193
194            let account = state
195                .accounts
196                .get(&hashed_address)
197                .ok_or(TrieWitnessError::MissingAccount(hashed_address))?
198                .unwrap_or_default();
199
200            if !sparse_trie.update_account(hashed_address, account, &blinded_provider_factory)? {
201                let nibbles = Nibbles::unpack(hashed_address);
202                sparse_trie.remove_account_leaf(&nibbles, &blinded_provider_factory)?;
203            }
204
205            while let Ok(node) = rx.try_recv() {
206                self.witness.insert(keccak256(&node), node);
207            }
208        }
209
210        Ok(self.witness)
211    }
212
213    /// Retrieve proof targets for incoming hashed state.
214    /// This method will aggregate all accounts and slots present in the hash state as well as
215    /// select all existing slots from the database for the accounts that have been destroyed.
216    fn get_proof_targets(
217        &self,
218        state: &HashedPostState,
219    ) -> Result<MultiProofTargets, StateProofError> {
220        let mut proof_targets = MultiProofTargets::default();
221        for hashed_address in state.accounts.keys() {
222            proof_targets.insert(*hashed_address, B256Set::default());
223        }
224        for (hashed_address, storage) in &state.storages {
225            let mut storage_keys = storage.storage.keys().copied().collect::<B256Set>();
226            if storage.wiped {
227                // storage for this account was destroyed, gather all slots from the current state
228                let mut storage_cursor =
229                    self.hashed_cursor_factory.hashed_storage_cursor(*hashed_address)?;
230                // position cursor at the start
231                let mut current_entry = storage_cursor.seek(B256::ZERO)?;
232                while let Some((hashed_slot, _)) = current_entry {
233                    storage_keys.insert(hashed_slot);
234                    current_entry = storage_cursor.next()?;
235                }
236            }
237            proof_targets.insert(*hashed_address, storage_keys);
238        }
239        Ok(proof_targets)
240    }
241}
242
243#[derive(Debug, Clone)]
244struct WitnessTrieNodeProviderFactory<F> {
245    /// Trie node provider factory.
246    provider_factory: F,
247    /// Sender for forwarding fetched trie node.
248    tx: mpsc::Sender<Bytes>,
249}
250
251impl<F> WitnessTrieNodeProviderFactory<F> {
252    const fn new(provider_factory: F, tx: mpsc::Sender<Bytes>) -> Self {
253        Self { provider_factory, tx }
254    }
255}
256
257impl<F> TrieNodeProviderFactory for WitnessTrieNodeProviderFactory<F>
258where
259    F: TrieNodeProviderFactory,
260    F::AccountNodeProvider: TrieNodeProvider,
261    F::StorageNodeProvider: TrieNodeProvider,
262{
263    type AccountNodeProvider = WitnessTrieNodeProvider<F::AccountNodeProvider>;
264    type StorageNodeProvider = WitnessTrieNodeProvider<F::StorageNodeProvider>;
265
266    fn account_node_provider(&self) -> Self::AccountNodeProvider {
267        let provider = self.provider_factory.account_node_provider();
268        WitnessTrieNodeProvider::new(provider, self.tx.clone())
269    }
270
271    fn storage_node_provider(&self, account: B256) -> Self::StorageNodeProvider {
272        let provider = self.provider_factory.storage_node_provider(account);
273        WitnessTrieNodeProvider::new(provider, self.tx.clone())
274    }
275}
276
277#[derive(Debug)]
278struct WitnessTrieNodeProvider<P> {
279    /// Proof-based blinded.
280    provider: P,
281    /// Sender for forwarding fetched blinded node.
282    tx: mpsc::Sender<Bytes>,
283}
284
285impl<P> WitnessTrieNodeProvider<P> {
286    const fn new(provider: P, tx: mpsc::Sender<Bytes>) -> Self {
287        Self { provider, tx }
288    }
289}
290
291impl<P: TrieNodeProvider> TrieNodeProvider for WitnessTrieNodeProvider<P> {
292    fn trie_node(&self, path: &Nibbles) -> Result<Option<RevealedNode>, SparseTrieError> {
293        let maybe_node = self.provider.trie_node(path)?;
294        if let Some(node) = &maybe_node {
295            self.tx
296                .send(node.node.clone())
297                .map_err(|error| SparseTrieErrorKind::Other(Box::new(error)))?;
298        }
299        Ok(maybe_node)
300    }
301}