reth_trie/hashed_cursor/
mock.rs

1use std::{collections::BTreeMap, fmt::Debug, sync::Arc};
2
3use crate::mock::{KeyVisit, KeyVisitType};
4
5use super::{HashedCursor, HashedCursorFactory, HashedStorageCursor};
6use alloy_primitives::{map::B256Map, B256, U256};
7use parking_lot::{Mutex, MutexGuard};
8use reth_primitives_traits::Account;
9use reth_storage_errors::db::DatabaseError;
10use reth_trie_common::HashedPostState;
11use tracing::instrument;
12
13/// Mock hashed cursor factory.
14#[derive(Clone, Default, Debug)]
15pub struct MockHashedCursorFactory {
16    hashed_accounts: Arc<BTreeMap<B256, Account>>,
17    hashed_storage_tries: Arc<B256Map<BTreeMap<B256, U256>>>,
18
19    /// List of keys that the hashed accounts cursor has visited.
20    visited_account_keys: Arc<Mutex<Vec<KeyVisit<B256>>>>,
21    /// List of keys that the hashed storages cursor has visited, per storage trie.
22    visited_storage_keys: Arc<B256Map<Mutex<Vec<KeyVisit<B256>>>>>,
23}
24
25impl MockHashedCursorFactory {
26    /// Creates a new mock hashed cursor factory.
27    pub fn new(
28        hashed_accounts: BTreeMap<B256, Account>,
29        hashed_storage_tries: B256Map<BTreeMap<B256, U256>>,
30    ) -> Self {
31        let visited_storage_keys =
32            hashed_storage_tries.keys().map(|k| (*k, Default::default())).collect();
33        Self {
34            hashed_accounts: Arc::new(hashed_accounts),
35            hashed_storage_tries: Arc::new(hashed_storage_tries),
36            visited_account_keys: Default::default(),
37            visited_storage_keys: Arc::new(visited_storage_keys),
38        }
39    }
40
41    /// Creates a new mock hashed cursor factory from a `HashedPostState`.
42    pub fn from_hashed_post_state(post_state: HashedPostState) -> Self {
43        // Extract accounts from post state, filtering out None (deleted accounts)
44        let hashed_accounts: BTreeMap<B256, Account> = post_state
45            .accounts
46            .into_iter()
47            .filter_map(|(addr, account)| account.map(|acc| (addr, acc)))
48            .collect();
49
50        // Extract storages from post state
51        let hashed_storages: B256Map<BTreeMap<B256, U256>> = post_state
52            .storages
53            .into_iter()
54            .map(|(addr, hashed_storage)| {
55                // Convert HashedStorage to BTreeMap, filtering out zero values (deletions)
56                let storage_map: BTreeMap<B256, U256> = hashed_storage
57                    .storage
58                    .into_iter()
59                    .filter_map(|(slot, value)| (value != U256::ZERO).then_some((slot, value)))
60                    .collect();
61                (addr, storage_map)
62            })
63            .collect();
64
65        Self::new(hashed_accounts, hashed_storages)
66    }
67
68    /// Returns a reference to the list of visited hashed account keys.
69    pub fn visited_account_keys(&self) -> MutexGuard<'_, Vec<KeyVisit<B256>>> {
70        self.visited_account_keys.lock()
71    }
72
73    /// Returns a reference to the list of visited hashed storage keys for the given hashed address.
74    pub fn visited_storage_keys(
75        &self,
76        hashed_address: B256,
77    ) -> MutexGuard<'_, Vec<KeyVisit<B256>>> {
78        self.visited_storage_keys.get(&hashed_address).expect("storage trie should exist").lock()
79    }
80}
81
82impl HashedCursorFactory for MockHashedCursorFactory {
83    type AccountCursor<'a>
84        = MockHashedCursor<Account>
85    where
86        Self: 'a;
87    type StorageCursor<'a>
88        = MockHashedCursor<U256>
89    where
90        Self: 'a;
91
92    fn hashed_account_cursor(&self) -> Result<Self::AccountCursor<'_>, DatabaseError> {
93        Ok(MockHashedCursor::new(self.hashed_accounts.clone(), self.visited_account_keys.clone()))
94    }
95
96    fn hashed_storage_cursor(
97        &self,
98        hashed_address: B256,
99    ) -> Result<Self::StorageCursor<'_>, DatabaseError> {
100        MockHashedCursor::new_storage(
101            self.hashed_storage_tries.clone(),
102            self.visited_storage_keys.clone(),
103            hashed_address,
104        )
105    }
106}
107
108/// Mock hashed cursor type - determines whether this is an account or storage cursor.
109#[derive(Debug)]
110enum MockHashedCursorType<T> {
111    Account {
112        values: Arc<BTreeMap<B256, T>>,
113        visited_keys: Arc<Mutex<Vec<KeyVisit<B256>>>>,
114    },
115    Storage {
116        all_storage_values: Arc<B256Map<BTreeMap<B256, T>>>,
117        all_visited_storage_keys: Arc<B256Map<Mutex<Vec<KeyVisit<B256>>>>>,
118        current_hashed_address: B256,
119    },
120}
121
122/// Mock hashed cursor.
123#[derive(Debug)]
124pub struct MockHashedCursor<T> {
125    /// The current key. If set, it is guaranteed to exist in `values`.
126    current_key: Option<B256>,
127    cursor_type: MockHashedCursorType<T>,
128}
129
130impl<T> MockHashedCursor<T> {
131    /// Creates a new mock hashed cursor for accounts with the given values and key tracking.
132    pub const fn new(
133        values: Arc<BTreeMap<B256, T>>,
134        visited_keys: Arc<Mutex<Vec<KeyVisit<B256>>>>,
135    ) -> Self {
136        Self {
137            current_key: None,
138            cursor_type: MockHashedCursorType::Account { values, visited_keys },
139        }
140    }
141
142    /// Creates a new mock hashed cursor for storage with access to all storage tries.
143    pub fn new_storage(
144        all_storage_values: Arc<B256Map<BTreeMap<B256, T>>>,
145        all_visited_storage_keys: Arc<B256Map<Mutex<Vec<KeyVisit<B256>>>>>,
146        hashed_address: B256,
147    ) -> Result<Self, DatabaseError> {
148        if !all_storage_values.contains_key(&hashed_address) {
149            return Err(DatabaseError::Other(format!(
150                "storage trie for {hashed_address:?} not found"
151            )));
152        }
153        Ok(Self {
154            current_key: None,
155            cursor_type: MockHashedCursorType::Storage {
156                all_storage_values,
157                all_visited_storage_keys,
158                current_hashed_address: hashed_address,
159            },
160        })
161    }
162
163    /// Returns the values map for the current cursor type.
164    fn values(&self) -> &BTreeMap<B256, T> {
165        match &self.cursor_type {
166            MockHashedCursorType::Account { values, .. } => values.as_ref(),
167            MockHashedCursorType::Storage {
168                all_storage_values, current_hashed_address, ..
169            } => all_storage_values
170                .get(current_hashed_address)
171                .expect("current_hashed_address should exist in all_storage_values"),
172        }
173    }
174
175    /// Returns the visited keys mutex for the current cursor type.
176    fn visited_keys(&self) -> &Mutex<Vec<KeyVisit<B256>>> {
177        match &self.cursor_type {
178            MockHashedCursorType::Account { visited_keys, .. } => visited_keys.as_ref(),
179            MockHashedCursorType::Storage {
180                all_visited_storage_keys,
181                current_hashed_address,
182                ..
183            } => all_visited_storage_keys
184                .get(current_hashed_address)
185                .expect("current_hashed_address should exist in all_visited_storage_keys"),
186        }
187    }
188}
189
190impl<T: Debug + Clone> HashedCursor for MockHashedCursor<T> {
191    type Value = T;
192
193    #[instrument(skip(self), ret(level = "trace"))]
194    fn seek(&mut self, key: B256) -> Result<Option<(B256, Self::Value)>, DatabaseError> {
195        // Find the first key that is greater than or equal to the given key.
196        let entry = self.values().iter().find_map(|(k, v)| (k >= &key).then(|| (*k, v.clone())));
197        if let Some((key, _)) = &entry {
198            self.current_key = Some(*key);
199        }
200        self.visited_keys().lock().push(KeyVisit {
201            visit_type: KeyVisitType::SeekNonExact(key),
202            visited_key: entry.as_ref().map(|(k, _)| *k),
203        });
204        Ok(entry)
205    }
206
207    #[instrument(skip(self), ret(level = "trace"))]
208    fn next(&mut self) -> Result<Option<(B256, Self::Value)>, DatabaseError> {
209        let mut iter = self.values().iter();
210        // Jump to the first key that has a prefix of the current key if it's set, or to the first
211        // key otherwise.
212        iter.find(|(k, _)| {
213            self.current_key.as_ref().is_none_or(|current| k.starts_with(current.as_slice()))
214        })
215        .expect("current key should exist in values");
216        // Get the next key-value pair.
217        let entry = iter.next().map(|(k, v)| (*k, v.clone()));
218        if let Some((key, _)) = &entry {
219            self.current_key = Some(*key);
220        }
221        self.visited_keys().lock().push(KeyVisit {
222            visit_type: KeyVisitType::Next,
223            visited_key: entry.as_ref().map(|(k, _)| *k),
224        });
225        Ok(entry)
226    }
227
228    fn reset(&mut self) {
229        self.current_key = None;
230    }
231}
232
233impl<T: Debug + Clone> HashedStorageCursor for MockHashedCursor<T> {
234    #[instrument(level = "trace", skip(self), ret)]
235    fn is_storage_empty(&mut self) -> Result<bool, DatabaseError> {
236        Ok(self.values().is_empty())
237    }
238
239    fn set_hashed_address(&mut self, hashed_address: B256) {
240        self.reset();
241        match &mut self.cursor_type {
242            MockHashedCursorType::Storage { current_hashed_address, .. } => {
243                *current_hashed_address = hashed_address;
244            }
245            MockHashedCursorType::Account { .. } => {
246                panic!("set_hashed_address called on account cursor")
247            }
248        }
249    }
250}