Skip to main content

reth_trie/trie_cursor/
in_memory.rs

1use super::{TrieCursor, TrieCursorFactory, TrieStorageCursor};
2use crate::{forward_cursor::ForwardInMemoryCursor, updates::TrieUpdatesSorted};
3use alloy_primitives::B256;
4use reth_storage_errors::db::DatabaseError;
5use reth_trie_common::{BranchNodeCompact, Nibbles};
6
/// Trie cursor factory that layers in-memory trie updates over another factory's cursors.
///
/// Cursors handed out by this factory prefer the nodes found in `trie_updates` and fall back to
/// the cursors produced by the wrapped `cursor_factory`.
#[derive(Debug, Clone)]
pub struct InMemoryTrieCursorFactory<CF, T> {
    /// Factory producing the underlying (database) cursors.
    cursor_factory: CF,
    /// The sorted trie updates used as the overlay; the `TrieCursorFactory` impl expects this to
    /// be a reference to something that is `AsRef<TrieUpdatesSorted>`.
    trie_updates: T,
}

impl<CF, T> InMemoryTrieCursorFactory<CF, T> {
    /// Wraps `cursor_factory` so that every cursor it creates is overlaid with `trie_updates`.
    pub const fn new(cursor_factory: CF, trie_updates: T) -> Self {
        Self { trie_updates, cursor_factory }
    }
}
22
23impl<'overlay, CF, T> TrieCursorFactory for InMemoryTrieCursorFactory<CF, &'overlay T>
24where
25    CF: TrieCursorFactory + 'overlay,
26    T: AsRef<TrieUpdatesSorted>,
27{
28    type AccountTrieCursor<'cursor>
29        = InMemoryTrieCursor<'overlay, CF::AccountTrieCursor<'cursor>>
30    where
31        Self: 'cursor;
32
33    type StorageTrieCursor<'cursor>
34        = InMemoryTrieCursor<'overlay, CF::StorageTrieCursor<'cursor>>
35    where
36        Self: 'cursor;
37
38    fn account_trie_cursor(&self) -> Result<Self::AccountTrieCursor<'_>, DatabaseError> {
39        let cursor = self.cursor_factory.account_trie_cursor()?;
40        Ok(InMemoryTrieCursor::new_account(cursor, self.trie_updates.as_ref()))
41    }
42
43    fn storage_trie_cursor(
44        &self,
45        hashed_address: B256,
46    ) -> Result<Self::StorageTrieCursor<'_>, DatabaseError> {
47        let trie_updates = self.trie_updates.as_ref();
48        let cursor = self.cursor_factory.storage_trie_cursor(hashed_address)?;
49        Ok(InMemoryTrieCursor::new_storage(cursor, trie_updates, hashed_address))
50    }
51}
52
53/// A cursor to iterate over trie updates and corresponding database entries.
54/// It will always give precedence to the data from the trie updates.
55#[derive(Debug)]
56pub struct InMemoryTrieCursor<'a, C> {
57    /// The underlying cursor.
58    cursor: C,
59    /// Tracks whether the DB cursor is available, positioned, or exhausted.
60    db_cursor_state: DbCursorState,
61    /// Forward-only in-memory cursor over storage trie nodes.
62    in_memory_cursor: ForwardInMemoryCursor<'a, Nibbles, Option<BranchNodeCompact>>,
63    /// The key most recently returned from the Cursor.
64    last_key: Option<Nibbles>,
65    #[cfg(debug_assertions)]
66    /// Whether an initial seek was called.
67    seeked: bool,
68    /// Reference to the full trie updates.
69    trie_updates: &'a TrieUpdatesSorted,
70}
71
72#[derive(Debug)]
73enum DbCursorState {
74    NeedsPosition,
75    Positioned((Nibbles, BranchNodeCompact)),
76    Exhausted,
77    Wiped,
78}
79
80impl DbCursorState {
81    const fn new(cursor_wiped: bool) -> Self {
82        if cursor_wiped {
83            Self::Wiped
84        } else {
85            Self::NeedsPosition
86        }
87    }
88
89    const fn entry(&self) -> Option<&(Nibbles, BranchNodeCompact)> {
90        match self {
91            Self::Positioned(entry) => Some(entry),
92            Self::NeedsPosition | Self::Exhausted | Self::Wiped => None,
93        }
94    }
95
96    fn set_entry(&mut self, entry: Option<(Nibbles, BranchNodeCompact)>) {
97        *self = match entry {
98            Some(entry) => Self::Positioned(entry),
99            None => Self::Exhausted,
100        };
101    }
102}
103
104impl<'a, C: TrieCursor> InMemoryTrieCursor<'a, C> {
105    /// Create new account trie cursor which combines a DB cursor and the trie updates.
106    pub fn new_account(cursor: C, trie_updates: &'a TrieUpdatesSorted) -> Self {
107        let in_memory_cursor = ForwardInMemoryCursor::new(trie_updates.account_nodes_ref());
108        Self {
109            cursor,
110            db_cursor_state: DbCursorState::NeedsPosition,
111            in_memory_cursor,
112            last_key: None,
113            #[cfg(debug_assertions)]
114            seeked: false,
115            trie_updates,
116        }
117    }
118
119    /// Create new storage trie cursor with full trie updates reference.
120    /// This allows the cursor to switch between storage tries when `set_hashed_address` is called.
121    pub fn new_storage(
122        cursor: C,
123        trie_updates: &'a TrieUpdatesSorted,
124        hashed_address: B256,
125    ) -> Self {
126        let (in_memory_cursor, cursor_wiped) =
127            Self::get_storage_overlay(trie_updates, hashed_address);
128        Self {
129            cursor,
130            db_cursor_state: DbCursorState::new(cursor_wiped),
131            in_memory_cursor,
132            last_key: None,
133            #[cfg(debug_assertions)]
134            seeked: false,
135            trie_updates,
136        }
137    }
138
139    /// Returns the storage overlay for `hashed_address` and whether it was deleted.
140    fn get_storage_overlay(
141        trie_updates: &'a TrieUpdatesSorted,
142        hashed_address: B256,
143    ) -> (ForwardInMemoryCursor<'a, Nibbles, Option<BranchNodeCompact>>, bool) {
144        let storage_trie_updates = trie_updates.storage_tries_ref().get(&hashed_address);
145        let cursor_wiped = storage_trie_updates.is_some_and(|u| u.is_deleted());
146        let storage_nodes = storage_trie_updates.map(|u| u.storage_nodes_ref()).unwrap_or(&[]);
147
148        (ForwardInMemoryCursor::new(storage_nodes), cursor_wiped)
149    }
150
151    /// Returns a mutable reference to the underlying cursor if it's not wiped, None otherwise.
152    fn get_cursor_mut(&mut self) -> Option<&mut C> {
153        (!matches!(self.db_cursor_state, DbCursorState::Wiped)).then_some(&mut self.cursor)
154    }
155
156    /// Asserts that the next entry to be returned from the cursor is not previous to the last entry
157    /// returned.
158    fn set_last_key(&mut self, next_entry: &Option<(Nibbles, BranchNodeCompact)>) {
159        let next_key = next_entry.as_ref().map(|e| e.0);
160        debug_assert!(
161            self.last_key.is_none_or(|last| next_key.is_none_or(|next| next >= last)),
162            "Cannot return entry {:?} previous to the last returned entry at {:?}",
163            next_key,
164            self.last_key,
165        );
166        self.last_key = next_key;
167    }
168
169    /// Positions the DB cursor state using the underlying cursor when needed.
170    fn cursor_seek(&mut self, key: Nibbles) -> Result<(), DatabaseError> {
171        // Only seek if:
172        // 1. We have a cursor entry and need to seek forward (entry.0 < key), OR
173        // 2. The DB cursor needs to be positioned.
174        let should_seek = match &self.db_cursor_state {
175            DbCursorState::NeedsPosition => true,
176            DbCursorState::Positioned((entry_key, _)) => entry_key < &key,
177            DbCursorState::Exhausted | DbCursorState::Wiped => false,
178        };
179
180        if should_seek {
181            let entry = self.get_cursor_mut().map(|c| c.seek(key)).transpose()?.flatten();
182            self.db_cursor_state.set_entry(entry);
183        }
184
185        Ok(())
186    }
187
188    /// Advances the DB cursor state to the subsequent entry using the underlying cursor.
189    fn cursor_next(&mut self) -> Result<(), DatabaseError> {
190        #[cfg(debug_assertions)]
191        {
192            debug_assert!(self.seeked);
193            debug_assert!(!matches!(self.db_cursor_state, DbCursorState::NeedsPosition));
194        }
195
196        // Exhausted and wiped states are stable; only advance if the DB cursor currently points to
197        // an entry.
198        if matches!(self.db_cursor_state, DbCursorState::Positioned(_)) {
199            let entry = self.get_cursor_mut().map(|c| c.next()).transpose()?.flatten();
200            self.db_cursor_state.set_entry(entry);
201        }
202
203        Ok(())
204    }
205
206    /// Compares the current in-memory entry with the current entry of the cursor, and applies the
207    /// in-memory entry to the cursor entry as an overlay.
208    //
209    /// This may consume and move forward the current entries when the overlay indicates a removed
210    /// node.
211    fn choose_next_entry(&mut self) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
212        loop {
213            let mem_entry = self.in_memory_cursor.current().cloned();
214            let db_entry = self.db_cursor_state.entry();
215
216            match (mem_entry, db_entry) {
217                (Some((mem_key, None)), _)
218                    if db_entry.is_none_or(|(db_key, _)| &mem_key < db_key) =>
219                {
220                    // If overlay has a removed node but DB cursor is exhausted or ahead of the
221                    // in-memory cursor then move ahead in-memory, as there might be further
222                    // non-removed overlay nodes.
223                    self.in_memory_cursor.first_after(&mem_key);
224                }
225                (Some((mem_key, None)), Some((db_key, _))) if &mem_key == db_key => {
226                    // If overlay has a removed node which is returned from DB then move both
227                    // cursors ahead to the next key.
228                    self.in_memory_cursor.first_after(&mem_key);
229                    self.cursor_next()?;
230                }
231                (Some((mem_key, Some(node))), _)
232                    if db_entry.is_none_or(|(db_key, _)| &mem_key <= db_key) =>
233                {
234                    // If overlay returns a node prior to the DB's node, or the DB is exhausted,
235                    // then we return the overlay's node.
236                    return Ok(Some((mem_key, node)))
237                }
238                // All other cases:
239                // - mem_key > db_key
240                // - overlay is exhausted
241                // Return the db_entry. If DB is also exhausted then this returns None.
242                _ => return Ok(db_entry.cloned()),
243            }
244        }
245    }
246}
247
248impl<C: TrieCursor> TrieCursor for InMemoryTrieCursor<'_, C> {
249    fn seek_exact(
250        &mut self,
251        key: Nibbles,
252    ) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
253        let mem_entry = self.in_memory_cursor.seek(&key);
254
255        if let Some((mem_key, entry_inner)) = mem_entry &&
256            *mem_key == key
257        {
258            #[cfg(debug_assertions)]
259            {
260                self.seeked = true;
261            }
262
263            // An exact overlay hit can move the logical cursor ahead without touching the DB. If
264            // the DB cursor was still behind this key, force a re-seek before the next DB-backed
265            // operation so `next()` cannot return a stale earlier entry.
266            if matches!(&self.db_cursor_state, DbCursorState::Positioned((db_key, _)) if db_key < &key)
267            {
268                self.db_cursor_state = DbCursorState::NeedsPosition;
269            }
270
271            let entry = entry_inner.clone().map(|node| (key, node));
272            self.set_last_key(&entry);
273            return Ok(entry)
274        }
275
276        self.cursor_seek(key)?;
277
278        #[cfg(debug_assertions)]
279        {
280            self.seeked = true;
281        }
282
283        let entry = match self.db_cursor_state.entry() {
284            Some((db_key, node)) if db_key == &key => Some((key, node.clone())),
285            _ => None,
286        };
287
288        self.set_last_key(&entry);
289        Ok(entry)
290    }
291
292    fn seek(
293        &mut self,
294        key: Nibbles,
295    ) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
296        let mem_entry = self.in_memory_cursor.seek(&key);
297
298        if let Some((mem_key, Some(node))) = mem_entry &&
299            *mem_key == key
300        {
301            #[cfg(debug_assertions)]
302            {
303                self.seeked = true;
304            }
305
306            // An exact overlay hit is the first logical entry at or after `key`, so the DB cursor
307            // can stay lazy until a later operation needs it.
308            if matches!(&self.db_cursor_state, DbCursorState::Positioned((db_key, _)) if db_key < &key)
309            {
310                self.db_cursor_state = DbCursorState::NeedsPosition;
311            }
312
313            let entry = Some((key, node.clone()));
314            self.set_last_key(&entry);
315            return Ok(entry)
316        }
317
318        self.cursor_seek(key)?;
319
320        #[cfg(debug_assertions)]
321        {
322            self.seeked = true;
323        }
324
325        let entry = self.choose_next_entry()?;
326        self.set_last_key(&entry);
327        Ok(entry)
328    }
329
330    fn next(&mut self) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
331        #[cfg(debug_assertions)]
332        {
333            debug_assert!(self.seeked, "Cursor must be seek'd before next is called");
334        }
335
336        // A `last_key` of `None` indicates that the cursor is exhausted.
337        let Some(last_key) = self.last_key else {
338            return Ok(None);
339        };
340
341        // If either cursor is currently pointing to the last entry which was returned then consume
342        // that entry so that `choose_next_entry` is looking at the subsequent one.
343        if let Some((key, _)) = self.in_memory_cursor.current() &&
344            key == &last_key
345        {
346            self.in_memory_cursor.first_after(&last_key);
347        }
348
349        if matches!(self.db_cursor_state, DbCursorState::NeedsPosition) {
350            self.cursor_seek(last_key)?;
351        }
352
353        if let Some((key, _)) = self.db_cursor_state.entry() &&
354            key == &last_key
355        {
356            self.cursor_next()?;
357        }
358
359        let entry = self.choose_next_entry()?;
360        self.set_last_key(&entry);
361        Ok(entry)
362    }
363
364    fn current(&mut self) -> Result<Option<Nibbles>, DatabaseError> {
365        match &self.last_key {
366            Some(key) => Ok(Some(*key)),
367            None => Ok(self.get_cursor_mut().map(|c| c.current()).transpose()?.flatten()),
368        }
369    }
370
371    fn reset(&mut self) {
372        self.cursor.reset();
373        self.in_memory_cursor.reset();
374
375        self.db_cursor_state = DbCursorState::NeedsPosition;
376        self.last_key = None;
377        #[cfg(debug_assertions)]
378        {
379            self.seeked = false;
380        }
381    }
382}
383
384impl<C: TrieStorageCursor> TrieStorageCursor for InMemoryTrieCursor<'_, C> {
385    fn set_hashed_address(&mut self, hashed_address: B256) {
386        self.reset();
387        self.cursor.set_hashed_address(hashed_address);
388        let (in_memory_cursor, cursor_wiped) =
389            Self::get_storage_overlay(self.trie_updates, hashed_address);
390        self.in_memory_cursor = in_memory_cursor;
391        self.db_cursor_state = DbCursorState::new(cursor_wiped);
392    }
393}
394
395#[cfg(test)]
396mod tests {
397    use super::*;
398    use crate::trie_cursor::mock::MockTrieCursor;
399    use parking_lot::Mutex;
400    use std::{collections::BTreeMap, sync::Arc};
401
402    #[derive(Debug)]
403    struct InMemoryTrieCursorTestCase {
404        db_nodes: Vec<(Nibbles, BranchNodeCompact)>,
405        in_memory_nodes: Vec<(Nibbles, Option<BranchNodeCompact>)>,
406        expected_results: Vec<(Nibbles, BranchNodeCompact)>,
407    }
408
409    fn execute_test(test_case: InMemoryTrieCursorTestCase) {
410        let db_nodes_map: BTreeMap<Nibbles, BranchNodeCompact> =
411            test_case.db_nodes.into_iter().collect();
412        let db_nodes_arc = Arc::new(db_nodes_map);
413        let visited_keys = Arc::new(Mutex::new(Vec::new()));
414        let mock_cursor = MockTrieCursor::new(db_nodes_arc, visited_keys);
415
416        let trie_updates = TrieUpdatesSorted::new(test_case.in_memory_nodes, Default::default());
417        let mut cursor = InMemoryTrieCursor::new_account(mock_cursor, &trie_updates);
418
419        let mut results = Vec::new();
420
421        if let Some(first_expected) = test_case.expected_results.first() &&
422            let Ok(Some(entry)) = cursor.seek(first_expected.0)
423        {
424            results.push(entry);
425        }
426
427        if !test_case.expected_results.is_empty() {
428            while let Ok(Some(entry)) = cursor.next() {
429                results.push(entry);
430            }
431        }
432
433        assert_eq!(
434            results, test_case.expected_results,
435            "Results mismatch.\nGot: {:?}\nExpected: {:?}",
436            results, test_case.expected_results
437        );
438    }
439
440    #[test]
441    fn test_empty_db_and_memory() {
442        let test_case = InMemoryTrieCursorTestCase {
443            db_nodes: vec![],
444            in_memory_nodes: vec![],
445            expected_results: vec![],
446        };
447        execute_test(test_case);
448    }
449
450    #[test]
451    fn test_only_db_nodes() {
452        let db_nodes = vec![
453            (Nibbles::from_nibbles([0x1]), BranchNodeCompact::new(0b0011, 0b0001, 0, vec![], None)),
454            (Nibbles::from_nibbles([0x2]), BranchNodeCompact::new(0b0011, 0b0010, 0, vec![], None)),
455            (Nibbles::from_nibbles([0x3]), BranchNodeCompact::new(0b0011, 0b0011, 0, vec![], None)),
456        ];
457
458        let test_case = InMemoryTrieCursorTestCase {
459            db_nodes: db_nodes.clone(),
460            in_memory_nodes: vec![],
461            expected_results: db_nodes,
462        };
463        execute_test(test_case);
464    }
465
466    #[test]
467    fn test_only_in_memory_nodes() {
468        let in_memory_nodes = vec![
469            (
470                Nibbles::from_nibbles([0x1]),
471                Some(BranchNodeCompact::new(0b0011, 0b0001, 0, vec![], None)),
472            ),
473            (
474                Nibbles::from_nibbles([0x2]),
475                Some(BranchNodeCompact::new(0b0011, 0b0010, 0, vec![], None)),
476            ),
477            (
478                Nibbles::from_nibbles([0x3]),
479                Some(BranchNodeCompact::new(0b0011, 0b0011, 0, vec![], None)),
480            ),
481        ];
482
483        let expected_results: Vec<(Nibbles, BranchNodeCompact)> = in_memory_nodes
484            .iter()
485            .filter_map(|(k, v)| v.as_ref().map(|node| (*k, node.clone())))
486            .collect();
487
488        let test_case =
489            InMemoryTrieCursorTestCase { db_nodes: vec![], in_memory_nodes, expected_results };
490        execute_test(test_case);
491    }
492
493    #[test]
494    fn test_in_memory_overwrites_db() {
495        let db_nodes = vec![
496            (Nibbles::from_nibbles([0x1]), BranchNodeCompact::new(0b0011, 0b0001, 0, vec![], None)),
497            (Nibbles::from_nibbles([0x2]), BranchNodeCompact::new(0b0011, 0b0010, 0, vec![], None)),
498        ];
499
500        let in_memory_nodes = vec![
501            (
502                Nibbles::from_nibbles([0x1]),
503                Some(BranchNodeCompact::new(0b1111, 0b1111, 0, vec![], None)),
504            ),
505            (
506                Nibbles::from_nibbles([0x3]),
507                Some(BranchNodeCompact::new(0b0011, 0b0011, 0, vec![], None)),
508            ),
509        ];
510
511        let expected_results = vec![
512            (Nibbles::from_nibbles([0x1]), BranchNodeCompact::new(0b1111, 0b1111, 0, vec![], None)),
513            (Nibbles::from_nibbles([0x2]), BranchNodeCompact::new(0b0011, 0b0010, 0, vec![], None)),
514            (Nibbles::from_nibbles([0x3]), BranchNodeCompact::new(0b0011, 0b0011, 0, vec![], None)),
515        ];
516
517        let test_case = InMemoryTrieCursorTestCase { db_nodes, in_memory_nodes, expected_results };
518        execute_test(test_case);
519    }
520
521    #[test]
522    fn test_in_memory_deletes_db_nodes() {
523        let db_nodes = vec![
524            (Nibbles::from_nibbles([0x1]), BranchNodeCompact::new(0b0011, 0b0001, 0, vec![], None)),
525            (Nibbles::from_nibbles([0x2]), BranchNodeCompact::new(0b0011, 0b0010, 0, vec![], None)),
526            (Nibbles::from_nibbles([0x3]), BranchNodeCompact::new(0b0011, 0b0011, 0, vec![], None)),
527        ];
528
529        let in_memory_nodes = vec![(Nibbles::from_nibbles([0x2]), None)];
530
531        let expected_results = vec![
532            (Nibbles::from_nibbles([0x1]), BranchNodeCompact::new(0b0011, 0b0001, 0, vec![], None)),
533            (Nibbles::from_nibbles([0x3]), BranchNodeCompact::new(0b0011, 0b0011, 0, vec![], None)),
534        ];
535
536        let test_case = InMemoryTrieCursorTestCase { db_nodes, in_memory_nodes, expected_results };
537        execute_test(test_case);
538    }
539
540    #[test]
541    fn test_complex_interleaving() {
542        let db_nodes = vec![
543            (Nibbles::from_nibbles([0x1]), BranchNodeCompact::new(0b0001, 0b0001, 0, vec![], None)),
544            (Nibbles::from_nibbles([0x3]), BranchNodeCompact::new(0b0011, 0b0011, 0, vec![], None)),
545            (Nibbles::from_nibbles([0x5]), BranchNodeCompact::new(0b0101, 0b0101, 0, vec![], None)),
546            (Nibbles::from_nibbles([0x7]), BranchNodeCompact::new(0b0111, 0b0111, 0, vec![], None)),
547        ];
548
549        let in_memory_nodes = vec![
550            (
551                Nibbles::from_nibbles([0x2]),
552                Some(BranchNodeCompact::new(0b0010, 0b0010, 0, vec![], None)),
553            ),
554            (Nibbles::from_nibbles([0x3]), None),
555            (
556                Nibbles::from_nibbles([0x4]),
557                Some(BranchNodeCompact::new(0b0100, 0b0100, 0, vec![], None)),
558            ),
559            (
560                Nibbles::from_nibbles([0x6]),
561                Some(BranchNodeCompact::new(0b0110, 0b0110, 0, vec![], None)),
562            ),
563            (Nibbles::from_nibbles([0x7]), None),
564            (
565                Nibbles::from_nibbles([0x8]),
566                Some(BranchNodeCompact::new(0b1000, 0b1000, 0, vec![], None)),
567            ),
568        ];
569
570        let expected_results = vec![
571            (Nibbles::from_nibbles([0x1]), BranchNodeCompact::new(0b0001, 0b0001, 0, vec![], None)),
572            (Nibbles::from_nibbles([0x2]), BranchNodeCompact::new(0b0010, 0b0010, 0, vec![], None)),
573            (Nibbles::from_nibbles([0x4]), BranchNodeCompact::new(0b0100, 0b0100, 0, vec![], None)),
574            (Nibbles::from_nibbles([0x5]), BranchNodeCompact::new(0b0101, 0b0101, 0, vec![], None)),
575            (Nibbles::from_nibbles([0x6]), BranchNodeCompact::new(0b0110, 0b0110, 0, vec![], None)),
576            (Nibbles::from_nibbles([0x8]), BranchNodeCompact::new(0b1000, 0b1000, 0, vec![], None)),
577        ];
578
579        let test_case = InMemoryTrieCursorTestCase { db_nodes, in_memory_nodes, expected_results };
580        execute_test(test_case);
581    }
582
583    #[test]
584    fn test_seek_exact() {
585        let db_nodes = vec![
586            (Nibbles::from_nibbles([0x1]), BranchNodeCompact::new(0b0001, 0b0001, 0, vec![], None)),
587            (Nibbles::from_nibbles([0x3]), BranchNodeCompact::new(0b0011, 0b0011, 0, vec![], None)),
588        ];
589
590        let in_memory_nodes = vec![(
591            Nibbles::from_nibbles([0x2]),
592            Some(BranchNodeCompact::new(0b0010, 0b0010, 0, vec![], None)),
593        )];
594
595        let db_nodes_map: BTreeMap<Nibbles, BranchNodeCompact> = db_nodes.into_iter().collect();
596        let db_nodes_arc = Arc::new(db_nodes_map);
597        let visited_keys = Arc::new(Mutex::new(Vec::new()));
598        let mock_cursor = MockTrieCursor::new(db_nodes_arc, visited_keys.clone());
599
600        let trie_updates = TrieUpdatesSorted::new(in_memory_nodes, Default::default());
601        let mut cursor = InMemoryTrieCursor::new_account(mock_cursor, &trie_updates);
602
603        let result = cursor.seek_exact(Nibbles::from_nibbles([0x2])).unwrap();
604        assert_eq!(
605            result,
606            Some((
607                Nibbles::from_nibbles([0x2]),
608                BranchNodeCompact::new(0b0010, 0b0010, 0, vec![], None)
609            ))
610        );
611        assert!(visited_keys.lock().is_empty(), "exact overlay hit should not touch the DB cursor");
612
613        let result = cursor.seek_exact(Nibbles::from_nibbles([0x3])).unwrap();
614        assert_eq!(
615            result,
616            Some((
617                Nibbles::from_nibbles([0x3]),
618                BranchNodeCompact::new(0b0011, 0b0011, 0, vec![], None)
619            ))
620        );
621
622        let result = cursor.seek_exact(Nibbles::from_nibbles([0x4])).unwrap();
623        assert_eq!(result, None);
624    }
625
626    #[test]
627    fn test_seek_overlay_exact_hit_does_not_touch_db_until_next() {
628        let db_nodes = vec![
629            (Nibbles::from_nibbles([0x2]), BranchNodeCompact::new(0b0010, 0b0010, 0, vec![], None)),
630            (Nibbles::from_nibbles([0x3]), BranchNodeCompact::new(0b0011, 0b0011, 0, vec![], None)),
631        ];
632
633        let in_memory_nodes = vec![(
634            Nibbles::from_nibbles([0x2]),
635            Some(BranchNodeCompact::new(0b1111, 0b1111, 0, vec![], None)),
636        )];
637
638        let db_nodes_map: BTreeMap<Nibbles, BranchNodeCompact> = db_nodes.into_iter().collect();
639        let db_nodes_arc = Arc::new(db_nodes_map);
640        let visited_keys = Arc::new(Mutex::new(Vec::new()));
641        let mock_cursor = MockTrieCursor::new(db_nodes_arc, visited_keys.clone());
642
643        let trie_updates = TrieUpdatesSorted::new(in_memory_nodes, Default::default());
644        let mut cursor = InMemoryTrieCursor::new_account(mock_cursor, &trie_updates);
645
646        let result = cursor.seek(Nibbles::from_nibbles([0x2])).unwrap();
647        assert_eq!(
648            result,
649            Some((
650                Nibbles::from_nibbles([0x2]),
651                BranchNodeCompact::new(0b1111, 0b1111, 0, vec![], None)
652            ))
653        );
654        assert!(visited_keys.lock().is_empty(), "exact overlay hit should not touch the DB cursor");
655
656        let result = cursor.next().unwrap();
657        assert_eq!(
658            result,
659            Some((
660                Nibbles::from_nibbles([0x3]),
661                BranchNodeCompact::new(0b0011, 0b0011, 0, vec![], None)
662            ))
663        );
664        assert!(!visited_keys.lock().is_empty(), "next should lazily position the DB cursor");
665    }
666
667    #[test]
668    fn test_seek_overlay_exact_hit_repositions_stale_db_on_next() {
669        let db_nodes = vec![
670            (Nibbles::from_nibbles([0x1]), BranchNodeCompact::new(0b0001, 0b0001, 0, vec![], None)),
671            (Nibbles::from_nibbles([0x3]), BranchNodeCompact::new(0b0011, 0b0011, 0, vec![], None)),
672        ];
673
674        let in_memory_nodes = vec![(
675            Nibbles::from_nibbles([0x2]),
676            Some(BranchNodeCompact::new(0b0010, 0b0010, 0, vec![], None)),
677        )];
678
679        let db_nodes_map: BTreeMap<Nibbles, BranchNodeCompact> = db_nodes.into_iter().collect();
680        let db_nodes_arc = Arc::new(db_nodes_map);
681        let visited_keys = Arc::new(Mutex::new(Vec::new()));
682        let mock_cursor = MockTrieCursor::new(db_nodes_arc, visited_keys.clone());
683
684        let trie_updates = TrieUpdatesSorted::new(in_memory_nodes, Default::default());
685        let mut cursor = InMemoryTrieCursor::new_account(mock_cursor, &trie_updates);
686
687        let result = cursor.seek(Nibbles::from_nibbles([0x1])).unwrap();
688        assert_eq!(
689            result,
690            Some((
691                Nibbles::from_nibbles([0x1]),
692                BranchNodeCompact::new(0b0001, 0b0001, 0, vec![], None)
693            ))
694        );
695        assert_eq!(visited_keys.lock().len(), 1);
696
697        let result = cursor.seek(Nibbles::from_nibbles([0x2])).unwrap();
698        assert_eq!(
699            result,
700            Some((
701                Nibbles::from_nibbles([0x2]),
702                BranchNodeCompact::new(0b0010, 0b0010, 0, vec![], None)
703            ))
704        );
705        assert_eq!(visited_keys.lock().len(), 1, "exact overlay hit should not seek the DB");
706
707        let result = cursor.next().unwrap();
708        assert_eq!(
709            result,
710            Some((
711                Nibbles::from_nibbles([0x3]),
712                BranchNodeCompact::new(0b0011, 0b0011, 0, vec![], None)
713            ))
714        );
715    }
716
717    #[test]
718    fn test_multiple_consecutive_deletes() {
719        let db_nodes: Vec<(Nibbles, BranchNodeCompact)> = (1..=10)
720            .map(|i| {
721                (
722                    Nibbles::from_nibbles([i]),
723                    BranchNodeCompact::new(i as u16, i as u16, 0, vec![], None),
724                )
725            })
726            .collect();
727
728        let in_memory_nodes = vec![
729            (Nibbles::from_nibbles([0x3]), None),
730            (Nibbles::from_nibbles([0x4]), None),
731            (Nibbles::from_nibbles([0x5]), None),
732            (Nibbles::from_nibbles([0x6]), None),
733        ];
734
735        let expected_results = vec![
736            (Nibbles::from_nibbles([0x1]), BranchNodeCompact::new(1, 1, 0, vec![], None)),
737            (Nibbles::from_nibbles([0x2]), BranchNodeCompact::new(2, 2, 0, vec![], None)),
738            (Nibbles::from_nibbles([0x7]), BranchNodeCompact::new(7, 7, 0, vec![], None)),
739            (Nibbles::from_nibbles([0x8]), BranchNodeCompact::new(8, 8, 0, vec![], None)),
740            (Nibbles::from_nibbles([0x9]), BranchNodeCompact::new(9, 9, 0, vec![], None)),
741            (Nibbles::from_nibbles([0xa]), BranchNodeCompact::new(10, 10, 0, vec![], None)),
742        ];
743
744        let test_case = InMemoryTrieCursorTestCase { db_nodes, in_memory_nodes, expected_results };
745        execute_test(test_case);
746    }
747
748    #[test]
749    fn test_empty_db_with_in_memory_deletes() {
750        let in_memory_nodes = vec![
751            (Nibbles::from_nibbles([0x1]), None),
752            (
753                Nibbles::from_nibbles([0x2]),
754                Some(BranchNodeCompact::new(0b0010, 0b0010, 0, vec![], None)),
755            ),
756            (Nibbles::from_nibbles([0x3]), None),
757        ];
758
759        let expected_results = vec![(
760            Nibbles::from_nibbles([0x2]),
761            BranchNodeCompact::new(0b0010, 0b0010, 0, vec![], None),
762        )];
763
764        let test_case =
765            InMemoryTrieCursorTestCase { db_nodes: vec![], in_memory_nodes, expected_results };
766        execute_test(test_case);
767    }
768
769    #[test]
770    fn test_current_key_tracking() {
771        let db_nodes = vec![(
772            Nibbles::from_nibbles([0x2]),
773            BranchNodeCompact::new(0b0010, 0b0010, 0, vec![], None),
774        )];
775
776        let in_memory_nodes = vec![
777            (
778                Nibbles::from_nibbles([0x1]),
779                Some(BranchNodeCompact::new(0b0001, 0b0001, 0, vec![], None)),
780            ),
781            (
782                Nibbles::from_nibbles([0x3]),
783                Some(BranchNodeCompact::new(0b0011, 0b0011, 0, vec![], None)),
784            ),
785        ];
786
787        let db_nodes_map: BTreeMap<Nibbles, BranchNodeCompact> = db_nodes.into_iter().collect();
788        let db_nodes_arc = Arc::new(db_nodes_map);
789        let visited_keys = Arc::new(Mutex::new(Vec::new()));
790        let mock_cursor = MockTrieCursor::new(db_nodes_arc, visited_keys);
791
792        let trie_updates = TrieUpdatesSorted::new(in_memory_nodes, Default::default());
793        let mut cursor = InMemoryTrieCursor::new_account(mock_cursor, &trie_updates);
794
795        assert_eq!(cursor.current().unwrap(), None);
796
797        cursor.seek(Nibbles::from_nibbles([0x1])).unwrap();
798        assert_eq!(cursor.current().unwrap(), Some(Nibbles::from_nibbles([0x1])));
799
800        cursor.next().unwrap();
801        assert_eq!(cursor.current().unwrap(), Some(Nibbles::from_nibbles([0x2])));
802
803        cursor.next().unwrap();
804        assert_eq!(cursor.current().unwrap(), Some(Nibbles::from_nibbles([0x3])));
805    }
806
807    #[test]
808    fn test_all_storage_slots_deleted_not_wiped_exact_keys() {
809        use tracing::debug;
810        reth_tracing::init_test_tracing();
811
812        // This test reproduces an edge case where:
813        // - cursor is not None (not wiped)
814        // - All in-memory entries are deletions (None values)
815        // - Database has corresponding entries
816        // - Expected: NO leaves should be returned (all deleted)
817
818        // Generate 42 trie node entries with keys distributed across the keyspace
819        let mut db_nodes: Vec<(Nibbles, BranchNodeCompact)> = (0..10)
820            .map(|i| {
821                let key_bytes = vec![(i * 6) as u8, i as u8]; // Spread keys across keyspace
822                let nibbles = Nibbles::unpack(key_bytes);
823                (nibbles, BranchNodeCompact::new(i as u16, i as u16, 0, vec![], None))
824            })
825            .collect();
826
827        db_nodes.sort_by_key(|(key, _)| *key);
828        db_nodes.dedup_by_key(|(key, _)| *key);
829
830        for (key, _) in &db_nodes {
831            debug!("node at {key:?}");
832        }
833
834        // Create in-memory entries with same keys but all None values (deletions)
835        let in_memory_nodes: Vec<(Nibbles, Option<BranchNodeCompact>)> =
836            db_nodes.iter().map(|(key, _)| (*key, None)).collect();
837
838        let db_nodes_map: BTreeMap<Nibbles, BranchNodeCompact> = db_nodes.into_iter().collect();
839        let db_nodes_arc = Arc::new(db_nodes_map);
840        let visited_keys = Arc::new(Mutex::new(Vec::new()));
841        let mock_cursor = MockTrieCursor::new(db_nodes_arc, visited_keys);
842
843        let trie_updates = TrieUpdatesSorted::new(in_memory_nodes, Default::default());
844        let mut cursor = InMemoryTrieCursor::new_account(mock_cursor, &trie_updates);
845
846        // Seek to beginning should return None (all nodes are deleted)
847        tracing::debug!("seeking to 0x");
848        let result = cursor.seek(Nibbles::default()).unwrap();
849        assert_eq!(
850            result, None,
851            "Expected no entries when all nodes are deleted, but got {:?}",
852            result
853        );
854
855        // Test seek operations at various positions - all should return None
856        let seek_keys = vec![
857            Nibbles::unpack([0x00]),
858            Nibbles::unpack([0x5d]),
859            Nibbles::unpack([0x5e]),
860            Nibbles::unpack([0x5f]),
861            Nibbles::unpack([0xc2]),
862            Nibbles::unpack([0xc5]),
863            Nibbles::unpack([0xc9]),
864            Nibbles::unpack([0xf0]),
865        ];
866
867        for seek_key in seek_keys {
868            tracing::debug!("seeking to {seek_key:?}");
869            let result = cursor.seek(seek_key).unwrap();
870            assert_eq!(
871                result, None,
872                "Expected None when seeking to {:?} but got {:?}",
873                seek_key, result
874            );
875        }
876
877        // next() should also always return None
878        let result = cursor.next().unwrap();
879        assert_eq!(result, None, "Expected None from next() but got {:?}", result);
880    }
881
882    mod proptest_tests {
883        use super::*;
884        use itertools::Itertools;
885        use proptest::prelude::*;
886
887        /// Merge `db_nodes` with `in_memory_nodes`, applying the in-memory overlay.
888        /// This properly handles deletions (None values in `in_memory_nodes`).
889        fn merge_with_overlay(
890            db_nodes: Vec<(Nibbles, BranchNodeCompact)>,
891            in_memory_nodes: Vec<(Nibbles, Option<BranchNodeCompact>)>,
892        ) -> Vec<(Nibbles, BranchNodeCompact)> {
893            db_nodes
894                .into_iter()
895                .merge_join_by(in_memory_nodes, |db_entry, mem_entry| db_entry.0.cmp(&mem_entry.0))
896                .filter_map(|entry| match entry {
897                    // Only in db: keep it
898                    itertools::EitherOrBoth::Left((key, node)) => Some((key, node)),
899                    // Only in memory: keep if not a deletion
900                    itertools::EitherOrBoth::Right((key, node_opt)) => {
901                        node_opt.map(|node| (key, node))
902                    }
903                    // In both: memory takes precedence (keep if not a deletion)
904                    itertools::EitherOrBoth::Both(_, (key, node_opt)) => {
905                        node_opt.map(|node| (key, node))
906                    }
907                })
908                .collect()
909        }
910
911        /// Generate a strategy for a `BranchNodeCompact` with simplified parameters.
912        /// The constraints are:
913        /// - `tree_mask` must be a subset of `state_mask`
914        /// - `hash_mask` must be a subset of `state_mask`
915        /// - `hash_mask.count_ones()` must equal `hashes.len()`
916        ///
917        /// To keep it simple, we use an empty hashes vec and `hash_mask` of 0.
918        fn branch_node_strategy() -> impl Strategy<Value = BranchNodeCompact> {
919            any::<u16>()
920                .prop_flat_map(|state_mask| {
921                    let tree_mask_strategy = any::<u16>().prop_map(move |tree| tree & state_mask);
922                    (Just(state_mask), tree_mask_strategy)
923                })
924                .prop_map(|(state_mask, tree_mask)| {
925                    BranchNodeCompact::new(state_mask, tree_mask, 0, vec![], None)
926                })
927        }
928
929        /// Generate a sorted vector of (Nibbles, `BranchNodeCompact`) entries
930        fn sorted_db_nodes_strategy() -> impl Strategy<Value = Vec<(Nibbles, BranchNodeCompact)>> {
931            prop::collection::vec(
932                (prop::collection::vec(any::<u8>(), 0..2), branch_node_strategy()),
933                0..20,
934            )
935            .prop_map(|entries| {
936                // Convert Vec<u8> to Nibbles and sort
937                let mut result: Vec<(Nibbles, BranchNodeCompact)> = entries
938                    .into_iter()
939                    .map(|(bytes, node)| (Nibbles::from_nibbles_unchecked(bytes), node))
940                    .collect();
941                result.sort_by_key(|a| a.0);
942                result.dedup_by(|a, b| a.0 == b.0);
943                result
944            })
945        }
946
947        /// Generate a sorted vector of (Nibbles, Option<BranchNodeCompact>) entries
948        fn sorted_in_memory_nodes_strategy(
949        ) -> impl Strategy<Value = Vec<(Nibbles, Option<BranchNodeCompact>)>> {
950            prop::collection::vec(
951                (
952                    prop::collection::vec(any::<u8>(), 0..2),
953                    prop::option::of(branch_node_strategy()),
954                ),
955                0..20,
956            )
957            .prop_map(|entries| {
958                // Convert Vec<u8> to Nibbles and sort
959                let mut result: Vec<(Nibbles, Option<BranchNodeCompact>)> = entries
960                    .into_iter()
961                    .map(|(bytes, node)| (Nibbles::from_nibbles_unchecked(bytes), node))
962                    .collect();
963                result.sort_by_key(|a| a.0);
964                result.dedup_by(|a, b| a.0 == b.0);
965                result
966            })
967        }
968
969        proptest! {
970            #![proptest_config(ProptestConfig::with_cases(10000))]
971
972            #[test]
973            fn proptest_in_memory_trie_cursor(
974                db_nodes in sorted_db_nodes_strategy(),
975                in_memory_nodes in sorted_in_memory_nodes_strategy(),
976                op_choices in prop::collection::vec(any::<u8>(), 10..500),
977            ) {
978                reth_tracing::init_test_tracing();
979                use tracing::debug;
980
981                debug!(
982                    db_paths=?db_nodes.iter().map(|(k, _)| k).collect::<Vec<_>>(),
983                    in_mem_nodes=?in_memory_nodes.iter().map(|(k, v)| (k, v.is_some())).collect::<Vec<_>>(),
984                    num_op_choices=?op_choices.len(),
985                    "Starting proptest!",
986                );
987
988                // Create the expected results by merging the two sorted vectors,
989                // properly handling deletions (None values in in_memory_nodes)
990                let expected_combined = merge_with_overlay(db_nodes.clone(), in_memory_nodes.clone());
991
992                // Collect all keys for operation generation
993                let all_keys: Vec<Nibbles> = expected_combined.iter().map(|(k, _)| *k).collect();
994
995                // Create a control cursor using the combined result with a mock cursor
996                let control_db_map: BTreeMap<Nibbles, BranchNodeCompact> =
997                    expected_combined.into_iter().collect();
998                let control_db_arc = Arc::new(control_db_map);
999                let control_visited_keys = Arc::new(Mutex::new(Vec::new()));
1000                let mut control_cursor = MockTrieCursor::new(control_db_arc, control_visited_keys);
1001
1002                // Create the InMemoryTrieCursor being tested
1003                let db_nodes_map: BTreeMap<Nibbles, BranchNodeCompact> =
1004                    db_nodes.into_iter().collect();
1005                let db_nodes_arc = Arc::new(db_nodes_map);
1006                let visited_keys = Arc::new(Mutex::new(Vec::new()));
1007                let mock_cursor = MockTrieCursor::new(db_nodes_arc, visited_keys);
1008                let trie_updates = TrieUpdatesSorted::new(in_memory_nodes, Default::default());
1009                let mut test_cursor = InMemoryTrieCursor::new_account(mock_cursor, &trie_updates);
1010
1011                // Test: seek to the beginning first
1012                let control_first = control_cursor.seek(Nibbles::default()).unwrap();
1013                let test_first = test_cursor.seek(Nibbles::default()).unwrap();
1014                debug!(
1015                    control=?control_first.as_ref().map(|(k, _)| k),
1016                    test=?test_first.as_ref().map(|(k, _)| k),
1017                    "Initial seek returned",
1018                );
1019                assert_eq!(control_first, test_first, "Initial seek mismatch");
1020
1021                // If both cursors returned None, nothing to test
1022                if control_first.is_none() && test_first.is_none() {
1023                    return Ok(());
1024                }
1025
1026                // Track the last key returned from the cursor
1027                let mut last_returned_key = control_first.as_ref().map(|(k, _)| *k);
1028
1029                // Execute a sequence of random operations
1030                for choice in op_choices {
1031                    let op_type = choice % 3;
1032
1033                    match op_type {
1034                        0 => {
1035                            // Next operation
1036                            let control_result = control_cursor.next().unwrap();
1037                            let test_result = test_cursor.next().unwrap();
1038                            debug!(
1039                                control=?control_result.as_ref().map(|(k, _)| k),
1040                                test=?test_result.as_ref().map(|(k, _)| k),
1041                                "Next returned",
1042                            );
1043                            assert_eq!(control_result, test_result, "Next operation mismatch");
1044
1045                            last_returned_key = control_result.as_ref().map(|(k, _)| *k);
1046
1047                            // Stop if both cursors are exhausted
1048                            if control_result.is_none() && test_result.is_none() {
1049                                break;
1050                            }
1051                        }
1052                        1 => {
1053                            // Seek operation - choose a key >= last_returned_key
1054                            if all_keys.is_empty() {
1055                                continue;
1056                            }
1057
1058                            let valid_keys: Vec<_> = all_keys
1059                                .iter()
1060                                .filter(|k| last_returned_key.is_none_or(|last| **k >= last))
1061                                .collect();
1062
1063                            if valid_keys.is_empty() {
1064                                continue;
1065                            }
1066
1067                            let key = *valid_keys[choice as usize % valid_keys.len()];
1068
1069                            let control_result = control_cursor.seek(key).unwrap();
1070                            let test_result = test_cursor.seek(key).unwrap();
1071                            debug!(
1072                                control=?control_result.as_ref().map(|(k, _)| k),
1073                                test=?test_result.as_ref().map(|(k, _)| k),
1074                                ?key,
1075                                "Seek returned",
1076                            );
1077                            assert_eq!(control_result, test_result, "Seek operation mismatch for key {:?}", key);
1078
1079                            last_returned_key = control_result.as_ref().map(|(k, _)| *k);
1080
1081                            // Stop if both cursors are exhausted
1082                            if control_result.is_none() && test_result.is_none() {
1083                                break;
1084                            }
1085                        }
1086                        _ => {
1087                            // SeekExact operation - choose a key >= last_returned_key
1088                            if all_keys.is_empty() {
1089                                continue;
1090                            }
1091
1092                            let valid_keys: Vec<_> = all_keys
1093                                .iter()
1094                                .filter(|k| last_returned_key.is_none_or(|last| **k >= last))
1095                                .collect();
1096
1097                            if valid_keys.is_empty() {
1098                                continue;
1099                            }
1100
1101                            let key = *valid_keys[choice as usize  % valid_keys.len()];
1102
1103                            let control_result = control_cursor.seek_exact(key).unwrap();
1104                            let test_result = test_cursor.seek_exact(key).unwrap();
1105                            debug!(
1106                                control=?control_result.as_ref().map(|(k, _)| k),
1107                                test=?test_result.as_ref().map(|(k, _)| k),
1108                                ?key,
1109                                "SeekExact returned",
1110                            );
1111                            assert_eq!(control_result, test_result, "SeekExact operation mismatch for key {:?}", key);
1112
1113                            // seek_exact updates the last_key internally but only if it found something
1114                            last_returned_key = control_result.as_ref().map(|(k, _)| *k);
1115                        }
1116                    }
1117                }
1118            }
1119        }
1120    }
1121}