1use super::ExecutedBlockWithTrieUpdates;
2use alloy_consensus::BlockHeader;
3use alloy_primitives::{keccak256, Address, BlockNumber, Bytes, StorageKey, StorageValue, B256};
4use reth_errors::ProviderResult;
5use reth_primitives_traits::{Account, Bytecode, NodePrimitives};
6use reth_storage_api::{
7 AccountReader, BlockHashReader, HashedPostStateProvider, StateProofProvider, StateProvider,
8 StateRootProvider, StorageRootProvider,
9};
10use reth_trie::{
11 updates::TrieUpdates, AccountProof, HashedPostState, HashedStorage, MultiProof,
12 MultiProofTargets, StorageMultiProof, TrieInput,
13};
14use revm_database::BundleState;
15use std::sync::OnceLock;
16
/// A state provider that first consults a list of in-memory executed blocks and falls back to a
/// historical [`StateProvider`] for anything those blocks do not answer.
// `Debug` is not derived because `historical` is a boxed trait object for which no `Debug` impl
// is assumed. NOTE(review): confirm `StateProvider` does not require `Debug`.
#[allow(missing_debug_implementations)]
pub struct MemoryOverlayStateProviderRef<
    'a,
    N: NodePrimitives = reth_ethereum_primitives::EthPrimitives,
