1use crate::{
2 hashed_cursor::{HashedCursor, HashedCursorFactory},
3 prefix_set::TriePrefixSetsMut,
4 proof::{Proof, ProofTrieNodeProviderFactory},
5 trie_cursor::TrieCursorFactory,
6};
7use alloy_rlp::EMPTY_STRING_CODE;
8use alloy_trie::EMPTY_ROOT_HASH;
9use reth_trie_common::HashedPostState;
10use reth_trie_sparse::SparseTrieInterface;
11
12use alloy_primitives::{
13 keccak256,
14 map::{B256Map, B256Set, Entry, HashMap},
15 Bytes, B256,
16};
17use itertools::Itertools;
18use reth_execution_errors::{
19 SparseStateTrieErrorKind, SparseTrieError, SparseTrieErrorKind, StateProofError,
20 TrieWitnessError,
21};
22use reth_trie_common::{MultiProofTargets, Nibbles};
23use reth_trie_sparse::{
24 provider::{RevealedNode, TrieNodeProvider, TrieNodeProviderFactory},
25 SerialSparseTrie, SparseStateTrie,
26};
27use std::sync::{mpsc, Arc};
28
/// Generates a witness for a [`HashedPostState`].
///
/// The witness is the set of trie nodes (keyed by their keccak256 hash) that are returned by
/// a multiproof over the touched accounts/slots, plus any additional nodes revealed while the
/// updates are applied to a sparse trie.
#[derive(Debug)]
pub struct TrieWitness<T, H> {
    /// Trie cursor factory, handed to [`Proof`] and [`ProofTrieNodeProviderFactory`].
    trie_cursor_factory: T,
    /// Hashed cursor factory, handed to [`Proof`] and [`ProofTrieNodeProviderFactory`] and
    /// also used directly to enumerate existing slots of wiped storages.
    hashed_cursor_factory: H,
    /// Prefix sets forwarded to proof generation and to on-demand node retrieval.
    prefix_sets: TriePrefixSetsMut,
    /// When `true`, computing a witness for an empty state still returns the state root node
    /// instead of an empty map.
    always_include_root_node: bool,
    /// Accumulated witness: node bytes keyed by `keccak256` of those bytes.
    witness: B256Map<Bytes>,
}
46
47impl<T, H> TrieWitness<T, H> {
48 pub fn new(trie_cursor_factory: T, hashed_cursor_factory: H) -> Self {
50 Self {
51 trie_cursor_factory,
52 hashed_cursor_factory,
53 prefix_sets: TriePrefixSetsMut::default(),
54 always_include_root_node: false,
55 witness: HashMap::default(),
56 }
57 }
58
59 pub fn with_trie_cursor_factory<TF>(self, trie_cursor_factory: TF) -> TrieWitness<TF, H> {
61 TrieWitness {
62 trie_cursor_factory,
63 hashed_cursor_factory: self.hashed_cursor_factory,
64 prefix_sets: self.prefix_sets,
65 always_include_root_node: self.always_include_root_node,
66 witness: self.witness,
67 }
68 }
69
70 pub fn with_hashed_cursor_factory<HF>(self, hashed_cursor_factory: HF) -> TrieWitness<T, HF> {
72 TrieWitness {
73 trie_cursor_factory: self.trie_cursor_factory,
74 hashed_cursor_factory,
75 prefix_sets: self.prefix_sets,
76 always_include_root_node: self.always_include_root_node,
77 witness: self.witness,
78 }
79 }
80
81 pub fn with_prefix_sets_mut(mut self, prefix_sets: TriePrefixSetsMut) -> Self {
83 self.prefix_sets = prefix_sets;
84 self
85 }
86
87 pub const fn always_include_root_node(mut self) -> Self {
91 self.always_include_root_node = true;
92 self
93 }
94}
95
impl<T, H> TrieWitness<T, H>
where
    T: TrieCursorFactory + Clone + Send + Sync,
    H: HashedCursorFactory + Clone + Send + Sync,
{
    /// Computes the witness for `state`.
    ///
    /// Returns a map from `keccak256(node)` to node bytes. The nodes come from two sources:
    /// - the multiproof over every account and slot targeted by `state`;
    /// - the nodes fetched on demand while `state`'s changes are applied to a sparse trie.
    ///
    /// Special cases:
    /// - If `state` is empty and `always_include_root_node` is not set, an empty map is
    ///   returned.
    /// - If `state` is empty and the flag is set, only the root node is returned. When the
    ///   proof has no root node, the map contains `EMPTY_ROOT_HASH` mapped to
    ///   `[EMPTY_STRING_CODE]`.
    ///
    /// # Errors
    ///
    /// Returns [`TrieWitnessError`] in these cases:
    /// - computing proof targets or the multiproof fails;
    /// - revealing or updating the sparse trie fails;
    /// - a targeted storage trie is still blind;
    /// - a targeted address has no entry in `state.accounts`.
    pub fn compute(mut self, state: HashedPostState) -> Result<B256Map<Bytes>, TrieWitnessError> {
        let is_state_empty = state.is_empty();
        if is_state_empty && !self.always_include_root_node {
            return Ok(Default::default())
        }

        // With no state to apply, a proof for an arbitrary key (zero) is used only to obtain
        // the root node.
        let proof_targets = if is_state_empty {
            MultiProofTargets::account(B256::ZERO)
        } else {
            self.get_proof_targets(&state)?
        };
        let multiproof =
            Proof::new(self.trie_cursor_factory.clone(), self.hashed_cursor_factory.clone())
                .with_prefix_sets_mut(self.prefix_sets.clone())
                .multiproof(proof_targets.clone())?;

        if is_state_empty {
            // The root lives at the empty path; fall back to the empty-trie root if absent.
            let (root_hash, root_node) = if let Some(root_node) =
                multiproof.account_subtree.into_inner().remove(&Nibbles::default())
            {
                (keccak256(&root_node), root_node)
            } else {
                (EMPTY_ROOT_HASH, Bytes::from([EMPTY_STRING_CODE]))
            };
            return Ok(B256Map::from_iter([(root_hash, root_node)]))
        }

        // Seed the witness with every proof node; the vacant-entry check skips cloning nodes
        // that are already present.
        for account_node in multiproof.account_subtree.values() {
            if let Entry::Vacant(entry) = self.witness.entry(keccak256(account_node.as_ref())) {
                entry.insert(account_node.clone());
            }
        }
        for storage_node in multiproof.storages.values().flat_map(|s| s.subtree.values()) {
            if let Entry::Vacant(entry) = self.witness.entry(keccak256(storage_node.as_ref())) {
                entry.insert(storage_node.clone());
            }
        }

        // Every node fetched through this factory's providers is also sent on `tx`, so
        // nodes revealed during the updates below can be collected from `rx`.
        let (tx, rx) = mpsc::channel();
        let blinded_provider_factory = WitnessTrieNodeProviderFactory::new(
            ProofTrieNodeProviderFactory::new(
                self.trie_cursor_factory,
                self.hashed_cursor_factory,
                Arc::new(self.prefix_sets),
            ),
            tx,
        );
        let mut sparse_trie = SparseStateTrie::<SerialSparseTrie>::new();
        sparse_trie.reveal_multiproof(multiproof)?;

        // Targets are processed in ascending address order (and slots in ascending order
        // below) — presumably for deterministic traversal; NOTE(review): confirm intent.
        for (hashed_address, hashed_slots) in
            proof_targets.into_iter().sorted_unstable_by_key(|(ha, _)| *ha)
        {
            let provider = blinded_provider_factory.storage_node_provider(hashed_address);
            // `None` when the account has no storage changes in `state`.
            let storage = state.storages.get(&hashed_address);
            let storage_trie = sparse_trie.storage_trie_mut(&hashed_address).ok_or(
                SparseStateTrieErrorKind::SparseStorageTrie(
                    hashed_address,
                    SparseTrieErrorKind::Blind,
                ),
            )?;
            for hashed_slot in hashed_slots.into_iter().sorted_unstable() {
                let storage_nibbles = Nibbles::unpack(hashed_slot);
                // Zero (or absent) values are treated as deletions; non-zero values are
                // RLP-encoded as the leaf value.
                let maybe_leaf_value = storage
                    .and_then(|s| s.storage.get(&hashed_slot))
                    .filter(|v| !v.is_zero())
                    .map(|v| alloy_rlp::encode_fixed_size(v).to_vec());

                if let Some(value) = maybe_leaf_value {
                    storage_trie.update_leaf(storage_nibbles, value, &provider).map_err(|err| {
                        SparseStateTrieErrorKind::SparseStorageTrie(hashed_address, err.into_kind())
                    })?;
                } else {
                    storage_trie.remove_leaf(&storage_nibbles, &provider).map_err(|err| {
                        SparseStateTrieErrorKind::SparseStorageTrie(hashed_address, err.into_kind())
                    })?;
                }
            }

            // The returned root is discarded; the call is kept for its effect on the storage
            // trie after the updates above. NOTE(review): confirm which nodes this must reveal.
            storage_trie.root();

            // The address must be present in `state.accounts`; a `None` value there means the
            // account is treated as default (and may be removed below).
            let account = state
                .accounts
                .get(&hashed_address)
                .ok_or(TrieWitnessError::MissingAccount(hashed_address))?
                .unwrap_or_default();

            // `update_account` returning `false` leads to removing the account leaf.
            if !sparse_trie.update_account(hashed_address, account, &blinded_provider_factory)? {
                let nibbles = Nibbles::unpack(hashed_address);
                sparse_trie.remove_account_leaf(&nibbles, &blinded_provider_factory)?;
            }

            // Collect every node revealed while processing this account.
            while let Ok(node) = rx.try_recv() {
                self.witness.insert(keccak256(&node), node);
            }
        }

        Ok(self.witness)
    }

    /// Builds the proof targets for `state`.
    ///
    /// Targets are assembled as follows:
    /// - Every address in `state.accounts` is targeted with no slots.
    /// - Every address in `state.storages` is targeted with its changed slots. This replaces
    ///   the empty slot set if the address was already added from `accounts`.
    /// - For wiped storages, every slot the hashed storage cursor yields from `B256::ZERO`
    ///   onward is added as well.
    ///
    /// # Errors
    ///
    /// Returns [`StateProofError`] if opening or advancing a hashed storage cursor fails.
    fn get_proof_targets(
        &self,
        state: &HashedPostState,
    ) -> Result<MultiProofTargets, StateProofError> {
        let mut proof_targets = MultiProofTargets::default();
        for hashed_address in state.accounts.keys() {
            proof_targets.insert(*hashed_address, B256Set::default());
        }
        for (hashed_address, storage) in &state.storages {
            let mut storage_keys = storage.storage.keys().copied().collect::<B256Set>();
            if storage.wiped {
                // A wiped storage touches every existing slot, so enumerate them all.
                let mut storage_cursor =
                    self.hashed_cursor_factory.hashed_storage_cursor(*hashed_address)?;
                let mut current_entry = storage_cursor.seek(B256::ZERO)?;
                while let Some((hashed_slot, _)) = current_entry {
                    storage_keys.insert(hashed_slot);
                    current_entry = storage_cursor.next()?;
                }
            }
            proof_targets.insert(*hashed_address, storage_keys);
        }
        Ok(proof_targets)
    }
}
242
/// [`TrieNodeProviderFactory`] wrapper whose providers forward every node they reveal over an
/// [`mpsc`] channel.
#[derive(Debug, Clone)]
struct WitnessTrieNodeProviderFactory<F> {
    /// Inner factory that produces the wrapped node providers.
    provider_factory: F,
    /// Sender cloned into each provider created by this factory.
    tx: mpsc::Sender<Bytes>,
}
250
251impl<F> WitnessTrieNodeProviderFactory<F> {
252 const fn new(provider_factory: F, tx: mpsc::Sender<Bytes>) -> Self {
253 Self { provider_factory, tx }
254 }
255}
256
257impl<F> TrieNodeProviderFactory for WitnessTrieNodeProviderFactory<F>
258where
259 F: TrieNodeProviderFactory,
260 F::AccountNodeProvider: TrieNodeProvider,
261 F::StorageNodeProvider: TrieNodeProvider,
262{
263 type AccountNodeProvider = WitnessTrieNodeProvider<F::AccountNodeProvider>;
264 type StorageNodeProvider = WitnessTrieNodeProvider<F::StorageNodeProvider>;
265
266 fn account_node_provider(&self) -> Self::AccountNodeProvider {
267 let provider = self.provider_factory.account_node_provider();
268 WitnessTrieNodeProvider::new(provider, self.tx.clone())
269 }
270
271 fn storage_node_provider(&self, account: B256) -> Self::StorageNodeProvider {
272 let provider = self.provider_factory.storage_node_provider(account);
273 WitnessTrieNodeProvider::new(provider, self.tx.clone())
274 }
275}
276
/// [`TrieNodeProvider`] wrapper that sends a copy of every node it reveals over an [`mpsc`]
/// channel.
#[derive(Debug)]
struct WitnessTrieNodeProvider<P> {
    /// Provider that actually looks up the nodes.
    provider: P,
    /// Channel on which revealed node bytes are sent.
    tx: mpsc::Sender<Bytes>,
}
284
285impl<P> WitnessTrieNodeProvider<P> {
286 const fn new(provider: P, tx: mpsc::Sender<Bytes>) -> Self {
287 Self { provider, tx }
288 }
289}
290
291impl<P: TrieNodeProvider> TrieNodeProvider for WitnessTrieNodeProvider<P> {
292 fn trie_node(&self, path: &Nibbles) -> Result<Option<RevealedNode>, SparseTrieError> {
293 let maybe_node = self.provider.trie_node(path)?;
294 if let Some(node) = &maybe_node {
295 self.tx
296 .send(node.node.clone())
297 .map_err(|error| SparseTrieErrorKind::Other(Box::new(error)))?;
298 }
299 Ok(maybe_node)
300 }
301}