1use crate::primitives::alloy_primitives::{BlockNumber, StorageKey, StorageValue};
2use alloy_primitives::{Address, B256, U256};
3use core::ops::{Deref, DerefMut};
4use reth_primitives_traits::Account;
5use reth_storage_api::{AccountReader, BlockHashReader, BytecodeReader, StateProvider};
6use reth_storage_errors::provider::{ProviderError, ProviderResult};
7use revm::{bytecode::Bytecode, state::AccountInfo, Database, DatabaseRef};
8
/// Read-only state lookups used to back an EVM database.
///
/// [`StateProviderDatabase`] forwards its [`Database`] and [`DatabaseRef`] methods to an
/// implementor of this trait.
pub trait EvmStateProvider {
    /// Returns the account stored for `address`, if any.
    fn basic_account(&self, address: &Address) -> ProviderResult<Option<Account>>;

    /// Returns the hash of block `number`, if any.
    fn block_hash(&self, number: BlockNumber) -> ProviderResult<Option<B256>>;

    /// Returns the bytecode registered under `code_hash`, if any.
    fn bytecode_by_hash(
        &self,
        code_hash: &B256,
    ) -> ProviderResult<Option<reth_primitives_traits::Bytecode>>;

    /// Returns the value stored at `storage_key` of `account`, if any.
    fn storage(
        &self,
        account: Address,
        storage_key: StorageKey,
    ) -> ProviderResult<Option<StorageValue>>;
}
35
36impl<T: StateProvider> EvmStateProvider for T {
38 fn basic_account(&self, address: &Address) -> ProviderResult<Option<Account>> {
39 <T as AccountReader>::basic_account(self, address)
40 }
41
42 fn block_hash(&self, number: BlockNumber) -> ProviderResult<Option<B256>> {
43 <T as BlockHashReader>::block_hash(self, number)
44 }
45
46 fn bytecode_by_hash(
47 &self,
48 code_hash: &B256,
49 ) -> ProviderResult<Option<reth_primitives_traits::Bytecode>> {
50 <T as BytecodeReader>::bytecode_by_hash(self, code_hash)
51 }
52
53 fn storage(
54 &self,
55 account: Address,
56 storage_key: StorageKey,
57 ) -> ProviderResult<Option<StorageValue>> {
58 <T as StateProvider>::storage(self, account, storage_key)
59 }
60}
61
/// Newtype wrapper around a state source `DB`.
///
/// When `DB` implements [`EvmStateProvider`], the wrapper implements [`Database`] and
/// [`DatabaseRef`] by forwarding to it.
#[derive(Clone)]
pub struct StateProviderDatabase<DB>(pub DB);

impl<DB> StateProviderDatabase<DB> {
    /// Wraps `db`.
    pub const fn new(db: DB) -> Self {
        Self(db)
    }

    /// Consumes the wrapper and hands back the wrapped value.
    pub fn into_inner(self) -> DB {
        let Self(db) = self;
        db
    }
}
78
79impl<DB> core::fmt::Debug for StateProviderDatabase<DB> {
80 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
81 f.debug_struct("StateProviderDatabase").finish_non_exhaustive()
82 }
83}
84
85impl<DB> AsRef<DB> for StateProviderDatabase<DB> {
86 fn as_ref(&self) -> &DB {
87 self
88 }
89}
90
91impl<DB> Deref for StateProviderDatabase<DB> {
92 type Target = DB;
93
94 fn deref(&self) -> &Self::Target {
95 &self.0
96 }
97}
98
99impl<DB> DerefMut for StateProviderDatabase<DB> {
100 fn deref_mut(&mut self) -> &mut Self::Target {
101 &mut self.0
102 }
103}
104
105impl<DB: EvmStateProvider> Database for StateProviderDatabase<DB> {
106 type Error = ProviderError;
107
108 fn basic(&mut self, address: Address) -> Result<Option<AccountInfo>, Self::Error> {
113 self.basic_ref(address)
114 }
115
116 fn code_by_hash(&mut self, code_hash: B256) -> Result<Bytecode, Self::Error> {
120 self.code_by_hash_ref(code_hash)
121 }
122
123 fn storage(&mut self, address: Address, index: U256) -> Result<U256, Self::Error> {
127 self.storage_ref(address, index)
128 }
129
130 fn block_hash(&mut self, number: u64) -> Result<B256, Self::Error> {
135 self.block_hash_ref(number)
136 }
137}
138
139impl<DB: EvmStateProvider> DatabaseRef for StateProviderDatabase<DB> {
140 type Error = <Self as Database>::Error;
141
142 fn basic_ref(&self, address: Address) -> Result<Option<AccountInfo>, Self::Error> {
147 Ok(self.basic_account(&address)?.map(Into::into))
148 }
149
150 fn code_by_hash_ref(&self, code_hash: B256) -> Result<Bytecode, Self::Error> {
154 Ok(self.bytecode_by_hash(&code_hash)?.unwrap_or_default().0)
155 }
156
157 fn storage_ref(&self, address: Address, index: U256) -> Result<U256, Self::Error> {
161 Ok(self.0.storage(address, B256::new(index.to_be_bytes()))?.unwrap_or_default())
162 }
163
164 fn block_hash_ref(&self, number: u64) -> Result<B256, Self::Error> {
168 Ok(self.0.block_hash(number)?.unwrap_or_default())
170 }
171}
172
/// Newtype wrapper around a database `DB`.
///
/// For `DB: DatabaseRef<Error = ProviderError>`, the wrapper implements [`AccountReader`] and
/// [`BytecodeReader`] by forwarding to `DB`.
#[derive(Clone)]
pub struct DatabaseStateProvider<DB>(pub DB);

