Skip to main content

reth_trie/trie_cursor/
in_memory.rs

1use super::{TrieCursor, TrieCursorFactory, TrieStorageCursor};
2use crate::{forward_cursor::ForwardInMemoryCursor, updates::TrieUpdatesSorted};
3use alloy_primitives::B256;
4use reth_storage_errors::db::DatabaseError;
5use reth_trie_common::{BranchNodeCompact, Nibbles};
6
/// Trie cursor factory that layers in-memory trie updates over the cursors produced by another
/// factory.
#[derive(Debug, Clone)]
pub struct InMemoryTrieCursorFactory<CF, T> {
    /// Factory producing the underlying trie cursors that the overlay is applied to.
    cursor_factory: CF,
    /// Handle to the sorted trie updates used as the overlay.
    trie_updates: T,
}

impl<CF, T> InMemoryTrieCursorFactory<CF, T> {
    /// Wraps `cursor_factory` so that every cursor it creates is overlaid with `trie_updates`.
    pub const fn new(cursor_factory: CF, trie_updates: T) -> Self {
        Self { trie_updates, cursor_factory }
    }
}
22
23impl<'overlay, CF, T> TrieCursorFactory for InMemoryTrieCursorFactory<CF, &'overlay T>
24where
25    CF: TrieCursorFactory + 'overlay,
26    T: AsRef<TrieUpdatesSorted>,
27{
28    type AccountTrieCursor<'cursor>
29        = InMemoryTrieCursor<'overlay, CF::AccountTrieCursor<'cursor>>
30    where
31        Self: 'cursor;
32
33    type StorageTrieCursor<'cursor>
34        = InMemoryTrieCursor<'overlay, CF::StorageTrieCursor<'cursor>>
35    where
36        Self: 'cursor;
37
38    fn account_trie_cursor(&self) -> Result<Self::AccountTrieCursor<'_>, DatabaseError> {
39        let cursor = self.cursor_factory.account_trie_cursor()?;
40        Ok(InMemoryTrieCursor::new_account(cursor, self.trie_updates.as_ref()))
41    }
42
43    fn storage_trie_cursor(
44        &self,
45        hashed_address: B256,
46    ) -> Result<Self::StorageTrieCursor<'_>, DatabaseError> {
47        let trie_updates = self.trie_updates.as_ref();
48        let cursor = self.cursor_factory.storage_trie_cursor(hashed_address)?;
49        Ok(InMemoryTrieCursor::new_storage(cursor, trie_updates, hashed_address))
50    }
51}
52
/// A cursor to iterate over trie updates and corresponding database entries.
/// It will always give precedence to the data from the trie updates.
///
/// Entries are yielded in ascending key order by merging a forward-only cursor over the
/// in-memory overlay with the underlying cursor `C`. An overlay value of `None` marks a removed
/// node and hides the database entry at the same key.
#[derive(Debug)]
pub struct InMemoryTrieCursor<'a, C> {
    /// The underlying cursor. It is never consulted while `db_cursor_state` is `Wiped`.
    cursor: C,
    /// Tracks whether the DB cursor is available, positioned, or exhausted.
    db_cursor_state: DbCursorState,
    /// Forward-only in-memory cursor over the overlay's trie nodes. These are the account nodes
    /// for an account cursor, or the current storage trie's nodes for a storage cursor.
    in_memory_cursor: ForwardInMemoryCursor<'a, Nibbles, Option<BranchNodeCompact>>,
    /// The key most recently returned from the Cursor. It is `None` until an entry has been
    /// returned, and again whenever the latest seek/next returned `None`; `next` treats that as
    /// exhaustion.
    last_key: Option<Nibbles>,
    #[cfg(debug_assertions)]
    /// Whether an initial seek was called.
    seeked: bool,
    /// Reference to the full trie updates. It is kept so that the storage overlay can be
    /// reloaded when the cursor is switched to another hashed address.
    trie_updates: &'a TrieUpdatesSorted,
}
71
72#[derive(Debug)]
73enum DbCursorState {
74    NeedsPosition,
75    Positioned((Nibbles, BranchNodeCompact)),
76    Exhausted,
77    Wiped,
78}
79
80impl DbCursorState {
81    const fn new(cursor_wiped: bool) -> Self {
82        if cursor_wiped {
83            Self::Wiped
84        } else {
85            Self::NeedsPosition
86        }
87    }
88
89    const fn entry(&self) -> Option<&(Nibbles, BranchNodeCompact)> {
90        match self {
91            Self::Positioned(entry) => Some(entry),
92            Self::NeedsPosition | Self::Exhausted | Self::Wiped => None,
93        }
94    }
95
96    fn set_entry(&mut self, entry: Option<(Nibbles, BranchNodeCompact)>) {
97        *self = match entry {
98            Some(entry) => Self::Positioned(entry),
99            None => Self::Exhausted,
100        };
101    }
102}
103
104impl<'a, C: TrieCursor> InMemoryTrieCursor<'a, C> {
105    /// Create new account trie cursor which combines a DB cursor and the trie updates.
106    pub fn new_account(cursor: C, trie_updates: &'a TrieUpdatesSorted) -> Self {
107        let in_memory_cursor = ForwardInMemoryCursor::new(trie_updates.account_nodes_ref());
108        Self {
109            cursor,
110            db_cursor_state: DbCursorState::NeedsPosition,
111            in_memory_cursor,
112            last_key: None,
113            #[cfg(debug_assertions)]
114            seeked: false,
115            trie_updates,
116        }
117    }
118
119    /// Create new storage trie cursor with full trie updates reference.
120    /// This allows the cursor to switch between storage tries when `set_hashed_address` is called.
121    pub fn new_storage(
122        cursor: C,
123        trie_updates: &'a TrieUpdatesSorted,
124        hashed_address: B256,
125    ) -> Self {
126        let (in_memory_cursor, cursor_wiped) =
127            Self::get_storage_overlay(trie_updates, hashed_address);
128        Self {
129            cursor,
130            db_cursor_state: DbCursorState::new(cursor_wiped),
131            in_memory_cursor,
132            last_key: None,
133            #[cfg(debug_assertions)]
134            seeked: false,
135            trie_updates,
136        }
137    }
138
139    /// Returns the storage overlay for `hashed_address` and whether it was deleted.
140    fn get_storage_overlay(
141        trie_updates: &'a TrieUpdatesSorted,
142        hashed_address: B256,
143    ) -> (ForwardInMemoryCursor<'a, Nibbles, Option<BranchNodeCompact>>, bool) {
144        let storage_trie_updates = trie_updates.storage_tries_ref().get(&hashed_address);
145        let cursor_wiped = storage_trie_updates.is_some_and(|u| u.is_deleted());
146        let storage_nodes = storage_trie_updates.map(|u| u.storage_nodes_ref()).unwrap_or(&[]);
147
148        (ForwardInMemoryCursor::new(storage_nodes), cursor_wiped)
149    }
150
151    /// Returns a mutable reference to the underlying cursor if it's not wiped, None otherwise.
152    fn get_cursor_mut(&mut self) -> Option<&mut C> {
153        (!matches!(self.db_cursor_state, DbCursorState::Wiped)).then_some(&mut self.cursor)
154    }
155
156    /// Asserts that the next entry to be returned from the cursor is not previous to the last entry
157    /// returned.
158    fn set_last_key(&mut self, next_entry: &Option<(Nibbles, BranchNodeCompact)>) {
159        let next_key = next_entry.as_ref().map(|e| e.0);
160        debug_assert!(
161            self.last_key.is_none_or(|last| next_key.is_none_or(|next| next >= last)),
162            "Cannot return entry {:?} previous to the last returned entry at {:?}",
163            next_key,
164            self.last_key,
165        );
166        self.last_key = next_key;
167    }
168
169    /// Positions the DB cursor state using the underlying cursor when needed.
170    fn cursor_seek(&mut self, key: Nibbles) -> Result<(), DatabaseError> {
171        // Only seek if:
172        // 1. We have a cursor entry and need to seek forward (entry.0 < key), OR
173        // 2. The DB cursor needs to be positioned.
174        let should_seek = match &self.db_cursor_state {
175            DbCursorState::NeedsPosition => true,
176            DbCursorState::Positioned((entry_key, _)) => entry_key < &key,
177            DbCursorState::Exhausted | DbCursorState::Wiped => false,
178        };
179
180        if should_seek {
181            let entry = self.get_cursor_mut().map(|c| c.seek(key)).transpose()?.flatten();
182            self.db_cursor_state.set_entry(entry);
183        }
184
185        Ok(())
186    }
187
188    /// Advances the DB cursor state to the subsequent entry using the underlying cursor.
189    fn cursor_next(&mut self) -> Result<(), DatabaseError> {
190        #[cfg(debug_assertions)]
191        {
192            debug_assert!(self.seeked);
193            debug_assert!(!matches!(self.db_cursor_state, DbCursorState::NeedsPosition));
194        }
195
196        // Exhausted and wiped states are stable; only advance if the DB cursor currently points to
197        // an entry.
198        if matches!(self.db_cursor_state, DbCursorState::Positioned(_)) {
199            let entry = self.get_cursor_mut().map(|c| c.next()).transpose()?.flatten();
200            self.db_cursor_state.set_entry(entry);
201        }
202
203        Ok(())
204    }
205
206    /// Compares the current in-memory entry with the current entry of the cursor, and applies the
207    /// in-memory entry to the cursor entry as an overlay.
208    //
209    /// This may consume and move forward the current entries when the overlay indicates a removed
210    /// node.
211    fn choose_next_entry(&mut self) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
212        loop {
213            let mem_entry = self.in_memory_cursor.current().cloned();
214            let db_entry = self.db_cursor_state.entry();
215
216            match (mem_entry, db_entry) {
217                (Some((mem_key, None)), _)
218                    if db_entry.is_none_or(|(db_key, _)| &mem_key < db_key) =>
219                {
220                    // If overlay has a removed node but DB cursor is exhausted or ahead of the
221                    // in-memory cursor then move ahead in-memory, as there might be further
222                    // non-removed overlay nodes.
223                    self.in_memory_cursor.first_after(&mem_key);
224                }
225                (Some((mem_key, None)), Some((db_key, _))) if &mem_key == db_key => {
226                    // If overlay has a removed node which is returned from DB then move both
227                    // cursors ahead to the next key.
228                    self.in_memory_cursor.first_after(&mem_key);
229                    self.cursor_next()?;
230                }
231                (Some((mem_key, Some(node))), _)
232                    if db_entry.is_none_or(|(db_key, _)| &mem_key <= db_key) =>
233                {
234                    // If overlay returns a node prior to the DB's node, or the DB is exhausted,
235                    // then we return the overlay's node.
236                    return Ok(Some((mem_key, node)))
237                }
238                // All other cases:
239                // - mem_key > db_key
240                // - overlay is exhausted
241                // Return the db_entry. If DB is also exhausted then this returns None.
242                _ => return Ok(db_entry.cloned()),
243            }
244        }
245    }
246}
247
248impl<C: TrieCursor> TrieCursor for InMemoryTrieCursor<'_, C> {
249    fn seek_exact(
250        &mut self,
251        key: Nibbles,
252    ) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
253        let mem_entry = self.in_memory_cursor.seek(&key);
254
255        if let Some((mem_key, entry_inner)) = mem_entry &&
256            *mem_key == key
257        {
258            #[cfg(debug_assertions)]
259            {
260                self.seeked = true;
261            }
262
263            // An exact overlay hit can move the logical cursor ahead without touching the DB. If
264            // the DB cursor was still behind this key, force a re-seek before the next DB-backed
265            // operation so `next()` cannot return a stale earlier entry.
266            if matches!(&self.db_cursor_state, DbCursorState::Positioned((db_key, _)) if db_key < &key)
267            {
268                self.db_cursor_state = DbCursorState::NeedsPosition;
269            }
270
271            let entry = entry_inner.clone().map(|node| (key, node));
272            self.set_last_key(&entry);
273            return Ok(entry)
274        }
275
276        self.cursor_seek(key)?;
277
278        #[cfg(debug_assertions)]
279        {
280            self.seeked = true;
281        }
282
283        let entry = match self.db_cursor_state.entry() {
284            Some((db_key, node)) if db_key == &key => Some((key, node.clone())),
285            _ => None,
286        };
287
288        self.set_last_key(&entry);
289        Ok(entry)
290    }
291
292    fn seek(
293        &mut self,
294        key: Nibbles,
295    ) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
296        self.cursor_seek(key)?;
297        self.in_memory_cursor.seek(&key);
298
299        #[cfg(debug_assertions)]
300        {
301            self.seeked = true;
302        }
303
304        let entry = self.choose_next_entry()?;
305        self.set_last_key(&entry);
306        Ok(entry)
307    }
308
309    fn next(&mut self) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
310        #[cfg(debug_assertions)]
311        {
312            debug_assert!(self.seeked, "Cursor must be seek'd before next is called");
313        }
314
315        // A `last_key` of `None` indicates that the cursor is exhausted.
316        let Some(last_key) = self.last_key else {
317            return Ok(None);
318        };
319
320        // If either cursor is currently pointing to the last entry which was returned then consume
321        // that entry so that `choose_next_entry` is looking at the subsequent one.
322        if let Some((key, _)) = self.in_memory_cursor.current() &&
323            key == &last_key
324        {
325            self.in_memory_cursor.first_after(&last_key);
326        }
327
328        if matches!(self.db_cursor_state, DbCursorState::NeedsPosition) {
329            self.cursor_seek(last_key)?;
330        }
331
332        if let Some((key, _)) = self.db_cursor_state.entry() &&
333            key == &last_key
334        {
335            self.cursor_next()?;
336        }
337
338        let entry = self.choose_next_entry()?;
339        self.set_last_key(&entry);
340        Ok(entry)
341    }
342
343    fn current(&mut self) -> Result<Option<Nibbles>, DatabaseError> {
344        match &self.last_key {
345            Some(key) => Ok(Some(*key)),
346            None => Ok(self.get_cursor_mut().map(|c| c.current()).transpose()?.flatten()),
347        }
348    }
349
350    fn reset(&mut self) {
351        self.cursor.reset();
352        self.in_memory_cursor.reset();
353
354        self.db_cursor_state = DbCursorState::NeedsPosition;
355        self.last_key = None;
356        #[cfg(debug_assertions)]
357        {
358            self.seeked = false;
359        }
360    }
361}
362
363impl<C: TrieStorageCursor> TrieStorageCursor for InMemoryTrieCursor<'_, C> {
364    fn set_hashed_address(&mut self, hashed_address: B256) {
365        self.reset();
366        self.cursor.set_hashed_address(hashed_address);
367        let (in_memory_cursor, cursor_wiped) =
368            Self::get_storage_overlay(self.trie_updates, hashed_address);
369        self.in_memory_cursor = in_memory_cursor;
370        self.db_cursor_state = DbCursorState::new(cursor_wiped);
371    }
372}
373
374#[cfg(test)]
375mod tests {
376    use super::*;
377    use crate::trie_cursor::mock::MockTrieCursor;
378    use parking_lot::Mutex;
379    use std::{collections::BTreeMap, sync::Arc};
380
381    #[derive(Debug)]
382    struct InMemoryTrieCursorTestCase {
383        db_nodes: Vec<(Nibbles, BranchNodeCompact)>,
384        in_memory_nodes: Vec<(Nibbles, Option<BranchNodeCompact>)>,
385        expected_results: Vec<(Nibbles, BranchNodeCompact)>,
386    }
387
388    fn execute_test(test_case: InMemoryTrieCursorTestCase) {
389        let db_nodes_map: BTreeMap<Nibbles, BranchNodeCompact> =
390            test_case.db_nodes.into_iter().collect();
391        let db_nodes_arc = Arc::new(db_nodes_map);
392        let visited_keys = Arc::new(Mutex::new(Vec::new()));
393        let mock_cursor = MockTrieCursor::new(db_nodes_arc, visited_keys);
394
395        let trie_updates = TrieUpdatesSorted::new(test_case.in_memory_nodes, Default::default());
396        let mut cursor = InMemoryTrieCursor::new_account(mock_cursor, &trie_updates);
397
398        let mut results = Vec::new();
399
400        if let Some(first_expected) = test_case.expected_results.first() &&
401            let Ok(Some(entry)) = cursor.seek(first_expected.0)
402        {
403            results.push(entry);
404        }
405
406        if !test_case.expected_results.is_empty() {
407            while let Ok(Some(entry)) = cursor.next() {
408                results.push(entry);
409            }
410        }
411
412        assert_eq!(
413            results, test_case.expected_results,
414            "Results mismatch.\nGot: {:?}\nExpected: {:?}",
415            results, test_case.expected_results
416        );
417    }
418
419    #[test]
420    fn test_empty_db_and_memory() {
421        let test_case = InMemoryTrieCursorTestCase {
422            db_nodes: vec![],
423            in_memory_nodes: vec![],
424            expected_results: vec![],
425        };
426        execute_test(test_case);
427    }
428
429    #[test]
430    fn test_only_db_nodes() {
431        let db_nodes = vec![
432            (Nibbles::from_nibbles([0x1]), BranchNodeCompact::new(0b0011, 0b0001, 0, vec![], None)),
433            (Nibbles::from_nibbles([0x2]), BranchNodeCompact::new(0b0011, 0b0010, 0, vec![], None)),
434            (Nibbles::from_nibbles([0x3]), BranchNodeCompact::new(0b0011, 0b0011, 0, vec![], None)),
435        ];
436
437        let test_case = InMemoryTrieCursorTestCase {
438            db_nodes: db_nodes.clone(),
439            in_memory_nodes: vec![],
440            expected_results: db_nodes,
441        };
442        execute_test(test_case);
443    }
444
445    #[test]
446    fn test_only_in_memory_nodes() {
447        let in_memory_nodes = vec![
448            (
449                Nibbles::from_nibbles([0x1]),
450                Some(BranchNodeCompact::new(0b0011, 0b0001, 0, vec![], None)),
451            ),
452            (
453                Nibbles::from_nibbles([0x2]),
454                Some(BranchNodeCompact::new(0b0011, 0b0010, 0, vec![], None)),
455            ),
456            (
457                Nibbles::from_nibbles([0x3]),
458                Some(BranchNodeCompact::new(0b0011, 0b0011, 0, vec![], None)),
459            ),
460        ];
461
462        let expected_results: Vec<(Nibbles, BranchNodeCompact)> = in_memory_nodes
463            .iter()
464            .filter_map(|(k, v)| v.as_ref().map(|node| (*k, node.clone())))
465            .collect();
466
467        let test_case =
468            InMemoryTrieCursorTestCase { db_nodes: vec![], in_memory_nodes, expected_results };
469        execute_test(test_case);
470    }
471
472    #[test]
473    fn test_in_memory_overwrites_db() {
474        let db_nodes = vec![
475            (Nibbles::from_nibbles([0x1]), BranchNodeCompact::new(0b0011, 0b0001, 0, vec![], None)),
476            (Nibbles::from_nibbles([0x2]), BranchNodeCompact::new(0b0011, 0b0010, 0, vec![], None)),
477        ];
478
479        let in_memory_nodes = vec![
480            (
481                Nibbles::from_nibbles([0x1]),
482                Some(BranchNodeCompact::new(0b1111, 0b1111, 0, vec![], None)),
483            ),
484            (
485                Nibbles::from_nibbles([0x3]),
486                Some(BranchNodeCompact::new(0b0011, 0b0011, 0, vec![], None)),
487            ),
488        ];
489
490        let expected_results = vec![
491            (Nibbles::from_nibbles([0x1]), BranchNodeCompact::new(0b1111, 0b1111, 0, vec![], None)),
492            (Nibbles::from_nibbles([0x2]), BranchNodeCompact::new(0b0011, 0b0010, 0, vec![], None)),
493            (Nibbles::from_nibbles([0x3]), BranchNodeCompact::new(0b0011, 0b0011, 0, vec![], None)),
494        ];
495
496        let test_case = InMemoryTrieCursorTestCase { db_nodes, in_memory_nodes, expected_results };
497        execute_test(test_case);
498    }
499
500    #[test]
501    fn test_in_memory_deletes_db_nodes() {
502        let db_nodes = vec![
503            (Nibbles::from_nibbles([0x1]), BranchNodeCompact::new(0b0011, 0b0001, 0, vec![], None)),
504            (Nibbles::from_nibbles([0x2]), BranchNodeCompact::new(0b0011, 0b0010, 0, vec![], None)),
505            (Nibbles::from_nibbles([0x3]), BranchNodeCompact::new(0b0011, 0b0011, 0, vec![], None)),
506        ];
507
508        let in_memory_nodes = vec![(Nibbles::from_nibbles([0x2]), None)];
509
510        let expected_results = vec![
511            (Nibbles::from_nibbles([0x1]), BranchNodeCompact::new(0b0011, 0b0001, 0, vec![], None)),
512            (Nibbles::from_nibbles([0x3]), BranchNodeCompact::new(0b0011, 0b0011, 0, vec![], None)),
513        ];
514
515        let test_case = InMemoryTrieCursorTestCase { db_nodes, in_memory_nodes, expected_results };
516        execute_test(test_case);
517    }
518
519    #[test]
520    fn test_complex_interleaving() {
521        let db_nodes = vec![
522            (Nibbles::from_nibbles([0x1]), BranchNodeCompact::new(0b0001, 0b0001, 0, vec![], None)),
523            (Nibbles::from_nibbles([0x3]), BranchNodeCompact::new(0b0011, 0b0011, 0, vec![], None)),
524            (Nibbles::from_nibbles([0x5]), BranchNodeCompact::new(0b0101, 0b0101, 0, vec![], None)),
525            (Nibbles::from_nibbles([0x7]), BranchNodeCompact::new(0b0111, 0b0111, 0, vec![], None)),
526        ];
527
528        let in_memory_nodes = vec![
529            (
530                Nibbles::from_nibbles([0x2]),
531                Some(BranchNodeCompact::new(0b0010, 0b0010, 0, vec![], None)),
532            ),
533            (Nibbles::from_nibbles([0x3]), None),
534            (
535                Nibbles::from_nibbles([0x4]),
536                Some(BranchNodeCompact::new(0b0100, 0b0100, 0, vec![], None)),
537            ),
538            (
539                Nibbles::from_nibbles([0x6]),
540                Some(BranchNodeCompact::new(0b0110, 0b0110, 0, vec![], None)),
541            ),
542            (Nibbles::from_nibbles([0x7]), None),
543            (
544                Nibbles::from_nibbles([0x8]),
545                Some(BranchNodeCompact::new(0b1000, 0b1000, 0, vec![], None)),
546            ),
547        ];
548
549        let expected_results = vec![
550            (Nibbles::from_nibbles([0x1]), BranchNodeCompact::new(0b0001, 0b0001, 0, vec![], None)),
551            (Nibbles::from_nibbles([0x2]), BranchNodeCompact::new(0b0010, 0b0010, 0, vec![], None)),
552            (Nibbles::from_nibbles([0x4]), BranchNodeCompact::new(0b0100, 0b0100, 0, vec![], None)),
553            (Nibbles::from_nibbles([0x5]), BranchNodeCompact::new(0b0101, 0b0101, 0, vec![], None)),
554            (Nibbles::from_nibbles([0x6]), BranchNodeCompact::new(0b0110, 0b0110, 0, vec![], None)),
555            (Nibbles::from_nibbles([0x8]), BranchNodeCompact::new(0b1000, 0b1000, 0, vec![], None)),
556        ];
557
558        let test_case = InMemoryTrieCursorTestCase { db_nodes, in_memory_nodes, expected_results };
559        execute_test(test_case);
560    }
561
562    #[test]
563    fn test_seek_exact() {
564        let db_nodes = vec![
565            (Nibbles::from_nibbles([0x1]), BranchNodeCompact::new(0b0001, 0b0001, 0, vec![], None)),
566            (Nibbles::from_nibbles([0x3]), BranchNodeCompact::new(0b0011, 0b0011, 0, vec![], None)),
567        ];
568
569        let in_memory_nodes = vec![(
570            Nibbles::from_nibbles([0x2]),
571            Some(BranchNodeCompact::new(0b0010, 0b0010, 0, vec![], None)),
572        )];
573
574        let db_nodes_map: BTreeMap<Nibbles, BranchNodeCompact> = db_nodes.into_iter().collect();
575        let db_nodes_arc = Arc::new(db_nodes_map);
576        let visited_keys = Arc::new(Mutex::new(Vec::new()));
577        let mock_cursor = MockTrieCursor::new(db_nodes_arc, visited_keys.clone());
578
579        let trie_updates = TrieUpdatesSorted::new(in_memory_nodes, Default::default());
580        let mut cursor = InMemoryTrieCursor::new_account(mock_cursor, &trie_updates);
581
582        let result = cursor.seek_exact(Nibbles::from_nibbles([0x2])).unwrap();
583        assert_eq!(
584            result,
585            Some((
586                Nibbles::from_nibbles([0x2]),
587                BranchNodeCompact::new(0b0010, 0b0010, 0, vec![], None)
588            ))
589        );
590        assert!(visited_keys.lock().is_empty(), "exact overlay hit should not touch the DB cursor");
591
592        let result = cursor.seek_exact(Nibbles::from_nibbles([0x3])).unwrap();
593        assert_eq!(
594            result,
595            Some((
596                Nibbles::from_nibbles([0x3]),
597                BranchNodeCompact::new(0b0011, 0b0011, 0, vec![], None)
598            ))
599        );
600
601        let result = cursor.seek_exact(Nibbles::from_nibbles([0x4])).unwrap();
602        assert_eq!(result, None);
603    }
604
605    #[test]
606    fn test_multiple_consecutive_deletes() {
607        let db_nodes: Vec<(Nibbles, BranchNodeCompact)> = (1..=10)
608            .map(|i| {
609                (
610                    Nibbles::from_nibbles([i]),
611                    BranchNodeCompact::new(i as u16, i as u16, 0, vec![], None),
612                )
613            })
614            .collect();
615
616        let in_memory_nodes = vec![
617            (Nibbles::from_nibbles([0x3]), None),
618            (Nibbles::from_nibbles([0x4]), None),
619            (Nibbles::from_nibbles([0x5]), None),
620            (Nibbles::from_nibbles([0x6]), None),
621        ];
622
623        let expected_results = vec![
624            (Nibbles::from_nibbles([0x1]), BranchNodeCompact::new(1, 1, 0, vec![], None)),
625            (Nibbles::from_nibbles([0x2]), BranchNodeCompact::new(2, 2, 0, vec![], None)),
626            (Nibbles::from_nibbles([0x7]), BranchNodeCompact::new(7, 7, 0, vec![], None)),
627            (Nibbles::from_nibbles([0x8]), BranchNodeCompact::new(8, 8, 0, vec![], None)),
628            (Nibbles::from_nibbles([0x9]), BranchNodeCompact::new(9, 9, 0, vec![], None)),
629            (Nibbles::from_nibbles([0xa]), BranchNodeCompact::new(10, 10, 0, vec![], None)),
630        ];
631
632        let test_case = InMemoryTrieCursorTestCase { db_nodes, in_memory_nodes, expected_results };
633        execute_test(test_case);
634    }
635
636    #[test]
637    fn test_empty_db_with_in_memory_deletes() {
638        let in_memory_nodes = vec![
639            (Nibbles::from_nibbles([0x1]), None),
640            (
641                Nibbles::from_nibbles([0x2]),
642                Some(BranchNodeCompact::new(0b0010, 0b0010, 0, vec![], None)),
643            ),
644            (Nibbles::from_nibbles([0x3]), None),
645        ];
646
647        let expected_results = vec![(
648            Nibbles::from_nibbles([0x2]),
649            BranchNodeCompact::new(0b0010, 0b0010, 0, vec![], None),
650        )];
651
652        let test_case =
653            InMemoryTrieCursorTestCase { db_nodes: vec![], in_memory_nodes, expected_results };
654        execute_test(test_case);
655    }
656
657    #[test]
658    fn test_current_key_tracking() {
659        let db_nodes = vec![(
660            Nibbles::from_nibbles([0x2]),
661            BranchNodeCompact::new(0b0010, 0b0010, 0, vec![], None),
662        )];
663
664        let in_memory_nodes = vec![
665            (
666                Nibbles::from_nibbles([0x1]),
667                Some(BranchNodeCompact::new(0b0001, 0b0001, 0, vec![], None)),
668            ),
669            (
670                Nibbles::from_nibbles([0x3]),
671                Some(BranchNodeCompact::new(0b0011, 0b0011, 0, vec![], None)),
672            ),
673        ];
674
675        let db_nodes_map: BTreeMap<Nibbles, BranchNodeCompact> = db_nodes.into_iter().collect();
676        let db_nodes_arc = Arc::new(db_nodes_map);
677        let visited_keys = Arc::new(Mutex::new(Vec::new()));
678        let mock_cursor = MockTrieCursor::new(db_nodes_arc, visited_keys);
679
680        let trie_updates = TrieUpdatesSorted::new(in_memory_nodes, Default::default());
681        let mut cursor = InMemoryTrieCursor::new_account(mock_cursor, &trie_updates);
682
683        assert_eq!(cursor.current().unwrap(), None);
684
685        cursor.seek(Nibbles::from_nibbles([0x1])).unwrap();
686        assert_eq!(cursor.current().unwrap(), Some(Nibbles::from_nibbles([0x1])));
687
688        cursor.next().unwrap();
689        assert_eq!(cursor.current().unwrap(), Some(Nibbles::from_nibbles([0x2])));
690
691        cursor.next().unwrap();
692        assert_eq!(cursor.current().unwrap(), Some(Nibbles::from_nibbles([0x3])));
693    }
694
695    #[test]
696    fn test_all_storage_slots_deleted_not_wiped_exact_keys() {
697        use tracing::debug;
698        reth_tracing::init_test_tracing();
699
700        // This test reproduces an edge case where:
701        // - cursor is not None (not wiped)
702        // - All in-memory entries are deletions (None values)
703        // - Database has corresponding entries
704        // - Expected: NO leaves should be returned (all deleted)
705
706        // Generate 42 trie node entries with keys distributed across the keyspace
707        let mut db_nodes: Vec<(Nibbles, BranchNodeCompact)> = (0..10)
708            .map(|i| {
709                let key_bytes = vec![(i * 6) as u8, i as u8]; // Spread keys across keyspace
710                let nibbles = Nibbles::unpack(key_bytes);
711                (nibbles, BranchNodeCompact::new(i as u16, i as u16, 0, vec![], None))
712            })
713            .collect();
714
715        db_nodes.sort_by_key(|(key, _)| *key);
716        db_nodes.dedup_by_key(|(key, _)| *key);
717
718        for (key, _) in &db_nodes {
719            debug!("node at {key:?}");
720        }
721
722        // Create in-memory entries with same keys but all None values (deletions)
723        let in_memory_nodes: Vec<(Nibbles, Option<BranchNodeCompact>)> =
724            db_nodes.iter().map(|(key, _)| (*key, None)).collect();
725
726        let db_nodes_map: BTreeMap<Nibbles, BranchNodeCompact> = db_nodes.into_iter().collect();
727        let db_nodes_arc = Arc::new(db_nodes_map);
728        let visited_keys = Arc::new(Mutex::new(Vec::new()));
729        let mock_cursor = MockTrieCursor::new(db_nodes_arc, visited_keys);
730
731        let trie_updates = TrieUpdatesSorted::new(in_memory_nodes, Default::default());
732        let mut cursor = InMemoryTrieCursor::new_account(mock_cursor, &trie_updates);
733
734        // Seek to beginning should return None (all nodes are deleted)
735        tracing::debug!("seeking to 0x");
736        let result = cursor.seek(Nibbles::default()).unwrap();
737        assert_eq!(
738            result, None,
739            "Expected no entries when all nodes are deleted, but got {:?}",
740            result
741        );
742
743        // Test seek operations at various positions - all should return None
744        let seek_keys = vec![
745            Nibbles::unpack([0x00]),
746            Nibbles::unpack([0x5d]),
747            Nibbles::unpack([0x5e]),
748            Nibbles::unpack([0x5f]),
749            Nibbles::unpack([0xc2]),
750            Nibbles::unpack([0xc5]),
751            Nibbles::unpack([0xc9]),
752            Nibbles::unpack([0xf0]),
753        ];
754
755        for seek_key in seek_keys {
756            tracing::debug!("seeking to {seek_key:?}");
757            let result = cursor.seek(seek_key).unwrap();
758            assert_eq!(
759                result, None,
760                "Expected None when seeking to {:?} but got {:?}",
761                seek_key, result
762            );
763        }
764
765        // next() should also always return None
766        let result = cursor.next().unwrap();
767        assert_eq!(result, None, "Expected None from next() but got {:?}", result);
768    }
769
770    mod proptest_tests {
771        use super::*;
772        use itertools::Itertools;
773        use proptest::prelude::*;
774
775        /// Merge `db_nodes` with `in_memory_nodes`, applying the in-memory overlay.
776        /// This properly handles deletions (None values in `in_memory_nodes`).
777        fn merge_with_overlay(
778            db_nodes: Vec<(Nibbles, BranchNodeCompact)>,
779            in_memory_nodes: Vec<(Nibbles, Option<BranchNodeCompact>)>,
780        ) -> Vec<(Nibbles, BranchNodeCompact)> {
781            db_nodes
782                .into_iter()
783                .merge_join_by(in_memory_nodes, |db_entry, mem_entry| db_entry.0.cmp(&mem_entry.0))
784                .filter_map(|entry| match entry {
785                    // Only in db: keep it
786                    itertools::EitherOrBoth::Left((key, node)) => Some((key, node)),
787                    // Only in memory: keep if not a deletion
788                    itertools::EitherOrBoth::Right((key, node_opt)) => {
789                        node_opt.map(|node| (key, node))
790                    }
791                    // In both: memory takes precedence (keep if not a deletion)
792                    itertools::EitherOrBoth::Both(_, (key, node_opt)) => {
793                        node_opt.map(|node| (key, node))
794                    }
795                })
796                .collect()
797        }
798
799        /// Generate a strategy for a `BranchNodeCompact` with simplified parameters.
800        /// The constraints are:
801        /// - `tree_mask` must be a subset of `state_mask`
802        /// - `hash_mask` must be a subset of `state_mask`
803        /// - `hash_mask.count_ones()` must equal `hashes.len()`
804        ///
805        /// To keep it simple, we use an empty hashes vec and `hash_mask` of 0.
806        fn branch_node_strategy() -> impl Strategy<Value = BranchNodeCompact> {
807            any::<u16>()
808                .prop_flat_map(|state_mask| {
809                    let tree_mask_strategy = any::<u16>().prop_map(move |tree| tree & state_mask);
810                    (Just(state_mask), tree_mask_strategy)
811                })
812                .prop_map(|(state_mask, tree_mask)| {
813                    BranchNodeCompact::new(state_mask, tree_mask, 0, vec![], None)
814                })
815        }
816
817        /// Generate a sorted vector of (Nibbles, `BranchNodeCompact`) entries
818        fn sorted_db_nodes_strategy() -> impl Strategy<Value = Vec<(Nibbles, BranchNodeCompact)>> {
819            prop::collection::vec(
820                (prop::collection::vec(any::<u8>(), 0..2), branch_node_strategy()),
821                0..20,
822            )
823            .prop_map(|entries| {
824                // Convert Vec<u8> to Nibbles and sort
825                let mut result: Vec<(Nibbles, BranchNodeCompact)> = entries
826                    .into_iter()
827                    .map(|(bytes, node)| (Nibbles::from_nibbles_unchecked(bytes), node))
828                    .collect();
829                result.sort_by_key(|a| a.0);
830                result.dedup_by(|a, b| a.0 == b.0);
831                result
832            })
833        }
834
835        /// Generate a sorted vector of (Nibbles, Option<BranchNodeCompact>) entries
836        fn sorted_in_memory_nodes_strategy(
837        ) -> impl Strategy<Value = Vec<(Nibbles, Option<BranchNodeCompact>)>> {
838            prop::collection::vec(
839                (
840                    prop::collection::vec(any::<u8>(), 0..2),
841                    prop::option::of(branch_node_strategy()),
842                ),
843                0..20,
844            )
845            .prop_map(|entries| {
846                // Convert Vec<u8> to Nibbles and sort
847                let mut result: Vec<(Nibbles, Option<BranchNodeCompact>)> = entries
848                    .into_iter()
849                    .map(|(bytes, node)| (Nibbles::from_nibbles_unchecked(bytes), node))
850                    .collect();
851                result.sort_by_key(|a| a.0);
852                result.dedup_by(|a, b| a.0 == b.0);
853                result
854            })
855        }
856
857        proptest! {
858            #![proptest_config(ProptestConfig::with_cases(10000))]
859
860            #[test]
861            fn proptest_in_memory_trie_cursor(
862                db_nodes in sorted_db_nodes_strategy(),
863                in_memory_nodes in sorted_in_memory_nodes_strategy(),
864                op_choices in prop::collection::vec(any::<u8>(), 10..500),
865            ) {
866                reth_tracing::init_test_tracing();
867                use tracing::debug;
868
869                debug!(
870                    db_paths=?db_nodes.iter().map(|(k, _)| k).collect::<Vec<_>>(),
871                    in_mem_nodes=?in_memory_nodes.iter().map(|(k, v)| (k, v.is_some())).collect::<Vec<_>>(),
872                    num_op_choices=?op_choices.len(),
873                    "Starting proptest!",
874                );
875
876                // Create the expected results by merging the two sorted vectors,
877                // properly handling deletions (None values in in_memory_nodes)
878                let expected_combined = merge_with_overlay(db_nodes.clone(), in_memory_nodes.clone());
879
880                // Collect all keys for operation generation
881                let all_keys: Vec<Nibbles> = expected_combined.iter().map(|(k, _)| *k).collect();
882
883                // Create a control cursor using the combined result with a mock cursor
884                let control_db_map: BTreeMap<Nibbles, BranchNodeCompact> =
885                    expected_combined.into_iter().collect();
886                let control_db_arc = Arc::new(control_db_map);
887                let control_visited_keys = Arc::new(Mutex::new(Vec::new()));
888                let mut control_cursor = MockTrieCursor::new(control_db_arc, control_visited_keys);
889
890                // Create the InMemoryTrieCursor being tested
891                let db_nodes_map: BTreeMap<Nibbles, BranchNodeCompact> =
892                    db_nodes.into_iter().collect();
893                let db_nodes_arc = Arc::new(db_nodes_map);
894                let visited_keys = Arc::new(Mutex::new(Vec::new()));
895                let mock_cursor = MockTrieCursor::new(db_nodes_arc, visited_keys);
896                let trie_updates = TrieUpdatesSorted::new(in_memory_nodes, Default::default());
897                let mut test_cursor = InMemoryTrieCursor::new_account(mock_cursor, &trie_updates);
898
899                // Test: seek to the beginning first
900                let control_first = control_cursor.seek(Nibbles::default()).unwrap();
901                let test_first = test_cursor.seek(Nibbles::default()).unwrap();
902                debug!(
903                    control=?control_first.as_ref().map(|(k, _)| k),
904                    test=?test_first.as_ref().map(|(k, _)| k),
905                    "Initial seek returned",
906                );
907                assert_eq!(control_first, test_first, "Initial seek mismatch");
908
909                // If both cursors returned None, nothing to test
910                if control_first.is_none() && test_first.is_none() {
911                    return Ok(());
912                }
913
914                // Track the last key returned from the cursor
915                let mut last_returned_key = control_first.as_ref().map(|(k, _)| *k);
916
917                // Execute a sequence of random operations
918                for choice in op_choices {
919                    let op_type = choice % 3;
920
921                    match op_type {
922                        0 => {
923                            // Next operation
924                            let control_result = control_cursor.next().unwrap();
925                            let test_result = test_cursor.next().unwrap();
926                            debug!(
927                                control=?control_result.as_ref().map(|(k, _)| k),
928                                test=?test_result.as_ref().map(|(k, _)| k),
929                                "Next returned",
930                            );
931                            assert_eq!(control_result, test_result, "Next operation mismatch");
932
933                            last_returned_key = control_result.as_ref().map(|(k, _)| *k);
934
935                            // Stop if both cursors are exhausted
936                            if control_result.is_none() && test_result.is_none() {
937                                break;
938                            }
939                        }
940                        1 => {
941                            // Seek operation - choose a key >= last_returned_key
942                            if all_keys.is_empty() {
943                                continue;
944                            }
945
946                            let valid_keys: Vec<_> = all_keys
947                                .iter()
948                                .filter(|k| last_returned_key.is_none_or(|last| **k >= last))
949                                .collect();
950
951                            if valid_keys.is_empty() {
952                                continue;
953                            }
954
955                            let key = *valid_keys[choice as usize % valid_keys.len()];
956
957                            let control_result = control_cursor.seek(key).unwrap();
958                            let test_result = test_cursor.seek(key).unwrap();
959                            debug!(
960                                control=?control_result.as_ref().map(|(k, _)| k),
961                                test=?test_result.as_ref().map(|(k, _)| k),
962                                ?key,
963                                "Seek returned",
964                            );
965                            assert_eq!(control_result, test_result, "Seek operation mismatch for key {:?}", key);
966
967                            last_returned_key = control_result.as_ref().map(|(k, _)| *k);
968
969                            // Stop if both cursors are exhausted
970                            if control_result.is_none() && test_result.is_none() {
971                                break;
972                            }
973                        }
974                        _ => {
975                            // SeekExact operation - choose a key >= last_returned_key
976                            if all_keys.is_empty() {
977                                continue;
978                            }
979
980                            let valid_keys: Vec<_> = all_keys
981                                .iter()
982                                .filter(|k| last_returned_key.is_none_or(|last| **k >= last))
983                                .collect();
984
985                            if valid_keys.is_empty() {
986                                continue;
987                            }
988
989                            let key = *valid_keys[choice as usize  % valid_keys.len()];
990
991                            let control_result = control_cursor.seek_exact(key).unwrap();
992                            let test_result = test_cursor.seek_exact(key).unwrap();
993                            debug!(
994                                control=?control_result.as_ref().map(|(k, _)| k),
995                                test=?test_result.as_ref().map(|(k, _)| k),
996                                ?key,
997                                "SeekExact returned",
998                            );
999                            assert_eq!(control_result, test_result, "SeekExact operation mismatch for key {:?}", key);
1000
1001                            // seek_exact updates the last_key internally but only if it found something
1002                            last_returned_key = control_result.as_ref().map(|(k, _)| *k);
1003                        }
1004                    }
1005                }
1006            }
1007        }
1008    }
1009}