use super::{TrieCursor, TrieCursorFactory};
use crate::{
forward_cursor::ForwardInMemoryCursor,
updates::{StorageTrieUpdatesSorted, TrieUpdatesSorted},
};
use alloy_primitives::B256;
use reth_storage_errors::db::DatabaseError;
use reth_trie_common::{BranchNodeCompact, Nibbles};
use std::collections::HashSet;
/// Trie cursor factory that wraps another [`TrieCursorFactory`] and overlays a set of sorted
/// in-memory trie updates on top of the cursors it produces.
#[derive(Debug, Clone)]
pub struct InMemoryTrieCursorFactory<'a, CF> {
/// Underlying factory used to open the wrapped (database) cursors.
cursor_factory: CF,
/// Sorted in-memory trie updates layered over the wrapped cursors.
trie_updates: &'a TrieUpdatesSorted,
}
impl<'a, CF> InMemoryTrieCursorFactory<'a, CF> {
    /// Build a factory that layers `trie_updates` over every cursor opened through
    /// `cursor_factory`.
    pub const fn new(cursor_factory: CF, trie_updates: &'a TrieUpdatesSorted) -> Self {
        Self { trie_updates, cursor_factory }
    }
}
impl<'a, CF: TrieCursorFactory> TrieCursorFactory for InMemoryTrieCursorFactory<'a, CF> {
    type AccountTrieCursor = InMemoryAccountTrieCursor<'a, CF::AccountTrieCursor>;
    type StorageTrieCursor = InMemoryStorageTrieCursor<'a, CF::StorageTrieCursor>;

    /// Open an account trie cursor from the wrapped factory and wrap it so that the in-memory
    /// account trie updates are visible through it.
    ///
    /// Errors from the wrapped factory are returned unchanged.
    fn account_trie_cursor(&self) -> Result<Self::AccountTrieCursor, DatabaseError> {
        let trie_updates = self.trie_updates;
        self.cursor_factory
            .account_trie_cursor()
            .map(move |db_cursor| InMemoryAccountTrieCursor::new(db_cursor, trie_updates))
    }

    /// Open a storage trie cursor for `hashed_address` from the wrapped factory and wrap it with
    /// the in-memory storage trie updates recorded for that address, if there are any.
    ///
    /// Errors from the wrapped factory are returned unchanged.
    fn storage_trie_cursor(
        &self,
        hashed_address: B256,
    ) -> Result<Self::StorageTrieCursor, DatabaseError> {
        let trie_updates = self.trie_updates;
        self.cursor_factory.storage_trie_cursor(hashed_address).map(move |db_cursor| {
            // `None` here means there are no in-memory updates for this storage trie.
            let storage_updates = trie_updates.storage_tries.get(&hashed_address);
            InMemoryStorageTrieCursor::new(hashed_address, db_cursor, storage_updates)
        })
    }
}
/// Account trie cursor that merges in-memory account trie node updates with an underlying
/// cursor. In-memory nodes shadow underlying nodes with the same key, and underlying nodes listed
/// in `removed_nodes` are skipped.
#[derive(Debug)]
pub struct InMemoryAccountTrieCursor<'a, C> {
/// Underlying (database) trie cursor.
cursor: C,
/// Forward-only cursor over the in-memory account trie nodes.
in_memory_cursor: ForwardInMemoryCursor<'a, Nibbles, BranchNodeCompact>,
/// Keys of nodes removed in memory; matching underlying entries are skipped.
removed_nodes: &'a HashSet<Nibbles>,
/// Key of the entry most recently returned by a seek or `next`; `None` once nothing was found.
last_key: Option<Nibbles>,
}
impl<'a, C: TrieCursor> InMemoryAccountTrieCursor<'a, C> {
    /// Create a new account trie cursor from the underlying cursor and a reference to the sorted
    /// in-memory trie updates.
    ///
    /// Only the account-level parts of `trie_updates` (`account_nodes` and `removed_nodes`) are
    /// used by this cursor.
    pub const fn new(cursor: C, trie_updates: &'a TrieUpdatesSorted) -> Self {
        let in_memory_cursor = ForwardInMemoryCursor::new(&trie_updates.account_nodes);
        Self {
            cursor,
            in_memory_cursor,
            removed_nodes: &trie_updates.removed_nodes,
            last_key: None,
        }
    }

    /// Find the lowest node with key `>= key`, merging in-memory and underlying entries.
    ///
    /// When `exact` is `true`, only a node whose key equals `key` is returned. In-memory nodes
    /// take precedence over underlying nodes with the same key, and underlying nodes listed in
    /// `removed_nodes` are skipped.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the underlying cursor.
    fn seek_inner(
        &mut self,
        key: Nibbles,
        exact: bool,
    ) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
        let in_memory = self.in_memory_cursor.seek(&key);
        // An exact in-memory hit shadows any underlying entry, so the database is not consulted.
        if exact && in_memory.as_ref().map_or(false, |entry| entry.0 == key) {
            return Ok(in_memory)
        }

