1use crate::{
2 hashed_cursor::{HashedCursor, HashedCursorFactory},
3 prefix_set::TriePrefixSetsMut,
4 proof::{Proof, ProofTrieNodeProviderFactory},
5 trie_cursor::TrieCursorFactory,
6};
7use alloy_rlp::EMPTY_STRING_CODE;
8use alloy_trie::EMPTY_ROOT_HASH;
9use reth_trie_common::HashedPostState;
10use reth_trie_sparse::SparseTrieInterface;
11
12use alloy_primitives::{
13 keccak256,
14 map::{B256Map, B256Set, Entry, HashMap},
15 Bytes, B256,
16};
17use itertools::Itertools;
18use reth_execution_errors::{
19 SparseStateTrieErrorKind, SparseTrieError, SparseTrieErrorKind, StateProofError,
20 TrieWitnessError,
21};
22use reth_trie_common::{MultiProofTargets, Nibbles};
23use reth_trie_sparse::{
24 provider::{RevealedNode, TrieNodeProvider, TrieNodeProviderFactory},
25 SerialSparseTrie, SparseStateTrie,
26};
27use std::sync::mpsc;
28
29#[derive(Debug)]
31pub struct TrieWitness<T, H> {
32 trie_cursor_factory: T,
34 hashed_cursor_factory: H,
36 prefix_sets: TriePrefixSetsMut,
38 always_include_root_node: bool,
43 witness: B256Map<Bytes>,
45}
46
47impl<T, H> TrieWitness<T, H> {
48 pub fn new(trie_cursor_factory: T, hashed_cursor_factory: H) -> Self {
50 Self {
51 trie_cursor_factory,
52 hashed_cursor_factory,
53 prefix_sets: TriePrefixSetsMut::default(),
54 always_include_root_node: false,
55 witness: HashMap::default(),
56 }
57 }
58
59 pub fn with_trie_cursor_factory<TF>(self, trie_cursor_factory: TF) -> TrieWitness<TF, H> {
61 TrieWitness {
62 trie_cursor_factory,
63 hashed_cursor_factory: self.hashed_cursor_factory,
64 prefix_sets: self.prefix_sets,
65 always_include_root_node: self.always_include_root_node,
66 witness: self.witness,
67 }
68 }
69
70 pub fn with_hashed_cursor_factory<HF>(self, hashed_cursor_factory: HF) -> TrieWitness<T, HF> {
72 TrieWitness {
73 trie_cursor_factory: self.trie_cursor_factory,
74 hashed_cursor_factory,
75 prefix_sets: self.prefix_sets,
76 always_include_root_node: self.always_include_root_node,
77 witness: self.witness,
78 }
79 }
80
81 pub fn with_prefix_sets_mut(mut self, prefix_sets: TriePrefixSetsMut) -> Self {
83 self.prefix_sets = prefix_sets;
84 self
85 }
86
87 pub const fn always_include_root_node(mut self) -> Self {
91 self.always_include_root_node = true;
92 self
93 }
94}
95
96impl<T, H> TrieWitness<T, H>
97where
98 T: TrieCursorFactory + Clone + Send + Sync,
99 H: HashedCursorFactory + Clone + Send + Sync,
100{
101 pub fn compute(mut self, state: HashedPostState) -> Result<B256Map<Bytes>, TrieWitnessError> {
108 let is_state_empty = state.is_empty();
109 if is_state_empty && !self.always_include_root_node {
110 return Ok(Default::default())
111 }
112
113 let proof_targets = if is_state_empty {
114 MultiProofTargets::account(B256::ZERO)
115 } else {
116 self.get_proof_targets(&state)?
117 };
118 let multiproof =
119 Proof::new(self.trie_cursor_factory.clone(), self.hashed_cursor_factory.clone())
120 .with_prefix_sets_mut(self.prefix_sets.clone())
121 .multiproof(proof_targets.clone())?;
122
123 if is_state_empty {
126 let (root_hash, root_node) = if let Some(root_node) =
127 multiproof.account_subtree.into_inner().remove(&Nibbles::default())
128 {
129 (keccak256(&root_node), root_node)
130 } else {
131 (EMPTY_ROOT_HASH, Bytes::from([EMPTY_STRING_CODE]))
132 };
133 return Ok(B256Map::from_iter([(root_hash, root_node)]))
134 }
135
136 for account_node in multiproof.account_subtree.values() {
138 if let Entry::Vacant(entry) = self.witness.entry(keccak256(account_node.as_ref())) {
139 entry.insert(account_node.clone());
140 }
141 }
142 for storage_node in multiproof.storages.values().flat_map(|s| s.subtree.values()) {
143 if let Entry::Vacant(entry) = self.witness.entry(keccak256(storage_node.as_ref())) {
144 entry.insert(storage_node.clone());
145 }
146 }
147
148 let (tx, rx) = mpsc::channel();
149 let blinded_provider_factory = WitnessTrieNodeProviderFactory::new(
150 ProofTrieNodeProviderFactory::new(self.trie_cursor_factory, self.hashed_cursor_factory),
151 tx,
152 );
153 let mut sparse_trie = SparseStateTrie::<SerialSparseTrie>::new();
154 sparse_trie.reveal_multiproof(multiproof)?;
155
156 for (hashed_address, hashed_slots) in
158 proof_targets.into_iter().sorted_unstable_by_key(|(ha, _)| *ha)
159 {
160 let provider = blinded_provider_factory.storage_node_provider(hashed_address);
162 let storage = state.storages.get(&hashed_address);
163 let storage_trie = sparse_trie.storage_trie_mut(&hashed_address).ok_or(
164 SparseStateTrieErrorKind::SparseStorageTrie(
165 hashed_address,
166 SparseTrieErrorKind::Blind,
167 ),
168 )?;
169 for hashed_slot in hashed_slots.into_iter().sorted_unstable() {
170 let storage_nibbles = Nibbles::unpack(hashed_slot);
171 let maybe_leaf_value = storage
172 .and_then(|s| s.storage.get(&hashed_slot))
173 .filter(|v| !v.is_zero())
174 .map(|v| alloy_rlp::encode_fixed_size(v).to_vec());
175
176 if let Some(value) = maybe_leaf_value {
177 storage_trie.update_leaf(storage_nibbles, value, &provider).map_err(|err| {
178 SparseStateTrieErrorKind::SparseStorageTrie(hashed_address, err.into_kind())
179 })?;
180 } else {
181 storage_trie.remove_leaf(&storage_nibbles, &provider).map_err(|err| {
182 SparseStateTrieErrorKind::SparseStorageTrie(hashed_address, err.into_kind())
183 })?;
184 }
185 }
186
187 storage_trie.root();
189
190 let account = state
191 .accounts
192 .get(&hashed_address)
193 .ok_or(TrieWitnessError::MissingAccount(hashed_address))?
194 .unwrap_or_default();
195
196 if !sparse_trie.update_account(hashed_address, account, &blinded_provider_factory)? {
197 let nibbles = Nibbles::unpack(hashed_address);
198 sparse_trie.remove_account_leaf(&nibbles, &blinded_provider_factory)?;
199 }
200
201 while let Ok(node) = rx.try_recv() {
202 self.witness.insert(keccak256(&node), node);
203 }
204 }
205
206 Ok(self.witness)
207 }
208
209 fn get_proof_targets(
213 &self,
214 state: &HashedPostState,
215 ) -> Result<MultiProofTargets, StateProofError> {
216 let mut proof_targets = MultiProofTargets::default();
217 for hashed_address in state.accounts.keys() {
218 proof_targets.insert(*hashed_address, B256Set::default());
219 }
220 for (hashed_address, storage) in &state.storages {
221 let mut storage_keys = storage.storage.keys().copied().collect::<B256Set>();
222 if storage.wiped {
223 let mut storage_cursor =
225 self.hashed_cursor_factory.hashed_storage_cursor(*hashed_address)?;
226 let mut current_entry = storage_cursor.seek(B256::ZERO)?;
228 while let Some((hashed_slot, _)) = current_entry {
229 storage_keys.insert(hashed_slot);
230 current_entry = storage_cursor.next()?;
231 }
232 }
233 proof_targets.insert(*hashed_address, storage_keys);
234 }
235 Ok(proof_targets)
236 }
237}
238
239#[derive(Debug, Clone)]
240struct WitnessTrieNodeProviderFactory<F> {
241 provider_factory: F,
243 tx: mpsc::Sender<Bytes>,
245}
246
247impl<F> WitnessTrieNodeProviderFactory<F> {
248 const fn new(provider_factory: F, tx: mpsc::Sender<Bytes>) -> Self {
249 Self { provider_factory, tx }
250 }
251}
252
253impl<F> TrieNodeProviderFactory for WitnessTrieNodeProviderFactory<F>
254where
255 F: TrieNodeProviderFactory,
256 F::AccountNodeProvider: TrieNodeProvider,
257 F::StorageNodeProvider: TrieNodeProvider,
258{
259 type AccountNodeProvider = WitnessTrieNodeProvider<F::AccountNodeProvider>;
260 type StorageNodeProvider = WitnessTrieNodeProvider<F::StorageNodeProvider>;
261
262 fn account_node_provider(&self) -> Self::AccountNodeProvider {
263 let provider = self.provider_factory.account_node_provider();
264 WitnessTrieNodeProvider::new(provider, self.tx.clone())
265 }
266
267 fn storage_node_provider(&self, account: B256) -> Self::StorageNodeProvider {
268 let provider = self.provider_factory.storage_node_provider(account);
269 WitnessTrieNodeProvider::new(provider, self.tx.clone())
270 }
271}
272
273#[derive(Debug)]
274struct WitnessTrieNodeProvider<P> {
275 provider: P,
277 tx: mpsc::Sender<Bytes>,
279}
280
281impl<P> WitnessTrieNodeProvider<P> {
282 const fn new(provider: P, tx: mpsc::Sender<Bytes>) -> Self {
283 Self { provider, tx }
284 }
285}
286
287impl<P: TrieNodeProvider> TrieNodeProvider for WitnessTrieNodeProvider<P> {
288 fn trie_node(&self, path: &Nibbles) -> Result<Option<RevealedNode>, SparseTrieError> {
289 let maybe_node = self.provider.trie_node(path)?;
290 if let Some(node) = &maybe_node {
291 self.tx
292 .send(node.node.clone())
293 .map_err(|error| SparseTrieErrorKind::Other(Box::new(error)))?;
294 }
295 Ok(maybe_node)
296 }
297}