1use crate::{
2 hashed_cursor::{HashedCursor, HashedCursorFactory},
3 prefix_set::TriePrefixSetsMut,
4 proof::Proof,
5 proof_v2,
6 trie_cursor::TrieCursorFactory,
7 TRIE_ACCOUNT_RLP_MAX_SIZE,
8};
9use alloy_primitives::{
10 keccak256,
11 map::{B256Map, HashMap},
12 Bytes, B256, U256,
13};
14use alloy_rlp::{Encodable, EMPTY_STRING_CODE};
15use alloy_trie::{nodes::BranchNodeRef, EMPTY_ROOT_HASH};
16use reth_execution_errors::{SparseStateTrieErrorKind, StateProofError, TrieWitnessError};
17use reth_trie_common::{
18 DecodedMultiProofV2, HashedPostState, MultiProofTargetsV2, ProofV2Target, TrieNodeV2,
19};
20use reth_trie_sparse::{LeafUpdate, SparseStateTrie, SparseTrie as _};
21
/// Collects the RLP-encoded trie nodes touched when applying a [`HashedPostState`] to the state
/// trie, keyed by the keccak256 hash of each encoded node.
///
/// `T` is the trie cursor factory and `H` the hashed cursor factory used to compute proofs and
/// storage roots.
#[derive(Debug)]
pub struct TrieWitness<T, H> {
    /// Factory for trie cursors; passed to every [`Proof`] computation and used to open storage
    /// trie cursors when a storage root has to be computed directly.
    trie_cursor_factory: T,
    /// Factory for hashed state cursors; passed to every [`Proof`] computation and used to
    /// enumerate existing storage slots of wiped storages.
    hashed_cursor_factory: H,
    /// Prefix sets forwarded to every proof computation (and, per account, to direct storage
    /// root computation).
    prefix_sets: TriePrefixSetsMut,
    /// When set, computing a witness for an empty state still returns the account trie root
    /// node instead of an empty map.
    always_include_root_node: bool,
    /// Recorded witness: keccak256 hash of an RLP-encoded node -> the encoded node.
    witness: B256Map<Bytes>,
}
39
40impl<T, H> TrieWitness<T, H> {
41 pub fn new(trie_cursor_factory: T, hashed_cursor_factory: H) -> Self {
43 Self {
44 trie_cursor_factory,
45 hashed_cursor_factory,
46 prefix_sets: TriePrefixSetsMut::default(),
47 always_include_root_node: false,
48 witness: HashMap::default(),
49 }
50 }
51
52 pub fn with_trie_cursor_factory<TF>(self, trie_cursor_factory: TF) -> TrieWitness<TF, H> {
54 TrieWitness {
55 trie_cursor_factory,
56 hashed_cursor_factory: self.hashed_cursor_factory,
57 prefix_sets: self.prefix_sets,
58 always_include_root_node: self.always_include_root_node,
59 witness: self.witness,
60 }
61 }
62
63 pub fn with_hashed_cursor_factory<HF>(self, hashed_cursor_factory: HF) -> TrieWitness<T, HF> {
65 TrieWitness {
66 trie_cursor_factory: self.trie_cursor_factory,
67 hashed_cursor_factory,
68 prefix_sets: self.prefix_sets,
69 always_include_root_node: self.always_include_root_node,
70 witness: self.witness,
71 }
72 }
73
74 pub fn with_prefix_sets_mut(mut self, prefix_sets: TriePrefixSetsMut) -> Self {
76 self.prefix_sets = prefix_sets;
77 self
78 }
79
80 pub const fn always_include_root_node(mut self) -> Self {
84 self.always_include_root_node = true;
85 self
86 }
87}
88
89impl<T, H> TrieWitness<T, H>
90where
91 T: TrieCursorFactory + Clone,
92 H: HashedCursorFactory + Clone,
93{
94 pub fn compute(
101 mut self,
102 mut state: HashedPostState,
103 ) -> Result<B256Map<Bytes>, TrieWitnessError> {
104 let is_state_empty = state.is_empty();
105 if is_state_empty && !self.always_include_root_node {
106 return Ok(Default::default())
107 }
108
109 self.expand_wiped_storages(&mut state)?;
112
113 let proof_targets = if is_state_empty {
114 MultiProofTargetsV2 {
115 account_targets: vec![ProofV2Target::new(B256::ZERO)],
116 ..Default::default()
117 }
118 } else {
119 Self::get_proof_targets(&state)
120 };
121 let multiproof =
122 Proof::new(self.trie_cursor_factory.clone(), self.hashed_cursor_factory.clone())
123 .with_prefix_sets_mut(self.prefix_sets.clone())
124 .multiproof_v2(proof_targets)?;
125
126 if is_state_empty {
129 let (root_hash, root_node) = if let Some(root_node) =
130 multiproof.account_proofs.into_iter().find(|n| n.path.is_empty())
131 {
132 let mut encoded = Vec::new();
133 root_node.node.encode(&mut encoded);
134 let bytes = Bytes::from(encoded);
135 (keccak256(&bytes), bytes)
136 } else {
137 (EMPTY_ROOT_HASH, Bytes::from([EMPTY_STRING_CODE]))
138 };
139 return Ok(B256Map::from_iter([(root_hash, root_node)]))
140 }
141
142 self.record_multiproof_nodes(&multiproof);
144
145 let mut sparse_trie = SparseStateTrie::new();
146 sparse_trie.reveal_decoded_multiproof_v2(multiproof)?;
147
148 let mut storage_removals: B256Map<B256Map<LeafUpdate>> = B256Map::default();
155 let mut storage_upserts: B256Map<B256Map<LeafUpdate>> = B256Map::default();
156 for (hashed_address, storage) in &state.storages {
157 for (&hashed_slot, value) in &storage.storage {
158 if value.is_zero() {
159 storage_removals
160 .entry(*hashed_address)
161 .or_default()
162 .insert(hashed_slot, LeafUpdate::Changed(vec![]));
163 } else {
164 storage_upserts.entry(*hashed_address).or_default().insert(
165 hashed_slot,
166 LeafUpdate::Changed(alloy_rlp::encode_fixed_size(value).to_vec()),
167 );
168 }
169 }
170 }
171
172 for storage_updates in [&mut storage_removals, &mut storage_upserts] {
174 loop {
175 let mut targets = MultiProofTargetsV2::default();
176
177 for (&hashed_address, slot_updates) in storage_updates.iter_mut() {
178 if slot_updates.is_empty() {
179 continue;
180 }
181 let storage_trie = sparse_trie
182 .storage_trie_mut(&hashed_address)
183 .expect("storage trie was revealed from multiproof");
184 storage_trie
185 .update_leaves(slot_updates, |key, min_len| {
186 targets
187 .storage_targets
188 .entry(hashed_address)
189 .or_default()
190 .push(ProofV2Target::new(key).with_min_len(min_len));
191 })
192 .map_err(|err| {
193 SparseStateTrieErrorKind::SparseStorageTrie(
194 hashed_address,
195 err.into_kind(),
196 )
197 })?;
198 }
199
200 if targets.is_empty() {
201 break;
202 }
203
204 let multiproof = Proof::new(
205 self.trie_cursor_factory.clone(),
206 self.hashed_cursor_factory.clone(),
207 )
208 .with_prefix_sets_mut(self.prefix_sets.clone())
209 .multiproof_v2(targets)?;
210 self.record_multiproof_nodes(&multiproof);
211 sparse_trie.reveal_decoded_multiproof_v2(multiproof)?;
212 }
213 }
214
215 let mut account_removals: B256Map<LeafUpdate> = B256Map::default();
218 let mut account_upserts: B256Map<LeafUpdate> = B256Map::default();
219 for &hashed_address in state.accounts.keys().chain(state.storages.keys()) {
220 if account_removals.contains_key(&hashed_address) ||
221 account_upserts.contains_key(&hashed_address)
222 {
223 continue;
224 }
225
226 let account = state
227 .accounts
228 .get(&hashed_address)
229 .ok_or(TrieWitnessError::MissingAccount(hashed_address))?
230 .unwrap_or_default();
231
232 let storage_root =
233 if let Some(storage_trie) = sparse_trie.storage_trie_mut(&hashed_address) {
234 storage_trie.root()
235 } else {
236 self.account_storage_root(hashed_address)?
237 };
238
239 if account.is_empty() && storage_root == EMPTY_ROOT_HASH {
240 account_removals.insert(hashed_address, LeafUpdate::Changed(vec![]));
241 } else {
242 let mut rlp = Vec::with_capacity(TRIE_ACCOUNT_RLP_MAX_SIZE);
243 account.into_trie_account(storage_root).encode(&mut rlp);
244 account_upserts.insert(hashed_address, LeafUpdate::Changed(rlp));
245 }
246 }
247
248 for account_updates in [&mut account_removals, &mut account_upserts] {
250 loop {
251 let mut targets = MultiProofTargetsV2::default();
252
253 sparse_trie
254 .trie_mut()
255 .update_leaves(account_updates, |key, min_len| {
256 targets.account_targets.push(ProofV2Target::new(key).with_min_len(min_len));
257 })
258 .map_err(SparseStateTrieErrorKind::from)?;
259
260 if targets.is_empty() {
261 break;
262 }
263
264 let multiproof = Proof::new(
265 self.trie_cursor_factory.clone(),
266 self.hashed_cursor_factory.clone(),
267 )
268 .with_prefix_sets_mut(self.prefix_sets.clone())
269 .multiproof_v2(targets)?;
270 self.record_multiproof_nodes(&multiproof);
271 sparse_trie.reveal_decoded_multiproof_v2(multiproof)?;
272 }
273 }
274
275 Ok(self.witness)
276 }
277
278 fn record_multiproof_nodes(&mut self, multiproof: &DecodedMultiProofV2) {
280 let mut encoded = Vec::new();
281 for proof_node in &multiproof.account_proofs {
282 self.record_witness_node(&proof_node.node, &mut encoded);
283 }
284 for proof_nodes in multiproof.storage_proofs.values() {
285 for proof_node in proof_nodes {
286 self.record_witness_node(&proof_node.node, &mut encoded);
287 }
288 }
289 }
290
291 fn record_witness_node(&mut self, node: &TrieNodeV2, encoded: &mut Vec<u8>) {
293 encoded.clear();
294 node.encode(encoded);
295 let bytes = Bytes::from(encoded.clone());
296 self.witness.entry(keccak256(&bytes)).or_insert(bytes);
297
298 if let TrieNodeV2::Branch(branch) = node &&
299 !branch.key.is_empty()
300 {
301 encoded.clear();
302 BranchNodeRef::new(&branch.stack, branch.state_mask).encode(encoded);
303 let bytes = Bytes::from(encoded.clone());
304 self.witness.entry(keccak256(&bytes)).or_insert(bytes);
305 }
306 }
307
308 fn account_storage_root(&mut self, hashed_address: B256) -> Result<B256, TrieWitnessError> {
311 let storage_trie_cursor = self
312 .trie_cursor_factory
313 .storage_trie_cursor(hashed_address)
314 .map_err(StateProofError::from)?;
315 let hashed_storage_cursor = self
316 .hashed_cursor_factory
317 .hashed_storage_cursor(hashed_address)
318 .map_err(StateProofError::from)?;
319 let mut calculator = proof_v2::StorageProofCalculator::new_storage(
320 storage_trie_cursor,
321 hashed_storage_cursor,
322 );
323 if let Some(prefix_set) = self.prefix_sets.storage_prefix_sets.get(&hashed_address) {
324 calculator = calculator.with_prefix_set(prefix_set.clone().freeze());
325 }
326 let root_node = calculator.storage_root_node(hashed_address)?;
327 let root_hash = calculator
328 .compute_root_hash(core::slice::from_ref(&root_node))?
329 .unwrap_or(EMPTY_ROOT_HASH);
330 drop(calculator);
331 let mut encoded = Vec::new();
332 self.record_witness_node(&root_node.node, &mut encoded);
333 Ok(root_hash)
334 }
335
336 fn expand_wiped_storages(&self, state: &mut HashedPostState) -> Result<(), StateProofError> {
339 for (hashed_address, storage) in &mut state.storages {
340 if !storage.wiped {
341 continue;
342 }
343 let mut storage_cursor =
344 self.hashed_cursor_factory.hashed_storage_cursor(*hashed_address)?;
345 let mut current_entry = storage_cursor.seek(B256::ZERO)?;
346 while let Some((hashed_slot, _)) = current_entry {
347 storage.storage.entry(hashed_slot).or_insert(U256::ZERO);
348 current_entry = storage_cursor.next()?;
349 }
350 storage.wiped = false;
351 }
352 Ok(())
353 }
354
355 fn get_proof_targets(state: &HashedPostState) -> MultiProofTargetsV2 {
359 let mut targets = MultiProofTargetsV2::default();
360 for &hashed_address in state.accounts.keys() {
361 targets.account_targets.push(ProofV2Target::new(hashed_address));
362 }
363 for (&hashed_address, storage) in &state.storages {
364 if !state.accounts.contains_key(&hashed_address) {
365 targets.account_targets.push(ProofV2Target::new(hashed_address));
366 }
367 if storage.storage.is_empty() {
370 continue;
371 }
372 let storage_keys = storage.storage.keys().map(|k| ProofV2Target::new(*k)).collect();
373 targets.storage_targets.insert(hashed_address, storage_keys);
374 }
375 targets
376 }
377}