1use super::{TrieCursor, TrieCursorFactory};
2use crate::{forward_cursor::ForwardInMemoryCursor, updates::TrieUpdatesSorted};
3use alloy_primitives::B256;
4use reth_storage_errors::db::DatabaseError;
5use reth_trie_common::{BranchNodeCompact, Nibbles};
6
7#[derive(Debug, Clone)]
9pub struct InMemoryTrieCursorFactory<'a, CF> {
10 cursor_factory: CF,
12 trie_updates: &'a TrieUpdatesSorted,
14}
15
16impl<'a, CF> InMemoryTrieCursorFactory<'a, CF> {
17 pub const fn new(cursor_factory: CF, trie_updates: &'a TrieUpdatesSorted) -> Self {
19 Self { cursor_factory, trie_updates }
20 }
21}
22
23impl<'a, CF: TrieCursorFactory> TrieCursorFactory for InMemoryTrieCursorFactory<'a, CF> {
24 type AccountTrieCursor = InMemoryTrieCursor<'a, CF::AccountTrieCursor>;
25 type StorageTrieCursor = InMemoryTrieCursor<'a, CF::StorageTrieCursor>;
26
27 fn account_trie_cursor(&self) -> Result<Self::AccountTrieCursor, DatabaseError> {
28 let cursor = self.cursor_factory.account_trie_cursor()?;
29 Ok(InMemoryTrieCursor::new(Some(cursor), self.trie_updates.account_nodes_ref()))
30 }
31
32 fn storage_trie_cursor(
33 &self,
34 hashed_address: B256,
35 ) -> Result<Self::StorageTrieCursor, DatabaseError> {
36 static EMPTY_UPDATES: Vec<(Nibbles, Option<BranchNodeCompact>)> = Vec::new();
38
39 let storage_trie_updates = self.trie_updates.storage_tries.get(&hashed_address);
40 let (storage_nodes, cleared) = storage_trie_updates
41 .map(|u| (u.storage_nodes_ref(), u.is_deleted()))
42 .unwrap_or((&EMPTY_UPDATES, false));
43
44 let cursor = if cleared {
45 None
46 } else {
47 Some(self.cursor_factory.storage_trie_cursor(hashed_address)?)
48 };
49
50 Ok(InMemoryTrieCursor::new(cursor, storage_nodes))
51 }
52}
53
54#[derive(Debug)]
57pub struct InMemoryTrieCursor<'a, C> {
58 cursor: Option<C>,
60 in_memory_cursor: ForwardInMemoryCursor<'a, Nibbles, Option<BranchNodeCompact>>,
62 last_key: Option<Nibbles>,
64}
65
66impl<'a, C: TrieCursor> InMemoryTrieCursor<'a, C> {
67 pub fn new(
70 cursor: Option<C>,
71 trie_updates: &'a [(Nibbles, Option<BranchNodeCompact>)],
72 ) -> Self {
73 let in_memory_cursor = ForwardInMemoryCursor::new(trie_updates);
74 Self { cursor, in_memory_cursor, last_key: None }
75 }
76
77 fn seek_inner(
78 &mut self,
79 key: Nibbles,
80 exact: bool,
81 ) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
82 let mut mem_entry = self.in_memory_cursor.seek(&key);
83 let mut db_entry = self.cursor.as_mut().map(|c| c.seek(key)).transpose()?.flatten();
84
85 if exact {
88 return Ok(match (mem_entry, db_entry) {
89 (Some((mem_key, entry_inner)), _) if mem_key == key => {
90 entry_inner.map(|node| (key, node))
91 }
92 (_, Some((db_key, node))) if db_key == key => Some((key, node)),
93 _ => None,
94 })
95 }
96
97 loop {
98 match (mem_entry, &db_entry) {
99 (Some((mem_key, None)), _)
100 if db_entry.as_ref().is_none_or(|(db_key, _)| &mem_key < db_key) =>
101 {
102 mem_entry = self.in_memory_cursor.first_after(&mem_key);
106 }
107 (Some((mem_key, None)), Some((db_key, _))) if &mem_key == db_key => {
108 mem_entry = self.in_memory_cursor.first_after(&mem_key);
111 db_entry = self.cursor.as_mut().map(|c| c.next()).transpose()?.flatten();
112 }
113 (Some((mem_key, Some(node))), _)
114 if db_entry.as_ref().is_none_or(|(db_key, _)| &mem_key <= db_key) =>
115 {
116 return Ok(Some((mem_key, node)))
119 }
120 _ => return Ok(db_entry),
125 }
126 }
127 }
128
129 fn next_inner(
130 &mut self,
131 last: Nibbles,
132 ) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
133 let Some(key) = last.increment() else { return Ok(None) };
134 self.seek_inner(key, false)
135 }
136}
137
138impl<C: TrieCursor> TrieCursor for InMemoryTrieCursor<'_, C> {
139 fn seek_exact(
140 &mut self,
141 key: Nibbles,
142 ) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
143 let entry = self.seek_inner(key, true)?;
144 self.last_key = entry.as_ref().map(|(nibbles, _)| *nibbles);
145 Ok(entry)
146 }
147
148 fn seek(
149 &mut self,
150 key: Nibbles,
151 ) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
152 let entry = self.seek_inner(key, false)?;
153 self.last_key = entry.as_ref().map(|(nibbles, _)| *nibbles);
154 Ok(entry)
155 }
156
157 fn next(&mut self) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
158 let next = match &self.last_key {
159 Some(last) => {
160 let entry = self.next_inner(*last)?;
161 self.last_key = entry.as_ref().map(|entry| entry.0);
162 entry
163 }
164 None => None,
166 };
167 Ok(next)
168 }
169
170 fn current(&mut self) -> Result<Option<Nibbles>, DatabaseError> {
171 match &self.last_key {
172 Some(key) => Ok(Some(*key)),
173 None => Ok(self.cursor.as_mut().map(|c| c.current()).transpose()?.flatten()),
174 }
175 }
176}
177
#[cfg(test)]
mod tests {
    use super::*;
    use crate::trie_cursor::mock::MockTrieCursor;
    use parking_lot::Mutex;
    use std::{collections::BTreeMap, sync::Arc};

