1use super::ExecutedBlock;
2use alloy_consensus::BlockHeader;
3use alloy_primitives::{keccak256, Address, BlockNumber, Bytes, StorageKey, StorageValue, B256};
4use reth_errors::ProviderResult;
5use reth_primitives_traits::{Account, Bytecode, NodePrimitives};
6use reth_storage_api::{
7 AccountReader, BlockHashReader, BytecodeReader, HashedPostStateProvider, StateProofProvider,
8 StateProvider, StateRootProvider, StorageRootProvider,
9};
10use reth_trie::{
11 updates::TrieUpdates, AccountProof, HashedPostState, HashedStorage, MultiProof,
12 MultiProofTargets, StorageMultiProof, TrieInput,
13};
14use revm_database::BundleState;
15use std::sync::OnceLock;
16
17#[expect(missing_debug_implementations)]
20pub struct MemoryOverlayStateProviderRef<
21 'a,
22 N: NodePrimitives = reth_ethereum_primitives::EthPrimitives,
23> {
24 pub(crate) historical: Box<dyn StateProvider + 'a>,
26 pub(crate) in_memory: Vec<ExecutedBlock<N>>,
28 pub(crate) trie_input: OnceLock<TrieInput>,
30}
31
32pub type MemoryOverlayStateProvider<N> = MemoryOverlayStateProviderRef<'static, N>;
35
36impl<'a, N: NodePrimitives> MemoryOverlayStateProviderRef<'a, N> {
37 pub fn new(historical: Box<dyn StateProvider + 'a>, in_memory: Vec<ExecutedBlock<N>>) -> Self {
45 Self { historical, in_memory, trie_input: OnceLock::new() }
46 }
47
48 pub fn boxed(self) -> Box<dyn StateProvider + 'a> {
50 Box::new(self)
51 }
52
53 fn trie_input(&self) -> &TrieInput {
55 self.trie_input.get_or_init(|| {
56 TrieInput::from_blocks(
57 self.in_memory
58 .iter()
59 .rev()
60 .map(|block| (block.hashed_state.as_ref(), block.trie_updates.as_ref())),
61 )
62 })
63 }
64
65 fn merged_hashed_storage(&self, address: Address, storage: HashedStorage) -> HashedStorage {
66 let state = &self.trie_input().state;
67 let mut hashed = state.storages.get(&keccak256(address)).cloned().unwrap_or_default();
68 hashed.extend(&storage);
69 hashed
70 }
71}
72
73impl<N: NodePrimitives> BlockHashReader for MemoryOverlayStateProviderRef<'_, N> {
74 fn block_hash(&self, number: BlockNumber) -> ProviderResult<Option<B256>> {
75 for block in &self.in_memory {
76 if block.recovered_block().number() == number {
77 return Ok(Some(block.recovered_block().hash()));
78 }
79 }
80
81 self.historical.block_hash(number)
82 }
83
84 fn canonical_hashes_range(
85 &self,
86 start: BlockNumber,
87 end: BlockNumber,
88 ) -> ProviderResult<Vec<B256>> {
89 let range = start..end;
90 let mut earliest_block_number = None;
91 let mut in_memory_hashes = Vec::with_capacity(range.size_hint().0);
92
93 for block in &self.in_memory {
95 let block_num = block.recovered_block().number();
96 if range.contains(&block_num) {
97 in_memory_hashes.push(block.recovered_block().hash());
98 earliest_block_number = Some(block_num);
99 }
100 }
101
102 in_memory_hashes.reverse();
106
107 let mut hashes =
108 self.historical.canonical_hashes_range(start, earliest_block_number.unwrap_or(end))?;
109 hashes.append(&mut in_memory_hashes);
110 Ok(hashes)
111 }
112}
113
114impl<N: NodePrimitives> AccountReader for MemoryOverlayStateProviderRef<'_, N> {
115 fn basic_account(&self, address: &Address) -> ProviderResult<Option<Account>> {
116 for block in &self.in_memory {
117 if let Some(account) = block.execution_output.account(address) {
118 return Ok(account);
119 }
120 }
121
122 self.historical.basic_account(address)
123 }
124}
125
126impl<N: NodePrimitives> StateRootProvider for MemoryOverlayStateProviderRef<'_, N> {
127 fn state_root(&self, state: HashedPostState) -> ProviderResult<B256> {
128 self.state_root_from_nodes(TrieInput::from_state(state))
129 }
130
131 fn state_root_from_nodes(&self, mut input: TrieInput) -> ProviderResult<B256> {
132 input.prepend_self(self.trie_input().clone());
133 self.historical.state_root_from_nodes(input)
134 }
135
136 fn state_root_with_updates(
137 &self,
138 state: HashedPostState,
139 ) -> ProviderResult<(B256, TrieUpdates)> {
140 self.state_root_from_nodes_with_updates(TrieInput::from_state(state))
141 }
142
143 fn state_root_from_nodes_with_updates(
144 &self,
145 mut input: TrieInput,
146 ) -> ProviderResult<(B256, TrieUpdates)> {
147 input.prepend_self(self.trie_input().clone());
148 self.historical.state_root_from_nodes_with_updates(input)
149 }
150}
151
152impl<N: NodePrimitives> StorageRootProvider for MemoryOverlayStateProviderRef<'_, N> {
153 fn storage_root(&self, address: Address, storage: HashedStorage) -> ProviderResult<B256> {
155 let merged = self.merged_hashed_storage(address, storage);
156 self.historical.storage_root(address, merged)
157 }
158
159 fn storage_proof(
161 &self,
162 address: Address,
163 slot: B256,
164 storage: HashedStorage,
165 ) -> ProviderResult<reth_trie::StorageProof> {
166 let merged = self.merged_hashed_storage(address, storage);
167 self.historical.storage_proof(address, slot, merged)
168 }
169
170 fn storage_multiproof(
172 &self,
173 address: Address,
174 slots: &[B256],
175 storage: HashedStorage,
176 ) -> ProviderResult<StorageMultiProof> {
177 let merged = self.merged_hashed_storage(address, storage);
178 self.historical.storage_multiproof(address, slots, merged)
179 }
180}
181
182impl<N: NodePrimitives> StateProofProvider for MemoryOverlayStateProviderRef<'_, N> {
183 fn proof(
184 &self,
185 mut input: TrieInput,
186 address: Address,
187 slots: &[B256],
188 ) -> ProviderResult<AccountProof> {
189 input.prepend_self(self.trie_input().clone());
190 self.historical.proof(input, address, slots)
191 }
192
193 fn multiproof(
194 &self,
195 mut input: TrieInput,
196 targets: MultiProofTargets,
197 ) -> ProviderResult<MultiProof> {
198 input.prepend_self(self.trie_input().clone());
199 self.historical.multiproof(input, targets)
200 }
201
202 fn witness(&self, mut input: TrieInput, target: HashedPostState) -> ProviderResult<Vec<Bytes>> {
203 input.prepend_self(self.trie_input().clone());
204 self.historical.witness(input, target)
205 }
206}
207
208impl<N: NodePrimitives> HashedPostStateProvider for MemoryOverlayStateProviderRef<'_, N> {
209 fn hashed_post_state(&self, bundle_state: &BundleState) -> HashedPostState {
210 self.historical.hashed_post_state(bundle_state)
211 }
212}
213
214impl<N: NodePrimitives> StateProvider for MemoryOverlayStateProviderRef<'_, N> {
215 fn storage(
216 &self,
217 address: Address,
218 storage_key: StorageKey,
219 ) -> ProviderResult<Option<StorageValue>> {
220 for block in &self.in_memory {
221 if let Some(value) = block.execution_output.storage(&address, storage_key.into()) {
222 return Ok(Some(value));
223 }
224 }
225
226 self.historical.storage(address, storage_key)
227 }
228}
229
230impl<N: NodePrimitives> BytecodeReader for MemoryOverlayStateProviderRef<'_, N> {
231 fn bytecode_by_hash(&self, code_hash: &B256) -> ProviderResult<Option<Bytecode>> {
232 for block in &self.in_memory {
233 if let Some(contract) = block.execution_output.bytecode(code_hash) {
234 return Ok(Some(contract));
235 }
236 }
237
238 self.historical.bytecode_by_hash(code_hash)
239 }
240}