> {
    /// Fallback provider queried when none of the in-memory blocks answers a lookup.
    pub(crate) historical: Box<dyn StateProvider + 'a>,
    /// Executed blocks layered on top of `historical`. Point lookups return the first matching
    /// entry, so earlier entries take precedence (presumably newest first — confirm with callers).
    pub(crate) in_memory: Vec<ExecutedBlockWithTrieUpdates<N>>,
    /// Lazily-built aggregate of the in-memory blocks' hashed state and trie node updates.
    pub(crate) trie_state: OnceLock<MemoryOverlayTrieState>,
}
31
/// A [`MemoryOverlayStateProviderRef`] whose historical provider holds no non-`'static` borrows
/// (`'a = 'static`).
pub type MemoryOverlayStateProvider<N> = MemoryOverlayStateProviderRef<'static, N>;
35
36impl<'a, N: NodePrimitives> MemoryOverlayStateProviderRef<'a, N> {
37 pub fn new(
45 historical: Box<dyn StateProvider + 'a>,
46 in_memory: Vec<ExecutedBlockWithTrieUpdates<N>>,
47 ) -> Self {
48 Self { historical, in_memory, trie_state: OnceLock::new() }
49 }
50
51 pub fn boxed(self) -> Box<dyn StateProvider + 'a> {
53 Box::new(self)
54 }
55
56 fn trie_state(&self) -> &MemoryOverlayTrieState {
58 self.trie_state.get_or_init(|| {
59 let mut trie_state = MemoryOverlayTrieState::default();
60 for block in self.in_memory.iter().rev() {
61 trie_state.state.extend_ref(block.hashed_state.as_ref());
62 trie_state.nodes.extend_ref(block.trie.as_ref());
63 }
64 trie_state
65 })
66 }
67}
68
69impl<N: NodePrimitives> BlockHashReader for MemoryOverlayStateProviderRef<'_, N> {
70 fn block_hash(&self, number: BlockNumber) -> ProviderResult<Option<B256>> {
71 for block in &self.in_memory {
72 if block.recovered_block().number() == number {
73 return Ok(Some(block.recovered_block().hash()));
74 }
75 }
76
77 self.historical.block_hash(number)
78 }
79
80 fn canonical_hashes_range(
81 &self,
82 start: BlockNumber,
83 end: BlockNumber,
84 ) -> ProviderResult<Vec<B256>> {
85 let range = start..end;
86 let mut earliest_block_number = None;
87 let mut in_memory_hashes = Vec::new();
88 for block in &self.in_memory {
89 if range.contains(&block.recovered_block().number()) {
90 in_memory_hashes.insert(0, block.recovered_block().hash());
91 earliest_block_number = Some(block.recovered_block().number());
92 }
93 }
94
95 let mut hashes =
96 self.historical.canonical_hashes_range(start, earliest_block_number.unwrap_or(end))?;
97 hashes.append(&mut in_memory_hashes);
98 Ok(hashes)
99 }
100}
101
102impl<N: NodePrimitives> AccountReader for MemoryOverlayStateProviderRef<'_, N> {
103 fn basic_account(&self, address: &Address) -> ProviderResult<Option<Account>> {
104 for block in &self.in_memory {
105 if let Some(account) = block.execution_output.account(address) {
106 return Ok(account);
107 }
108 }
109
110 self.historical.basic_account(address)
111 }
112}
113
114impl<N: NodePrimitives> StateRootProvider for MemoryOverlayStateProviderRef<'_, N> {
115 fn state_root(&self, state: HashedPostState) -> ProviderResult<B256> {
116 self.state_root_from_nodes(TrieInput::from_state(state))
117 }
118
119 fn state_root_from_nodes(&self, mut input: TrieInput) -> ProviderResult<B256> {
120 let MemoryOverlayTrieState { nodes, state } = self.trie_state().clone();
121 input.prepend_cached(nodes, state);
122 self.historical.state_root_from_nodes(input)
123 }
124
125 fn state_root_with_updates(
126 &self,
127 state: HashedPostState,
128 ) -> ProviderResult<(B256, TrieUpdates)> {
129 self.state_root_from_nodes_with_updates(TrieInput::from_state(state))
130 }
131
132 fn state_root_from_nodes_with_updates(
133 &self,
134 mut input: TrieInput,
135 ) -> ProviderResult<(B256, TrieUpdates)> {
136 let MemoryOverlayTrieState { nodes, state } = self.trie_state().clone();
137 input.prepend_cached(nodes, state);
138 self.historical.state_root_from_nodes_with_updates(input)
139 }
140}
141
142impl<N: NodePrimitives> StorageRootProvider for MemoryOverlayStateProviderRef<'_, N> {
143 fn storage_root(&self, address: Address, storage: HashedStorage) -> ProviderResult<B256> {
145 let state = &self.trie_state().state;
146 let mut hashed_storage =
147 state.storages.get(&keccak256(address)).cloned().unwrap_or_default();
148 hashed_storage.extend(&storage);
149 self.historical.storage_root(address, hashed_storage)
150 }
151
152 fn storage_proof(
154 &self,
155 address: Address,
156 slot: B256,
157 storage: HashedStorage,
158 ) -> ProviderResult<reth_trie::StorageProof> {
159 let state = &self.trie_state().state;
160 let mut hashed_storage =
161 state.storages.get(&keccak256(address)).cloned().unwrap_or_default();
162 hashed_storage.extend(&storage);
163 self.historical.storage_proof(address, slot, hashed_storage)
164 }
165
166 fn storage_multiproof(
168 &self,
169 address: Address,
170 slots: &[B256],
171 storage: HashedStorage,
172 ) -> ProviderResult<StorageMultiProof> {
173 let state = &self.trie_state().state;
174 let mut hashed_storage =
175 state.storages.get(&keccak256(address)).cloned().unwrap_or_default();
176 hashed_storage.extend(&storage);
177 self.historical.storage_multiproof(address, slots, hashed_storage)
178 }
179}
180
181impl<N: NodePrimitives> StateProofProvider for MemoryOverlayStateProviderRef<'_, N> {
182 fn proof(
183 &self,
184 mut input: TrieInput,
185 address: Address,
186 slots: &[B256],
187 ) -> ProviderResult<AccountProof> {
188 let MemoryOverlayTrieState { nodes, state } = self.trie_state().clone();
189 input.prepend_cached(nodes, state);
190 self.historical.proof(input, address, slots)
191 }
192
193 fn multiproof(
194 &self,
195 mut input: TrieInput,
196 targets: MultiProofTargets,
197 ) -> ProviderResult<MultiProof> {
198 let MemoryOverlayTrieState { nodes, state } = self.trie_state().clone();
199 input.prepend_cached(nodes, state);
200 self.historical.multiproof(input, targets)
201 }
202
203 fn witness(&self, mut input: TrieInput, target: HashedPostState) -> ProviderResult<Vec<Bytes>> {
204 let MemoryOverlayTrieState { nodes, state } = self.trie_state().clone();
205 input.prepend_cached(nodes, state);
206 self.historical.witness(input, target)
207 }
208}
209
210impl<N: NodePrimitives> HashedPostStateProvider for MemoryOverlayStateProviderRef<'_, N> {
211 fn hashed_post_state(&self, bundle_state: &BundleState) -> HashedPostState {
212 self.historical.hashed_post_state(bundle_state)
213 }
214}
215
216impl<N: NodePrimitives> StateProvider for MemoryOverlayStateProviderRef<'_, N> {
217 fn storage(
218 &self,
219 address: Address,
220 storage_key: StorageKey,
221 ) -> ProviderResult<Option<StorageValue>> {
222 for block in &self.in_memory {
223 if let Some(value) = block.execution_output.storage(&address, storage_key.into()) {
224 return Ok(Some(value));
225 }
226 }
227
228 self.historical.storage(address, storage_key)
229 }
230
231 fn bytecode_by_hash(&self, code_hash: &B256) -> ProviderResult<Option<Bytecode>> {
232 for block in &self.in_memory {
233 if let Some(contract) = block.execution_output.bytecode(code_hash) {
234 return Ok(Some(contract));
235 }
236 }
237
238 self.historical.bytecode_by_hash(code_hash)
239 }
240}
241
/// Trie data merged across all in-memory blocks of a [`MemoryOverlayStateProviderRef`]
/// (built lazily by its `trie_state` method).
#[derive(Clone, Default, Debug)]
pub(crate) struct MemoryOverlayTrieState {
    /// Trie node updates accumulated from the in-memory blocks.
    pub(crate) nodes: TrieUpdates,
    /// Hashed post-state accumulated from the in-memory blocks.
    pub(crate) state: HashedPostState,
}