impl<DB> DatabaseStateProvider<DB> {
    /// Wraps `db`.
    pub const fn new(db: DB) -> Self {
        Self(db)
    }

    /// Consumes the wrapper and hands back the wrapped database.
    pub fn into_inner(self) -> DB {
        let Self(db) = self;
        db
    }

    /// Borrows the wrapped database.
    pub const fn inner(&self) -> &DB {
        let Self(db) = self;
        db
    }
}
201
202impl<DB> core::fmt::Debug for DatabaseStateProvider<DB> {
203 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
204 f.debug_struct("DatabaseStateProvider").finish_non_exhaustive()
205 }
206}
207
208impl<DB> AccountReader for DatabaseStateProvider<DB>
209where
210 DB: DatabaseRef<Error = ProviderError>,
211{
212 fn basic_account(&self, address: &Address) -> ProviderResult<Option<Account>> {
213 Ok(self.0.basic_ref(*address)?.map(Into::into))
214 }
215}
216
217impl<DB> BytecodeReader for DatabaseStateProvider<DB>
218where
219 DB: DatabaseRef<Error = ProviderError>,
220{
221 fn bytecode_by_hash(
222 &self,
223 code_hash: &B256,
224 ) -> ProviderResult<Option<reth_primitives_traits::Bytecode>> {
225 Ok(Some(reth_primitives_traits::Bytecode(self.0.code_by_hash_ref(*code_hash)?)))
226 }
227}
228
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cached::CachedReads;
    use alloy_consensus::constants::KECCAK_EMPTY;
    use alloy_primitives::Bytes;
    use core::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    /// [`DatabaseRef`] test double that serves a single account and a single bytecode.
    ///
    /// Successful account and bytecode reads are counted, and either kind of read can be forced
    /// to fail.
    #[derive(Clone)]
    struct CountingDatabaseRef {
        address: Address,
        code_hash: B256,
        account: Option<AccountInfo>,
        bytecode: Bytecode,
        account_reads: Arc<AtomicUsize>,
        bytecode_reads: Arc<AtomicUsize>,
        fail_account_reads: bool,
        fail_bytecode_reads: bool,
    }

    impl CountingDatabaseRef {
        /// Builds a double that serves `account` at `address` and `bytecode` under that
        /// account's code hash (the default hash when there is no account).
        fn new(address: Address, account: Option<AccountInfo>, bytecode: Bytecode) -> Self {
            let code_hash = match &account {
                Some(info) => info.code_hash,
                None => B256::default(),
            };
            Self {
                address,
                code_hash,
                account,
                bytecode,
                account_reads: Arc::default(),
                bytecode_reads: Arc::default(),
                fail_account_reads: false,
                fail_bytecode_reads: false,
            }
        }
    }

    impl DatabaseRef for CountingDatabaseRef {
        type Error = ProviderError;

        fn basic_ref(&self, address: Address) -> Result<Option<AccountInfo>, Self::Error> {
            if self.fail_account_reads {
                return Err(ProviderError::UnsupportedProvider);
            }

            // Only reads that are not forced to fail are counted.
            self.account_reads.fetch_add(1, Ordering::Relaxed);
            if address == self.address {
                Ok(self.account.clone())
            } else {
                Ok(None)
            }
        }

        fn code_by_hash_ref(&self, code_hash: B256) -> Result<Bytecode, Self::Error> {
            if self.fail_bytecode_reads {
                return Err(ProviderError::UnsupportedProvider);
            }

            // Only reads that are not forced to fail are counted.
            self.bytecode_reads.fetch_add(1, Ordering::Relaxed);
            let code = if code_hash == self.code_hash {
                self.bytecode.clone()
            } else {
                Bytecode::default()
            };
            Ok(code)
        }

        fn storage_ref(&self, _address: Address, _index: U256) -> Result<U256, Self::Error> {
            Ok(U256::ZERO)
        }

        fn block_hash_ref(&self, _number: u64) -> Result<B256, Self::Error> {
            Ok(B256::ZERO)
        }
    }

    /// Address shared by all tests.
    fn test_address() -> Address {
        Address::repeat_byte(0x01)
    }

    /// Non-default code hash shared by the tests that need one.
    fn test_code_hash() -> B256 {
        B256::repeat_byte(0x42)
    }

    /// Raw two-byte bytecode shared by the tests that need non-default code.
    fn test_bytecode() -> Bytecode {
        Bytecode::new_raw(Bytes::from_static(&[0x60, 0x00]))
    }

    /// Database-side account fixture with nonce 7 and balance 42.
    fn account_info(code_hash: B256, code: Option<Bytecode>) -> AccountInfo {
        AccountInfo { nonce: 7, balance: U256::from(42), code_hash, code, account_id: None }
    }

    /// Provider-side account expected for [`account_info`].
    fn expected_account(bytecode_hash: Option<B256>) -> Account {
        Account { nonce: 7, balance: U256::from(42), bytecode_hash }
    }

    #[test]
    fn database_state_provider_maps_missing_account() {
        let address = test_address();
        let db = CountingDatabaseRef::new(address, None, Bytecode::default());
        let provider = DatabaseStateProvider::new(db);

        let account = provider.basic_account(&address).unwrap();
        assert_eq!(account, None);
    }

    #[test]
    fn database_state_provider_maps_empty_code_hash() {
        let address = test_address();
        let info = account_info(KECCAK_EMPTY, None);
        let db = CountingDatabaseRef::new(address, Some(info), Bytecode::default());
        let provider = DatabaseStateProvider::new(db);

        // The empty code hash is expected to surface as "no bytecode hash".
        let account = provider.basic_account(&address).unwrap();
        assert_eq!(account, Some(expected_account(None)));
    }

    #[test]
    fn database_state_provider_maps_code_hash_and_bytecode() {
        let address = test_address();
        let code_hash = test_code_hash();
        let bytecode = test_bytecode();
        let info = account_info(code_hash, Some(bytecode.clone()));
        let db = CountingDatabaseRef::new(address, Some(info), bytecode.clone());
        let provider = DatabaseStateProvider::new(db);

        let account = provider.basic_account(&address).unwrap();
        assert_eq!(account, Some(expected_account(Some(code_hash))));

        let code = provider.bytecode_by_hash(&code_hash).unwrap();
        assert_eq!(code, Some(reth_primitives_traits::Bytecode(bytecode)));
    }

    #[test]
    fn database_state_provider_wraps_default_bytecode_for_unknown_hash() {
        let db = CountingDatabaseRef::new(test_address(), None, Bytecode::default());
        let provider = DatabaseStateProvider::new(db);

        let code = provider.bytecode_by_hash(&test_code_hash()).unwrap();
        assert_eq!(code, Some(reth_primitives_traits::Bytecode(Bytecode::default())));
    }

    #[test]
    fn database_state_provider_propagates_database_errors() {
        let address = test_address();
        let mut db = CountingDatabaseRef::new(address, None, Bytecode::default());
        db.fail_account_reads = true;
        db.fail_bytecode_reads = true;
        let provider = DatabaseStateProvider::new(db);

        let account_result = provider.basic_account(&address);
        assert!(matches!(account_result, Err(ProviderError::UnsupportedProvider)));

        let bytecode_result = provider.bytecode_by_hash(&test_code_hash());
        assert!(matches!(bytecode_result, Err(ProviderError::UnsupportedProvider)));
    }

    #[test]
    fn database_state_provider_uses_cached_reads() {
        let address = test_address();
        let code_hash = test_code_hash();
        let bytecode = test_bytecode();
        let info = account_info(code_hash, Some(bytecode.clone()));
        let db = CountingDatabaseRef::new(address, Some(info), bytecode.clone());
        let account_reads = Arc::clone(&db.account_reads);
        let bytecode_reads = Arc::clone(&db.bytecode_reads);
        let mut cached_reads = CachedReads::default();
        let provider = DatabaseStateProvider::new(cached_reads.as_db(db));

        // Two identical account lookups must reach the underlying database only once.
        let expected_account = Some(expected_account(Some(code_hash)));
        for _ in 0..2 {
            assert_eq!(provider.basic_account(&address).unwrap(), expected_account);
        }
        assert_eq!(account_reads.load(Ordering::Relaxed), 1);

        // Likewise for two identical bytecode lookups.
        let expected_code = Some(reth_primitives_traits::Bytecode(bytecode));
        for _ in 0..2 {
            assert_eq!(provider.bytecode_by_hash(&code_hash).unwrap(), expected_code);
        }
        assert_eq!(bytecode_reads.load(Ordering::Relaxed), 1);
    }
}