1use std::{collections::BTreeMap, fmt::Debug, sync::Arc};
2
3use crate::mock::{KeyVisit, KeyVisitType};
4
5use super::{HashedCursor, HashedCursorFactory, HashedStorageCursor};
6use alloy_primitives::{map::B256Map, B256, U256};
7use parking_lot::{Mutex, MutexGuard};
8use reth_primitives_traits::Account;
9use reth_storage_errors::db::DatabaseError;
10use reth_trie_common::HashedPostState;
11use tracing::instrument;
12
/// Mock [`HashedCursorFactory`] backed by in-memory maps.
///
/// Every cursor it hands out shares the factory's maps and visit logs (via `Arc`), so the keys
/// visited by those cursors can be inspected afterwards through the factory.
#[derive(Clone, Default, Debug)]
pub struct MockHashedCursorFactory {
    /// Account entries, keyed by hashed address.
    hashed_accounts: Arc<BTreeMap<B256, Account>>,
    /// Storage entries per hashed address, each keyed by hashed slot.
    hashed_storage_tries: Arc<B256Map<BTreeMap<B256, U256>>>,

    /// Log of the keys visited by account cursors created from this factory.
    visited_account_keys: Arc<Mutex<Vec<KeyVisit<B256>>>>,
    /// Logs of the keys visited by storage cursors, one log per hashed address present in
    /// `hashed_storage_tries` at construction time.
    visited_storage_keys: Arc<B256Map<Mutex<Vec<KeyVisit<B256>>>>>,
}
24
25impl MockHashedCursorFactory {
26 pub fn new(
28 hashed_accounts: BTreeMap<B256, Account>,
29 hashed_storage_tries: B256Map<BTreeMap<B256, U256>>,
30 ) -> Self {
31 let visited_storage_keys =
32 hashed_storage_tries.keys().map(|k| (*k, Default::default())).collect();
33 Self {
34 hashed_accounts: Arc::new(hashed_accounts),
35 hashed_storage_tries: Arc::new(hashed_storage_tries),
36 visited_account_keys: Default::default(),
37 visited_storage_keys: Arc::new(visited_storage_keys),
38 }
39 }
40
41 pub fn from_hashed_post_state(post_state: HashedPostState) -> Self {
43 let hashed_accounts: BTreeMap<B256, Account> = post_state
45 .accounts
46 .into_iter()
47 .filter_map(|(addr, account)| account.map(|acc| (addr, acc)))
48 .collect();
49
50 let mut hashed_storages: B256Map<BTreeMap<B256, U256>> = post_state
52 .storages
53 .into_iter()
54 .map(|(addr, hashed_storage)| {
55 let storage_map: BTreeMap<B256, U256> = hashed_storage
57 .storage
58 .into_iter()
59 .filter_map(|(slot, value)| (value != U256::ZERO).then_some((slot, value)))
60 .collect();
61 (addr, storage_map)
62 })
63 .collect();
64
65 for account in hashed_accounts.keys() {
67 hashed_storages.entry(*account).or_default();
68 }
69
70 Self::new(hashed_accounts, hashed_storages)
71 }
72
73 pub fn visited_account_keys(&self) -> MutexGuard<'_, Vec<KeyVisit<B256>>> {
75 self.visited_account_keys.lock()
76 }
77
78 pub fn visited_storage_keys(
80 &self,
81 hashed_address: B256,
82 ) -> MutexGuard<'_, Vec<KeyVisit<B256>>> {
83 self.visited_storage_keys.get(&hashed_address).expect("storage trie should exist").lock()
84 }
85}
86
87impl HashedCursorFactory for MockHashedCursorFactory {
88 type AccountCursor<'a>
89 = MockHashedCursor<Account>
90 where
91 Self: 'a;
92 type StorageCursor<'a>
93 = MockHashedCursor<U256>
94 where
95 Self: 'a;
96
97 fn hashed_account_cursor(&self) -> Result<Self::AccountCursor<'_>, DatabaseError> {
98 Ok(MockHashedCursor::new(self.hashed_accounts.clone(), self.visited_account_keys.clone()))
99 }
100
101 fn hashed_storage_cursor(
102 &self,
103 hashed_address: B256,
104 ) -> Result<Self::StorageCursor<'_>, DatabaseError> {
105 MockHashedCursor::new_storage(
106 self.hashed_storage_tries.clone(),
107 self.visited_storage_keys.clone(),
108 hashed_address,
109 )
110 }
111}
112
/// Backing data of a [`MockHashedCursor`]: either a single map, or the storage maps of all
/// addresses together with the address currently being read.
#[derive(Debug)]
enum MockHashedCursorType<T> {
    /// Cursor over a single map (the factory uses this for accounts).
    Account {
        /// Entries the cursor iterates over.
        values: Arc<BTreeMap<B256, T>>,
        /// Log that every `seek`/`next` on this cursor appends to.
        visited_keys: Arc<Mutex<Vec<KeyVisit<B256>>>>,
    },
    /// Cursor over storage that can be re-pointed at another address's storage map.
    Storage {
        /// Storage maps of every hashed address.
        all_storage_values: Arc<B256Map<BTreeMap<B256, T>>>,
        /// Visit logs of every hashed address; the entry for `current_hashed_address` is the
        /// one appended to.
        all_visited_storage_keys: Arc<B256Map<Mutex<Vec<KeyVisit<B256>>>>>,
        /// Hashed address whose storage map the cursor currently reads.
        current_hashed_address: B256,
    },
}
126
/// In-memory cursor over `B256`-keyed, sorted entries that logs every key it visits.
#[derive(Debug)]
pub struct MockHashedCursor<T> {
    /// Current cursor position, which `next` advances from; `None` while the cursor is
    /// unpositioned (newly created or after `reset`).
    current_key: Option<B256>,
    /// Which map the cursor reads and which log it appends visits to.
    cursor_type: MockHashedCursorType<T>,
}
134
135impl<T> MockHashedCursor<T> {
136 pub const fn new(
138 values: Arc<BTreeMap<B256, T>>,
139 visited_keys: Arc<Mutex<Vec<KeyVisit<B256>>>>,
140 ) -> Self {
141 Self {
142 current_key: None,
143 cursor_type: MockHashedCursorType::Account { values, visited_keys },
144 }
145 }
146
147 pub fn new_storage(
149 all_storage_values: Arc<B256Map<BTreeMap<B256, T>>>,
150 all_visited_storage_keys: Arc<B256Map<Mutex<Vec<KeyVisit<B256>>>>>,
151 hashed_address: B256,
152 ) -> Result<Self, DatabaseError> {
153 if !all_storage_values.contains_key(&hashed_address) {
154 return Err(DatabaseError::Other(format!(
155 "storage trie for {hashed_address:?} not found"
156 )));
157 }
158 Ok(Self {
159 current_key: None,
160 cursor_type: MockHashedCursorType::Storage {
161 all_storage_values,
162 all_visited_storage_keys,
163 current_hashed_address: hashed_address,
164 },
165 })
166 }
167
168 fn values(&self) -> &BTreeMap<B256, T> {
170 match &self.cursor_type {
171 MockHashedCursorType::Account { values, .. } => values.as_ref(),
172 MockHashedCursorType::Storage {
173 all_storage_values, current_hashed_address, ..
174 } => all_storage_values
175 .get(current_hashed_address)
176 .expect("current_hashed_address should exist in all_storage_values"),
177 }
178 }
179
180 fn visited_keys(&self) -> &Mutex<Vec<KeyVisit<B256>>> {
182 match &self.cursor_type {
183 MockHashedCursorType::Account { visited_keys, .. } => visited_keys.as_ref(),
184 MockHashedCursorType::Storage {
185 all_visited_storage_keys,
186 current_hashed_address,
187 ..
188 } => all_visited_storage_keys
189 .get(current_hashed_address)
190 .expect("current_hashed_address should exist in all_visited_storage_keys"),
191 }
192 }
193}
194
195impl<T: Debug + Clone> HashedCursor for MockHashedCursor<T> {
196 type Value = T;
197
198 #[instrument(skip(self), ret(level = "trace"))]
199 fn seek(&mut self, key: B256) -> Result<Option<(B256, Self::Value)>, DatabaseError> {
200 let entry = self.values().iter().find_map(|(k, v)| (k >= &key).then(|| (*k, v.clone())));
202 if let Some((key, _)) = &entry {
203 self.current_key = Some(*key);
204 }
205 self.visited_keys().lock().push(KeyVisit {
206 visit_type: KeyVisitType::SeekNonExact(key),
207 visited_key: entry.as_ref().map(|(k, _)| *k),
208 });
209 Ok(entry)
210 }
211
212 #[instrument(skip(self), ret(level = "trace"))]
213 fn next(&mut self) -> Result<Option<(B256, Self::Value)>, DatabaseError> {
214 let mut iter = self.values().iter();
215 iter.find(|(k, _)| {
218 self.current_key.as_ref().is_none_or(|current| k.starts_with(current.as_slice()))
219 })
220 .expect("current key should exist in values");
221 let entry = iter.next().map(|(k, v)| (*k, v.clone()));
223 if let Some((key, _)) = &entry {
224 self.current_key = Some(*key);
225 }
226 self.visited_keys().lock().push(KeyVisit {
227 visit_type: KeyVisitType::Next,
228 visited_key: entry.as_ref().map(|(k, _)| *k),
229 });
230 Ok(entry)
231 }
232
233 fn reset(&mut self) {
234 self.current_key = None;
235 }
236}
237
238impl<T: Debug + Clone> HashedStorageCursor for MockHashedCursor<T> {
239 #[instrument(level = "trace", skip(self), ret)]
240 fn is_storage_empty(&mut self) -> Result<bool, DatabaseError> {
241 Ok(self.values().is_empty())
242 }
243
244 fn set_hashed_address(&mut self, hashed_address: B256) {
245 self.reset();
246 match &mut self.cursor_type {
247 MockHashedCursorType::Storage { current_hashed_address, .. } => {
248 *current_hashed_address = hashed_address;
249 }
250 MockHashedCursorType::Account { .. } => {
251 panic!("set_hashed_address called on account cursor")
252 }
253 }
254 }
255}