1use std::{collections::BTreeMap, fmt::Debug, sync::Arc};
2
3use crate::mock::{KeyVisit, KeyVisitType};
4
5use super::{HashedCursor, HashedCursorFactory, HashedStorageCursor};
6use alloy_primitives::{map::B256Map, B256, U256};
7use parking_lot::{Mutex, MutexGuard};
8use reth_primitives_traits::Account;
9use reth_storage_errors::db::DatabaseError;
10use reth_trie_common::HashedPostState;
11use tracing::instrument;
12
13#[derive(Clone, Default, Debug)]
15pub struct MockHashedCursorFactory {
16 hashed_accounts: Arc<BTreeMap<B256, Account>>,
17 hashed_storage_tries: Arc<B256Map<BTreeMap<B256, U256>>>,
18
19 visited_account_keys: Arc<Mutex<Vec<KeyVisit<B256>>>>,
21 visited_storage_keys: Arc<B256Map<Mutex<Vec<KeyVisit<B256>>>>>,
23}
24
25impl MockHashedCursorFactory {
26 pub fn new(
28 hashed_accounts: BTreeMap<B256, Account>,
29 hashed_storage_tries: B256Map<BTreeMap<B256, U256>>,
30 ) -> Self {
31 let visited_storage_keys =
32 hashed_storage_tries.keys().map(|k| (*k, Default::default())).collect();
33 Self {
34 hashed_accounts: Arc::new(hashed_accounts),
35 hashed_storage_tries: Arc::new(hashed_storage_tries),
36 visited_account_keys: Default::default(),
37 visited_storage_keys: Arc::new(visited_storage_keys),
38 }
39 }
40
41 pub fn from_hashed_post_state(post_state: HashedPostState) -> Self {
43 let hashed_accounts: BTreeMap<B256, Account> = post_state
45 .accounts
46 .into_iter()
47 .filter_map(|(addr, account)| account.map(|acc| (addr, acc)))
48 .collect();
49
50 let hashed_storages: B256Map<BTreeMap<B256, U256>> = post_state
52 .storages
53 .into_iter()
54 .map(|(addr, hashed_storage)| {
55 let storage_map: BTreeMap<B256, U256> = hashed_storage
57 .storage
58 .into_iter()
59 .filter_map(|(slot, value)| (value != U256::ZERO).then_some((slot, value)))
60 .collect();
61 (addr, storage_map)
62 })
63 .collect();
64
65 Self::new(hashed_accounts, hashed_storages)
66 }
67
68 pub fn visited_account_keys(&self) -> MutexGuard<'_, Vec<KeyVisit<B256>>> {
70 self.visited_account_keys.lock()
71 }
72
73 pub fn visited_storage_keys(
75 &self,
76 hashed_address: B256,
77 ) -> MutexGuard<'_, Vec<KeyVisit<B256>>> {
78 self.visited_storage_keys.get(&hashed_address).expect("storage trie should exist").lock()
79 }
80}
81
82impl HashedCursorFactory for MockHashedCursorFactory {
83 type AccountCursor<'a>
84 = MockHashedCursor<Account>
85 where
86 Self: 'a;
87 type StorageCursor<'a>
88 = MockHashedCursor<U256>
89 where
90 Self: 'a;
91
92 fn hashed_account_cursor(&self) -> Result<Self::AccountCursor<'_>, DatabaseError> {
93 Ok(MockHashedCursor::new(self.hashed_accounts.clone(), self.visited_account_keys.clone()))
94 }
95
96 fn hashed_storage_cursor(
97 &self,
98 hashed_address: B256,
99 ) -> Result<Self::StorageCursor<'_>, DatabaseError> {
100 MockHashedCursor::new_storage(
101 self.hashed_storage_tries.clone(),
102 self.visited_storage_keys.clone(),
103 hashed_address,
104 )
105 }
106}
107
108#[derive(Debug)]
110enum MockHashedCursorType<T> {
111 Account {
112 values: Arc<BTreeMap<B256, T>>,
113 visited_keys: Arc<Mutex<Vec<KeyVisit<B256>>>>,
114 },
115 Storage {
116 all_storage_values: Arc<B256Map<BTreeMap<B256, T>>>,
117 all_visited_storage_keys: Arc<B256Map<Mutex<Vec<KeyVisit<B256>>>>>,
118 current_hashed_address: B256,
119 },
120}
121
122#[derive(Debug)]
124pub struct MockHashedCursor<T> {
125 current_key: Option<B256>,
127 cursor_type: MockHashedCursorType<T>,
128}
129
130impl<T> MockHashedCursor<T> {
131 pub const fn new(
133 values: Arc<BTreeMap<B256, T>>,
134 visited_keys: Arc<Mutex<Vec<KeyVisit<B256>>>>,
135 ) -> Self {
136 Self {
137 current_key: None,
138 cursor_type: MockHashedCursorType::Account { values, visited_keys },
139 }
140 }
141
142 pub fn new_storage(
144 all_storage_values: Arc<B256Map<BTreeMap<B256, T>>>,
145 all_visited_storage_keys: Arc<B256Map<Mutex<Vec<KeyVisit<B256>>>>>,
146 hashed_address: B256,
147 ) -> Result<Self, DatabaseError> {
148 if !all_storage_values.contains_key(&hashed_address) {
149 return Err(DatabaseError::Other(format!(
150 "storage trie for {hashed_address:?} not found"
151 )));
152 }
153 Ok(Self {
154 current_key: None,
155 cursor_type: MockHashedCursorType::Storage {
156 all_storage_values,
157 all_visited_storage_keys,
158 current_hashed_address: hashed_address,
159 },
160 })
161 }
162
163 fn values(&self) -> &BTreeMap<B256, T> {
165 match &self.cursor_type {
166 MockHashedCursorType::Account { values, .. } => values.as_ref(),
167 MockHashedCursorType::Storage {
168 all_storage_values, current_hashed_address, ..
169 } => all_storage_values
170 .get(current_hashed_address)
171 .expect("current_hashed_address should exist in all_storage_values"),
172 }
173 }
174
175 fn visited_keys(&self) -> &Mutex<Vec<KeyVisit<B256>>> {
177 match &self.cursor_type {
178 MockHashedCursorType::Account { visited_keys, .. } => visited_keys.as_ref(),
179 MockHashedCursorType::Storage {
180 all_visited_storage_keys,
181 current_hashed_address,
182 ..
183 } => all_visited_storage_keys
184 .get(current_hashed_address)
185 .expect("current_hashed_address should exist in all_visited_storage_keys"),
186 }
187 }
188}
189
190impl<T: Debug + Clone> HashedCursor for MockHashedCursor<T> {
191 type Value = T;
192
193 #[instrument(skip(self), ret(level = "trace"))]
194 fn seek(&mut self, key: B256) -> Result<Option<(B256, Self::Value)>, DatabaseError> {
195 let entry = self.values().iter().find_map(|(k, v)| (k >= &key).then(|| (*k, v.clone())));
197 if let Some((key, _)) = &entry {
198 self.current_key = Some(*key);
199 }
200 self.visited_keys().lock().push(KeyVisit {
201 visit_type: KeyVisitType::SeekNonExact(key),
202 visited_key: entry.as_ref().map(|(k, _)| *k),
203 });
204 Ok(entry)
205 }
206
207 #[instrument(skip(self), ret(level = "trace"))]
208 fn next(&mut self) -> Result<Option<(B256, Self::Value)>, DatabaseError> {
209 let mut iter = self.values().iter();
210 iter.find(|(k, _)| {
213 self.current_key.as_ref().is_none_or(|current| k.starts_with(current.as_slice()))
214 })
215 .expect("current key should exist in values");
216 let entry = iter.next().map(|(k, v)| (*k, v.clone()));
218 if let Some((key, _)) = &entry {
219 self.current_key = Some(*key);
220 }
221 self.visited_keys().lock().push(KeyVisit {
222 visit_type: KeyVisitType::Next,
223 visited_key: entry.as_ref().map(|(k, _)| *k),
224 });
225 Ok(entry)
226 }
227
228 fn reset(&mut self) {
229 self.current_key = None;
230 }
231}
232
233impl<T: Debug + Clone> HashedStorageCursor for MockHashedCursor<T> {
234 #[instrument(level = "trace", skip(self), ret)]
235 fn is_storage_empty(&mut self) -> Result<bool, DatabaseError> {
236 Ok(self.values().is_empty())
237 }
238
239 fn set_hashed_address(&mut self, hashed_address: B256) {
240 self.reset();
241 match &mut self.cursor_type {
242 MockHashedCursorType::Storage { current_hashed_address, .. } => {
243 *current_hashed_address = hashed_address;
244 }
245 MockHashedCursorType::Account { .. } => {
246 panic!("set_hashed_address called on account cursor")
247 }
248 }
249 }
250}