1use super::ExecutedBlockWithTrieUpdates;
2use alloy_consensus::BlockHeader;
3use alloy_primitives::{keccak256, Address, BlockNumber, Bytes, StorageKey, StorageValue, B256};
4use reth_errors::ProviderResult;
5use reth_primitives_traits::{Account, Bytecode, NodePrimitives};
6use reth_storage_api::{
7 AccountReader, BlockHashReader, BytecodeReader, HashedPostStateProvider, StateProofProvider,
8 StateProvider, StateRootProvider, StorageRootProvider,
9};
10use reth_trie::{
11 updates::TrieUpdates, AccountProof, HashedPostState, HashedStorage, MultiProof,
12 MultiProofTargets, StorageMultiProof, TrieInput,
13};
14use revm_database::BundleState;
15use std::sync::OnceLock;
16
/// State provider that layers a list of in-memory executed blocks on top of a historical
/// [`StateProvider`].
///
/// Point lookups (block hashes, accounts, storage, bytecode) consult `in_memory` first and
/// fall back to `historical`. Trie-related queries combine the caller's input with a cached
/// [`TrieInput`] built from the in-memory blocks before delegating to `historical`.
#[expect(missing_debug_implementations)]
pub struct MemoryOverlayStateProviderRef<
    'a,
    N: NodePrimitives = reth_ethereum_primitives::EthPrimitives,
