use crate::{
hashed_cursor::{HashedCursorFactory, HashedStorageCursor},
node_iter::{TrieElement, TrieNodeIter},
prefix_set::{PrefixSet, TriePrefixSets},
progress::{IntermediateStateRootState, StateRootProgress},
stats::TrieTracker,
trie_cursor::TrieCursorFactory,
updates::{StorageTrieUpdates, TrieUpdates},
walker::TrieWalker,
HashBuilder, Nibbles, TrieAccount,
};
use alloy_consensus::EMPTY_ROOT_HASH;
use alloy_primitives::{keccak256, Address, B256};
use alloy_rlp::{BufMut, Encodable};
use reth_execution_errors::{StateRootError, StorageRootError};
use tracing::trace;
#[cfg(feature = "metrics")]
use crate::metrics::{StateRootMetrics, TrieRootMetrics};
/// Computes the root of the account (state) trie from a stored trie plus hashed account and
/// hashed storage data, optionally collecting trie updates and yielding intermediate progress
/// once a configurable number of updates has been collected.
#[derive(Debug)]
pub struct StateRoot<T, H> {
    /// Factory for the account trie cursor; it is also cloned into every per-account
    /// [`StorageRoot`] to open storage trie cursors.
    pub trie_cursor_factory: T,
    /// Factory for the hashed account cursor; it is also cloned into every per-account
    /// [`StorageRoot`] to open hashed storage cursors.
    pub hashed_cursor_factory: H,
    /// Prefix sets driving the calculation: the account prefix set handed to the account trie
    /// walker, the per-account storage prefix sets handed to each [`StorageRoot`], and the
    /// destroyed accounts passed on when finalizing the trie updates.
    pub prefix_sets: TriePrefixSets,
    /// Intermediate state to resume a previously interrupted calculation from; `None` starts
    /// a fresh walk.
    previous_state: Option<IntermediateStateRootState>,
    /// Number of collected updates at which a calculation that retains updates stops and
    /// returns intermediate progress instead of a finished root.
    threshold: u64,
    /// Metrics recorded for the state trie and, via its `storage_trie` member, for every
    /// storage trie computed along the way.
    #[cfg(feature = "metrics")]
    metrics: StateRootMetrics,
}
impl<T, H> StateRoot<T, H> {
    /// Creates a state root calculator from the given cursor factories.
    ///
    /// The calculator starts with empty prefix sets and no intermediate state. Its update
    /// threshold is `100_000`.
    pub fn new(trie_cursor_factory: T, hashed_cursor_factory: H) -> Self {
        Self {
            trie_cursor_factory,
            hashed_cursor_factory,
            prefix_sets: Default::default(),
            previous_state: None,
            threshold: 100_000,
            #[cfg(feature = "metrics")]
            metrics: Default::default(),
        }
    }

    /// Replaces the prefix sets used to drive the trie walk.
    pub fn with_prefix_sets(self, prefix_sets: TriePrefixSets) -> Self {
        Self { prefix_sets, ..self }
    }

    /// Sets how many collected updates trigger an intermediate progress return.
    pub const fn with_threshold(mut self, threshold: u64) -> Self {
        self.threshold = threshold;
        self
    }

    /// Effectively disables intermediate progress by raising the threshold to `u64::MAX`.
    pub const fn with_no_threshold(self) -> Self {
        self.with_threshold(u64::MAX)
    }

    /// Sets the intermediate state to resume from (`None` starts from scratch).
    pub fn with_intermediate_state(self, state: Option<IntermediateStateRootState>) -> Self {
        Self { previous_state: state, ..self }
    }

    /// Swaps in a different hashed cursor factory, keeping all other settings.
    pub fn with_hashed_cursor_factory<HF>(self, hashed_cursor_factory: HF) -> StateRoot<T, HF> {
        StateRoot {
            hashed_cursor_factory,
            trie_cursor_factory: self.trie_cursor_factory,
            previous_state: self.previous_state,
            prefix_sets: self.prefix_sets,
            threshold: self.threshold,
            #[cfg(feature = "metrics")]
            metrics: self.metrics,
        }
    }

