use super::ExecutedBlock;
use alloy_primitives::{
keccak256,
map::{HashMap, HashSet},
Address, BlockNumber, Bytes, StorageKey, StorageValue, B256,
};
use reth_errors::ProviderResult;
use reth_primitives::{Account, Bytecode};
use reth_storage_api::{
AccountReader, BlockHashReader, StateProofProvider, StateProvider, StateRootProvider,
StorageRootProvider,
};
use reth_trie::{
updates::TrieUpdates, AccountProof, HashedPostState, HashedStorage, MultiProof, TrieInput,
};
use std::sync::OnceLock;
/// State provider that layers a list of executed in-memory blocks on top of a historical
/// state provider whose boxed trait object may hold references valid for `'a`.
///
/// Point lookups consult `in_memory` first and fall back to `historical`; trie computations
/// are delegated to `historical` with the merged in-memory trie data prepended.
// NOTE(review): `Debug` is presumably not implemented because `dyn StateProvider` is not
// required to be `Debug`, so it cannot be derived for the boxed field — confirm.
#[allow(missing_debug_implementations)]
pub struct MemoryOverlayStateProviderRef<'a> {
/// Provider consulted when none of the in-memory blocks answers a query.
pub(crate) historical: Box<dyn StateProvider + 'a>,
/// Executed blocks held in memory; lookups scan them front to back and take the first hit.
pub(crate) in_memory: Vec<ExecutedBlock>,
/// Lazily computed, cached merge of the in-memory blocks' trie nodes and hashed state.
pub(crate) trie_state: OnceLock<MemoryOverlayTrieState>,
}
/// State provider that layers a list of executed in-memory blocks on top of an owned
/// historical state provider (`Box<dyn StateProvider>` defaults to a `'static` trait object).
///
/// Point lookups consult `in_memory` first and fall back to `historical`; trie computations
/// are delegated to `historical` with the merged in-memory trie data prepended.
// NOTE(review): `Debug` is presumably not implemented because `dyn StateProvider` is not
// required to be `Debug`, so it cannot be derived for the boxed field — confirm.
#[allow(missing_debug_implementations)]
pub struct MemoryOverlayStateProvider {
/// Provider consulted when none of the in-memory blocks answers a query.
pub(crate) historical: Box<dyn StateProvider>,
/// Executed blocks held in memory; lookups scan them front to back and take the first hit.
pub(crate) in_memory: Vec<ExecutedBlock>,
/// Lazily computed, cached merge of the in-memory blocks' trie nodes and hashed state.
pub(crate) trie_state: OnceLock<MemoryOverlayTrieState>,
}
/// Generates the inherent constructor/helpers and all state-provider trait impls shared by
/// `MemoryOverlayStateProvider` and `MemoryOverlayStateProviderRef`.
///
/// Arguments:
/// - `[$($tokens)*]`: generic parameters emitted right after `impl` (`[]` for none, `[<'a>]`).
/// - `$type`: the provider type to implement.
/// - `$historical_type`: the boxed historical provider type; also the return type of `boxed`.
///
/// Point lookups (`block_hash`, `basic_account`, `storage`, `bytecode_by_hash`) scan
/// `in_memory` front to back, return the first hit, and otherwise delegate to `historical`.
/// Trie-based methods prepend the merged in-memory trie data before delegating.
macro_rules! impl_state_provider {
    ([$($tokens:tt)*],$type:ty, $historical_type:ty) => {
        impl $($tokens)* $type {
            /// Creates a provider that overlays `in_memory` blocks on top of `historical`.
            ///
            /// The merged trie state is not computed here; it is built on first use.
            pub fn new(historical: $historical_type, in_memory: Vec<ExecutedBlock>) -> Self {
                Self { historical, in_memory, trie_state: OnceLock::new() }
            }

            /// Boxes this provider as the same trait-object type used for the historical
            /// provider.
            pub fn boxed(self) -> $historical_type {
                Box::new(self)
            }

            /// Returns the merged hashed state and trie updates of all in-memory blocks,
            /// computing them on the first call and caching them in `self.trie_state`.
            ///
            /// Blocks are applied from the back of `in_memory` to the front. Assuming
            /// `extend_ref` lets later data override earlier data, entries nearer the front
            /// win, which matches the front-to-back scans of the point lookups.
            fn trie_state(&self) -> &MemoryOverlayTrieState {
                self.trie_state.get_or_init(|| {
                    let mut trie_state = MemoryOverlayTrieState::default();
                    for block in self.in_memory.iter().rev() {
                        trie_state.state.extend_ref(block.hashed_state.as_ref());
                        trie_state.nodes.extend_ref(block.trie.as_ref());
                    }
                    trie_state
                })
            }

            /// Returns `input` with a clone of the merged in-memory trie nodes and hashed
            /// state prepended via `TrieInput::prepend_cached`.
            fn trie_input_with_overlay(&self, mut input: TrieInput) -> TrieInput {
                let MemoryOverlayTrieState { nodes, state } = self.trie_state().clone();
                input.prepend_cached(nodes, state);
                input
            }

            /// Returns the in-memory hashed storage of `address` (empty if the overlay has
            /// none) extended with the caller-supplied `storage`.
            fn hashed_storage_with_overlay(
                &self,
                address: Address,
                storage: &HashedStorage,
            ) -> HashedStorage {
                let mut hashed_storage = self
                    .trie_state()
                    .state
                    .storages
                    .get(&keccak256(address))
                    .cloned()
                    .unwrap_or_default();
                hashed_storage.extend(storage);
                hashed_storage
            }
        }

        impl $($tokens)* BlockHashReader for $type {
            /// Returns the hash of the first in-memory block numbered `number`, otherwise
            /// delegates to the historical provider.
            fn block_hash(&self, number: BlockNumber) -> ProviderResult<Option<B256>> {
                match self.in_memory.iter().find(|block| block.block.number == number) {
                    Some(block) => Ok(Some(block.block.hash())),
                    None => self.historical.block_hash(number),
                }
            }

            /// Returns block hashes for `start..end` (end exclusive).
            ///
            /// The historical provider serves `start..earliest`, where `earliest` is the number
            /// of the last matching in-memory block in `in_memory` order (the oldest one when
            /// `in_memory` is newest-first, as the rest of this impl assumes). The in-memory
            /// hashes follow, in reverse `in_memory` order.
            fn canonical_hashes_range(
                &self,
                start: BlockNumber,
                end: BlockNumber,
            ) -> ProviderResult<Vec<B256>> {
                let range = start..end;
                let mut earliest_block_number = None;
                let mut in_memory_hashes = Vec::new();
                for block in
                    self.in_memory.iter().filter(|block| range.contains(&block.block.number))
                {
                    in_memory_hashes.push(block.block.hash());
                    earliest_block_number = Some(block.block.number);
                }
                // Pushing and reversing once yields the reverse of `in_memory` order in O(n),
                // avoiding the O(n^2) cost of repeatedly inserting at the front.
                in_memory_hashes.reverse();

                let mut hashes = self
                    .historical
                    .canonical_hashes_range(start, earliest_block_number.unwrap_or(end))?;
                hashes.append(&mut in_memory_hashes);
                Ok(hashes)
            }
        }

        impl $($tokens)* AccountReader for $type {
            /// Returns the account from the first in-memory block whose execution output knows
            /// `address`. An outer `Some` ends the search even if the inner value is `None`.
            /// Otherwise delegates to the historical provider.
            fn basic_account(&self, address: Address) -> ProviderResult<Option<Account>> {
                match self
                    .in_memory
                    .iter()
                    .find_map(|block| block.execution_output.account(&address))
                {
                    Some(account) => Ok(account),
                    None => self.historical.basic_account(address),
                }
            }
        }

        impl $($tokens)* StateRootProvider for $type {
            fn state_root(&self, state: HashedPostState) -> ProviderResult<B256> {
                self.state_root_from_nodes(TrieInput::from_state(state))
            }

            fn state_root_from_nodes(&self, input: TrieInput) -> ProviderResult<B256> {
                self.historical.state_root_from_nodes(self.trie_input_with_overlay(input))
            }

            fn state_root_with_updates(
                &self,
                state: HashedPostState,
            ) -> ProviderResult<(B256, TrieUpdates)> {
                self.state_root_from_nodes_with_updates(TrieInput::from_state(state))
            }

            fn state_root_from_nodes_with_updates(
                &self,
                input: TrieInput,
            ) -> ProviderResult<(B256, TrieUpdates)> {
                self.historical
                    .state_root_from_nodes_with_updates(self.trie_input_with_overlay(input))
            }
        }

        // Only the in-memory hashed storage is forwarded; in-memory trie nodes are not reused
        // by these two methods.
        impl $($tokens)* StorageRootProvider for $type {
            fn storage_root(&self, address: Address, storage: HashedStorage) -> ProviderResult<B256> {
                let hashed_storage = self.hashed_storage_with_overlay(address, &storage);
                self.historical.storage_root(address, hashed_storage)
            }

            fn storage_proof(
                &self,
                address: Address,
                slot: B256,
                storage: HashedStorage,
            ) -> ProviderResult<reth_trie::StorageProof> {
                let hashed_storage = self.hashed_storage_with_overlay(address, &storage);
                self.historical.storage_proof(address, slot, hashed_storage)
            }
        }

        impl $($tokens)* StateProofProvider for $type {
            fn proof(
                &self,
                input: TrieInput,
                address: Address,
                slots: &[B256],
            ) -> ProviderResult<AccountProof> {
                self.historical.proof(self.trie_input_with_overlay(input), address, slots)
            }

            fn multiproof(
                &self,
                input: TrieInput,
                targets: HashMap<B256, HashSet<B256>>,
            ) -> ProviderResult<MultiProof> {
                self.historical.multiproof(self.trie_input_with_overlay(input), targets)
            }

            fn witness(
                &self,
                input: TrieInput,
                target: HashedPostState,
            ) -> ProviderResult<HashMap<B256, Bytes>> {
                self.historical.witness(self.trie_input_with_overlay(input), target)
            }
        }

        impl $($tokens)* StateProvider for $type {
            /// Returns the slot value from the first in-memory block that knows it, otherwise
            /// delegates to the historical provider.
            fn storage(
                &self,
                address: Address,
                storage_key: StorageKey,
            ) -> ProviderResult<Option<StorageValue>> {
                match self
                    .in_memory
                    .iter()
                    .find_map(|block| block.execution_output.storage(&address, storage_key.into()))
                {
                    Some(value) => Ok(Some(value)),
                    None => self.historical.storage(address, storage_key),
                }
            }

            /// Returns the bytecode from the first in-memory block that has it, otherwise
            /// delegates to the historical provider.
            fn bytecode_by_hash(&self, code_hash: B256) -> ProviderResult<Option<Bytecode>> {
                match self
                    .in_memory
                    .iter()
                    .find_map(|block| block.execution_output.bytecode(&code_hash))
                {
                    Some(contract) => Ok(Some(contract)),
                    None => self.historical.bytecode_by_hash(code_hash),
                }
            }
        }
    };
}
// Borrowing variant: the boxed historical provider may hold references valid for `'a`.
impl_state_provider! { [<'a>], MemoryOverlayStateProviderRef<'a>, Box<dyn StateProvider + 'a> }
// Owning variant: `Box<dyn StateProvider>` is an implicitly `'static` trait object.
impl_state_provider! { [], MemoryOverlayStateProvider, Box<dyn StateProvider> }
/// Merged trie data of all in-memory blocks of an overlay provider. A clone of it is
/// prepended to the `TrieInput` of state-root, proof, multiproof and witness requests before
/// they are delegated to the historical provider.
#[derive(Clone, Default, Debug)]
pub(crate) struct MemoryOverlayTrieState {
/// Trie node updates accumulated from every in-memory block.
pub(crate) nodes: TrieUpdates,
/// Hashed post-state accumulated from every in-memory block.
pub(crate) state: HashedPostState,
}