1use parking_lot::{Mutex, MutexGuard};
2use std::{collections::BTreeMap, sync::Arc};
3use tracing::instrument;
4
5use super::{TrieCursor, TrieCursorFactory, TrieStorageCursor};
6use crate::{
7 mock::{KeyVisit, KeyVisitType},
8 BranchNodeCompact, Nibbles,
9};
10use alloy_primitives::{map::B256Map, B256};
11use reth_storage_errors::db::DatabaseError;
12use reth_trie_common::updates::TrieUpdates;
13
14#[derive(Clone, Default, Debug)]
16pub struct MockTrieCursorFactory {
17 account_trie_nodes: Arc<BTreeMap<Nibbles, BranchNodeCompact>>,
18 storage_tries: Arc<B256Map<BTreeMap<Nibbles, BranchNodeCompact>>>,
19
20 visited_account_keys: Arc<Mutex<Vec<KeyVisit<Nibbles>>>>,
22 visited_storage_keys: Arc<B256Map<Mutex<Vec<KeyVisit<Nibbles>>>>>,
24}
25
26impl MockTrieCursorFactory {
27 pub fn new(
29 account_trie_nodes: BTreeMap<Nibbles, BranchNodeCompact>,
30 storage_tries: B256Map<BTreeMap<Nibbles, BranchNodeCompact>>,
31 ) -> Self {
32 let visited_storage_keys = storage_tries.keys().map(|k| (*k, Default::default())).collect();
33 Self {
34 account_trie_nodes: Arc::new(account_trie_nodes),
35 storage_tries: Arc::new(storage_tries),
36 visited_account_keys: Default::default(),
37 visited_storage_keys: Arc::new(visited_storage_keys),
38 }
39 }
40
41 pub fn from_trie_updates(updates: TrieUpdates) -> Self {
43 let account_trie_nodes: BTreeMap<Nibbles, BranchNodeCompact> =
45 updates.account_nodes.into_iter().collect();
46
47 let storage_tries: B256Map<BTreeMap<Nibbles, BranchNodeCompact>> = updates
49 .storage_tries
50 .into_iter()
51 .map(|(addr, storage_updates)| {
52 let storage_nodes: BTreeMap<Nibbles, BranchNodeCompact> =
54 storage_updates.storage_nodes.into_iter().collect();
55 (addr, storage_nodes)
56 })
57 .collect();
58
59 Self::new(account_trie_nodes, storage_tries)
60 }
61
62 pub fn visited_account_keys(&self) -> MutexGuard<'_, Vec<KeyVisit<Nibbles>>> {
64 self.visited_account_keys.lock()
65 }
66
67 pub fn visited_storage_keys(
69 &self,
70 hashed_address: B256,
71 ) -> MutexGuard<'_, Vec<KeyVisit<Nibbles>>> {
72 self.visited_storage_keys.get(&hashed_address).expect("storage trie should exist").lock()
73 }
74}
75
76impl TrieCursorFactory for MockTrieCursorFactory {
77 type AccountTrieCursor<'a>
78 = MockTrieCursor
79 where
80 Self: 'a;
81 type StorageTrieCursor<'a>
82 = MockTrieCursor
83 where
84 Self: 'a;
85
86 fn account_trie_cursor(&self) -> Result<Self::AccountTrieCursor<'_>, DatabaseError> {
88 Ok(MockTrieCursor::new(self.account_trie_nodes.clone(), self.visited_account_keys.clone()))
89 }
90
91 fn storage_trie_cursor(
93 &self,
94 hashed_address: B256,
95 ) -> Result<Self::StorageTrieCursor<'_>, DatabaseError> {
96 MockTrieCursor::new_storage(
97 self.storage_tries.clone(),
98 self.visited_storage_keys.clone(),
99 hashed_address,
100 )
101 }
102}
103
104#[derive(Debug)]
106enum MockTrieCursorType {
107 Account {
108 trie_nodes: Arc<BTreeMap<Nibbles, BranchNodeCompact>>,
109 visited_keys: Arc<Mutex<Vec<KeyVisit<Nibbles>>>>,
110 },
111 Storage {
112 all_storage_tries: Arc<B256Map<BTreeMap<Nibbles, BranchNodeCompact>>>,
113 all_visited_storage_keys: Arc<B256Map<Mutex<Vec<KeyVisit<Nibbles>>>>>,
114 current_hashed_address: B256,
115 },
116}
117
118#[derive(Debug)]
120#[non_exhaustive]
121pub struct MockTrieCursor {
122 current_key: Option<Nibbles>,
124 cursor_type: MockTrieCursorType,
125}
126
127impl MockTrieCursor {
128 pub const fn new(
130 trie_nodes: Arc<BTreeMap<Nibbles, BranchNodeCompact>>,
131 visited_keys: Arc<Mutex<Vec<KeyVisit<Nibbles>>>>,
132 ) -> Self {
133 Self {
134 current_key: None,
135 cursor_type: MockTrieCursorType::Account { trie_nodes, visited_keys },
136 }
137 }
138
139 pub fn new_storage(
141 all_storage_tries: Arc<B256Map<BTreeMap<Nibbles, BranchNodeCompact>>>,
142 all_visited_storage_keys: Arc<B256Map<Mutex<Vec<KeyVisit<Nibbles>>>>>,
143 hashed_address: B256,
144 ) -> Result<Self, DatabaseError> {
145 if !all_storage_tries.contains_key(&hashed_address) {
146 return Err(DatabaseError::Other(format!(
147 "storage trie for {hashed_address:?} not found"
148 )));
149 }
150 Ok(Self {
151 current_key: None,
152 cursor_type: MockTrieCursorType::Storage {
153 all_storage_tries,
154 all_visited_storage_keys,
155 current_hashed_address: hashed_address,
156 },
157 })
158 }
159
160 fn trie_nodes(&self) -> &BTreeMap<Nibbles, BranchNodeCompact> {
162 match &self.cursor_type {
163 MockTrieCursorType::Account { trie_nodes, .. } => trie_nodes.as_ref(),
164 MockTrieCursorType::Storage { all_storage_tries, current_hashed_address, .. } => {
165 all_storage_tries
166 .get(current_hashed_address)
167 .expect("current_hashed_address should exist in all_storage_tries")
168 }
169 }
170 }
171
172 fn visited_keys(&self) -> &Mutex<Vec<KeyVisit<Nibbles>>> {
174 match &self.cursor_type {
175 MockTrieCursorType::Account { visited_keys, .. } => visited_keys.as_ref(),
176 MockTrieCursorType::Storage {
177 all_visited_storage_keys,
178 current_hashed_address,
179 ..
180 } => all_visited_storage_keys
181 .get(current_hashed_address)
182 .expect("current_hashed_address should exist in all_visited_storage_keys"),
183 }
184 }
185}
186
187impl TrieCursor for MockTrieCursor {
188 #[instrument(skip(self), ret(level = "trace"))]
189 fn seek_exact(
190 &mut self,
191 key: Nibbles,
192 ) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
193 let entry = self.trie_nodes().get(&key).cloned().map(|value| (key, value));
194 if let Some((key, _)) = &entry {
195 self.current_key = Some(*key);
196 }
197 self.visited_keys().lock().push(KeyVisit {
198 visit_type: KeyVisitType::SeekExact(key),
199 visited_key: entry.as_ref().map(|(k, _)| *k),
200 });
201 Ok(entry)
202 }
203
204 #[instrument(skip(self), ret(level = "trace"))]
205 fn seek(
206 &mut self,
207 key: Nibbles,
208 ) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
209 let entry =
211 self.trie_nodes().iter().find_map(|(k, v)| (k >= &key).then(|| (*k, v.clone())));
212 if let Some((key, _)) = &entry {
213 self.current_key = Some(*key);
214 }
215 self.visited_keys().lock().push(KeyVisit {
216 visit_type: KeyVisitType::SeekNonExact(key),
217 visited_key: entry.as_ref().map(|(k, _)| *k),
218 });
219 Ok(entry)
220 }
221
222 #[instrument(skip(self), ret(level = "trace"))]
223 fn next(&mut self) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
224 let mut iter = self.trie_nodes().iter();
225 iter.find(|(k, _)| self.current_key.as_ref().is_none_or(|current| k.starts_with(current)))
228 .expect("current key should exist in trie nodes");
229 let entry = iter.next().map(|(k, v)| (*k, v.clone()));
231 if let Some((key, _)) = &entry {
232 self.current_key = Some(*key);
233 }
234 self.visited_keys().lock().push(KeyVisit {
235 visit_type: KeyVisitType::Next,
236 visited_key: entry.as_ref().map(|(k, _)| *k),
237 });
238 Ok(entry)
239 }
240
241 #[instrument(skip(self), ret(level = "trace"))]
242 fn current(&mut self) -> Result<Option<Nibbles>, DatabaseError> {
243 Ok(self.current_key)
244 }
245
246 fn reset(&mut self) {
247 self.current_key = None;
248 }
249}
250
251impl TrieStorageCursor for MockTrieCursor {
252 fn set_hashed_address(&mut self, hashed_address: B256) {
253 self.reset();
254 match &mut self.cursor_type {
255 MockTrieCursorType::Storage { current_hashed_address, .. } => {
256 *current_hashed_address = hashed_address;
257 }
258 MockTrieCursorType::Account { .. } => {
259 panic!("set_hashed_address called on account cursor")
260 }
261 }
262 }
263}