> {
    /// Provider consulted for anything the in-memory blocks do not answer.
    pub(crate) historical: Box<dyn StateProvider + 'a>,
    /// Executed blocks overlaid on `historical`.
    ///
    /// NOTE(review): lookups return the first match and `trie_input` reverses this list, which
    /// implies newest → oldest ordering. Confirm against the callers that build it.
    pub(crate) in_memory: Vec<ExecutedBlockWithTrieUpdates<N>>,
    /// Lazily computed [`TrieInput`] aggregating the in-memory blocks' hashed state and trie
    /// updates; filled on first use and reused afterwards.
    pub(crate) trie_input: OnceLock<TrieInput>,
}
31
/// [`MemoryOverlayStateProviderRef`] whose boxed historical provider is `'static`.
pub type MemoryOverlayStateProvider<N> = MemoryOverlayStateProviderRef<'static, N>;
35
36impl<'a, N: NodePrimitives> MemoryOverlayStateProviderRef<'a, N> {
37 pub fn new(
45 historical: Box<dyn StateProvider + 'a>,
46 in_memory: Vec<ExecutedBlockWithTrieUpdates<N>>,
47 ) -> Self {
48 Self { historical, in_memory, trie_input: OnceLock::new() }
49 }
50
51 pub fn boxed(self) -> Box<dyn StateProvider + 'a> {
53 Box::new(self)
54 }
55
56 fn trie_input(&self) -> &TrieInput {
58 self.trie_input.get_or_init(|| {
59 TrieInput::from_blocks(
60 self.in_memory
61 .iter()
62 .rev()
63 .map(|block| (block.hashed_state.as_ref(), block.trie.as_ref())),
64 )
65 })
66 }
67}
68
69impl<N: NodePrimitives> BlockHashReader for MemoryOverlayStateProviderRef<'_, N> {
70 fn block_hash(&self, number: BlockNumber) -> ProviderResult<Option<B256>> {
71 for block in &self.in_memory {
72 if block.recovered_block().number() == number {
73 return Ok(Some(block.recovered_block().hash()));
74 }
75 }
76
77 self.historical.block_hash(number)
78 }
79
80 fn canonical_hashes_range(
81 &self,
82 start: BlockNumber,
83 end: BlockNumber,
84 ) -> ProviderResult<Vec<B256>> {
85 let range = start..end;
86 let mut earliest_block_number = None;
87 let mut in_memory_hashes = Vec::with_capacity(range.size_hint().0);
88
89 for block in &self.in_memory {
91 let block_num = block.recovered_block().number();
92 if range.contains(&block_num) {
93 in_memory_hashes.push(block.recovered_block().hash());
94 earliest_block_number = Some(block_num);
95 }
96 }
97
98 in_memory_hashes.reverse();
102
103 let mut hashes =
104 self.historical.canonical_hashes_range(start, earliest_block_number.unwrap_or(end))?;
105 hashes.append(&mut in_memory_hashes);
106 Ok(hashes)
107 }
108}
109
110impl<N: NodePrimitives> AccountReader for MemoryOverlayStateProviderRef<'_, N> {
111 fn basic_account(&self, address: &Address) -> ProviderResult<Option<Account>> {
112 for block in &self.in_memory {
113 if let Some(account) = block.execution_output.account(address) {
114 return Ok(account);
115 }
116 }
117
118 self.historical.basic_account(address)
119 }
120}
121
122impl<N: NodePrimitives> StateRootProvider for MemoryOverlayStateProviderRef<'_, N> {
123 fn state_root(&self, state: HashedPostState) -> ProviderResult<B256> {
124 self.state_root_from_nodes(TrieInput::from_state(state))
125 }
126
127 fn state_root_from_nodes(&self, mut input: TrieInput) -> ProviderResult<B256> {
128 input.prepend_self(self.trie_input().clone());
129 self.historical.state_root_from_nodes(input)
130 }
131
132 fn state_root_with_updates(
133 &self,
134 state: HashedPostState,
135 ) -> ProviderResult<(B256, TrieUpdates)> {
136 self.state_root_from_nodes_with_updates(TrieInput::from_state(state))
137 }
138
139 fn state_root_from_nodes_with_updates(
140 &self,
141 mut input: TrieInput,
142 ) -> ProviderResult<(B256, TrieUpdates)> {
143 input.prepend_self(self.trie_input().clone());
144 self.historical.state_root_from_nodes_with_updates(input)
145 }
146}
147
148impl<N: NodePrimitives> StorageRootProvider for MemoryOverlayStateProviderRef<'_, N> {
149 fn storage_root(&self, address: Address, storage: HashedStorage) -> ProviderResult<B256> {
151 let state = &self.trie_input().state;
152 let mut hashed_storage =
153 state.storages.get(&keccak256(address)).cloned().unwrap_or_default();
154 hashed_storage.extend(&storage);
155 self.historical.storage_root(address, hashed_storage)
156 }
157
158 fn storage_proof(
160 &self,
161 address: Address,
162 slot: B256,
163 storage: HashedStorage,
164 ) -> ProviderResult<reth_trie::StorageProof> {
165 let state = &self.trie_input().state;
166 let mut hashed_storage =
167 state.storages.get(&keccak256(address)).cloned().unwrap_or_default();
168 hashed_storage.extend(&storage);
169 self.historical.storage_proof(address, slot, hashed_storage)
170 }
171
172 fn storage_multiproof(
174 &self,
175 address: Address,
176 slots: &[B256],
177 storage: HashedStorage,
178 ) -> ProviderResult<StorageMultiProof> {
179 let state = &self.trie_input().state;
180 let mut hashed_storage =
181 state.storages.get(&keccak256(address)).cloned().unwrap_or_default();
182 hashed_storage.extend(&storage);
183 self.historical.storage_multiproof(address, slots, hashed_storage)
184 }
185}
186
187impl<N: NodePrimitives> StateProofProvider for MemoryOverlayStateProviderRef<'_, N> {
188 fn proof(
189 &self,
190 mut input: TrieInput,
191 address: Address,
192 slots: &[B256],
193 ) -> ProviderResult<AccountProof> {
194 input.prepend_self(self.trie_input().clone());
195 self.historical.proof(input, address, slots)
196 }
197
198 fn multiproof(
199 &self,
200 mut input: TrieInput,
201 targets: MultiProofTargets,
202 ) -> ProviderResult<MultiProof> {
203 input.prepend_self(self.trie_input().clone());
204 self.historical.multiproof(input, targets)
205 }
206
207 fn witness(&self, mut input: TrieInput, target: HashedPostState) -> ProviderResult<Vec<Bytes>> {
208 input.prepend_self(self.trie_input().clone());
209 self.historical.witness(input, target)
210 }
211}
212
213impl<N: NodePrimitives> HashedPostStateProvider for MemoryOverlayStateProviderRef<'_, N> {
214 fn hashed_post_state(&self, bundle_state: &BundleState) -> HashedPostState {
215 self.historical.hashed_post_state(bundle_state)
216 }
217}
218
219impl<N: NodePrimitives> StateProvider for MemoryOverlayStateProviderRef<'_, N> {
220 fn storage(
221 &self,
222 address: Address,
223 storage_key: StorageKey,
224 ) -> ProviderResult<Option<StorageValue>> {
225 for block in &self.in_memory {
226 if let Some(value) = block.execution_output.storage(&address, storage_key.into()) {
227 return Ok(Some(value));
228 }
229 }
230
231 self.historical.storage(address, storage_key)
232 }
233}
234
235impl<N: NodePrimitives> BytecodeReader for MemoryOverlayStateProviderRef<'_, N> {
236 fn bytecode_by_hash(&self, code_hash: &B256) -> ProviderResult<Option<Bytecode>> {
237 for block in &self.in_memory {
238 if let Some(contract) = block.execution_output.bytecode(code_hash) {
239 return Ok(Some(contract));
240 }
241 }
242
243 self.historical.bytecode_by_hash(code_hash)
244 }
245}