        // Reposition the underlying cursor to the first node at or after `key` that was not
        // removed in memory.
        let mut db_entry = self.cursor.seek(key.clone())?;
        while db_entry.as_ref().map_or(false, |entry| self.removed_nodes.contains(&entry.0)) {
            db_entry = self.cursor.next()?;
        }

        // Take the lower of the two candidates; for an exact seek require a key match.
        Ok(compare_trie_node_entries(in_memory, db_entry)
            .filter(|(nibbles, _)| !exact || nibbles == &key))
    }

    /// Find the lowest node with key strictly greater than `last`, merging in-memory and
    /// underlying entries.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the underlying cursor.
    fn next_inner(
        &mut self,
        last: Nibbles,
    ) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
        let in_memory = self.in_memory_cursor.first_after(&last);

        // `seek(last)` can land on `last` itself, which the caller has already consumed.
        // Skip it (hence `<=`, not `<`) together with any node removed in memory.
        // Otherwise an underlying node equal to `last` would win the comparison below, and
        // `next` would keep yielding the same key.
        let mut db_entry = self.cursor.seek(last.clone())?;
        while db_entry
            .as_ref()
            .map_or(false, |entry| entry.0 <= last || self.removed_nodes.contains(&entry.0))
        {
            db_entry = self.cursor.next()?;
        }

        // Take the lower of the two candidates.
        Ok(compare_trie_node_entries(in_memory, db_entry))
    }
}
impl<C: TrieCursor> TrieCursor for InMemoryAccountTrieCursor<'_, C> {
    /// Seek to the node whose key equals `key` and record the outcome as the cursor position.
    fn seek_exact(
        &mut self,
        key: Nibbles,
    ) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
        self.seek_inner(key, true).map(|found| {
            self.last_key = found.as_ref().map(|(nibbles, _)| nibbles.clone());
            found
        })
    }

    /// Seek to the lowest node with key `>= key` and record the outcome as the cursor position.
    fn seek(
        &mut self,
        key: Nibbles,
    ) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
        self.seek_inner(key, false).map(|found| {
            self.last_key = found.as_ref().map(|(nibbles, _)| nibbles.clone());
            found
        })
    }

    /// Advance past the recorded position.
    ///
    /// Returns `Ok(None)` without touching the underlying cursor when no position is recorded.
    fn next(&mut self) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
        let Some(previous) = self.last_key.clone() else { return Ok(None) };
        let found = self.next_inner(previous)?;
        self.last_key = found.as_ref().map(|(nibbles, _)| nibbles.clone());
        Ok(found)
    }

    /// Report the recorded position, falling back to the underlying cursor when there is none.
    fn current(&mut self) -> Result<Option<Nibbles>, DatabaseError> {
        if let Some(position) = &self.last_key {
            return Ok(Some(position.clone()))
        }
        self.cursor.current()
    }
}
/// Storage trie cursor that merges in-memory storage trie node updates for a single account with
/// an underlying cursor. In-memory nodes shadow underlying nodes with the same key, underlying
/// nodes listed in `removed_nodes` are skipped, and a cleared storage trie ignores the underlying
/// cursor entirely.
#[derive(Debug)]
// `hashed_address` is stored but never read within this file, hence the lint allowance.
#[allow(dead_code)]
pub struct InMemoryStorageTrieCursor<'a, C> {
/// Hashed address of the account whose storage trie this cursor was opened for.
hashed_address: B256,
/// Underlying (database) storage trie cursor.
cursor: C,
/// Forward-only cursor over the in-memory storage nodes; `None` when there are no updates.
in_memory_cursor: Option<ForwardInMemoryCursor<'a, Nibbles, BranchNodeCompact>>,
/// Keys of storage nodes removed in memory; matching underlying entries are skipped.
removed_nodes: Option<&'a HashSet<Nibbles>>,
/// When `true`, underlying entries are ignored and only in-memory nodes are returned.
storage_trie_cleared: bool,
/// Key of the entry most recently returned by a seek or `next`; `None` once nothing was found.
last_key: Option<Nibbles>,
}
impl<'a, C> InMemoryStorageTrieCursor<'a, C> {
    /// Create a storage trie cursor for `hashed_address` over `cursor`, layering `updates` on top.
    ///
    /// With `updates == None` the cursor has no in-memory nodes, no removals and is not
    /// considered cleared, so it behaves like the underlying cursor.
    pub fn new(
        hashed_address: B256,
        cursor: C,
        updates: Option<&'a StorageTrieUpdatesSorted>,
    ) -> Self {
        let (in_memory_cursor, removed_nodes, storage_trie_cleared) = match updates {
            Some(storage_updates) => (
                Some(ForwardInMemoryCursor::new(&storage_updates.storage_nodes)),
                Some(&storage_updates.removed_nodes),
                storage_updates.is_deleted,
            ),
            None => (None, None, false),
        };
        Self {
            hashed_address,
            cursor,
            in_memory_cursor,
            removed_nodes,
            storage_trie_cleared,
            last_key: None,
        }
    }
}
impl<C: TrieCursor> InMemoryStorageTrieCursor<'_, C> {
    /// Find the lowest storage node with key `>= key`, merging in-memory and underlying entries.
    ///
    /// When `exact` is `true`, only a node whose key equals `key` is returned. If the storage
    /// trie was cleared, only in-memory nodes are considered. Otherwise in-memory nodes take
    /// precedence over underlying nodes with the same key, and underlying nodes listed in
    /// `removed_nodes` are skipped.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the underlying cursor.
    fn seek_inner(
        &mut self,
        key: Nibbles,
        exact: bool,
    ) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
        let in_memory = self.in_memory_cursor.as_mut().and_then(|c| c.seek(&key));
        // A cleared trie has no underlying state to merge. An exact in-memory hit shadows any
        // underlying entry. In both cases the database is not consulted.
        if self.storage_trie_cleared ||
            (exact && in_memory.as_ref().map_or(false, |entry| entry.0 == key))
        {
            return Ok(in_memory.filter(|(nibbles, _)| !exact || nibbles == &key))
        }

        // Reposition the underlying cursor to the first node at or after `key` that was not
        // removed in memory.
        let mut db_entry = self.cursor.seek(key.clone())?;
        while db_entry.as_ref().map_or(false, |entry| {
            self.removed_nodes.as_ref().map_or(false, |r| r.contains(&entry.0))
        }) {
            db_entry = self.cursor.next()?;
        }

        // Take the lower of the two candidates; for an exact seek require a key match.
        Ok(compare_trie_node_entries(in_memory, db_entry)
            .filter(|(nibbles, _)| !exact || nibbles == &key))
    }

    /// Find the lowest storage node with key strictly greater than `last`, merging in-memory
    /// and underlying entries (only in-memory ones if the storage trie was cleared).
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the underlying cursor.
    fn next_inner(
        &mut self,
        last: Nibbles,
    ) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
        let in_memory = self.in_memory_cursor.as_mut().and_then(|c| c.first_after(&last));
        if self.storage_trie_cleared {
            return Ok(in_memory)
        }

        // `seek(last)` can land on `last` itself, which the caller has already consumed.
        // Skip it (hence `<=`, not `<`) together with any node removed in memory.
        // Otherwise an underlying node equal to `last` would win the comparison below, and
        // `next` would keep yielding the same key.
        let mut db_entry = self.cursor.seek(last.clone())?;
        while db_entry.as_ref().map_or(false, |entry| {
            entry.0 <= last || self.removed_nodes.as_ref().map_or(false, |r| r.contains(&entry.0))
        }) {
            db_entry = self.cursor.next()?;
        }

        // Take the lower of the two candidates.
        Ok(compare_trie_node_entries(in_memory, db_entry))
    }
}
impl<C: TrieCursor> TrieCursor for InMemoryStorageTrieCursor<'_, C> {
    /// Seek to the storage node whose key equals `key` and record the outcome as the cursor
    /// position.
    fn seek_exact(
        &mut self,
        key: Nibbles,
    ) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
        self.seek_inner(key, true).map(|found| {
            self.last_key = found.as_ref().map(|(nibbles, _)| nibbles.clone());
            found
        })
    }

    /// Seek to the lowest storage node with key `>= key` and record the outcome as the cursor
    /// position.
    fn seek(
        &mut self,
        key: Nibbles,
    ) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
        self.seek_inner(key, false).map(|found| {
            self.last_key = found.as_ref().map(|(nibbles, _)| nibbles.clone());
            found
        })
    }

    /// Advance past the recorded position.
    ///
    /// Returns `Ok(None)` without touching the underlying cursor when no position is recorded.
    fn next(&mut self) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
        let Some(previous) = self.last_key.clone() else { return Ok(None) };
        let found = self.next_inner(previous)?;
        self.last_key = found.as_ref().map(|(nibbles, _)| nibbles.clone());
        Ok(found)
    }

    /// Report the recorded position, falling back to the underlying cursor when there is none.
    fn current(&mut self) -> Result<Option<Nibbles>, DatabaseError> {
        if let Some(position) = &self.last_key {
            return Ok(Some(position.clone()))
        }
        self.cursor.current()
    }
}
/// Pick the entry with the lower key out of an in-memory candidate and an underlying candidate.
///
/// On a key tie the in-memory entry wins, so in-memory updates shadow underlying state. If only
/// one candidate is present it is returned as-is; if neither is, the result is `None`.
fn compare_trie_node_entries(
    in_memory_item: Option<(Nibbles, BranchNodeCompact)>,
    db_item: Option<(Nibbles, BranchNodeCompact)>,
) -> Option<(Nibbles, BranchNodeCompact)> {
    match (in_memory_item, db_item) {
        (Some(in_memory), Some(db)) => Some(if db.0 < in_memory.0 { db } else { in_memory }),
        // At most one side is present here.
        (in_memory, db) => db.or(in_memory),
    }
}