use alloy_primitives::{
map::{Entry, HashMap},
Address, B256, U256,
};
use core::cell::RefCell;
use revm::primitives::{
db::{Database, DatabaseRef},
AccountInfo, Bytecode,
};
/// A cache of state reads: account info together with the storage slots read for that account,
/// contract bytecode keyed by code hash, and block hashes keyed by block number.
///
/// Use [`CachedReads::as_db`] or [`CachedReads::as_db_mut`] to put it in front of a
/// [`DatabaseRef`] so that every value read through the wrapper is recorded here.
#[derive(Debug, Clone, Default)]
pub struct CachedReads {
    /// Cached account info plus any storage slots read for that address.
    accounts: HashMap<Address, CachedAccount>,
    /// Cached bytecode, keyed by code hash.
    contracts: HashMap<B256, Bytecode>,
    /// Cached block hashes, keyed by block number.
    block_hashes: HashMap<u64, B256>,
}

impl CachedReads {
    /// Wraps `db` in a [`CachedReadsDBRef`] (a [`DatabaseRef`]) that fills this cache on misses.
    pub fn as_db<DB>(&mut self, db: DB) -> CachedReadsDBRef<'_, DB> {
        let mutable_view = self.as_db_mut(db);
        mutable_view.into_db()
    }

    /// Wraps `db` in a [`CachedReadsDbMut`] (a [`Database`]) that fills this cache on misses.
    pub fn as_db_mut<DB>(&mut self, db: DB) -> CachedReadsDbMut<'_, DB> {
        CachedReadsDbMut { db, cached: self }
    }

    /// Records `info` and the given `storage` slots for `address`.
    ///
    /// Any entry previously cached for `address` (including its storage) is replaced.
    pub fn insert_account(
        &mut self,
        address: Address,
        info: AccountInfo,
        storage: HashMap<U256, U256>,
    ) {
        let account = CachedAccount { info: Some(info), storage };
        self.accounts.insert(address, account);
    }

    /// Moves every entry of `other` into `self`.
    ///
    /// On key collisions the value from `other` wins. For accounts the whole cached entry is
    /// replaced, so storage slots cached in `self` for a colliding address are discarded rather
    /// than merged.
    pub fn extend(&mut self, other: Self) {
        let Self { accounts, contracts, block_hashes } = other;
        self.accounts.extend(accounts);
        self.contracts.extend(contracts);
        self.block_hashes.extend(block_hashes);
    }
}
/// A mutable read-through view that pairs a [`CachedReads`] cache with a backing database `db`.
///
/// When `DB: DatabaseRef`, this type implements [`Database`]. Each lookup is answered from
/// `cached` when possible; otherwise it is fetched from `db` and recorded in `cached`.
#[derive(Debug)]
pub struct CachedReadsDbMut<'a, DB> {
    /// The cache that receives every value loaded from `db`.
    pub cached: &'a mut CachedReads,
    /// The backing database queried on cache misses.
    pub db: DB,
}

impl<'a, DB> CachedReadsDbMut<'a, DB> {
    /// Converts this view into a [`CachedReadsDBRef`], which implements [`DatabaseRef`] by
    /// placing the view behind a [`RefCell`].
    pub const fn into_db(self) -> CachedReadsDBRef<'a, DB> {
        let inner = RefCell::new(self);
        CachedReadsDBRef { inner }
    }

    /// Returns a shared reference to the backing database.
    pub const fn inner(&self) -> &DB {
        &self.db
    }
}
/// Forwards `AsRef<T>` to the backing database, for any `T` that `DB` can be viewed as.
impl<DB, T> AsRef<T> for CachedReadsDbMut<'_, DB>
where
    DB: AsRef<T>,
{
    fn as_ref(&self) -> &T {
        <DB as AsRef<T>>::as_ref(self.inner())
    }
}
/// Read-through [`Database`] implementation.
///
/// Every method first consults `self.cached` and only calls the corresponding `*_ref` method
/// on `self.db` for a miss. A successfully loaded value is stored in the cache before being
/// returned.
///
/// # Errors
///
/// Errors from the backing database are returned unchanged. When a load fails, nothing is
/// inserted into the cache for that key.
impl<DB: DatabaseRef> Database for CachedReadsDbMut<'_, DB> {
    type Error = <DB as DatabaseRef>::Error;

    /// Returns the (possibly absent) account info for `address`, caching the result of the
    /// backing lookup, including a `None` result.
    fn basic(&mut self, address: Address) -> Result<Option<AccountInfo>, Self::Error> {
        let account = match self.cached.accounts.entry(address) {
            Entry::Occupied(occupied) => occupied.into_mut(),
            Entry::Vacant(vacant) => {
                let loaded = self.db.basic_ref(address)?;
                vacant.insert(CachedAccount::new(loaded))
            }
        };
        Ok(account.info.clone())
    }

    /// Returns the bytecode for `code_hash`, caching it on first access.
    fn code_by_hash(&mut self, code_hash: B256) -> Result<Bytecode, Self::Error> {
        let bytecode = match self.cached.contracts.entry(code_hash) {
            Entry::Occupied(occupied) => occupied.into_mut(),
            Entry::Vacant(vacant) => vacant.insert(self.db.code_by_hash_ref(code_hash)?),
        };
        Ok(bytecode.clone())
    }

    /// Returns the value of storage slot `index` for `address`.
    ///
    /// If the account is not cached yet, its info is loaded first. If that info is `None`, the
    /// account is cached without slots and `U256::ZERO` is returned without querying storage.
    /// If the account is already cached (even with `None` info), a missing slot is always
    /// fetched from the backing database.
    fn storage(&mut self, address: Address, index: U256) -> Result<U256, Self::Error> {
        match self.cached.accounts.entry(address) {
            Entry::Occupied(occupied) => {
                let slots = &mut occupied.into_mut().storage;
                match slots.entry(index) {
                    Entry::Occupied(slot) => Ok(*slot.get()),
                    Entry::Vacant(slot) => {
                        let loaded = self.db.storage_ref(address, index)?;
                        Ok(*slot.insert(loaded))
                    }
                }
            }
            Entry::Vacant(vacant) => {
                // The account info is needed before its slots can be read.
                let info = self.db.basic_ref(address)?;
                let mut account = CachedAccount::new(info);
                let value = if account.info.is_none() {
                    U256::ZERO
                } else {
                    let loaded = self.db.storage_ref(address, index)?;
                    account.storage.insert(index, loaded);
                    loaded
                };
                vacant.insert(account);
                Ok(value)
            }
        }
    }

    /// Returns the hash of block `number`, caching it on first access.
    fn block_hash(&mut self, number: u64) -> Result<B256, Self::Error> {
        let hash = match self.cached.block_hashes.entry(number) {
            Entry::Occupied(occupied) => occupied.into_mut(),
            Entry::Vacant(vacant) => vacant.insert(self.db.block_hash_ref(number)?),
        };
        Ok(*hash)
    }
}
/// A [`DatabaseRef`] adapter over [`CachedReadsDbMut`].
///
/// The mutable view lives in a [`RefCell`] so that `&self` lookups can still populate the cache.
/// Because of the `RefCell`, this type is not `Sync`.
#[derive(Debug)]
pub struct CachedReadsDBRef<'a, DB> {
    /// The wrapped mutable view; mutably borrowed for the duration of each lookup.
    pub inner: RefCell<CachedReadsDbMut<'a, DB>>,
}

