reth_trie/trie_cursor/
mock.rs

1use parking_lot::{Mutex, MutexGuard};
2use std::{collections::BTreeMap, sync::Arc};
3use tracing::instrument;
4
5use super::{TrieCursor, TrieCursorFactory, TrieStorageCursor};
6use crate::{
7    mock::{KeyVisit, KeyVisitType},
8    BranchNodeCompact, Nibbles,
9};
10use alloy_primitives::{map::B256Map, B256};
11use reth_storage_errors::db::DatabaseError;
12use reth_trie_common::updates::TrieUpdates;
13
14/// Mock trie cursor factory.
15#[derive(Clone, Default, Debug)]
16pub struct MockTrieCursorFactory {
17    account_trie_nodes: Arc<BTreeMap<Nibbles, BranchNodeCompact>>,
18    storage_tries: Arc<B256Map<BTreeMap<Nibbles, BranchNodeCompact>>>,
19
20    /// List of keys that the account trie cursor has visited.
21    visited_account_keys: Arc<Mutex<Vec<KeyVisit<Nibbles>>>>,
22    /// List of keys that the storage trie cursor has visited, per storage trie.
23    visited_storage_keys: Arc<B256Map<Mutex<Vec<KeyVisit<Nibbles>>>>>,
24}
25
26impl MockTrieCursorFactory {
27    /// Creates a new mock trie cursor factory.
28    pub fn new(
29        account_trie_nodes: BTreeMap<Nibbles, BranchNodeCompact>,
30        storage_tries: B256Map<BTreeMap<Nibbles, BranchNodeCompact>>,
31    ) -> Self {
32        let visited_storage_keys = storage_tries.keys().map(|k| (*k, Default::default())).collect();
33        Self {
34            account_trie_nodes: Arc::new(account_trie_nodes),
35            storage_tries: Arc::new(storage_tries),
36            visited_account_keys: Default::default(),
37            visited_storage_keys: Arc::new(visited_storage_keys),
38        }
39    }
40
41    /// Creates a new mock trie cursor factory from `TrieUpdates`.
42    pub fn from_trie_updates(updates: TrieUpdates) -> Self {
43        // Convert account nodes from HashMap to BTreeMap
44        let account_trie_nodes: BTreeMap<Nibbles, BranchNodeCompact> =
45            updates.account_nodes.into_iter().collect();
46
47        // Convert storage tries
48        let storage_tries: B256Map<BTreeMap<Nibbles, BranchNodeCompact>> = updates
49            .storage_tries
50            .into_iter()
51            .map(|(addr, storage_updates)| {
52                // Convert storage nodes from HashMap to BTreeMap
53                let storage_nodes: BTreeMap<Nibbles, BranchNodeCompact> =
54                    storage_updates.storage_nodes.into_iter().collect();
55                (addr, storage_nodes)
56            })
57            .collect();
58
59        Self::new(account_trie_nodes, storage_tries)
60    }
61
62    /// Returns a reference to the list of visited account keys.
63    pub fn visited_account_keys(&self) -> MutexGuard<'_, Vec<KeyVisit<Nibbles>>> {
64        self.visited_account_keys.lock()
65    }
66
67    /// Returns a reference to the list of visited storage keys for the given hashed address.
68    pub fn visited_storage_keys(
69        &self,
70        hashed_address: B256,
71    ) -> MutexGuard<'_, Vec<KeyVisit<Nibbles>>> {
72        self.visited_storage_keys.get(&hashed_address).expect("storage trie should exist").lock()
73    }
74}
75
76impl TrieCursorFactory for MockTrieCursorFactory {
77    type AccountTrieCursor<'a>
78        = MockTrieCursor
79    where
80        Self: 'a;
81    type StorageTrieCursor<'a>
82        = MockTrieCursor
83    where
84        Self: 'a;
85
86    /// Generates a mock account trie cursor.
87    fn account_trie_cursor(&self) -> Result<Self::AccountTrieCursor<'_>, DatabaseError> {
88        Ok(MockTrieCursor::new(self.account_trie_nodes.clone(), self.visited_account_keys.clone()))
89    }
90
91    /// Generates a mock storage trie cursor.
92    fn storage_trie_cursor(
93        &self,
94        hashed_address: B256,
95    ) -> Result<Self::StorageTrieCursor<'_>, DatabaseError> {
96        MockTrieCursor::new_storage(
97            self.storage_tries.clone(),
98            self.visited_storage_keys.clone(),
99            hashed_address,
100        )
101    }
102}
103
104/// Mock trie cursor type - determines whether this is an account or storage cursor.
105#[derive(Debug)]
106enum MockTrieCursorType {
107    Account {
108        trie_nodes: Arc<BTreeMap<Nibbles, BranchNodeCompact>>,
109        visited_keys: Arc<Mutex<Vec<KeyVisit<Nibbles>>>>,
110    },
111    Storage {
112        all_storage_tries: Arc<B256Map<BTreeMap<Nibbles, BranchNodeCompact>>>,
113        all_visited_storage_keys: Arc<B256Map<Mutex<Vec<KeyVisit<Nibbles>>>>>,
114        current_hashed_address: B256,
115    },
116}
117
118/// Mock trie cursor.
119#[derive(Debug)]
120#[non_exhaustive]
121pub struct MockTrieCursor {
122    /// The current key. If set, it is guaranteed to exist in `trie_nodes`.
123    current_key: Option<Nibbles>,
124    cursor_type: MockTrieCursorType,
125}
126
127impl MockTrieCursor {
128    /// Creates a new mock trie cursor for accounts with the given trie nodes and key tracking.
129    pub const fn new(
130        trie_nodes: Arc<BTreeMap<Nibbles, BranchNodeCompact>>,
131        visited_keys: Arc<Mutex<Vec<KeyVisit<Nibbles>>>>,
132    ) -> Self {
133        Self {
134            current_key: None,
135            cursor_type: MockTrieCursorType::Account { trie_nodes, visited_keys },
136        }
137    }
138
139    /// Creates a new mock trie cursor for storage with access to all storage tries.
140    pub fn new_storage(
141        all_storage_tries: Arc<B256Map<BTreeMap<Nibbles, BranchNodeCompact>>>,
142        all_visited_storage_keys: Arc<B256Map<Mutex<Vec<KeyVisit<Nibbles>>>>>,
143        hashed_address: B256,
144    ) -> Result<Self, DatabaseError> {
145        if !all_storage_tries.contains_key(&hashed_address) {
146            return Err(DatabaseError::Other(format!(
147                "storage trie for {hashed_address:?} not found"
148            )));
149        }
150        Ok(Self {
151            current_key: None,
152            cursor_type: MockTrieCursorType::Storage {
153                all_storage_tries,
154                all_visited_storage_keys,
155                current_hashed_address: hashed_address,
156            },
157        })
158    }
159
160    /// Returns the trie nodes map for the current cursor type.
161    fn trie_nodes(&self) -> &BTreeMap<Nibbles, BranchNodeCompact> {
162        match &self.cursor_type {
163            MockTrieCursorType::Account { trie_nodes, .. } => trie_nodes.as_ref(),
164            MockTrieCursorType::Storage { all_storage_tries, current_hashed_address, .. } => {
165                all_storage_tries
166                    .get(current_hashed_address)
167                    .expect("current_hashed_address should exist in all_storage_tries")
168            }
169        }
170    }
171
172    /// Returns the visited keys mutex for the current cursor type.
173    fn visited_keys(&self) -> &Mutex<Vec<KeyVisit<Nibbles>>> {
174        match &self.cursor_type {
175            MockTrieCursorType::Account { visited_keys, .. } => visited_keys.as_ref(),
176            MockTrieCursorType::Storage {
177                all_visited_storage_keys,
178                current_hashed_address,
179                ..
180            } => all_visited_storage_keys
181                .get(current_hashed_address)
182                .expect("current_hashed_address should exist in all_visited_storage_keys"),
183        }
184    }
185}
186
187impl TrieCursor for MockTrieCursor {
188    #[instrument(skip(self), ret(level = "trace"))]
189    fn seek_exact(
190        &mut self,
191        key: Nibbles,
192    ) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
193        let entry = self.trie_nodes().get(&key).cloned().map(|value| (key, value));
194        if let Some((key, _)) = &entry {
195            self.current_key = Some(*key);
196        }
197        self.visited_keys().lock().push(KeyVisit {
198            visit_type: KeyVisitType::SeekExact(key),
199            visited_key: entry.as_ref().map(|(k, _)| *k),
200        });
201        Ok(entry)
202    }
203
204    #[instrument(skip(self), ret(level = "trace"))]
205    fn seek(
206        &mut self,
207        key: Nibbles,
208    ) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
209        // Find the first key that is greater than or equal to the given key.
210        let entry =
211            self.trie_nodes().iter().find_map(|(k, v)| (k >= &key).then(|| (*k, v.clone())));
212        if let Some((key, _)) = &entry {
213            self.current_key = Some(*key);
214        }
215        self.visited_keys().lock().push(KeyVisit {
216            visit_type: KeyVisitType::SeekNonExact(key),
217            visited_key: entry.as_ref().map(|(k, _)| *k),
218        });
219        Ok(entry)
220    }
221
222    #[instrument(skip(self), ret(level = "trace"))]
223    fn next(&mut self) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
224        let mut iter = self.trie_nodes().iter();
225        // Jump to the first key that has a prefix of the current key if it's set, or to the first
226        // key otherwise.
227        iter.find(|(k, _)| self.current_key.as_ref().is_none_or(|current| k.starts_with(current)))
228            .expect("current key should exist in trie nodes");
229        // Get the next key-value pair.
230        let entry = iter.next().map(|(k, v)| (*k, v.clone()));
231        if let Some((key, _)) = &entry {
232            self.current_key = Some(*key);
233        }
234        self.visited_keys().lock().push(KeyVisit {
235            visit_type: KeyVisitType::Next,
236            visited_key: entry.as_ref().map(|(k, _)| *k),
237        });
238        Ok(entry)
239    }
240
241    #[instrument(skip(self), ret(level = "trace"))]
242    fn current(&mut self) -> Result<Option<Nibbles>, DatabaseError> {
243        Ok(self.current_key)
244    }
245
246    fn reset(&mut self) {
247        self.current_key = None;
248    }
249}
250
251impl TrieStorageCursor for MockTrieCursor {
252    fn set_hashed_address(&mut self, hashed_address: B256) {
253        self.reset();
254        match &mut self.cursor_type {
255            MockTrieCursorType::Storage { current_hashed_address, .. } => {
256                *current_hashed_address = hashed_address;
257            }
258            MockTrieCursorType::Account { .. } => {
259                panic!("set_hashed_address called on account cursor")
260            }
261        }
262    }
263}