reth_trie/
witness.rs

1use crate::{
2    hashed_cursor::{HashedCursor, HashedCursorFactory},
3    prefix_set::TriePrefixSetsMut,
4    proof::{Proof, ProofBlindedProviderFactory},
5    trie_cursor::TrieCursorFactory,
6};
7use alloy_rlp::EMPTY_STRING_CODE;
8use alloy_trie::EMPTY_ROOT_HASH;
9use reth_trie_common::HashedPostState;
10
11use alloy_primitives::{
12    keccak256,
13    map::{B256Map, B256Set, Entry, HashMap},
14    Bytes, B256,
15};
16use itertools::Itertools;
17use reth_execution_errors::{
18    SparseStateTrieErrorKind, SparseTrieError, SparseTrieErrorKind, StateProofError,
19    TrieWitnessError,
20};
21use reth_trie_common::{MultiProofTargets, Nibbles};
22use reth_trie_sparse::{
23    blinded::{BlindedProvider, BlindedProviderFactory, RevealedNode},
24    SparseStateTrie,
25};
26use std::sync::{mpsc, Arc};
27
/// State transition witness for the trie.
///
/// Collects the trie nodes, keyed by the keccak256 hash of their bytes, that are needed to apply
/// a [`HashedPostState`] on top of the current trie state. See [`TrieWitness::compute`].
#[derive(Debug)]
pub struct TrieWitness<T, H> {
    /// The cursor factory for traversing trie nodes.
    trie_cursor_factory: T,
    /// The factory for hashed cursors.
    hashed_cursor_factory: H,
    /// A set of prefix sets that have changes.
    prefix_sets: TriePrefixSetsMut,
    /// Flag indicating whether the root node should always be included (even if the target state
    /// is empty). This setting is useful if the caller wants to verify the witness against the
    /// parent state root.
    /// Set to `false` by default.
    always_include_root_node: bool,
    /// Recorded witness: keccak256 hash of a node mapped to the node's bytes.
    witness: B256Map<Bytes>,
}
45
46impl<T, H> TrieWitness<T, H> {
47    /// Creates a new witness generator.
48    pub fn new(trie_cursor_factory: T, hashed_cursor_factory: H) -> Self {
49        Self {
50            trie_cursor_factory,
51            hashed_cursor_factory,
52            prefix_sets: TriePrefixSetsMut::default(),
53            always_include_root_node: false,
54            witness: HashMap::default(),
55        }
56    }
57
58    /// Set the trie cursor factory.
59    pub fn with_trie_cursor_factory<TF>(self, trie_cursor_factory: TF) -> TrieWitness<TF, H> {
60        TrieWitness {
61            trie_cursor_factory,
62            hashed_cursor_factory: self.hashed_cursor_factory,
63            prefix_sets: self.prefix_sets,
64            always_include_root_node: self.always_include_root_node,
65            witness: self.witness,
66        }
67    }
68
69    /// Set the hashed cursor factory.
70    pub fn with_hashed_cursor_factory<HF>(self, hashed_cursor_factory: HF) -> TrieWitness<T, HF> {
71        TrieWitness {
72            trie_cursor_factory: self.trie_cursor_factory,
73            hashed_cursor_factory,
74            prefix_sets: self.prefix_sets,
75            always_include_root_node: self.always_include_root_node,
76            witness: self.witness,
77        }
78    }
79
80    /// Set the prefix sets. They have to be mutable in order to allow extension with proof target.
81    pub fn with_prefix_sets_mut(mut self, prefix_sets: TriePrefixSetsMut) -> Self {
82        self.prefix_sets = prefix_sets;
83        self
84    }
85
86    /// Set `always_include_root_node` to true. Root node will be included even on empty state.
87    /// This setting is useful if the caller wants to verify the witness against the
88    /// parent state root.
89    pub const fn always_include_root_node(mut self) -> Self {
90        self.always_include_root_node = true;
91        self
92    }
93}
94
95impl<T, H> TrieWitness<T, H>
96where
97    T: TrieCursorFactory + Clone + Send + Sync,
98    H: HashedCursorFactory + Clone + Send + Sync,
99{
100    /// Compute the state transition witness for the trie. Gather all required nodes
101    /// to apply `state` on top of the current trie state.
102    ///
103    /// # Arguments
104    ///
105    /// `state` - state transition containing both modified and touched accounts and storage slots.
106    pub fn compute(mut self, state: HashedPostState) -> Result<B256Map<Bytes>, TrieWitnessError> {
107        let is_state_empty = state.is_empty();
108        if is_state_empty && !self.always_include_root_node {
109            return Ok(Default::default())
110        }
111
112        let proof_targets = if is_state_empty {
113            MultiProofTargets::account(B256::ZERO)
114        } else {
115            self.get_proof_targets(&state)?
116        };
117        let multiproof =
118            Proof::new(self.trie_cursor_factory.clone(), self.hashed_cursor_factory.clone())
119                .with_prefix_sets_mut(self.prefix_sets.clone())
120                .multiproof(proof_targets.clone())?;
121
122        // No need to reconstruct the rest of the trie, we just need to include
123        // the root node and return.
124        if is_state_empty {
125            let (root_hash, root_node) = if let Some(root_node) =
126                multiproof.account_subtree.into_inner().remove(&Nibbles::default())
127            {
128                (keccak256(&root_node), root_node)
129            } else {
130                (EMPTY_ROOT_HASH, Bytes::from([EMPTY_STRING_CODE]))
131            };
132            return Ok(B256Map::from_iter([(root_hash, root_node)]))
133        }
134
135        // Record all nodes from multiproof in the witness
136        for account_node in multiproof.account_subtree.values() {
137            if let Entry::Vacant(entry) = self.witness.entry(keccak256(account_node.as_ref())) {
138                entry.insert(account_node.clone());
139            }
140        }
141        for storage_node in multiproof.storages.values().flat_map(|s| s.subtree.values()) {
142            if let Entry::Vacant(entry) = self.witness.entry(keccak256(storage_node.as_ref())) {
143                entry.insert(storage_node.clone());
144            }
145        }
146
147        let (tx, rx) = mpsc::channel();
148        let blinded_provider_factory = WitnessBlindedProviderFactory::new(
149            ProofBlindedProviderFactory::new(
150                self.trie_cursor_factory,
151                self.hashed_cursor_factory,
152                Arc::new(self.prefix_sets),
153            ),
154            tx,
155        );
156        let mut sparse_trie = SparseStateTrie::new(blinded_provider_factory);
157        sparse_trie.reveal_multiproof(multiproof)?;
158
159        // Attempt to update state trie to gather additional information for the witness.
160        for (hashed_address, hashed_slots) in
161            proof_targets.into_iter().sorted_unstable_by_key(|(ha, _)| *ha)
162        {
163            // Update storage trie first.
164            let storage = state.storages.get(&hashed_address);
165            let storage_trie = sparse_trie.storage_trie_mut(&hashed_address).ok_or(
166                SparseStateTrieErrorKind::SparseStorageTrie(
167                    hashed_address,
168                    SparseTrieErrorKind::Blind,
169                ),
170            )?;
171            for hashed_slot in hashed_slots.into_iter().sorted_unstable() {
172                let storage_nibbles = Nibbles::unpack(hashed_slot);
173                let maybe_leaf_value = storage
174                    .and_then(|s| s.storage.get(&hashed_slot))
175                    .filter(|v| !v.is_zero())
176                    .map(|v| alloy_rlp::encode_fixed_size(v).to_vec());
177
178                if let Some(value) = maybe_leaf_value {
179                    storage_trie.update_leaf(storage_nibbles, value).map_err(|err| {
180                        SparseStateTrieErrorKind::SparseStorageTrie(hashed_address, err.into_kind())
181                    })?;
182                } else {
183                    storage_trie.remove_leaf(&storage_nibbles).map_err(|err| {
184                        SparseStateTrieErrorKind::SparseStorageTrie(hashed_address, err.into_kind())
185                    })?;
186                }
187            }
188
189            // Calculate storage root after updates.
190            storage_trie.root();
191
192            let account = state
193                .accounts
194                .get(&hashed_address)
195                .ok_or(TrieWitnessError::MissingAccount(hashed_address))?
196                .unwrap_or_default();
197            sparse_trie.update_account(hashed_address, account)?;
198
199            while let Ok(node) = rx.try_recv() {
200                self.witness.insert(keccak256(&node), node);
201            }
202        }
203
204        Ok(self.witness)
205    }
206
207    /// Retrieve proof targets for incoming hashed state.
208    /// This method will aggregate all accounts and slots present in the hash state as well as
209    /// select all existing slots from the database for the accounts that have been destroyed.
210    fn get_proof_targets(
211        &self,
212        state: &HashedPostState,
213    ) -> Result<MultiProofTargets, StateProofError> {
214        let mut proof_targets = MultiProofTargets::default();
215        for hashed_address in state.accounts.keys() {
216            proof_targets.insert(*hashed_address, B256Set::default());
217        }
218        for (hashed_address, storage) in &state.storages {
219            let mut storage_keys = storage.storage.keys().copied().collect::<B256Set>();
220            if storage.wiped {
221                // storage for this account was destroyed, gather all slots from the current state
222                let mut storage_cursor =
223                    self.hashed_cursor_factory.hashed_storage_cursor(*hashed_address)?;
224                // position cursor at the start
225                let mut current_entry = storage_cursor.seek(B256::ZERO)?;
226                while let Some((hashed_slot, _)) = current_entry {
227                    storage_keys.insert(hashed_slot);
228                    current_entry = storage_cursor.next()?;
229                }
230            }
231            proof_targets.insert(*hashed_address, storage_keys);
232        }
233        Ok(proof_targets)
234    }
235}
236
/// Blinded provider factory wrapper that hands each provider it creates a clone of `tx`.
///
/// Nodes fetched on demand by those providers are therefore forwarded to the receiving end of
/// the channel, where `TrieWitness::compute` records them in the witness.
#[derive(Debug, Clone)]
struct WitnessBlindedProviderFactory<F> {
    /// Blinded node provider factory.
    provider_factory: F,
    /// Sender for forwarding fetched blinded nodes.
    tx: mpsc::Sender<Bytes>,
}
245impl<F> WitnessBlindedProviderFactory<F> {
246    const fn new(provider_factory: F, tx: mpsc::Sender<Bytes>) -> Self {
247        Self { provider_factory, tx }
248    }
249}
250
251impl<F> BlindedProviderFactory for WitnessBlindedProviderFactory<F>
252where
253    F: BlindedProviderFactory,
254    F::AccountNodeProvider: BlindedProvider,
255    F::StorageNodeProvider: BlindedProvider,
256{
257    type AccountNodeProvider = WitnessBlindedProvider<F::AccountNodeProvider>;
258    type StorageNodeProvider = WitnessBlindedProvider<F::StorageNodeProvider>;
259
260    fn account_node_provider(&self) -> Self::AccountNodeProvider {
261        let provider = self.provider_factory.account_node_provider();
262        WitnessBlindedProvider::new(provider, self.tx.clone())
263    }
264
265    fn storage_node_provider(&self, account: B256) -> Self::StorageNodeProvider {
266        let provider = self.provider_factory.storage_node_provider(account);
267        WitnessBlindedProvider::new(provider, self.tx.clone())
268    }
269}
270
/// Blinded node provider wrapper that sends a copy of every node returned by the inner provider
/// over `tx` before handing the node back to the caller.
#[derive(Debug)]
struct WitnessBlindedProvider<P> {
    /// Inner blinded node provider; in `TrieWitness::compute` it comes from a
    /// `ProofBlindedProviderFactory`.
    provider: P,
    /// Sender for forwarding fetched blinded nodes.
    tx: mpsc::Sender<Bytes>,
}
278
279impl<P> WitnessBlindedProvider<P> {
280    const fn new(provider: P, tx: mpsc::Sender<Bytes>) -> Self {
281        Self { provider, tx }
282    }
283}
284
285impl<P: BlindedProvider> BlindedProvider for WitnessBlindedProvider<P> {
286    fn blinded_node(&self, path: &Nibbles) -> Result<Option<RevealedNode>, SparseTrieError> {
287        let maybe_node = self.provider.blinded_node(path)?;
288        if let Some(node) = &maybe_node {
289            self.tx
290                .send(node.node.clone())
291                .map_err(|error| SparseTrieErrorKind::Other(Box::new(error)))?;
292        }
293        Ok(maybe_node)
294    }
295}