reth_trie/hashed_cursor/
mock.rs

1use std::{collections::BTreeMap, fmt::Debug, sync::Arc};
2
3use crate::mock::{KeyVisit, KeyVisitType};
4
5use super::{HashedCursor, HashedCursorFactory, HashedStorageCursor};
6use alloy_primitives::{map::B256Map, B256, U256};
7use parking_lot::{Mutex, MutexGuard};
8use reth_primitives_traits::Account;
9use reth_storage_errors::db::DatabaseError;
10use reth_trie_common::HashedPostState;
11use tracing::instrument;
12
13/// Mock hashed cursor factory.
14#[derive(Clone, Default, Debug)]
15pub struct MockHashedCursorFactory {
16    hashed_accounts: Arc<BTreeMap<B256, Account>>,
17    hashed_storage_tries: Arc<B256Map<BTreeMap<B256, U256>>>,
18
19    /// List of keys that the hashed accounts cursor has visited.
20    visited_account_keys: Arc<Mutex<Vec<KeyVisit<B256>>>>,
21    /// List of keys that the hashed storages cursor has visited, per storage trie.
22    visited_storage_keys: Arc<B256Map<Mutex<Vec<KeyVisit<B256>>>>>,
23}
24
25impl MockHashedCursorFactory {
26    /// Creates a new mock hashed cursor factory.
27    pub fn new(
28        hashed_accounts: BTreeMap<B256, Account>,
29        hashed_storage_tries: B256Map<BTreeMap<B256, U256>>,
30    ) -> Self {
31        let visited_storage_keys =
32            hashed_storage_tries.keys().map(|k| (*k, Default::default())).collect();
33        Self {
34            hashed_accounts: Arc::new(hashed_accounts),
35            hashed_storage_tries: Arc::new(hashed_storage_tries),
36            visited_account_keys: Default::default(),
37            visited_storage_keys: Arc::new(visited_storage_keys),
38        }
39    }
40
41    /// Creates a new mock hashed cursor factory from a `HashedPostState`.
42    pub fn from_hashed_post_state(post_state: HashedPostState) -> Self {
43        // Extract accounts from post state, filtering out None (deleted accounts)
44        let hashed_accounts: BTreeMap<B256, Account> = post_state
45            .accounts
46            .into_iter()
47            .filter_map(|(addr, account)| account.map(|acc| (addr, acc)))
48            .collect();
49
50        // Extract storages from post state
51        let mut hashed_storages: B256Map<BTreeMap<B256, U256>> = post_state
52            .storages
53            .into_iter()
54            .map(|(addr, hashed_storage)| {
55                // Convert HashedStorage to BTreeMap, filtering out zero values (deletions)
56                let storage_map: BTreeMap<B256, U256> = hashed_storage
57                    .storage
58                    .into_iter()
59                    .filter_map(|(slot, value)| (value != U256::ZERO).then_some((slot, value)))
60                    .collect();
61                (addr, storage_map)
62            })
63            .collect();
64
65        // Ensure all accounts have at least an empty storage
66        for account in hashed_accounts.keys() {
67            hashed_storages.entry(*account).or_default();
68        }
69
70        Self::new(hashed_accounts, hashed_storages)
71    }
72
73    /// Returns a reference to the list of visited hashed account keys.
74    pub fn visited_account_keys(&self) -> MutexGuard<'_, Vec<KeyVisit<B256>>> {
75        self.visited_account_keys.lock()
76    }
77
78    /// Returns a reference to the list of visited hashed storage keys for the given hashed address.
79    pub fn visited_storage_keys(
80        &self,
81        hashed_address: B256,
82    ) -> MutexGuard<'_, Vec<KeyVisit<B256>>> {
83        self.visited_storage_keys.get(&hashed_address).expect("storage trie should exist").lock()
84    }
85}
86
87impl HashedCursorFactory for MockHashedCursorFactory {
88    type AccountCursor<'a>
89        = MockHashedCursor<Account>
90    where
91        Self: 'a;
92    type StorageCursor<'a>
93        = MockHashedCursor<U256>
94    where
95        Self: 'a;
96
97    fn hashed_account_cursor(&self) -> Result<Self::AccountCursor<'_>, DatabaseError> {
98        Ok(MockHashedCursor::new(self.hashed_accounts.clone(), self.visited_account_keys.clone()))
99    }
100
101    fn hashed_storage_cursor(
102        &self,
103        hashed_address: B256,
104    ) -> Result<Self::StorageCursor<'_>, DatabaseError> {
105        MockHashedCursor::new_storage(
106            self.hashed_storage_tries.clone(),
107            self.visited_storage_keys.clone(),
108            hashed_address,
109        )
110    }
111}
112
113/// Mock hashed cursor type - determines whether this is an account or storage cursor.
114#[derive(Debug)]
115enum MockHashedCursorType<T> {
116    Account {
117        values: Arc<BTreeMap<B256, T>>,
118        visited_keys: Arc<Mutex<Vec<KeyVisit<B256>>>>,
119    },
120    Storage {
121        all_storage_values: Arc<B256Map<BTreeMap<B256, T>>>,
122        all_visited_storage_keys: Arc<B256Map<Mutex<Vec<KeyVisit<B256>>>>>,
123        current_hashed_address: B256,
124    },
125}
126
127/// Mock hashed cursor.
128#[derive(Debug)]
129pub struct MockHashedCursor<T> {
130    /// The current key. If set, it is guaranteed to exist in `values`.
131    current_key: Option<B256>,
132    cursor_type: MockHashedCursorType<T>,
133}
134
135impl<T> MockHashedCursor<T> {
136    /// Creates a new mock hashed cursor for accounts with the given values and key tracking.
137    pub const fn new(
138        values: Arc<BTreeMap<B256, T>>,
139        visited_keys: Arc<Mutex<Vec<KeyVisit<B256>>>>,
140    ) -> Self {
141        Self {
142            current_key: None,
143            cursor_type: MockHashedCursorType::Account { values, visited_keys },
144        }
145    }
146
147    /// Creates a new mock hashed cursor for storage with access to all storage tries.
148    pub fn new_storage(
149        all_storage_values: Arc<B256Map<BTreeMap<B256, T>>>,
150        all_visited_storage_keys: Arc<B256Map<Mutex<Vec<KeyVisit<B256>>>>>,
151        hashed_address: B256,
152    ) -> Result<Self, DatabaseError> {
153        if !all_storage_values.contains_key(&hashed_address) {
154            return Err(DatabaseError::Other(format!(
155                "storage trie for {hashed_address:?} not found"
156            )));
157        }
158        Ok(Self {
159            current_key: None,
160            cursor_type: MockHashedCursorType::Storage {
161                all_storage_values,
162                all_visited_storage_keys,
163                current_hashed_address: hashed_address,
164            },
165        })
166    }
167
168    /// Returns the values map for the current cursor type.
169    fn values(&self) -> &BTreeMap<B256, T> {
170        match &self.cursor_type {
171            MockHashedCursorType::Account { values, .. } => values.as_ref(),
172            MockHashedCursorType::Storage {
173                all_storage_values, current_hashed_address, ..
174            } => all_storage_values
175                .get(current_hashed_address)
176                .expect("current_hashed_address should exist in all_storage_values"),
177        }
178    }
179
180    /// Returns the visited keys mutex for the current cursor type.
181    fn visited_keys(&self) -> &Mutex<Vec<KeyVisit<B256>>> {
182        match &self.cursor_type {
183            MockHashedCursorType::Account { visited_keys, .. } => visited_keys.as_ref(),
184            MockHashedCursorType::Storage {
185                all_visited_storage_keys,
186                current_hashed_address,
187                ..
188            } => all_visited_storage_keys
189                .get(current_hashed_address)
190                .expect("current_hashed_address should exist in all_visited_storage_keys"),
191        }
192    }
193}
194
195impl<T: Debug + Clone> HashedCursor for MockHashedCursor<T> {
196    type Value = T;
197
198    #[instrument(skip(self), ret(level = "trace"))]
199    fn seek(&mut self, key: B256) -> Result<Option<(B256, Self::Value)>, DatabaseError> {
200        // Find the first key that is greater than or equal to the given key.
201        let entry = self.values().iter().find_map(|(k, v)| (k >= &key).then(|| (*k, v.clone())));
202        if let Some((key, _)) = &entry {
203            self.current_key = Some(*key);
204        }
205        self.visited_keys().lock().push(KeyVisit {
206            visit_type: KeyVisitType::SeekNonExact(key),
207            visited_key: entry.as_ref().map(|(k, _)| *k),
208        });
209        Ok(entry)
210    }
211
212    #[instrument(skip(self), ret(level = "trace"))]
213    fn next(&mut self) -> Result<Option<(B256, Self::Value)>, DatabaseError> {
214        let mut iter = self.values().iter();
215        // Jump to the first key that has a prefix of the current key if it's set, or to the first
216        // key otherwise.
217        iter.find(|(k, _)| {
218            self.current_key.as_ref().is_none_or(|current| k.starts_with(current.as_slice()))
219        })
220        .expect("current key should exist in values");
221        // Get the next key-value pair.
222        let entry = iter.next().map(|(k, v)| (*k, v.clone()));
223        if let Some((key, _)) = &entry {
224            self.current_key = Some(*key);
225        }
226        self.visited_keys().lock().push(KeyVisit {
227            visit_type: KeyVisitType::Next,
228            visited_key: entry.as_ref().map(|(k, _)| *k),
229        });
230        Ok(entry)
231    }
232
233    fn reset(&mut self) {
234        self.current_key = None;
235    }
236}
237
238impl<T: Debug + Clone> HashedStorageCursor for MockHashedCursor<T> {
239    #[instrument(level = "trace", skip(self), ret)]
240    fn is_storage_empty(&mut self) -> Result<bool, DatabaseError> {
241        Ok(self.values().is_empty())
242    }
243
244    fn set_hashed_address(&mut self, hashed_address: B256) {
245        self.reset();
246        match &mut self.cursor_type {
247            MockHashedCursorType::Storage { current_hashed_address, .. } => {
248                *current_hashed_address = hashed_address;
249            }
250            MockHashedCursorType::Account { .. } => {
251                panic!("set_hashed_address called on account cursor")
252            }
253        }
254    }
255}