Skip to main content

reth_trie/
witness.rs

1use crate::{
2    hashed_cursor::{HashedCursor, HashedCursorFactory},
3    prefix_set::TriePrefixSetsMut,
4    proof::Proof,
5    proof_v2,
6    trie_cursor::TrieCursorFactory,
7    TRIE_ACCOUNT_RLP_MAX_SIZE,
8};
9use alloy_primitives::{
10    keccak256,
11    map::{B256Map, HashMap},
12    Bytes, B256, U256,
13};
14use alloy_rlp::{Encodable, EMPTY_STRING_CODE};
15use alloy_trie::{nodes::BranchNodeRef, EMPTY_ROOT_HASH};
16use reth_execution_errors::{SparseStateTrieErrorKind, StateProofError, TrieWitnessError};
17use reth_trie_common::{
18    DecodedMultiProofV2, ExecutionWitnessMode, HashedPostState, MultiProofTargetsV2, ProofV2Target,
19    TrieNodeV2,
20};
21use reth_trie_sparse::{LeafUpdate, SparseStateTrie, SparseTrie as _};
22
23/// State transition witness for the trie.
24#[derive(Debug)]
25pub struct TrieWitness<T, H> {
26    /// The cursor factory for traversing trie nodes.
27    trie_cursor_factory: T,
28    /// The factory for hashed cursors.
29    hashed_cursor_factory: H,
30    /// A set of prefix sets that have changes.
31    prefix_sets: TriePrefixSetsMut,
32    /// Flag indicating whether the root node should always be included (even if the target state
33    /// is empty). This setting is useful if the caller wants to verify the witness against the
34    /// parent state root.
35    /// Set to `false` by default.
36    always_include_root_node: bool,
37    /// Controls how the witness is generated.
38    mode: ExecutionWitnessMode,
39    /// Recorded witness.
40    witness: B256Map<Bytes>,
41}
42
43impl<T, H> TrieWitness<T, H> {
44    /// Creates a new witness generator.
45    pub fn new(trie_cursor_factory: T, hashed_cursor_factory: H) -> Self {
46        Self {
47            trie_cursor_factory,
48            hashed_cursor_factory,
49            prefix_sets: TriePrefixSetsMut::default(),
50            always_include_root_node: false,
51            mode: ExecutionWitnessMode::Legacy,
52            witness: HashMap::default(),
53        }
54    }
55
56    /// Set the trie cursor factory.
57    pub fn with_trie_cursor_factory<TF>(self, trie_cursor_factory: TF) -> TrieWitness<TF, H> {
58        TrieWitness {
59            trie_cursor_factory,
60            hashed_cursor_factory: self.hashed_cursor_factory,
61            prefix_sets: self.prefix_sets,
62            always_include_root_node: self.always_include_root_node,
63            mode: self.mode,
64            witness: self.witness,
65        }
66    }
67
68    /// Set the hashed cursor factory.
69    pub fn with_hashed_cursor_factory<HF>(self, hashed_cursor_factory: HF) -> TrieWitness<T, HF> {
70        TrieWitness {
71            trie_cursor_factory: self.trie_cursor_factory,
72            hashed_cursor_factory,
73            prefix_sets: self.prefix_sets,
74            always_include_root_node: self.always_include_root_node,
75            mode: self.mode,
76            witness: self.witness,
77        }
78    }
79
80    /// Set the prefix sets. They have to be mutable in order to allow extension with proof target.
81    pub fn with_prefix_sets_mut(mut self, prefix_sets: TriePrefixSetsMut) -> Self {
82        self.prefix_sets = prefix_sets;
83        self
84    }
85
86    /// Set `always_include_root_node` to true. Root node will be included even in empty state.
87    /// This setting is useful if the caller wants to verify the witness against the
88    /// parent state root.
89    pub const fn always_include_root_node(mut self) -> Self {
90        self.always_include_root_node = true;
91        self
92    }
93
94    /// Set the execution witness generation mode.
95    pub const fn with_execution_witness_mode(mut self, mode: ExecutionWitnessMode) -> Self {
96        self.mode = mode;
97        self
98    }
99}
100
impl<T, H> TrieWitness<T, H>
where
    T: TrieCursorFactory + Clone,
    H: HashedCursorFactory + Clone,
{
    /// Compute the state transition witness for the trie. Gather all required nodes
    /// to apply `state` on top of the current trie state.
    ///
    /// The result maps the `keccak256` hash of each RLP-encoded trie node to the encoded node
    /// bytes. Account trie and storage trie nodes share the same map.
    ///
    /// # Arguments
    ///
    /// `state` - state transition containing both modified and touched accounts and storage slots.
    ///
    /// # Returns
    ///
    /// - An empty map if `state` is empty and `always_include_root_node` was not enabled.
    /// - Only the account trie root node if `state` is empty and `always_include_root_node` was
    ///   enabled. If the proof contains no root node, the empty-trie entry
    ///   (`EMPTY_ROOT_HASH` -> `[EMPTY_STRING_CODE]`) is returned instead.
    /// - Otherwise, the nodes of every multiproof fetched while applying `state` to a sparse
    ///   trie. This also includes storage root nodes computed from the database for accounts
    ///   whose storage trie was not revealed (see `account_storage_root`). In canonical mode,
    ///   empty-trie nodes are removed from the result.
    ///
    /// # Errors
    ///
    /// Propagates errors from cursor access, proof computation and sparse trie operations.
    /// Returns [`TrieWitnessError::MissingAccount`] for an address present in
    /// `state.storages` but absent from `state.accounts`.
    pub fn compute(
        mut self,
        mut state: HashedPostState,
    ) -> Result<B256Map<Bytes>, TrieWitnessError> {
        let is_state_empty = state.is_empty();
        // Nothing to prove and the caller does not need the root node.
        if is_state_empty && !self.always_include_root_node {
            return Ok(Default::default())
        }

        // Expand wiped storages into explicit zero-value entries for every existing slot,
        // so that downstream code can treat all storages uniformly.
        self.expand_wiped_storages(&mut state)?;

        let proof_targets = if is_state_empty {
            // No keys to prove: a single placeholder account target is used only to obtain
            // the root node, which is looked up by its empty path below.
            MultiProofTargetsV2 {
                account_targets: vec![ProofV2Target::new(B256::ZERO)],
                ..Default::default()
            }
        } else {
            Self::get_proof_targets(&state)
        };
        let multiproof =
            Proof::new(self.trie_cursor_factory.clone(), self.hashed_cursor_factory.clone())
                .with_prefix_sets_mut(self.prefix_sets.clone())
                .multiproof_v2(proof_targets)?;

        // No need to reconstruct the rest of the trie, we just need to include
        // the root node and return.
        if is_state_empty {
            let (root_hash, root_node) = if let Some(root_node) =
                multiproof.account_proofs.into_iter().find(|n| n.path.is_empty())
            {
                let mut encoded = Vec::new();
                root_node.node.encode(&mut encoded);
                let bytes = Bytes::from(encoded);
                (keccak256(&bytes), bytes)
            } else {
                // No root node in the proof: fall back to the empty-trie node.
                (EMPTY_ROOT_HASH, Bytes::from([EMPTY_STRING_CODE]))
            };
            return Ok(B256Map::from_iter([(root_hash, root_node)]))
        }

        // Record all nodes from multiproof in the witness.
        self.record_multiproof_nodes(&multiproof);

        // Reveal the proven nodes into a sparse trie so the updates below can be applied.
        let mut sparse_trie = SparseStateTrie::new();
        sparse_trie.reveal_decoded_multiproof_v2(multiproof)?;

        // Build storage leaf updates for all accounts with storage changes, split into
        // removals and upserts. Legacy mode applies removals first to preserve the
        // historical witness shape expected by existing consumers: a removal can collapse
        // a branch and force proof fetches that some consumers still rely on. Canonical
        // mode applies upserts first to avoid those compatibility-only nodes and emit
        // the minimized draft-spec witness.
        let mut storage_removals: B256Map<B256Map<LeafUpdate>> = B256Map::default();
        let mut storage_upserts: B256Map<B256Map<LeafUpdate>> = B256Map::default();
        for (hashed_address, storage) in &state.storages {
            for (&hashed_slot, value) in &storage.storage {
                // A zero value deletes the slot (empty leaf value). Any other value is stored
                // as its fixed-size RLP encoding.
                if value.is_zero() {
                    storage_removals
                        .entry(*hashed_address)
                        .or_default()
                        .insert(hashed_slot, LeafUpdate::Changed(vec![]));
                } else {
                    storage_upserts.entry(*hashed_address).or_default().insert(
                        hashed_slot,
                        LeafUpdate::Changed(alloy_rlp::encode_fixed_size(value).to_vec()),
                    );
                }
            }
        }

        let storage_update_sets = if self.mode.is_canonical() {
            [&mut storage_upserts, &mut storage_removals]
        } else {
            [&mut storage_removals, &mut storage_upserts]
        };

        // Apply storage updates in mode-specific order, fetching additional proofs as needed.
        // In each round, the sparse storage tries are asked to apply the pending updates.
        // Keys that need more proof nodes are reported through the `update_leaves` callback
        // and become the targets of the next multiproof. The loop ends once a round requests
        // no further proofs.
        // NOTE(review): this assumes `update_leaves` removes updates it has applied from the
        // map, so an empty `slot_updates` means the address is done — confirm.
        for storage_updates in storage_update_sets {
            loop {
                let mut targets = MultiProofTargetsV2::default();

                for (&hashed_address, slot_updates) in storage_updates.iter_mut() {
                    if slot_updates.is_empty() {
                        continue;
                    }
                    // Every address with slot changes got storage proof targets in
                    // `get_proof_targets`, so its storage trie was revealed above.
                    let storage_trie = sparse_trie
                        .storage_trie_mut(&hashed_address)
                        .expect("storage trie was revealed from multiproof");
                    storage_trie
                        .update_leaves(slot_updates, |key, min_len| {
                            targets
                                .storage_targets
                                .entry(hashed_address)
                                .or_default()
                                .push(ProofV2Target::new(key).with_min_len(min_len));
                        })
                        .map_err(|err| {
                            SparseStateTrieErrorKind::SparseStorageTrie(
                                hashed_address,
                                err.into_kind(),
                            )
                        })?;
                }

                if targets.is_empty() {
                    break;
                }

                // Fetch the missing nodes, record them in the witness and reveal them so the
                // next round can make progress.
                let multiproof = Proof::new(
                    self.trie_cursor_factory.clone(),
                    self.hashed_cursor_factory.clone(),
                )
                .with_prefix_sets_mut(self.prefix_sets.clone())
                .multiproof_v2(targets)?;
                self.record_multiproof_nodes(&multiproof);
                sparse_trie.reveal_decoded_multiproof_v2(multiproof)?;
            }
        }

        // Build account leaf updates, split into removals and upserts. Legacy mode keeps
        // removals-first for the same compatibility reason as storage updates, while
        // canonical mode uses upserts-first so account updates follow the minimized
        // draft-spec witness order.
        let mut account_removals: B256Map<LeafUpdate> = B256Map::default();
        let mut account_upserts: B256Map<LeafUpdate> = B256Map::default();
        for &hashed_address in state.accounts.keys().chain(state.storages.keys()) {
            // An address may appear in both `state.accounts` and `state.storages`;
            // process it only once.
            if account_removals.contains_key(&hashed_address) ||
                account_upserts.contains_key(&hashed_address)
            {
                continue;
            }

            // A missing entry is an error. A `None` entry is treated as the default (empty)
            // account.
            let account = state
                .accounts
                .get(&hashed_address)
                .ok_or(TrieWitnessError::MissingAccount(hashed_address))?
                .unwrap_or_default();

            // Prefer the root of the revealed sparse storage trie, which reflects the storage
            // updates applied above. Otherwise compute the root from the database cursors.
            // Its root node is recorded, except in canonical mode for accounts without slot
            // changes.
            let storage_root =
                if let Some(storage_trie) = sparse_trie.storage_trie_mut(&hashed_address) {
                    storage_trie.root()
                } else {
                    let record_root_node = !self.mode.is_canonical() ||
                        state
                            .storages
                            .get(&hashed_address)
                            .is_some_and(|storage| !storage.storage.is_empty());
                    self.account_storage_root(hashed_address, record_root_node)?
                };

            // An empty account with empty storage is removed from the account trie.
            // Anything else is (re)inserted as an RLP-encoded trie account.
            if account.is_empty() && storage_root == EMPTY_ROOT_HASH {
                account_removals.insert(hashed_address, LeafUpdate::Changed(vec![]));
            } else {
                let mut rlp = Vec::with_capacity(TRIE_ACCOUNT_RLP_MAX_SIZE);
                account.into_trie_account(storage_root).encode(&mut rlp);
                account_upserts.insert(hashed_address, LeafUpdate::Changed(rlp));
            }
        }

        let account_update_sets = if self.mode.is_canonical() {
            [&mut account_upserts, &mut account_removals]
        } else {
            [&mut account_removals, &mut account_upserts]
        };

        // Apply account updates in mode-specific order, fetching additional proofs as needed.
        // This uses the same round structure as the storage loop: keys reported by the
        // `update_leaves` callback are proven, recorded and revealed until no more proofs
        // are requested.
        for account_updates in account_update_sets {
            loop {
                let mut targets = MultiProofTargetsV2::default();

                sparse_trie
                    .trie_mut()
                    .update_leaves(account_updates, |key, min_len| {
                        targets.account_targets.push(ProofV2Target::new(key).with_min_len(min_len));
                    })
                    .map_err(SparseStateTrieErrorKind::from)?;

                if targets.is_empty() {
                    break;
                }

                let multiproof = Proof::new(
                    self.trie_cursor_factory.clone(),
                    self.hashed_cursor_factory.clone(),
                )
                .with_prefix_sets_mut(self.prefix_sets.clone())
                .multiproof_v2(targets)?;
                self.record_multiproof_nodes(&multiproof);
                sparse_trie.reveal_decoded_multiproof_v2(multiproof)?;
            }
        }

        if self.mode.is_canonical() {
            // Empty trie nodes carry no useful witness information and are trivially
            // reconstructible from the empty root hash.
            self.witness.retain(|_, value| value.as_ref() != [EMPTY_STRING_CODE]);
        }

        Ok(self.witness)
    }

    /// Record all nodes from a V2 decoded multiproof in the witness, covering both the
    /// account proof nodes and the storage proof nodes of every address.
    /// A single scratch buffer is reused for encoding.
    fn record_multiproof_nodes(&mut self, multiproof: &DecodedMultiProofV2) {
        let mut encoded = Vec::new();
        for proof_node in &multiproof.account_proofs {
            self.record_witness_node(&proof_node.node, &mut encoded);
        }
        for proof_nodes in multiproof.storage_proofs.values() {
            for proof_node in proof_nodes {
                self.record_witness_node(&proof_node.node, &mut encoded);
            }
        }
    }

    /// Record a single [`TrieNodeV2`] in the witness, keyed by the `keccak256` hash of its
    /// RLP encoding. Entries that are already present are left untouched.
    ///
    /// `encoded` is a scratch buffer and is cleared before each use.
    ///
    /// For a branch node with a non-empty `key`, the encoding of the bare branch
    /// (`BranchNodeRef` over its `stack` and `state_mask`) is recorded as well.
    /// NOTE(review): presumably such a node encodes as an extension node that references the
    /// branch, so the branch itself has to be recorded separately — confirm against
    /// `TrieNodeV2::encode`.
    fn record_witness_node(&mut self, node: &TrieNodeV2, encoded: &mut Vec<u8>) {
        encoded.clear();
        node.encode(encoded);
        let bytes = Bytes::from(encoded.clone());
        self.witness.entry(keccak256(&bytes)).or_insert(bytes);

        if let TrieNodeV2::Branch(branch) = node &&
            !branch.key.is_empty()
        {
            encoded.clear();
            BranchNodeRef::new(&branch.stack, branch.state_mask).encode(encoded);
            let bytes = Bytes::from(encoded.clone());
            self.witness.entry(keccak256(&bytes)).or_insert(bytes);
        }
    }

    /// Compute the storage root for an account by walking the storage trie using the cursor
    /// factories and trie input prefix sets. Records the root node in the witness when requested.
    ///
    /// The account's storage prefix set, if one exists, is applied to the calculator.
    /// Returns `EMPTY_ROOT_HASH` when the calculator yields no root hash.
    ///
    /// # Errors
    ///
    /// Propagates cursor creation and storage proof calculation errors.
    fn account_storage_root(
        &mut self,
        hashed_address: B256,
        record_root_node: bool,
    ) -> Result<B256, TrieWitnessError> {
        let storage_trie_cursor = self
            .trie_cursor_factory
            .storage_trie_cursor(hashed_address)
            .map_err(StateProofError::from)?;
        let hashed_storage_cursor = self
            .hashed_cursor_factory
            .hashed_storage_cursor(hashed_address)
            .map_err(StateProofError::from)?;
        let mut calculator = proof_v2::StorageProofCalculator::new_storage(
            storage_trie_cursor,
            hashed_storage_cursor,
        );
        if let Some(prefix_set) = self.prefix_sets.storage_prefix_sets.get(&hashed_address) {
            calculator = calculator.with_prefix_set(prefix_set.clone().freeze());
        }
        let root_node = calculator.storage_root_node(hashed_address)?;
        let root_hash = calculator
            .compute_root_hash(core::slice::from_ref(&root_node))?
            .unwrap_or(EMPTY_ROOT_HASH);
        // NOTE(review): presumably this releases any borrow of `self`'s cursor factories
        // held by the calculator's cursors before `self` is borrowed mutably below — confirm.
        drop(calculator);
        if record_root_node {
            let mut encoded = Vec::new();
            self.record_witness_node(&root_node.node, &mut encoded);
        }
        Ok(root_hash)
    }

    /// Expand wiped storages into explicit zero-value entries for every existing slot in the
    /// database. After this, all storages can be treated uniformly without special wiped handling.
    ///
    /// Slots already present in the post state keep their value. Every other slot found by
    /// the hashed storage cursor is added with value zero (a deletion). The `wiped` flag is
    /// then cleared.
    fn expand_wiped_storages(&self, state: &mut HashedPostState) -> Result<(), StateProofError> {
        for (hashed_address, storage) in &mut state.storages {
            if !storage.wiped {
                continue;
            }
            let mut storage_cursor =
                self.hashed_cursor_factory.hashed_storage_cursor(*hashed_address)?;
            let mut current_entry = storage_cursor.seek(B256::ZERO)?;
            while let Some((hashed_slot, _)) = current_entry {
                storage.storage.entry(hashed_slot).or_insert(U256::ZERO);
                current_entry = storage_cursor.next()?;
            }
            storage.wiped = false;
        }
        Ok(())
    }

    /// Retrieve proof targets for incoming hashed state.
    /// Aggregates all accounts and slots present in the state. Wiped storages must have been
    /// expanded via [`Self::expand_wiped_storages`] before calling this.
    fn get_proof_targets(state: &HashedPostState) -> MultiProofTargetsV2 {
        let mut targets = MultiProofTargetsV2::default();
        for &hashed_address in state.accounts.keys() {
            targets.account_targets.push(ProofV2Target::new(hashed_address));
        }
        for (&hashed_address, storage) in &state.storages {
            // Addresses with only storage changes still need an account proof; the leaf is
            // rewritten later with the updated storage root.
            if !state.accounts.contains_key(&hashed_address) {
                targets.account_targets.push(ProofV2Target::new(hashed_address));
            }
            // Skip accounts with no storage slot changes — an empty target set would produce
            // an empty proof vec which cannot be revealed (no root node).
            if storage.storage.is_empty() {
                continue;
            }
            let storage_keys = storage.storage.keys().map(|k| ProofV2Target::new(*k)).collect();
            targets.storage_targets.insert(hashed_address, storage_keys);
        }
        targets
    }
}