1use crate::{
2 hashed_cursor::{HashedCursor, HashedCursorFactory},
3 proof::{Proof, ProofTrieNodeProviderFactory},
4 trie_cursor::TrieCursorFactory,
5};
6use alloy_rlp::EMPTY_STRING_CODE;
7use alloy_trie::EMPTY_ROOT_HASH;
8use reth_trie_common::HashedPostState;
9use reth_trie_sparse::SparseTrie;
10
11use alloy_primitives::{
12 keccak256,
13 map::{B256Map, B256Set, Entry, HashMap},
14 Bytes, B256,
15};
16use itertools::Itertools;
17use reth_execution_errors::{
18 SparseStateTrieErrorKind, SparseTrieError, SparseTrieErrorKind, StateProofError,
19 TrieWitnessError,
20};
21use reth_trie_common::{MultiProofTargets, Nibbles};
22use reth_trie_sparse::{
23 provider::{RevealedNode, TrieNodeProvider, TrieNodeProviderFactory},
24 SparseStateTrie,
25};
26use std::sync::mpsc;
27
/// Builder that computes a state witness: the set of trie nodes, keyed by their `keccak256`
/// hash, that are touched when applying a [`HashedPostState`] to the state trie.
#[derive(Debug)]
pub struct TrieWitness<T, H> {
    /// Factory for cursors over stored trie nodes; used both for the multiproof and for
    /// fetching additional nodes while the sparse trie is updated.
    trie_cursor_factory: T,
    /// Factory for cursors over hashed accounts and hashed storage slots.
    hashed_cursor_factory: H,
    /// When `true`, [`TrieWitness::compute`] returns the account trie root node even if the
    /// supplied state is empty; otherwise an empty state yields an empty witness.
    always_include_root_node: bool,
    /// Collected witness: encoded trie node bytes keyed by the `keccak256` hash of those bytes.
    witness: B256Map<Bytes>,
}
43
44impl<T, H> TrieWitness<T, H> {
45 pub fn new(trie_cursor_factory: T, hashed_cursor_factory: H) -> Self {
47 Self {
48 trie_cursor_factory,
49 hashed_cursor_factory,
50 always_include_root_node: false,
51 witness: HashMap::default(),
52 }
53 }
54
55 pub fn with_trie_cursor_factory<TF>(self, trie_cursor_factory: TF) -> TrieWitness<TF, H> {
57 TrieWitness {
58 trie_cursor_factory,
59 hashed_cursor_factory: self.hashed_cursor_factory,
60 always_include_root_node: self.always_include_root_node,
61 witness: self.witness,
62 }
63 }
64
65 pub fn with_hashed_cursor_factory<HF>(self, hashed_cursor_factory: HF) -> TrieWitness<T, HF> {
67 TrieWitness {
68 trie_cursor_factory: self.trie_cursor_factory,
69 hashed_cursor_factory,
70 always_include_root_node: self.always_include_root_node,
71 witness: self.witness,
72 }
73 }
74
75 pub const fn always_include_root_node(mut self) -> Self {
79 self.always_include_root_node = true;
80 self
81 }
82}
83
impl<T, H> TrieWitness<T, H>
where
    T: TrieCursorFactory + Clone,
    H: HashedCursorFactory + Clone,
{
    /// Computes the witness for the given hashed post state.
    ///
    /// Returns a map from the `keccak256` hash of each trie node's bytes to those bytes.
    ///
    /// - If `state` is empty and `always_include_root_node` was not set, an empty map is
    ///   returned before any proof is computed.
    /// - If `state` is empty and the flag was set, only the account trie root node is
    ///   returned, or `EMPTY_ROOT_HASH -> [EMPTY_STRING_CODE]` when the proof has no root node.
    /// - Otherwise, a multiproof is fetched for every touched account and storage slot, and
    ///   all of its nodes are recorded. The changes in `state` are then applied to a sparse
    ///   trie revealed from that multiproof. Every node the sparse trie fetches from the
    ///   provider while applying the changes is recorded as well.
    ///
    /// # Errors
    ///
    /// Returns [`TrieWitnessError::MissingAccount`] if a proof target address (for example
    /// one that only appears in `state.storages`) has no entry in `state.accounts`.
    /// Errors from the cursors, the proof computation and the sparse trie are propagated.
    pub fn compute(mut self, state: HashedPostState) -> Result<B256Map<Bytes>, TrieWitnessError> {
        let is_state_empty = state.is_empty();
        if is_state_empty && !self.always_include_root_node {
            return Ok(Default::default())
        }

        // With no state changes, a single arbitrary target (`B256::ZERO`) is proved only so
        // that the multiproof contains the account trie root node.
        let proof_targets = if is_state_empty {
            MultiProofTargets::account(B256::ZERO)
        } else {
            self.get_proof_targets(&state)?
        };
        // `proof_targets` is cloned because it is iterated again below to apply the updates.
        let multiproof =
            Proof::new(self.trie_cursor_factory.clone(), self.hashed_cursor_factory.clone())
                .multiproof(proof_targets.clone())?;

        if is_state_empty {
            // Only the root node is needed; it lives at the empty path.
            let (root_hash, root_node) = if let Some(root_node) =
                multiproof.account_subtree.into_inner().remove(&Nibbles::default())
            {
                (keccak256(&root_node), root_node)
            } else {
                (EMPTY_ROOT_HASH, Bytes::from([EMPTY_STRING_CODE]))
            };
            return Ok(B256Map::from_iter([(root_hash, root_node)]))
        }

        // Record every account and storage proof node. The `Vacant` check avoids cloning
        // nodes that are already present in the witness.
        for account_node in multiproof.account_subtree.values() {
            if let Entry::Vacant(entry) = self.witness.entry(keccak256(account_node.as_ref())) {
                entry.insert(account_node.clone());
            }
        }
        for storage_node in multiproof.storages.values().flat_map(|s| s.subtree.values()) {
            if let Entry::Vacant(entry) = self.witness.entry(keccak256(storage_node.as_ref())) {
                entry.insert(storage_node.clone());
            }
        }

        // Every node returned by a provider created from this factory is also sent over
        // `tx`; those nodes are drained from `rx` below and added to the witness.
        let (tx, rx) = mpsc::channel();
        let blinded_provider_factory = WitnessTrieNodeProviderFactory::new(
            ProofTrieNodeProviderFactory::new(self.trie_cursor_factory, self.hashed_cursor_factory),
            tx,
        );
        let mut sparse_trie = SparseStateTrie::new();
        sparse_trie.reveal_multiproof(multiproof)?;

        // Apply the state changes target by target, in ascending address order (and
        // ascending slot order within each storage). Presumably the sort makes processing
        // independent of map iteration order — confirm.
        for (hashed_address, hashed_slots) in
            proof_targets.into_iter().sorted_unstable_by_key(|(ha, _)| *ha)
        {
            // Storage updates for this address are applied before its account leaf.
            let provider = blinded_provider_factory.storage_node_provider(hashed_address);
            let storage = state.storages.get(&hashed_address);
            let storage_trie = sparse_trie.storage_trie_mut(&hashed_address).ok_or(
                SparseStateTrieErrorKind::SparseStorageTrie(
                    hashed_address,
                    SparseTrieErrorKind::Blind,
                ),
            )?;
            for hashed_slot in hashed_slots.into_iter().sorted_unstable() {
                let storage_nibbles = Nibbles::unpack(hashed_slot);
                // Non-zero values become RLP-encoded leaves. A zero or absent value removes
                // the leaf. This includes slots of a wiped storage that have no new value.
                let maybe_leaf_value = storage
                    .and_then(|s| s.storage.get(&hashed_slot))
                    .filter(|v| !v.is_zero())
                    .map(|v| alloy_rlp::encode_fixed_size(v).to_vec());

                if let Some(value) = maybe_leaf_value {
                    storage_trie.update_leaf(storage_nibbles, value, &provider).map_err(|err| {
                        SparseStateTrieErrorKind::SparseStorageTrie(hashed_address, err.into_kind())
                    })?;
                } else {
                    storage_trie.remove_leaf(&storage_nibbles, &provider).map_err(|err| {
                        SparseStateTrieErrorKind::SparseStorageTrie(hashed_address, err.into_kind())
                    })?;
                }
            }

            // A missing entry is an error. An entry holding `None` falls back to the
            // default account value.
            let account = state
                .accounts
                .get(&hashed_address)
                .ok_or(TrieWitnessError::MissingAccount(hashed_address))?
                .unwrap_or_default();

            // NOTE(review): `update_account` returning `false` presumably means the account
            // must not be stored (e.g. it is empty), so its leaf is removed — confirm against
            // `SparseStateTrie::update_account`.
            if !sparse_trie.update_account(hashed_address, account, &blinded_provider_factory)? {
                let nibbles = Nibbles::unpack(hashed_address);
                sparse_trie.remove_account_leaf(&nibbles, &blinded_provider_factory)?;
            }

            // Collect the nodes the providers returned while this target was applied.
            while let Ok(node) = rx.try_recv() {
                self.witness.insert(keccak256(&node), node);
            }
        }

        Ok(self.witness)
    }

    /// Builds the multiproof targets for `state`.
    ///
    /// Every address in `state.accounts` is targeted with no storage slots. Every address
    /// in `state.storages` is targeted with its changed slots. If that storage was wiped,
    /// every slot currently returned by the hashed storage cursor is targeted as well, so
    /// `compute` can remove those existing leaves.
    ///
    /// # Errors
    ///
    /// Propagates errors from creating or advancing the hashed storage cursor.
    fn get_proof_targets(
        &self,
        state: &HashedPostState,
    ) -> Result<MultiProofTargets, StateProofError> {
        let mut proof_targets = MultiProofTargets::default();
        for hashed_address in state.accounts.keys() {
            proof_targets.insert(*hashed_address, B256Set::default());
        }
        for (hashed_address, storage) in &state.storages {
            let mut storage_keys = storage.storage.keys().copied().collect::<B256Set>();
            if storage.wiped {
                // Walk all stored slots of this account, starting from the lowest key.
                let mut storage_cursor =
                    self.hashed_cursor_factory.hashed_storage_cursor(*hashed_address)?;
                let mut current_entry = storage_cursor.seek(B256::ZERO)?;
                while let Some((hashed_slot, _)) = current_entry {
                    storage_keys.insert(hashed_slot);
                    current_entry = storage_cursor.next()?;
                }
            }
            proof_targets.insert(*hashed_address, storage_keys);
        }
        Ok(proof_targets)
    }
}
222
/// Trie node provider factory that wraps another factory and hands out
/// [`WitnessTrieNodeProvider`]s, which forward every node they return over `tx`.
#[derive(Debug, Clone)]
struct WitnessTrieNodeProviderFactory<F> {
    /// The wrapped factory that produces the underlying node providers.
    provider_factory: F,
    /// Sender cloned into every created provider, used to forward the nodes it returns.
    tx: mpsc::Sender<Bytes>,
}
230
231impl<F> WitnessTrieNodeProviderFactory<F> {
232 const fn new(provider_factory: F, tx: mpsc::Sender<Bytes>) -> Self {
233 Self { provider_factory, tx }
234 }
235}
236
237impl<F> TrieNodeProviderFactory for WitnessTrieNodeProviderFactory<F>
238where
239 F: TrieNodeProviderFactory,
240 F::AccountNodeProvider: TrieNodeProvider,
241 F::StorageNodeProvider: TrieNodeProvider,
242{
243 type AccountNodeProvider = WitnessTrieNodeProvider<F::AccountNodeProvider>;
244 type StorageNodeProvider = WitnessTrieNodeProvider<F::StorageNodeProvider>;
245
246 fn account_node_provider(&self) -> Self::AccountNodeProvider {
247 let provider = self.provider_factory.account_node_provider();
248 WitnessTrieNodeProvider::new(provider, self.tx.clone())
249 }
250
251 fn storage_node_provider(&self, account: B256) -> Self::StorageNodeProvider {
252 let provider = self.provider_factory.storage_node_provider(account);
253 WitnessTrieNodeProvider::new(provider, self.tx.clone())
254 }
255}
256
/// Trie node provider that delegates lookups to an inner provider and sends a copy of every
/// node it returns over `tx`, so the caller can record those nodes in the witness.
#[derive(Debug)]
struct WitnessTrieNodeProvider<P> {
    /// The wrapped provider that performs the actual trie node lookups.
    provider: P,
    /// Sender for the bytes of every node returned by `provider`.
    tx: mpsc::Sender<Bytes>,
}
264
265impl<P> WitnessTrieNodeProvider<P> {
266 const fn new(provider: P, tx: mpsc::Sender<Bytes>) -> Self {
267 Self { provider, tx }
268 }
269}
270
271impl<P: TrieNodeProvider> TrieNodeProvider for WitnessTrieNodeProvider<P> {
272 fn trie_node(&self, path: &Nibbles) -> Result<Option<RevealedNode>, SparseTrieError> {
273 let maybe_node = self.provider.trie_node(path)?;
274 if let Some(node) = &maybe_node {
275 self.tx
276 .send(node.node.clone())
277 .map_err(|error| SparseTrieErrorKind::Other(Box::new(error)))?;
278 }
279 Ok(maybe_node)
280 }
281}