    /// A database state, an in-memory overlay, and the entries a full walk over their merge is
    /// expected to yield.
    #[derive(Debug)]
    struct InMemoryTrieCursorTestCase {
        db_nodes: Vec<(Nibbles, BranchNodeCompact)>,
        in_memory_nodes: Vec<(Nibbles, Option<BranchNodeCompact>)>,
        expected_results: Vec<(Nibbles, BranchNodeCompact)>,
    }

    /// Single-nibble path shorthand.
    fn path(nibble: u8) -> Nibbles {
        Nibbles::from_nibbles([nibble])
    }

    /// Branch node with the given state/tree masks, an empty hash mask, no hashes and no root
    /// hash.
    fn branch(state_mask: u16, tree_mask: u16) -> BranchNodeCompact {
        BranchNodeCompact::new(state_mask, tree_mask, 0, vec![], None)
    }

    /// Mock database cursor over `db_nodes`. Its visited-keys log is not inspected by these
    /// tests.
    fn mock_cursor(db_nodes: Vec<(Nibbles, BranchNodeCompact)>) -> MockTrieCursor {
        let db_nodes: BTreeMap<Nibbles, BranchNodeCompact> = db_nodes.into_iter().collect();
        MockTrieCursor::new(Arc::new(db_nodes), Arc::new(Mutex::new(Vec::new())))
    }

    /// Seeks from the smallest path, then walks the cursor with `next` until it is exhausted.
    /// Asserts that exactly `expected_results` were produced, in order.
    ///
    /// Cursor errors fail the test instead of silently ending the walk.
    fn execute_test(test_case: InMemoryTrieCursorTestCase) {
        let InMemoryTrieCursorTestCase { db_nodes, in_memory_nodes, expected_results } = test_case;
        let mut cursor = InMemoryTrieCursor::new(Some(mock_cursor(db_nodes)), &in_memory_nodes);

        let mut results = Vec::new();
        if let Some(entry) = cursor.seek(Nibbles::default()).expect("seek failed") {
            results.push(entry);
            while let Some(entry) = cursor.next().expect("next failed") {
                results.push(entry);
            }
        }

        assert_eq!(results, expected_results);
    }

    #[test]
    fn test_empty_db_and_memory() {
        execute_test(InMemoryTrieCursorTestCase {
            db_nodes: vec![],
            in_memory_nodes: vec![],
            expected_results: vec![],
        });
    }

