1use crate::{
2 hashed_cursor::{HashedCursor, HashedCursorFactory},
3 prefix_set::TriePrefixSetsMut,
4 proof::{Proof, ProofBlindedProviderFactory},
5 trie_cursor::TrieCursorFactory,
6};
7use alloy_rlp::EMPTY_STRING_CODE;
8use alloy_trie::EMPTY_ROOT_HASH;
9use reth_trie_common::HashedPostState;
10
11use alloy_primitives::{
12 keccak256,
13 map::{B256Map, B256Set, Entry, HashMap},
14 Bytes, B256,
15};
16use itertools::Itertools;
17use reth_execution_errors::{
18 SparseStateTrieErrorKind, SparseTrieError, SparseTrieErrorKind, StateProofError,
19 TrieWitnessError,
20};
21use reth_trie_common::{MultiProofTargets, Nibbles};
22use reth_trie_sparse::{
23 blinded::{BlindedProvider, BlindedProviderFactory, RevealedNode},
24 SparseStateTrie,
25};
26use std::sync::{mpsc, Arc};
27
28#[derive(Debug)]
30pub struct TrieWitness<T, H> {
31 trie_cursor_factory: T,
33 hashed_cursor_factory: H,
35 prefix_sets: TriePrefixSetsMut,
37 always_include_root_node: bool,
42 witness: B256Map<Bytes>,
44}
45
46impl<T, H> TrieWitness<T, H> {
47 pub fn new(trie_cursor_factory: T, hashed_cursor_factory: H) -> Self {
49 Self {
50 trie_cursor_factory,
51 hashed_cursor_factory,
52 prefix_sets: TriePrefixSetsMut::default(),
53 always_include_root_node: false,
54 witness: HashMap::default(),
55 }
56 }
57
58 pub fn with_trie_cursor_factory<TF>(self, trie_cursor_factory: TF) -> TrieWitness<TF, H> {
60 TrieWitness {
61 trie_cursor_factory,
62 hashed_cursor_factory: self.hashed_cursor_factory,
63 prefix_sets: self.prefix_sets,
64 always_include_root_node: self.always_include_root_node,
65 witness: self.witness,
66 }
67 }
68
69 pub fn with_hashed_cursor_factory<HF>(self, hashed_cursor_factory: HF) -> TrieWitness<T, HF> {
71 TrieWitness {
72 trie_cursor_factory: self.trie_cursor_factory,
73 hashed_cursor_factory,
74 prefix_sets: self.prefix_sets,
75 always_include_root_node: self.always_include_root_node,
76 witness: self.witness,
77 }
78 }
79
80 pub fn with_prefix_sets_mut(mut self, prefix_sets: TriePrefixSetsMut) -> Self {
82 self.prefix_sets = prefix_sets;
83 self
84 }
85
86 pub const fn always_include_root_node(mut self) -> Self {
90 self.always_include_root_node = true;
91 self
92 }
93}
94
95impl<T, H> TrieWitness<T, H>
96where
97 T: TrieCursorFactory + Clone + Send + Sync,
98 H: HashedCursorFactory + Clone + Send + Sync,
99{
100 pub fn compute(mut self, state: HashedPostState) -> Result<B256Map<Bytes>, TrieWitnessError> {
107 let is_state_empty = state.is_empty();
108 if is_state_empty && !self.always_include_root_node {
109 return Ok(Default::default())
110 }
111
112 let proof_targets = if is_state_empty {
113 MultiProofTargets::account(B256::ZERO)
114 } else {
115 self.get_proof_targets(&state)?
116 };
117 let multiproof =
118 Proof::new(self.trie_cursor_factory.clone(), self.hashed_cursor_factory.clone())
119 .with_prefix_sets_mut(self.prefix_sets.clone())
120 .multiproof(proof_targets.clone())?;
121
122 if is_state_empty {
125 let (root_hash, root_node) = if let Some(root_node) =
126 multiproof.account_subtree.into_inner().remove(&Nibbles::default())
127 {
128 (keccak256(&root_node), root_node)
129 } else {
130 (EMPTY_ROOT_HASH, Bytes::from([EMPTY_STRING_CODE]))
131 };
132 return Ok(B256Map::from_iter([(root_hash, root_node)]))
133 }
134
135 for account_node in multiproof.account_subtree.values() {
137 if let Entry::Vacant(entry) = self.witness.entry(keccak256(account_node.as_ref())) {
138 entry.insert(account_node.clone());
139 }
140 }
141 for storage_node in multiproof.storages.values().flat_map(|s| s.subtree.values()) {
142 if let Entry::Vacant(entry) = self.witness.entry(keccak256(storage_node.as_ref())) {
143 entry.insert(storage_node.clone());
144 }
145 }
146
147 let (tx, rx) = mpsc::channel();
148 let blinded_provider_factory = WitnessBlindedProviderFactory::new(
149 ProofBlindedProviderFactory::new(
150 self.trie_cursor_factory,
151 self.hashed_cursor_factory,
152 Arc::new(self.prefix_sets),
153 ),
154 tx,
155 );
156 let mut sparse_trie = SparseStateTrie::new(blinded_provider_factory);
157 sparse_trie.reveal_multiproof(multiproof)?;
158
159 for (hashed_address, hashed_slots) in
161 proof_targets.into_iter().sorted_unstable_by_key(|(ha, _)| *ha)
162 {
163 let storage = state.storages.get(&hashed_address);
165 let storage_trie = sparse_trie.storage_trie_mut(&hashed_address).ok_or(
166 SparseStateTrieErrorKind::SparseStorageTrie(
167 hashed_address,
168 SparseTrieErrorKind::Blind,
169 ),
170 )?;
171 for hashed_slot in hashed_slots.into_iter().sorted_unstable() {
172 let storage_nibbles = Nibbles::unpack(hashed_slot);
173 let maybe_leaf_value = storage
174 .and_then(|s| s.storage.get(&hashed_slot))
175 .filter(|v| !v.is_zero())
176 .map(|v| alloy_rlp::encode_fixed_size(v).to_vec());
177
178 if let Some(value) = maybe_leaf_value {
179 storage_trie.update_leaf(storage_nibbles, value).map_err(|err| {
180 SparseStateTrieErrorKind::SparseStorageTrie(hashed_address, err.into_kind())
181 })?;
182 } else {
183 storage_trie.remove_leaf(&storage_nibbles).map_err(|err| {
184 SparseStateTrieErrorKind::SparseStorageTrie(hashed_address, err.into_kind())
185 })?;
186 }
187 }
188
189 storage_trie.root();
191
192 let account = state
193 .accounts
194 .get(&hashed_address)
195 .ok_or(TrieWitnessError::MissingAccount(hashed_address))?
196 .unwrap_or_default();
197 sparse_trie.update_account(hashed_address, account)?;
198
199 while let Ok(node) = rx.try_recv() {
200 self.witness.insert(keccak256(&node), node);
201 }
202 }
203
204 Ok(self.witness)
205 }
206
207 fn get_proof_targets(
211 &self,
212 state: &HashedPostState,
213 ) -> Result<MultiProofTargets, StateProofError> {
214 let mut proof_targets = MultiProofTargets::default();
215 for hashed_address in state.accounts.keys() {
216 proof_targets.insert(*hashed_address, B256Set::default());
217 }
218 for (hashed_address, storage) in &state.storages {
219 let mut storage_keys = storage.storage.keys().copied().collect::<B256Set>();
220 if storage.wiped {
221 let mut storage_cursor =
223 self.hashed_cursor_factory.hashed_storage_cursor(*hashed_address)?;
224 let mut current_entry = storage_cursor.seek(B256::ZERO)?;
226 while let Some((hashed_slot, _)) = current_entry {
227 storage_keys.insert(hashed_slot);
228 current_entry = storage_cursor.next()?;
229 }
230 }
231 proof_targets.insert(*hashed_address, storage_keys);
232 }
233 Ok(proof_targets)
234 }
235}
236
237#[derive(Debug, Clone)]
238struct WitnessBlindedProviderFactory<F> {
239 provider_factory: F,
241 tx: mpsc::Sender<Bytes>,
243}
244
245impl<F> WitnessBlindedProviderFactory<F> {
246 const fn new(provider_factory: F, tx: mpsc::Sender<Bytes>) -> Self {
247 Self { provider_factory, tx }
248 }
249}
250
251impl<F> BlindedProviderFactory for WitnessBlindedProviderFactory<F>
252where
253 F: BlindedProviderFactory,
254 F::AccountNodeProvider: BlindedProvider,
255 F::StorageNodeProvider: BlindedProvider,
256{
257 type AccountNodeProvider = WitnessBlindedProvider<F::AccountNodeProvider>;
258 type StorageNodeProvider = WitnessBlindedProvider<F::StorageNodeProvider>;
259
260 fn account_node_provider(&self) -> Self::AccountNodeProvider {
261 let provider = self.provider_factory.account_node_provider();
262 WitnessBlindedProvider::new(provider, self.tx.clone())
263 }
264
265 fn storage_node_provider(&self, account: B256) -> Self::StorageNodeProvider {
266 let provider = self.provider_factory.storage_node_provider(account);
267 WitnessBlindedProvider::new(provider, self.tx.clone())
268 }
269}
270
271#[derive(Debug)]
272struct WitnessBlindedProvider<P> {
273 provider: P,
275 tx: mpsc::Sender<Bytes>,
277}
278
279impl<P> WitnessBlindedProvider<P> {
280 const fn new(provider: P, tx: mpsc::Sender<Bytes>) -> Self {
281 Self { provider, tx }
282 }
283}
284
285impl<P: BlindedProvider> BlindedProvider for WitnessBlindedProvider<P> {
286 fn blinded_node(&self, path: &Nibbles) -> Result<Option<RevealedNode>, SparseTrieError> {
287 let maybe_node = self.provider.blinded_node(path)?;
288 if let Some(node) = &maybe_node {
289 self.tx
290 .send(node.node.clone())
291 .map_err(|error| SparseTrieErrorKind::Other(Box::new(error)))?;
292 }
293 Ok(maybe_node)
294 }
295}