    /// Swaps in a different trie cursor factory, keeping all other settings.
    pub fn with_trie_cursor_factory<TF>(self, trie_cursor_factory: TF) -> StateRoot<TF, H> {
        StateRoot {
            trie_cursor_factory,
            previous_state: self.previous_state,
            hashed_cursor_factory: self.hashed_cursor_factory,
            threshold: self.threshold,
            prefix_sets: self.prefix_sets,
            #[cfg(feature = "metrics")]
            metrics: self.metrics,
        }
    }
}
impl<T, H> StateRoot<T, H>
where
T: TrieCursorFactory + Clone,
H: HashedCursorFactory + Clone,
{
pub fn root_with_updates(self) -> Result<(B256, TrieUpdates), StateRootError> {
match self.with_no_threshold().calculate(true)? {
StateRootProgress::Complete(root, _, updates) => Ok((root, updates)),
StateRootProgress::Progress(..) => unreachable!(), }
}
pub fn root(self) -> Result<B256, StateRootError> {
match self.calculate(false)? {
StateRootProgress::Complete(root, _, _) => Ok(root),
StateRootProgress::Progress(..) => unreachable!(), }
}
pub fn root_with_progress(self) -> Result<StateRootProgress, StateRootError> {
self.calculate(true)
}
fn calculate(self, retain_updates: bool) -> Result<StateRootProgress, StateRootError> {
trace!(target: "trie::state_root", "calculating state root");
let mut tracker = TrieTracker::default();
let mut trie_updates = TrieUpdates::default();
let trie_cursor = self.trie_cursor_factory.account_trie_cursor()?;
let hashed_account_cursor = self.hashed_cursor_factory.hashed_account_cursor()?;
let (mut hash_builder, mut account_node_iter) = match self.previous_state {
Some(state) => {
let hash_builder = state.hash_builder.with_updates(retain_updates);
let walker = TrieWalker::from_stack(
trie_cursor,
state.walker_stack,
self.prefix_sets.account_prefix_set,
)
.with_deletions_retained(retain_updates);
let node_iter = TrieNodeIter::new(walker, hashed_account_cursor)
.with_last_hashed_key(state.last_account_key);
(hash_builder, node_iter)
}
None => {
let hash_builder = HashBuilder::default().with_updates(retain_updates);
let walker = TrieWalker::new(trie_cursor, self.prefix_sets.account_prefix_set)
.with_deletions_retained(retain_updates);
let node_iter = TrieNodeIter::new(walker, hashed_account_cursor);
(hash_builder, node_iter)
}
};
let mut account_rlp = Vec::with_capacity(128);
let mut hashed_entries_walked = 0;
let mut updated_storage_nodes = 0;
while let Some(node) = account_node_iter.try_next()? {
match node {
TrieElement::Branch(node) => {
tracker.inc_branch();
hash_builder.add_branch(node.key, node.value, node.children_are_in_trie);
}
TrieElement::Leaf(hashed_address, account) => {
tracker.inc_leaf();
hashed_entries_walked += 1;
let storage_root_calculator = StorageRoot::new_hashed(
self.trie_cursor_factory.clone(),
self.hashed_cursor_factory.clone(),
hashed_address,
#[cfg(feature = "metrics")]
self.metrics.storage_trie.clone(),
)
.with_prefix_set(
self.prefix_sets
.storage_prefix_sets
.get(&hashed_address)
.cloned()
.unwrap_or_default(),
);
let storage_root = if retain_updates {
let (root, storage_slots_walked, updates) =
storage_root_calculator.root_with_updates()?;
hashed_entries_walked += storage_slots_walked;
updated_storage_nodes += updates.len();
trie_updates.insert_storage_updates(hashed_address, updates);
root
} else {
storage_root_calculator.root()?
};
account_rlp.clear();
let account = TrieAccount::from((account, storage_root));
account.encode(&mut account_rlp as &mut dyn BufMut);
hash_builder.add_leaf(Nibbles::unpack(hashed_address), &account_rlp);
let total_updates_len = updated_storage_nodes +
account_node_iter.walker.removed_keys_len() +
hash_builder.updates_len();
if retain_updates && total_updates_len as u64 >= self.threshold {
let (walker_stack, walker_deleted_keys) = account_node_iter.walker.split();
trie_updates.removed_nodes.extend(walker_deleted_keys);
let (hash_builder, hash_builder_updates) = hash_builder.split();
trie_updates.account_nodes.extend(hash_builder_updates);
let state = IntermediateStateRootState {
hash_builder,
walker_stack,
last_account_key: hashed_address,
};
return Ok(StateRootProgress::Progress(
Box::new(state),
hashed_entries_walked,
trie_updates,
))
}
}
}
}
let root = hash_builder.root();
trie_updates.finalize(
account_node_iter.walker,
hash_builder,
self.prefix_sets.destroyed_accounts,
);
let stats = tracker.finish();
#[cfg(feature = "metrics")]
self.metrics.state_trie.record(stats);
trace!(
target: "trie::state_root",
%root,
duration = ?stats.duration(),
branches_added = stats.branches_added(),
leaves_added = stats.leaves_added(),
"calculated state root"
);
Ok(StateRootProgress::Complete(root, hashed_entries_walked, trie_updates))
}
}
/// Computes the root of a single account's storage trie, optionally collecting the storage
/// trie updates produced along the way.
#[derive(Debug)]
pub struct StorageRoot<T, H> {
    /// Factory used to open the storage trie cursor for `hashed_address`.
    pub trie_cursor_factory: T,
    /// Factory used to open the hashed storage cursor for `hashed_address`.
    pub hashed_cursor_factory: H,
    /// Keccak-256 hash of the account address whose storage trie is computed.
    pub hashed_address: B256,
    /// Prefix set handed to the storage trie walker.
    pub prefix_set: PrefixSet,
    /// Metrics recorded once the storage root has been computed.
    #[cfg(feature = "metrics")]
    metrics: TrieRootMetrics,
}
impl<T, H> StorageRoot<T, H> {
    /// Creates a storage root calculator for the plain `address`, which is hashed with
    /// Keccak-256 before use.
    pub fn new(
        trie_cursor_factory: T,
        hashed_cursor_factory: H,
        address: Address,
        #[cfg(feature = "metrics")] metrics: TrieRootMetrics,
    ) -> Self {
        let hashed_address = keccak256(address);
        Self::new_hashed(
            trie_cursor_factory,
            hashed_cursor_factory,
            hashed_address,
            #[cfg(feature = "metrics")]
            metrics,
        )
    }