    #[test]
    fn test_only_db_nodes() {
        let db_nodes = vec![
            (path(0x1), branch(0b0011, 0b0001)),
            (path(0x2), branch(0b0011, 0b0010)),
            (path(0x3), branch(0b0011, 0b0011)),
        ];

        execute_test(InMemoryTrieCursorTestCase {
            db_nodes: db_nodes.clone(),
            in_memory_nodes: vec![],
            expected_results: db_nodes,
        });
    }

    #[test]
    fn test_only_in_memory_nodes() {
        let in_memory_nodes = vec![
            (path(0x1), Some(branch(0b0011, 0b0001))),
            (path(0x2), Some(branch(0b0011, 0b0010))),
            (path(0x3), Some(branch(0b0011, 0b0011))),
        ];

        let expected_results: Vec<_> = in_memory_nodes
            .iter()
            .filter_map(|(key, node)| node.clone().map(|node| (*key, node)))
            .collect();

        execute_test(InMemoryTrieCursorTestCase {
            db_nodes: vec![],
            in_memory_nodes,
            expected_results,
        });
    }

    #[test]
    fn test_in_memory_overwrites_db() {
        let db_nodes =
            vec![(path(0x1), branch(0b0011, 0b0001)), (path(0x2), branch(0b0011, 0b0010))];

        let in_memory_nodes = vec![
            (path(0x1), Some(branch(0b1111, 0b1111))),
            (path(0x3), Some(branch(0b0011, 0b0011))),
        ];

        let expected_results = vec![
            (path(0x1), branch(0b1111, 0b1111)),
            (path(0x2), branch(0b0011, 0b0010)),
            (path(0x3), branch(0b0011, 0b0011)),
        ];

        execute_test(InMemoryTrieCursorTestCase { db_nodes, in_memory_nodes, expected_results });
    }

    #[test]
    fn test_in_memory_deletes_db_nodes() {
        let db_nodes = vec![
            (path(0x1), branch(0b0011, 0b0001)),
            (path(0x2), branch(0b0011, 0b0010)),
            (path(0x3), branch(0b0011, 0b0011)),
        ];

        let in_memory_nodes = vec![(path(0x2), None)];

        let expected_results =
            vec![(path(0x1), branch(0b0011, 0b0001)), (path(0x3), branch(0b0011, 0b0011))];

        execute_test(InMemoryTrieCursorTestCase { db_nodes, in_memory_nodes, expected_results });
    }

    #[test]
    fn test_complex_interleaving() {
        let db_nodes = vec![
            (path(0x1), branch(0b0001, 0b0001)),
            (path(0x3), branch(0b0011, 0b0011)),
            (path(0x5), branch(0b0101, 0b0101)),
            (path(0x7), branch(0b0111, 0b0111)),
        ];

        let in_memory_nodes = vec![
            (path(0x2), Some(branch(0b0010, 0b0010))),
            (path(0x3), None),
            (path(0x4), Some(branch(0b0100, 0b0100))),
            (path(0x6), Some(branch(0b0110, 0b0110))),
            (path(0x7), None),
            (path(0x8), Some(branch(0b1000, 0b1000))),
        ];

        let expected_results = vec![
            (path(0x1), branch(0b0001, 0b0001)),
            (path(0x2), branch(0b0010, 0b0010)),
            (path(0x4), branch(0b0100, 0b0100)),
            (path(0x5), branch(0b0101, 0b0101)),
            (path(0x6), branch(0b0110, 0b0110)),
            (path(0x8), branch(0b1000, 0b1000)),
        ];

        execute_test(InMemoryTrieCursorTestCase { db_nodes, in_memory_nodes, expected_results });
    }

    #[test]
    fn test_seek_exact() {
        let db_nodes =
            vec![(path(0x1), branch(0b0001, 0b0001)), (path(0x3), branch(0b0011, 0b0011))];
        let in_memory_nodes = vec![(path(0x2), Some(branch(0b0010, 0b0010)))];

        let mut cursor = InMemoryTrieCursor::new(Some(mock_cursor(db_nodes)), &in_memory_nodes);

        assert_eq!(
            cursor.seek_exact(path(0x2)).unwrap(),
            Some((path(0x2), branch(0b0010, 0b0010)))
        );
        assert_eq!(
            cursor.seek_exact(path(0x3)).unwrap(),
            Some((path(0x3), branch(0b0011, 0b0011)))
        );
        assert_eq!(cursor.seek_exact(path(0x4)).unwrap(), None);
    }

