use crate::{
hashed_cursor::{HashedCursorFactory, HashedStorageCursor},
node_iter::{TrieElement, TrieNodeIter},
prefix_set::TriePrefixSetsMut,
trie_cursor::TrieCursorFactory,
walker::TrieWalker,
HashBuilder, Nibbles,
};
use alloy_primitives::{keccak256, Address, B256};
use alloy_rlp::{BufMut, Encodable};
use reth_execution_errors::trie::StateProofError;
use reth_trie_common::{
proof::ProofRetainer, AccountProof, MultiProof, StorageMultiProof, TrieAccount,
};
use std::collections::{HashMap, HashSet};
/// Computes Merkle proofs for accounts and their storage slots by walking the account and
/// storage tries and feeding the visited nodes into a [`HashBuilder`] with a proof retainer.
#[derive(Debug)]
pub struct Proof<T, H> {
    /// Factory used to open the account trie cursor and per-account storage trie cursors.
    trie_cursor_factory: T,
    /// Factory used to open the hashed account cursor and per-account hashed storage cursors.
    hashed_cursor_factory: H,
    /// Account and storage prefix sets; the proof targets are added to these before walking.
    prefix_sets: TriePrefixSetsMut,
    /// Proof targets: hashed account address mapped to the set of hashed storage slots to
    /// prove for that account.
    targets: HashMap<B256, HashSet<B256>>,
}
impl<T, H> Proof<T, H> {
    /// Creates a proof generator from a trie cursor factory `t` and a hashed cursor factory
    /// `h`, with default prefix sets and no proof targets.
    pub fn new(t: T, h: H) -> Self {
        Self {
            trie_cursor_factory: t,
            hashed_cursor_factory: h,
            prefix_sets: Default::default(),
            targets: Default::default(),
        }
    }

    /// Swaps in a different trie cursor factory, carrying over the hashed cursor factory,
    /// prefix sets and targets unchanged.
    pub fn with_trie_cursor_factory<TF>(self, trie_cursor_factory: TF) -> Proof<TF, H> {
        let Self { hashed_cursor_factory, prefix_sets, targets, .. } = self;
        Proof { trie_cursor_factory, hashed_cursor_factory, prefix_sets, targets }
    }

    /// Swaps in a different hashed cursor factory, carrying over the trie cursor factory,
    /// prefix sets and targets unchanged.
    pub fn with_hashed_cursor_factory<HF>(self, hashed_cursor_factory: HF) -> Proof<T, HF> {
        let Self { trie_cursor_factory, prefix_sets, targets, .. } = self;
        Proof { trie_cursor_factory, hashed_cursor_factory, prefix_sets, targets }
    }

    /// Replaces the configured prefix sets.
    pub fn with_prefix_sets_mut(self, prefix_sets: TriePrefixSetsMut) -> Self {
        Self { prefix_sets, ..self }
    }

    /// Sets a single proof target `(hashed_address, hashed_slots)`, discarding any targets
    /// configured before.
    pub fn with_target(self, target: (B256, HashSet<B256>)) -> Self {
        let single_target: HashMap<B256, HashSet<B256>> = std::iter::once(target).collect();
        self.with_targets(single_target)
    }

    /// Replaces the whole target map (hashed account address -> hashed storage slots).
    pub fn with_targets(self, targets: HashMap<B256, HashSet<B256>>) -> Self {
        Self { targets, ..self }
    }
}
impl<T, H> Proof<T, H>
where
T: TrieCursorFactory,
H: HashedCursorFactory + Clone,
{
pub fn account_proof(
self,
address: Address,
slots: &[B256],
) -> Result<AccountProof, StateProofError> {
Ok(self
.with_target((keccak256(address), slots.iter().map(keccak256).collect()))
.multiproof()?
.account_proof(address, slots)?)
}
pub fn multiproof(&self) -> Result<MultiProof, StateProofError> {
let hashed_account_cursor = self.hashed_cursor_factory.hashed_account_cursor()?;
let trie_cursor = self.trie_cursor_factory.account_trie_cursor()?;
let mut prefix_set = self.prefix_sets.account_prefix_set.clone();
prefix_set.extend(self.targets.keys().map(Nibbles::unpack));
let walker = TrieWalker::new(trie_cursor, prefix_set.freeze());
let retainer = ProofRetainer::from_iter(self.targets.keys().map(Nibbles::unpack));
let mut hash_builder = HashBuilder::default().with_proof_retainer(retainer);
let mut storages = HashMap::default();
let mut account_rlp = Vec::with_capacity(128);
let mut account_node_iter = TrieNodeIter::new(walker, hashed_account_cursor);
while let Some(account_node) = account_node_iter.try_next()? {
match account_node {
TrieElement::Branch(node) => {
hash_builder.add_branch(node.key, node.value, node.children_are_in_trie);
}
TrieElement::Leaf(hashed_address, account) => {
let storage_multiproof = self.storage_multiproof(hashed_address)?;
account_rlp.clear();
let account = TrieAccount::from((account, storage_multiproof.root));
account.encode(&mut account_rlp as &mut dyn BufMut);
hash_builder.add_leaf(Nibbles::unpack(hashed_address), &account_rlp);
storages.insert(hashed_address, storage_multiproof);
}
}
}
let _ = hash_builder.root();
Ok(MultiProof { account_subtree: hash_builder.take_proofs(), storages })
}
pub fn storage_multiproof(
&self,
hashed_address: B256,
) -> Result<StorageMultiProof, StateProofError> {
let mut hashed_storage_cursor =
self.hashed_cursor_factory.hashed_storage_cursor(hashed_address)?;
if hashed_storage_cursor.is_storage_empty()? {
return Ok(StorageMultiProof::default())
}
let target_nibbles = self
.targets
.get(&hashed_address)
.map_or(Vec::new(), |slots| slots.iter().map(Nibbles::unpack).collect());
let mut prefix_set =
self.prefix_sets.storage_prefix_sets.get(&hashed_address).cloned().unwrap_or_default();
prefix_set.extend(target_nibbles.clone());
let trie_cursor = self.trie_cursor_factory.storage_trie_cursor(hashed_address)?;
let walker = TrieWalker::new(trie_cursor, prefix_set.freeze());
let retainer = ProofRetainer::from_iter(target_nibbles);
let mut hash_builder = HashBuilder::default().with_proof_retainer(retainer);
let mut storage_node_iter = TrieNodeIter::new(walker, hashed_storage_cursor);
while let Some(node) = storage_node_iter.try_next()? {
match node {
TrieElement::Branch(node) => {
hash_builder.add_branch(node.key, node.value, node.children_are_in_trie);
}
TrieElement::Leaf(hashed_slot, value) => {
hash_builder.add_leaf(
Nibbles::unpack(hashed_slot),
alloy_rlp::encode_fixed_size(&value).as_ref(),
);
}
}
}
let root = hash_builder.root();
Ok(StorageMultiProof { root, subtree: hash_builder.take_proofs() })
}
}