    /// Creates a storage root calculator for an already-hashed account address, starting
    /// with an empty prefix set.
    pub fn new_hashed(
        trie_cursor_factory: T,
        hashed_cursor_factory: H,
        hashed_address: B256,
        #[cfg(feature = "metrics")] metrics: TrieRootMetrics,
    ) -> Self {
        Self {
            hashed_address,
            trie_cursor_factory,
            hashed_cursor_factory,
            prefix_set: Default::default(),
            #[cfg(feature = "metrics")]
            metrics,
        }
    }

    /// Replaces the prefix set handed to the storage trie walker.
    pub fn with_prefix_set(self, prefix_set: PrefixSet) -> Self {
        Self { prefix_set, ..self }
    }

    /// Swaps in a different hashed cursor factory, keeping all other settings.
    pub fn with_hashed_cursor_factory<HF>(self, hashed_cursor_factory: HF) -> StorageRoot<T, HF> {
        StorageRoot {
            hashed_cursor_factory,
            hashed_address: self.hashed_address,
            trie_cursor_factory: self.trie_cursor_factory,
            prefix_set: self.prefix_set,
            #[cfg(feature = "metrics")]
            metrics: self.metrics,
        }
    }

    /// Swaps in a different trie cursor factory, keeping all other settings.
    pub fn with_trie_cursor_factory<TF>(self, trie_cursor_factory: TF) -> StorageRoot<TF, H> {
        StorageRoot {
            trie_cursor_factory,
            hashed_address: self.hashed_address,
            prefix_set: self.prefix_set,
            hashed_cursor_factory: self.hashed_cursor_factory,
            #[cfg(feature = "metrics")]
            metrics: self.metrics,
        }
    }
}
impl<T, H> StorageRoot<T, H>
where
    T: TrieCursorFactory,
    H: HashedCursorFactory,
{
    /// Computes the storage root while retaining trie updates.
    ///
    /// Returns the root, the number of storage slots walked and the collected updates.
    pub fn root_with_updates(self) -> Result<(B256, usize, StorageTrieUpdates), StorageRootError> {
        self.calculate(true)
    }

    /// Computes only the storage root, discarding the slot count and updates.
    pub fn root(self) -> Result<B256, StorageRootError> {
        self.calculate(false).map(|(root, _, _)| root)
    }

    /// Walks the storage trie of `hashed_address` and computes its root.
    ///
    /// Returns `(root, storage_slots_walked, updates)`. If the hashed storage is empty, the
    /// walk is skipped and `(EMPTY_ROOT_HASH, 0, StorageTrieUpdates::deleted())` is returned.
    /// `retain_updates` controls whether the hash builder and walker keep updates/deletions.
    pub fn calculate(
        self,
        retain_updates: bool,
    ) -> Result<(B256, usize, StorageTrieUpdates), StorageRootError> {
        let hashed_address = self.hashed_address;
        trace!(target: "trie::storage_root", hashed_address = ?hashed_address, "calculating storage root");

        let mut storage_cursor =
            self.hashed_cursor_factory.hashed_storage_cursor(hashed_address)?;
        // Short-circuit: with no storage entries there is nothing to walk.
        if storage_cursor.is_storage_empty()? {
            return Ok((EMPTY_ROOT_HASH, 0, StorageTrieUpdates::deleted()))
        }

        let mut tracker = TrieTracker::default();
        let storage_trie_cursor = self.trie_cursor_factory.storage_trie_cursor(hashed_address)?;
        let walker = TrieWalker::new(storage_trie_cursor, self.prefix_set)
            .with_deletions_retained(retain_updates);
        let mut builder = HashBuilder::default().with_updates(retain_updates);
        let mut node_iter = TrieNodeIter::new(walker, storage_cursor);

        while let Some(element) = node_iter.try_next()? {
            match element {
                TrieElement::Branch(branch) => {
                    tracker.inc_branch();
                    builder.add_branch(branch.key, branch.value, branch.children_are_in_trie);
                }
                TrieElement::Leaf(hashed_slot, value) => {
                    tracker.inc_leaf();
                    // Storage values are stored in the trie as their RLP encoding.
                    let encoded_value = alloy_rlp::encode_fixed_size(&value);
                    builder.add_leaf(Nibbles::unpack(hashed_slot), encoded_value.as_ref());
                }
            }
        }

        let root = builder.root();
        let mut trie_updates = StorageTrieUpdates::default();
        trie_updates.finalize(node_iter.walker, builder);

        let stats = tracker.finish();
        #[cfg(feature = "metrics")]
        self.metrics.record(stats);
        trace!(
            target: "trie::storage_root",
            %root,
            hashed_address = %hashed_address,
            duration = ?stats.duration(),
            branches_added = stats.branches_added(),
            leaves_added = stats.leaves_added(),
            "calculated storage root"
        );

        // Every leaf added corresponds to one storage slot walked.
        Ok((root, stats.leaves_added() as usize, trie_updates))
    }
}