1use crate::{
2 hashed_cursor::{HashedCursor, HashedCursorFactory},
3 prefix_set::TriePrefixSetsMut,
4 proof::{Proof, ProofTrieNodeProviderFactory},
5 trie_cursor::TrieCursorFactory,
6};
7use alloy_rlp::EMPTY_STRING_CODE;
8use alloy_trie::EMPTY_ROOT_HASH;
9use reth_trie_common::HashedPostState;
10use reth_trie_sparse::SparseTrie;
11
12use alloy_primitives::{
13 keccak256,
14 map::{B256Map, B256Set, Entry, HashMap},
15 Bytes, B256,
16};
17use itertools::Itertools;
18use reth_execution_errors::{
19 SparseStateTrieErrorKind, SparseTrieError, SparseTrieErrorKind, StateProofError,
20 TrieWitnessError,
21};
22use reth_trie_common::{MultiProofTargets, Nibbles};
23use reth_trie_sparse::{
24 provider::{RevealedNode, TrieNodeProvider, TrieNodeProviderFactory},
25 SparseStateTrie,
26};
27use std::sync::mpsc;
28
29#[derive(Debug)]
31pub struct TrieWitness<T, H> {
32 trie_cursor_factory: T,
34 hashed_cursor_factory: H,
36 prefix_sets: TriePrefixSetsMut,
38 always_include_root_node: bool,
43 witness: B256Map<Bytes>,
45}
46
47impl<T, H> TrieWitness<T, H> {
48 pub fn new(trie_cursor_factory: T, hashed_cursor_factory: H) -> Self {
50 Self {
51 trie_cursor_factory,
52 hashed_cursor_factory,
53 prefix_sets: TriePrefixSetsMut::default(),
54 always_include_root_node: false,
55 witness: HashMap::default(),
56 }
57 }
58
59 pub fn with_trie_cursor_factory<TF>(self, trie_cursor_factory: TF) -> TrieWitness<TF, H> {
61 TrieWitness {
62 trie_cursor_factory,
63 hashed_cursor_factory: self.hashed_cursor_factory,
64 prefix_sets: self.prefix_sets,
65 always_include_root_node: self.always_include_root_node,
66 witness: self.witness,
67 }
68 }
69
70 pub fn with_hashed_cursor_factory<HF>(self, hashed_cursor_factory: HF) -> TrieWitness<T, HF> {
72 TrieWitness {
73 trie_cursor_factory: self.trie_cursor_factory,
74 hashed_cursor_factory,
75 prefix_sets: self.prefix_sets,
76 always_include_root_node: self.always_include_root_node,
77 witness: self.witness,
78 }
79 }
80
81 pub fn with_prefix_sets_mut(mut self, prefix_sets: TriePrefixSetsMut) -> Self {
83 self.prefix_sets = prefix_sets;
84 self
85 }
86
87 pub const fn always_include_root_node(mut self) -> Self {
91 self.always_include_root_node = true;
92 self
93 }
94}
95
96impl<T, H> TrieWitness<T, H>
97where
98 T: TrieCursorFactory + Clone,
99 H: HashedCursorFactory + Clone,
100{
101 pub fn compute(mut self, state: HashedPostState) -> Result<B256Map<Bytes>, TrieWitnessError> {
108 let is_state_empty = state.is_empty();
109 if is_state_empty && !self.always_include_root_node {
110 return Ok(Default::default())
111 }
112
113 let proof_targets = if is_state_empty {
114 MultiProofTargets::account(B256::ZERO)
115 } else {
116 self.get_proof_targets(&state)?
117 };
118 let prefix_sets = core::mem::take(&mut self.prefix_sets);
119 let multiproof =
120 Proof::new(self.trie_cursor_factory.clone(), self.hashed_cursor_factory.clone())
121 .with_prefix_sets_mut(prefix_sets)
122 .multiproof(proof_targets.clone())?;
123
124 if is_state_empty {
127 let (root_hash, root_node) = if let Some(root_node) =
128 multiproof.account_subtree.into_inner().remove(&Nibbles::default())
129 {
130 (keccak256(&root_node), root_node)
131 } else {
132 (EMPTY_ROOT_HASH, Bytes::from([EMPTY_STRING_CODE]))
133 };
134 return Ok(B256Map::from_iter([(root_hash, root_node)]))
135 }
136
137 for account_node in multiproof.account_subtree.values() {
139 if let Entry::Vacant(entry) = self.witness.entry(keccak256(account_node.as_ref())) {
140 entry.insert(account_node.clone());
141 }
142 }
143 for storage_node in multiproof.storages.values().flat_map(|s| s.subtree.values()) {
144 if let Entry::Vacant(entry) = self.witness.entry(keccak256(storage_node.as_ref())) {
145 entry.insert(storage_node.clone());
146 }
147 }
148
149 let (tx, rx) = mpsc::channel();
150 let blinded_provider_factory = WitnessTrieNodeProviderFactory::new(
151 ProofTrieNodeProviderFactory::new(self.trie_cursor_factory, self.hashed_cursor_factory),
152 tx,
153 );
154 let mut sparse_trie = SparseStateTrie::new();
155 sparse_trie.reveal_multiproof(multiproof)?;
156
157 for (hashed_address, hashed_slots) in
159 proof_targets.into_iter().sorted_unstable_by_key(|(ha, _)| *ha)
160 {
161 let provider = blinded_provider_factory.storage_node_provider(hashed_address);
163 let storage = state.storages.get(&hashed_address);
164 let storage_trie = sparse_trie.storage_trie_mut(&hashed_address).ok_or(
165 SparseStateTrieErrorKind::SparseStorageTrie(
166 hashed_address,
167 SparseTrieErrorKind::Blind,
168 ),
169 )?;
170 for hashed_slot in hashed_slots.into_iter().sorted_unstable() {
171 let storage_nibbles = Nibbles::unpack(hashed_slot);
172 let maybe_leaf_value = storage
173 .and_then(|s| s.storage.get(&hashed_slot))
174 .filter(|v| !v.is_zero())
175 .map(|v| alloy_rlp::encode_fixed_size(v).to_vec());
176
177 if let Some(value) = maybe_leaf_value {
178 storage_trie.update_leaf(storage_nibbles, value, &provider).map_err(|err| {
179 SparseStateTrieErrorKind::SparseStorageTrie(hashed_address, err.into_kind())
180 })?;
181 } else {
182 storage_trie.remove_leaf(&storage_nibbles, &provider).map_err(|err| {
183 SparseStateTrieErrorKind::SparseStorageTrie(hashed_address, err.into_kind())
184 })?;
185 }
186 }
187
188 let account = state
189 .accounts
190 .get(&hashed_address)
191 .ok_or(TrieWitnessError::MissingAccount(hashed_address))?
192 .unwrap_or_default();
193
194 if !sparse_trie.update_account(hashed_address, account, &blinded_provider_factory)? {
195 let nibbles = Nibbles::unpack(hashed_address);
196 sparse_trie.remove_account_leaf(&nibbles, &blinded_provider_factory)?;
197 }
198
199 while let Ok(node) = rx.try_recv() {
200 self.witness.insert(keccak256(&node), node);
201 }
202 }
203
204 Ok(self.witness)
205 }
206
207 fn get_proof_targets(
211 &self,
212 state: &HashedPostState,
213 ) -> Result<MultiProofTargets, StateProofError> {
214 let mut proof_targets = MultiProofTargets::default();
215 for hashed_address in state.accounts.keys() {
216 proof_targets.insert(*hashed_address, B256Set::default());
217 }
218 for (hashed_address, storage) in &state.storages {
219 let mut storage_keys = storage.storage.keys().copied().collect::<B256Set>();
220 if storage.wiped {
221 let mut storage_cursor =
223 self.hashed_cursor_factory.hashed_storage_cursor(*hashed_address)?;
224 let mut current_entry = storage_cursor.seek(B256::ZERO)?;
226 while let Some((hashed_slot, _)) = current_entry {
227 storage_keys.insert(hashed_slot);
228 current_entry = storage_cursor.next()?;
229 }
230 }
231 proof_targets.insert(*hashed_address, storage_keys);
232 }
233 Ok(proof_targets)
234 }
235}
236
237#[derive(Debug, Clone)]
238struct WitnessTrieNodeProviderFactory<F> {
239 provider_factory: F,
241 tx: mpsc::Sender<Bytes>,
243}
244
245impl<F> WitnessTrieNodeProviderFactory<F> {
246 const fn new(provider_factory: F, tx: mpsc::Sender<Bytes>) -> Self {
247 Self { provider_factory, tx }
248 }
249}
250
251impl<F> TrieNodeProviderFactory for WitnessTrieNodeProviderFactory<F>
252where
253 F: TrieNodeProviderFactory,
254 F::AccountNodeProvider: TrieNodeProvider,
255 F::StorageNodeProvider: TrieNodeProvider,
256{
257 type AccountNodeProvider = WitnessTrieNodeProvider<F::AccountNodeProvider>;
258 type StorageNodeProvider = WitnessTrieNodeProvider<F::StorageNodeProvider>;
259
260 fn account_node_provider(&self) -> Self::AccountNodeProvider {
261 let provider = self.provider_factory.account_node_provider();
262 WitnessTrieNodeProvider::new(provider, self.tx.clone())
263 }
264
265 fn storage_node_provider(&self, account: B256) -> Self::StorageNodeProvider {
266 let provider = self.provider_factory.storage_node_provider(account);
267 WitnessTrieNodeProvider::new(provider, self.tx.clone())
268 }
269}
270
271#[derive(Debug)]
272struct WitnessTrieNodeProvider<P> {
273 provider: P,
275 tx: mpsc::Sender<Bytes>,
277}
278
279impl<P> WitnessTrieNodeProvider<P> {
280 const fn new(provider: P, tx: mpsc::Sender<Bytes>) -> Self {
281 Self { provider, tx }
282 }
283}
284
285impl<P: TrieNodeProvider> TrieNodeProvider for WitnessTrieNodeProvider<P> {
286 fn trie_node(&self, path: &Nibbles) -> Result<Option<RevealedNode>, SparseTrieError> {
287 let maybe_node = self.provider.trie_node(path)?;
288 if let Some(node) = &maybe_node {
289 self.tx
290 .send(node.node.clone())
291 .map_err(|error| SparseTrieErrorKind::Other(Box::new(error)))?;
292 }
293 Ok(maybe_node)
294 }
295}