1use crate::{
2 hashed_cursor::{HashedCursor, HashedCursorFactory},
3 prefix_set::TriePrefixSetsMut,
4 proof::{Proof, ProofTrieNodeProviderFactory},
5 trie_cursor::TrieCursorFactory,
6};
7use alloy_rlp::EMPTY_STRING_CODE;
8use alloy_trie::EMPTY_ROOT_HASH;
9use reth_trie_common::HashedPostState;
10use reth_trie_sparse::SparseTrieInterface;
11
12use alloy_primitives::{
13 keccak256,
14 map::{B256Map, B256Set, Entry, HashMap},
15 Bytes, B256,
16};
17use itertools::Itertools;
18use reth_execution_errors::{
19 SparseStateTrieErrorKind, SparseTrieError, SparseTrieErrorKind, StateProofError,
20 TrieWitnessError,
21};
22use reth_trie_common::{MultiProofTargets, Nibbles};
23use reth_trie_sparse::{
24 provider::{RevealedNode, TrieNodeProvider, TrieNodeProviderFactory},
25 SerialSparseTrie, SparseStateTrie,
26};
27use std::sync::mpsc;
28
/// Builder-style generator of a trie witness: a map from the keccak256 hash of a trie node to
/// that node's encoded bytes.
#[derive(Debug)]
pub struct TrieWitness<T, H> {
    /// Trie cursor factory handed to the multiproof computation and to the proof-backed node
    /// provider factory.
    trie_cursor_factory: T,
    /// Hashed cursor factory handed to the multiproof computation; also used directly to
    /// enumerate the existing slots of wiped storages.
    hashed_cursor_factory: H,
    /// Prefix sets forwarded to the multiproof computation.
    prefix_sets: TriePrefixSetsMut,
    /// When `true`, an empty target state still produces a witness containing only the root
    /// node; when `false`, an empty target state produces an empty witness.
    always_include_root_node: bool,
    /// Witness collected so far, keyed by the keccak256 hash of each node.
    witness: B256Map<Bytes>,
}
46
47impl<T, H> TrieWitness<T, H> {
48 pub fn new(trie_cursor_factory: T, hashed_cursor_factory: H) -> Self {
50 Self {
51 trie_cursor_factory,
52 hashed_cursor_factory,
53 prefix_sets: TriePrefixSetsMut::default(),
54 always_include_root_node: false,
55 witness: HashMap::default(),
56 }
57 }
58
59 pub fn with_trie_cursor_factory<TF>(self, trie_cursor_factory: TF) -> TrieWitness<TF, H> {
61 TrieWitness {
62 trie_cursor_factory,
63 hashed_cursor_factory: self.hashed_cursor_factory,
64 prefix_sets: self.prefix_sets,
65 always_include_root_node: self.always_include_root_node,
66 witness: self.witness,
67 }
68 }
69
70 pub fn with_hashed_cursor_factory<HF>(self, hashed_cursor_factory: HF) -> TrieWitness<T, HF> {
72 TrieWitness {
73 trie_cursor_factory: self.trie_cursor_factory,
74 hashed_cursor_factory,
75 prefix_sets: self.prefix_sets,
76 always_include_root_node: self.always_include_root_node,
77 witness: self.witness,
78 }
79 }
80
81 pub fn with_prefix_sets_mut(mut self, prefix_sets: TriePrefixSetsMut) -> Self {
83 self.prefix_sets = prefix_sets;
84 self
85 }
86
87 pub const fn always_include_root_node(mut self) -> Self {
91 self.always_include_root_node = true;
92 self
93 }
94}
95
96impl<T, H> TrieWitness<T, H>
97where
98 T: TrieCursorFactory + Clone,
99 H: HashedCursorFactory + Clone,
100{
101 pub fn compute(mut self, state: HashedPostState) -> Result<B256Map<Bytes>, TrieWitnessError> {
108 let is_state_empty = state.is_empty();
109 if is_state_empty && !self.always_include_root_node {
110 return Ok(Default::default())
111 }
112
113 let proof_targets = if is_state_empty {
114 MultiProofTargets::account(B256::ZERO)
115 } else {
116 self.get_proof_targets(&state)?
117 };
118 let multiproof =
119 Proof::new(self.trie_cursor_factory.clone(), self.hashed_cursor_factory.clone())
120 .with_prefix_sets_mut(self.prefix_sets.clone())
121 .multiproof(proof_targets.clone())?;
122
123 if is_state_empty {
126 let (root_hash, root_node) = if let Some(root_node) =
127 multiproof.account_subtree.into_inner().remove(&Nibbles::default())
128 {
129 (keccak256(&root_node), root_node)
130 } else {
131 (EMPTY_ROOT_HASH, Bytes::from([EMPTY_STRING_CODE]))
132 };
133 return Ok(B256Map::from_iter([(root_hash, root_node)]))
134 }
135
136 for account_node in multiproof.account_subtree.values() {
138 if let Entry::Vacant(entry) = self.witness.entry(keccak256(account_node.as_ref())) {
139 entry.insert(account_node.clone());
140 }
141 }
142 for storage_node in multiproof.storages.values().flat_map(|s| s.subtree.values()) {
143 if let Entry::Vacant(entry) = self.witness.entry(keccak256(storage_node.as_ref())) {
144 entry.insert(storage_node.clone());
145 }
146 }
147
148 let (tx, rx) = mpsc::channel();
149 let blinded_provider_factory = WitnessTrieNodeProviderFactory::new(
150 ProofTrieNodeProviderFactory::new(self.trie_cursor_factory, self.hashed_cursor_factory),
151 tx,
152 );
153 let mut sparse_trie = SparseStateTrie::<SerialSparseTrie>::new();
154 sparse_trie.reveal_multiproof(multiproof)?;
155
156 for (hashed_address, hashed_slots) in
158 proof_targets.into_iter().sorted_unstable_by_key(|(ha, _)| *ha)
159 {
160 let provider = blinded_provider_factory.storage_node_provider(hashed_address);
162 let storage = state.storages.get(&hashed_address);
163 let storage_trie = sparse_trie.storage_trie_mut(&hashed_address).ok_or(
164 SparseStateTrieErrorKind::SparseStorageTrie(
165 hashed_address,
166 SparseTrieErrorKind::Blind,
167 ),
168 )?;
169 for hashed_slot in hashed_slots.into_iter().sorted_unstable() {
170 let storage_nibbles = Nibbles::unpack(hashed_slot);
171 let maybe_leaf_value = storage
172 .and_then(|s| s.storage.get(&hashed_slot))
173 .filter(|v| !v.is_zero())
174 .map(|v| alloy_rlp::encode_fixed_size(v).to_vec());
175
176 if let Some(value) = maybe_leaf_value {
177 storage_trie.update_leaf(storage_nibbles, value, &provider).map_err(|err| {
178 SparseStateTrieErrorKind::SparseStorageTrie(hashed_address, err.into_kind())
179 })?;
180 } else {
181 storage_trie.remove_leaf(&storage_nibbles, &provider).map_err(|err| {
182 SparseStateTrieErrorKind::SparseStorageTrie(hashed_address, err.into_kind())
183 })?;
184 }
185 }
186
187 let account = state
188 .accounts
189 .get(&hashed_address)
190 .ok_or(TrieWitnessError::MissingAccount(hashed_address))?
191 .unwrap_or_default();
192
193 if !sparse_trie.update_account(hashed_address, account, &blinded_provider_factory)? {
194 let nibbles = Nibbles::unpack(hashed_address);
195 sparse_trie.remove_account_leaf(&nibbles, &blinded_provider_factory)?;
196 }
197
198 while let Ok(node) = rx.try_recv() {
199 self.witness.insert(keccak256(&node), node);
200 }
201 }
202
203 Ok(self.witness)
204 }
205
206 fn get_proof_targets(
210 &self,
211 state: &HashedPostState,
212 ) -> Result<MultiProofTargets, StateProofError> {
213 let mut proof_targets = MultiProofTargets::default();
214 for hashed_address in state.accounts.keys() {
215 proof_targets.insert(*hashed_address, B256Set::default());
216 }
217 for (hashed_address, storage) in &state.storages {
218 let mut storage_keys = storage.storage.keys().copied().collect::<B256Set>();
219 if storage.wiped {
220 let mut storage_cursor =
222 self.hashed_cursor_factory.hashed_storage_cursor(*hashed_address)?;
223 let mut current_entry = storage_cursor.seek(B256::ZERO)?;
225 while let Some((hashed_slot, _)) = current_entry {
226 storage_keys.insert(hashed_slot);
227 current_entry = storage_cursor.next()?;
228 }
229 }
230 proof_targets.insert(*hashed_address, storage_keys);
231 }
232 Ok(proof_targets)
233 }
234}
235
/// Provider factory that wraps another factory and pairs every provider it creates with a
/// clone of the same channel sender.
#[derive(Debug, Clone)]
struct WitnessTrieNodeProviderFactory<F> {
    /// Wrapped factory that creates the underlying account and storage node providers.
    provider_factory: F,
    /// Channel sender that is cloned into each created provider.
    tx: mpsc::Sender<Bytes>,
}
243
impl<F> WitnessTrieNodeProviderFactory<F> {
    /// Creates a factory from the wrapped `provider_factory` and the channel sender `tx`.
    const fn new(provider_factory: F, tx: mpsc::Sender<Bytes>) -> Self {
        Self { provider_factory, tx }
    }
}
249
250impl<F> TrieNodeProviderFactory for WitnessTrieNodeProviderFactory<F>
251where
252 F: TrieNodeProviderFactory,
253 F::AccountNodeProvider: TrieNodeProvider,
254 F::StorageNodeProvider: TrieNodeProvider,
255{
256 type AccountNodeProvider = WitnessTrieNodeProvider<F::AccountNodeProvider>;
257 type StorageNodeProvider = WitnessTrieNodeProvider<F::StorageNodeProvider>;
258
259 fn account_node_provider(&self) -> Self::AccountNodeProvider {
260 let provider = self.provider_factory.account_node_provider();
261 WitnessTrieNodeProvider::new(provider, self.tx.clone())
262 }
263
264 fn storage_node_provider(&self, account: B256) -> Self::StorageNodeProvider {
265 let provider = self.provider_factory.storage_node_provider(account);
266 WitnessTrieNodeProvider::new(provider, self.tx.clone())
267 }
268}
269
/// Node provider that wraps another provider and sends the bytes of every node it returns
/// over a channel.
#[derive(Debug)]
struct WitnessTrieNodeProvider<P> {
    /// Wrapped provider that performs the actual node lookup.
    provider: P,
    /// Channel sender that receives the bytes of each node that was found.
    tx: mpsc::Sender<Bytes>,
}
277
impl<P> WitnessTrieNodeProvider<P> {
    /// Creates a provider from the wrapped `provider` and the channel sender `tx`.
    const fn new(provider: P, tx: mpsc::Sender<Bytes>) -> Self {
        Self { provider, tx }
    }
}
283
284impl<P: TrieNodeProvider> TrieNodeProvider for WitnessTrieNodeProvider<P> {
285 fn trie_node(&self, path: &Nibbles) -> Result<Option<RevealedNode>, SparseTrieError> {
286 let maybe_node = self.provider.trie_node(path)?;
287 if let Some(node) = &maybe_node {
288 self.tx
289 .send(node.node.clone())
290 .map_err(|error| SparseTrieErrorKind::Other(Box::new(error)))?;
291 }
292 Ok(maybe_node)
293 }
294}