/// Each method mutably borrows `inner` and delegates to the matching [`Database`] method.
///
/// # Panics
///
/// Panics if `inner` is already borrowed when a method is called (standard [`RefCell`]
/// borrow rules).
impl<DB: DatabaseRef> DatabaseRef for CachedReadsDBRef<'_, DB> {
    type Error = <DB as DatabaseRef>::Error;

    fn basic_ref(&self, address: Address) -> Result<Option<AccountInfo>, Self::Error> {
        let mut view = self.inner.borrow_mut();
        view.basic(address)
    }

    fn code_by_hash_ref(&self, code_hash: B256) -> Result<Bytecode, Self::Error> {
        let mut view = self.inner.borrow_mut();
        view.code_by_hash(code_hash)
    }

    fn storage_ref(&self, address: Address, index: U256) -> Result<U256, Self::Error> {
        let mut view = self.inner.borrow_mut();
        view.storage(address, index)
    }

    fn block_hash_ref(&self, number: u64) -> Result<B256, Self::Error> {
        let mut view = self.inner.borrow_mut();
        view.block_hash(number)
    }
}
/// A cached account: its info (`None` when the backing database reported no account) and
/// the storage slots read for it so far.
#[derive(Debug, Clone)]
struct CachedAccount {
    /// Account info as returned by the backing database.
    info: Option<AccountInfo>,
    /// Storage slot values read so far, keyed by slot index.
    storage: HashMap<U256, U256>,
}

impl CachedAccount {
    /// Creates an entry holding `info` and no cached storage slots.
    fn new(info: Option<AccountInfo>) -> Self {
        let storage = HashMap::default();
        Self { info, storage }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a cache with one account, one contract and one block hash.
    ///
    /// The address and hash are filled with the byte `seed`, and the block hash is stored
    /// under `seed` as the block number. Returns the cache with that address and hash.
    fn single_entry_cache(seed: u8) -> (CachedReads, Address, B256) {
        let hash = B256::from_slice(&[seed; 32]);
        let address = Address::from_slice(&[seed; 20]);
        let mut cache = CachedReads::default();
        cache.accounts.insert(address, CachedAccount::new(Some(AccountInfo::default())));
        cache.contracts.insert(hash, Bytecode::default());
        cache.block_hashes.insert(u64::from(seed), hash);
        (cache, address, hash)
    }

    #[test]
    fn test_extend_with_two_cached_reads() {
        let (mut primary, address1, hash1) = single_entry_cache(1);
        let (additional, address2, hash2) = single_entry_cache(2);

        primary.extend(additional);

        let sizes = [primary.accounts.len(), primary.contracts.len(), primary.block_hashes.len()];
        assert!(sizes.iter().all(|&len| len == 2), "All maps should contain 2 entries");

        let expectations = [
            primary.accounts.contains_key(&address1),
            primary.accounts.contains_key(&address2),
            primary.contracts.contains_key(&hash1),
            primary.contracts.contains_key(&hash2),
            primary.block_hashes.get(&1) == Some(&hash1),
            primary.block_hashes.get(&2) == Some(&hash2),
        ];
        assert!(expectations.iter().all(|&ok| ok), "All expected entries should be present");
    }
}