    #[test]
    fn test_seek_exact_on_removed_node() {
        let db_nodes =
            vec![(path(0x1), branch(0b0001, 0b0001)), (path(0x2), branch(0b0010, 0b0010))];
        let in_memory_nodes: Vec<(Nibbles, Option<BranchNodeCompact>)> =
            vec![(path(0x1), None)];

        let mut cursor = InMemoryTrieCursor::new(Some(mock_cursor(db_nodes)), &in_memory_nodes);

        // The in-memory removal hides the database node at the same path.
        assert_eq!(cursor.seek_exact(path(0x1)).unwrap(), None);
        assert_eq!(
            cursor.seek_exact(path(0x2)).unwrap(),
            Some((path(0x2), branch(0b0010, 0b0010)))
        );
    }

    #[test]
    fn test_multiple_consecutive_deletes() {
        let db_nodes: Vec<(Nibbles, BranchNodeCompact)> =
            (1..=10u8).map(|i| (path(i), branch(u16::from(i), u16::from(i)))).collect();

        let in_memory_nodes =
            vec![(path(0x3), None), (path(0x4), None), (path(0x5), None), (path(0x6), None)];

        let expected_results: Vec<_> = [1u8, 2, 7, 8, 9, 0xa]
            .into_iter()
            .map(|i| (path(i), branch(u16::from(i), u16::from(i))))
            .collect();

        execute_test(InMemoryTrieCursorTestCase { db_nodes, in_memory_nodes, expected_results });
    }

    #[test]
    fn test_empty_db_with_in_memory_deletes() {
        let in_memory_nodes =
            vec![(path(0x1), None), (path(0x2), Some(branch(0b0010, 0b0010))), (path(0x3), None)];

        let expected_results = vec![(path(0x2), branch(0b0010, 0b0010))];

        execute_test(InMemoryTrieCursorTestCase {
            db_nodes: vec![],
            in_memory_nodes,
            expected_results,
        });
    }

    #[test]
    fn test_without_db_cursor() {
        let in_memory_nodes = vec![(path(0x1), None), (path(0x2), Some(branch(0b0010, 0b0010)))];

        let mut cursor: InMemoryTrieCursor<'_, MockTrieCursor> =
            InMemoryTrieCursor::new(None, &in_memory_nodes);

        assert_eq!(
            cursor.seek(Nibbles::default()).unwrap(),
            Some((path(0x2), branch(0b0010, 0b0010)))
        );
        assert_eq!(cursor.next().unwrap(), None);
        assert_eq!(cursor.current().unwrap(), None);
    }

    #[test]
    fn test_next_before_seek_returns_none() {
        let db_nodes = vec![(path(0x1), branch(0b0001, 0b0001))];
        let in_memory_nodes = vec![(path(0x2), Some(branch(0b0010, 0b0010)))];

        let mut cursor = InMemoryTrieCursor::new(Some(mock_cursor(db_nodes)), &in_memory_nodes);

        // Without a prior seek there is no position to advance from.
        assert_eq!(cursor.next().unwrap(), None);
    }

    #[test]
    fn test_current_key_tracking() {
        let db_nodes = vec![(path(0x2), branch(0b0010, 0b0010))];
        let in_memory_nodes = vec![
            (path(0x1), Some(branch(0b0001, 0b0001))),
            (path(0x3), Some(branch(0b0011, 0b0011))),
        ];

        let mut cursor = InMemoryTrieCursor::new(Some(mock_cursor(db_nodes)), &in_memory_nodes);

        assert_eq!(cursor.current().unwrap(), None);

        cursor.seek(path(0x1)).unwrap();
        assert_eq!(cursor.current().unwrap(), Some(path(0x1)));

        cursor.next().unwrap();
        assert_eq!(cursor.current().unwrap(), Some(path(0x2)));

        cursor.next().unwrap();
        assert_eq!(cursor.current().unwrap(), Some(path(0x3)));
    }
}