reth_trie/trie_cursor/in_memory.rs

1use super::{TrieCursor, TrieCursorFactory};
2use crate::{forward_cursor::ForwardInMemoryCursor, updates::TrieUpdatesSorted};
3use alloy_primitives::B256;
4use reth_storage_errors::db::DatabaseError;
5use reth_trie_common::{BranchNodeCompact, Nibbles};
6
/// Trie cursor factory that layers a set of trie updates over another cursor factory.
#[derive(Debug, Clone)]
pub struct InMemoryTrieCursorFactory<CF, T> {
    /// Factory whose cursors provide the underlying (non-overlay) trie nodes.
    cursor_factory: CF,
    /// Trie updates used as the in-memory overlay (a reference to sorted updates in the
    /// `TrieCursorFactory` implementation).
    trie_updates: T,
}

impl<CF, T> InMemoryTrieCursorFactory<CF, T> {
    /// Builds the factory from the underlying `cursor_factory` and the `trie_updates` overlay.
    pub const fn new(cursor_factory: CF, trie_updates: T) -> Self {
        Self { trie_updates, cursor_factory }
    }
}
22
23impl<'overlay, CF, T> TrieCursorFactory for InMemoryTrieCursorFactory<CF, &'overlay T>
24where
25    CF: TrieCursorFactory + 'overlay,
26    T: AsRef<TrieUpdatesSorted>,
27{
28    type AccountTrieCursor<'cursor>
29        = InMemoryTrieCursor<'overlay, CF::AccountTrieCursor<'cursor>>
30    where
31        Self: 'cursor;
32
33    type StorageTrieCursor<'cursor>
34        = InMemoryTrieCursor<'overlay, CF::StorageTrieCursor<'cursor>>
35    where
36        Self: 'cursor;
37
38    fn account_trie_cursor(&self) -> Result<Self::AccountTrieCursor<'_>, DatabaseError> {
39        let cursor = self.cursor_factory.account_trie_cursor()?;
40        Ok(InMemoryTrieCursor::new(Some(cursor), self.trie_updates.as_ref().account_nodes_ref()))
41    }
42
43    fn storage_trie_cursor(
44        &self,
45        hashed_address: B256,
46    ) -> Result<Self::StorageTrieCursor<'_>, DatabaseError> {
47        // if the storage trie has no updates then we use this as the in-memory overlay.
48        static EMPTY_UPDATES: Vec<(Nibbles, Option<BranchNodeCompact>)> = Vec::new();
49
50        let storage_trie_updates =
51            self.trie_updates.as_ref().storage_tries_ref().get(&hashed_address);
52        let (storage_nodes, cleared) = storage_trie_updates
53            .map(|u| (u.storage_nodes_ref(), u.is_deleted()))
54            .unwrap_or((&EMPTY_UPDATES, false));
55
56        let cursor = if cleared {
57            None
58        } else {
59            Some(self.cursor_factory.storage_trie_cursor(hashed_address)?)
60        };
61
62        Ok(InMemoryTrieCursor::new(cursor, storage_nodes))
63    }
64}
65
66/// A cursor to iterate over trie updates and corresponding database entries.
67/// It will always give precedence to the data from the trie updates.
68#[derive(Debug)]
69pub struct InMemoryTrieCursor<'a, C> {
70    /// The underlying cursor. If None then it is assumed there is no DB data.
71    cursor: Option<C>,
72    /// Entry that `cursor` is currently pointing to.
73    cursor_entry: Option<(Nibbles, BranchNodeCompact)>,
74    /// Forward-only in-memory cursor over storage trie nodes.
75    in_memory_cursor: ForwardInMemoryCursor<'a, Nibbles, Option<BranchNodeCompact>>,
76    /// The key most recently returned from the Cursor.
77    last_key: Option<Nibbles>,
78    #[cfg(debug_assertions)]
79    /// Whether an initial seek was called.
80    seeked: bool,
81}
82
83impl<'a, C: TrieCursor> InMemoryTrieCursor<'a, C> {
84    /// Create new trie cursor which combines a DB cursor (None to assume empty DB) and a set of
85    /// in-memory trie nodes.
86    pub fn new(
87        cursor: Option<C>,
88        trie_updates: &'a [(Nibbles, Option<BranchNodeCompact>)],
89    ) -> Self {
90        let in_memory_cursor = ForwardInMemoryCursor::new(trie_updates);
91        Self {
92            cursor,
93            cursor_entry: None,
94            in_memory_cursor,
95            last_key: None,
96            #[cfg(debug_assertions)]
97            seeked: false,
98        }
99    }
100
101    /// Asserts that the next entry to be returned from the cursor is not previous to the last entry
102    /// returned.
103    fn set_last_key(&mut self, next_entry: &Option<(Nibbles, BranchNodeCompact)>) {
104        let next_key = next_entry.as_ref().map(|e| e.0);
105        debug_assert!(
106            self.last_key.is_none_or(|last| next_key.is_none_or(|next| next >= last)),
107            "Cannot return entry {:?} previous to the last returned entry at {:?}",
108            next_key,
109            self.last_key,
110        );
111        self.last_key = next_key;
112    }
113
114    /// Seeks the `cursor_entry` field of the struct using the cursor.
115    fn cursor_seek(&mut self, key: Nibbles) -> Result<(), DatabaseError> {
116        if let Some(entry) = self.cursor_entry.as_ref() &&
117            entry.0 >= key
118        {
119            // If already seeked to the given key then don't do anything. Also if we're seeked past
120            // the given key then don't anything, because `TrieCursor` is specifically a
121            // forward-only cursor.
122        } else {
123            self.cursor_entry = self.cursor.as_mut().map(|c| c.seek(key)).transpose()?.flatten();
124        }
125
126        Ok(())
127    }
128
129    /// Seeks the `cursor_entry` field of the struct to the subsequent entry using the cursor.
130    fn cursor_next(&mut self) -> Result<(), DatabaseError> {
131        #[cfg(debug_assertions)]
132        {
133            debug_assert!(self.seeked);
134        }
135
136        // If the previous entry is `None`, and we've done a seek previously, then the cursor is
137        // exhausted and we shouldn't call `next` again.
138        if self.cursor_entry.is_some() {
139            self.cursor_entry = self.cursor.as_mut().map(|c| c.next()).transpose()?.flatten();
140        }
141
142        Ok(())
143    }
144
145    /// Compares the current in-memory entry with the current entry of the cursor, and applies the
146    /// in-memory entry to the cursor entry as an overlay.
147    //
148    /// This may consume and move forward the current entries when the overlay indicates a removed
149    /// node.
150    fn choose_next_entry(&mut self) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
151        loop {
152            match (self.in_memory_cursor.current().cloned(), &self.cursor_entry) {
153                (Some((mem_key, None)), _)
154                    if self.cursor_entry.as_ref().is_none_or(|(db_key, _)| &mem_key < db_key) =>
155                {
156                    // If overlay has a removed node but DB cursor is exhausted or ahead of the
157                    // in-memory cursor then move ahead in-memory, as there might be further
158                    // non-removed overlay nodes.
159                    self.in_memory_cursor.first_after(&mem_key);
160                }
161                (Some((mem_key, None)), Some((db_key, _))) if &mem_key == db_key => {
162                    // If overlay has a removed node which is returned from DB then move both
163                    // cursors ahead to the next key.
164                    self.in_memory_cursor.first_after(&mem_key);
165                    self.cursor_next()?;
166                }
167                (Some((mem_key, Some(node))), _)
168                    if self.cursor_entry.as_ref().is_none_or(|(db_key, _)| &mem_key <= db_key) =>
169                {
170                    // If overlay returns a node prior to the DB's node, or the DB is exhausted,
171                    // then we return the overlay's node.
172                    return Ok(Some((mem_key, node)))
173                }
174                // All other cases:
175                // - mem_key > db_key
176                // - overlay is exhausted
177                // Return the db_entry. If DB is also exhausted then this returns None.
178                _ => return Ok(self.cursor_entry.clone()),
179            }
180        }
181    }
182}
183
184impl<C: TrieCursor> TrieCursor for InMemoryTrieCursor<'_, C> {
185    fn seek_exact(
186        &mut self,
187        key: Nibbles,
188    ) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
189        self.cursor_seek(key)?;
190        let mem_entry = self.in_memory_cursor.seek(&key);
191
192        #[cfg(debug_assertions)]
193        {
194            self.seeked = true;
195        }
196
197        let entry = match (mem_entry, &self.cursor_entry) {
198            (Some((mem_key, entry_inner)), _) if mem_key == key => {
199                entry_inner.map(|node| (key, node))
200            }
201            (_, Some((db_key, node))) if db_key == &key => Some((key, node.clone())),
202            _ => None,
203        };
204
205        self.set_last_key(&entry);
206        Ok(entry)
207    }
208
209    fn seek(
210        &mut self,
211        key: Nibbles,
212    ) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
213        self.cursor_seek(key)?;
214        self.in_memory_cursor.seek(&key);
215
216        #[cfg(debug_assertions)]
217        {
218            self.seeked = true;
219        }
220
221        let entry = self.choose_next_entry()?;
222        self.set_last_key(&entry);
223        Ok(entry)
224    }
225
226    fn next(&mut self) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
227        #[cfg(debug_assertions)]
228        {
229            debug_assert!(self.seeked, "Cursor must be seek'd before next is called");
230        }
231
232        // A `last_key` of `None` indicates that the cursor is exhausted.
233        let Some(last_key) = self.last_key else {
234            return Ok(None);
235        };
236
237        // If either cursor is currently pointing to the last entry which was returned then consume
238        // that entry so that `choose_next_entry` is looking at the subsequent one.
239        if let Some((key, _)) = self.in_memory_cursor.current() &&
240            key == &last_key
241        {
242            self.in_memory_cursor.first_after(&last_key);
243        }
244
245        if let Some((key, _)) = &self.cursor_entry &&
246            key == &last_key
247        {
248            self.cursor_next()?;
249        }
250
251        let entry = self.choose_next_entry()?;
252        self.set_last_key(&entry);
253        Ok(entry)
254    }
255
256    fn current(&mut self) -> Result<Option<Nibbles>, DatabaseError> {
257        match &self.last_key {
258            Some(key) => Ok(Some(*key)),
259            None => Ok(self.cursor.as_mut().map(|c| c.current()).transpose()?.flatten()),
260        }
261    }
262}
263
#[cfg(test)]
mod tests {
    use super::*;
    use crate::trie_cursor::mock::MockTrieCursor;
    use parking_lot::Mutex;
    use std::{collections::BTreeMap, sync::Arc};

    /// A single scenario: DB contents, in-memory overlay, and the merged entries the
    /// `InMemoryTrieCursor` is expected to yield.
    #[derive(Debug)]
    struct InMemoryTrieCursorTestCase {
        /// Nodes stored in the mock database.
        db_nodes: Vec<(Nibbles, BranchNodeCompact)>,
        /// Sorted in-memory overlay; a `None` value marks a removed node.
        in_memory_nodes: Vec<(Nibbles, Option<BranchNodeCompact>)>,
        /// Entries the cursor should return, in order, starting from the first one.
        expected_results: Vec<(Nibbles, BranchNodeCompact)>,
    }

    /// Builds a [`MockTrieCursor`] over `db_nodes`, with a fresh visited-keys log.
    fn mock_db_cursor(db_nodes: Vec<(Nibbles, BranchNodeCompact)>) -> MockTrieCursor {
        let db_nodes_map: BTreeMap<Nibbles, BranchNodeCompact> = db_nodes.into_iter().collect();
        MockTrieCursor::new(Arc::new(db_nodes_map), Arc::new(Mutex::new(Vec::new())))
    }

    /// Seeks to the first expected key (the empty key when nothing is expected), then calls
    /// `next` until the cursor is exhausted and compares the collected entries against
    /// `expected_results`. Cursor errors fail the test instead of being silently dropped.
    fn execute_test(test_case: InMemoryTrieCursorTestCase) {
        let mut cursor = InMemoryTrieCursor::new(
            Some(mock_db_cursor(test_case.db_nodes)),
            &test_case.in_memory_nodes,
        );

        // Seek even when no results are expected, so that such cases still exercise the cursor.
        let start_key =
            test_case.expected_results.first().map(|(key, _)| *key).unwrap_or_default();

        let mut results = Vec::new();
        let mut entry = cursor.seek(start_key).expect("seek should not fail");
        while let Some(found) = entry {
            results.push(found);
            entry = cursor.next().expect("next should not fail");
        }

        assert_eq!(results, test_case.expected_results, "Results mismatch");
    }

    #[test]
    fn test_empty_db_and_memory() {
        let test_case = InMemoryTrieCursorTestCase {
            db_nodes: vec![],
            in_memory_nodes: vec![],
            expected_results: vec![],
        };
        execute_test(test_case);
    }

    #[test]
    fn test_only_db_nodes() {
        let db_nodes = vec![
            (Nibbles::from_nibbles([0x1]), BranchNodeCompact::new(0b0011, 0b0001, 0, vec![], None)),
            (Nibbles::from_nibbles([0x2]), BranchNodeCompact::new(0b0011, 0b0010, 0, vec![], None)),
            (Nibbles::from_nibbles([0x3]), BranchNodeCompact::new(0b0011, 0b0011, 0, vec![], None)),
        ];

        let test_case = InMemoryTrieCursorTestCase {
            db_nodes: db_nodes.clone(),
            in_memory_nodes: vec![],
            expected_results: db_nodes,
        };
        execute_test(test_case);
    }

    #[test]
    fn test_only_in_memory_nodes() {
        let in_memory_nodes = vec![
            (
                Nibbles::from_nibbles([0x1]),
                Some(BranchNodeCompact::new(0b0011, 0b0001, 0, vec![], None)),
            ),
            (
                Nibbles::from_nibbles([0x2]),
                Some(BranchNodeCompact::new(0b0011, 0b0010, 0, vec![], None)),
            ),
            (
                Nibbles::from_nibbles([0x3]),
                Some(BranchNodeCompact::new(0b0011, 0b0011, 0, vec![], None)),
            ),
        ];

        let expected_results: Vec<(Nibbles, BranchNodeCompact)> = in_memory_nodes
            .iter()
            .filter_map(|(k, v)| v.as_ref().map(|node| (*k, node.clone())))
            .collect();

        let test_case =
            InMemoryTrieCursorTestCase { db_nodes: vec![], in_memory_nodes, expected_results };
        execute_test(test_case);
    }

    #[test]
    fn test_in_memory_overwrites_db() {
        let db_nodes = vec![
            (Nibbles::from_nibbles([0x1]), BranchNodeCompact::new(0b0011, 0b0001, 0, vec![], None)),
            (Nibbles::from_nibbles([0x2]), BranchNodeCompact::new(0b0011, 0b0010, 0, vec![], None)),
        ];

        let in_memory_nodes = vec![
            (
                Nibbles::from_nibbles([0x1]),
                Some(BranchNodeCompact::new(0b1111, 0b1111, 0, vec![], None)),
            ),
            (
                Nibbles::from_nibbles([0x3]),
                Some(BranchNodeCompact::new(0b0011, 0b0011, 0, vec![], None)),
            ),
        ];

        let expected_results = vec![
            (Nibbles::from_nibbles([0x1]), BranchNodeCompact::new(0b1111, 0b1111, 0, vec![], None)),
            (Nibbles::from_nibbles([0x2]), BranchNodeCompact::new(0b0011, 0b0010, 0, vec![], None)),
            (Nibbles::from_nibbles([0x3]), BranchNodeCompact::new(0b0011, 0b0011, 0, vec![], None)),
        ];

        let test_case = InMemoryTrieCursorTestCase { db_nodes, in_memory_nodes, expected_results };
        execute_test(test_case);
    }

    #[test]
    fn test_in_memory_deletes_db_nodes() {
        let db_nodes = vec![
            (Nibbles::from_nibbles([0x1]), BranchNodeCompact::new(0b0011, 0b0001, 0, vec![], None)),
            (Nibbles::from_nibbles([0x2]), BranchNodeCompact::new(0b0011, 0b0010, 0, vec![], None)),
            (Nibbles::from_nibbles([0x3]), BranchNodeCompact::new(0b0011, 0b0011, 0, vec![], None)),
        ];

        let in_memory_nodes = vec![(Nibbles::from_nibbles([0x2]), None)];

        let expected_results = vec![
            (Nibbles::from_nibbles([0x1]), BranchNodeCompact::new(0b0011, 0b0001, 0, vec![], None)),
            (Nibbles::from_nibbles([0x3]), BranchNodeCompact::new(0b0011, 0b0011, 0, vec![], None)),
        ];

        let test_case = InMemoryTrieCursorTestCase { db_nodes, in_memory_nodes, expected_results };
        execute_test(test_case);
    }

    #[test]
    fn test_complex_interleaving() {
        let db_nodes = vec![
            (Nibbles::from_nibbles([0x1]), BranchNodeCompact::new(0b0001, 0b0001, 0, vec![], None)),
            (Nibbles::from_nibbles([0x3]), BranchNodeCompact::new(0b0011, 0b0011, 0, vec![], None)),
            (Nibbles::from_nibbles([0x5]), BranchNodeCompact::new(0b0101, 0b0101, 0, vec![], None)),
            (Nibbles::from_nibbles([0x7]), BranchNodeCompact::new(0b0111, 0b0111, 0, vec![], None)),
        ];

        let in_memory_nodes = vec![
            (
                Nibbles::from_nibbles([0x2]),
                Some(BranchNodeCompact::new(0b0010, 0b0010, 0, vec![], None)),
            ),
            (Nibbles::from_nibbles([0x3]), None),
            (
                Nibbles::from_nibbles([0x4]),
                Some(BranchNodeCompact::new(0b0100, 0b0100, 0, vec![], None)),
            ),
            (
                Nibbles::from_nibbles([0x6]),
                Some(BranchNodeCompact::new(0b0110, 0b0110, 0, vec![], None)),
            ),
            (Nibbles::from_nibbles([0x7]), None),
            (
                Nibbles::from_nibbles([0x8]),
                Some(BranchNodeCompact::new(0b1000, 0b1000, 0, vec![], None)),
            ),
        ];

        let expected_results = vec![
            (Nibbles::from_nibbles([0x1]), BranchNodeCompact::new(0b0001, 0b0001, 0, vec![], None)),
            (Nibbles::from_nibbles([0x2]), BranchNodeCompact::new(0b0010, 0b0010, 0, vec![], None)),
            (Nibbles::from_nibbles([0x4]), BranchNodeCompact::new(0b0100, 0b0100, 0, vec![], None)),
            (Nibbles::from_nibbles([0x5]), BranchNodeCompact::new(0b0101, 0b0101, 0, vec![], None)),
            (Nibbles::from_nibbles([0x6]), BranchNodeCompact::new(0b0110, 0b0110, 0, vec![], None)),
            (Nibbles::from_nibbles([0x8]), BranchNodeCompact::new(0b1000, 0b1000, 0, vec![], None)),
        ];

        let test_case = InMemoryTrieCursorTestCase { db_nodes, in_memory_nodes, expected_results };
        execute_test(test_case);
    }

    #[test]
    fn test_seek_exact() {
        let db_nodes = vec![
            (Nibbles::from_nibbles([0x1]), BranchNodeCompact::new(0b0001, 0b0001, 0, vec![], None)),
            (Nibbles::from_nibbles([0x3]), BranchNodeCompact::new(0b0011, 0b0011, 0, vec![], None)),
        ];

        let in_memory_nodes = vec![(
            Nibbles::from_nibbles([0x2]),
            Some(BranchNodeCompact::new(0b0010, 0b0010, 0, vec![], None)),
        )];

        let mut cursor = InMemoryTrieCursor::new(Some(mock_db_cursor(db_nodes)), &in_memory_nodes);

        let result = cursor.seek_exact(Nibbles::from_nibbles([0x2])).unwrap();
        assert_eq!(
            result,
            Some((
                Nibbles::from_nibbles([0x2]),
                BranchNodeCompact::new(0b0010, 0b0010, 0, vec![], None)
            ))
        );

        let result = cursor.seek_exact(Nibbles::from_nibbles([0x3])).unwrap();
        assert_eq!(
            result,
            Some((
                Nibbles::from_nibbles([0x3]),
                BranchNodeCompact::new(0b0011, 0b0011, 0, vec![], None)
            ))
        );

        let result = cursor.seek_exact(Nibbles::from_nibbles([0x4])).unwrap();
        assert_eq!(result, None);
    }

    #[test]
    fn test_multiple_consecutive_deletes() {
        let db_nodes: Vec<(Nibbles, BranchNodeCompact)> = (1..=10)
            .map(|i| {
                (
                    Nibbles::from_nibbles([i]),
                    BranchNodeCompact::new(i as u16, i as u16, 0, vec![], None),
                )
            })
            .collect();

        let in_memory_nodes = vec![
            (Nibbles::from_nibbles([0x3]), None),
            (Nibbles::from_nibbles([0x4]), None),
            (Nibbles::from_nibbles([0x5]), None),
            (Nibbles::from_nibbles([0x6]), None),
        ];

        let expected_results = vec![
            (Nibbles::from_nibbles([0x1]), BranchNodeCompact::new(1, 1, 0, vec![], None)),
            (Nibbles::from_nibbles([0x2]), BranchNodeCompact::new(2, 2, 0, vec![], None)),
            (Nibbles::from_nibbles([0x7]), BranchNodeCompact::new(7, 7, 0, vec![], None)),
            (Nibbles::from_nibbles([0x8]), BranchNodeCompact::new(8, 8, 0, vec![], None)),
            (Nibbles::from_nibbles([0x9]), BranchNodeCompact::new(9, 9, 0, vec![], None)),
            (Nibbles::from_nibbles([0xa]), BranchNodeCompact::new(10, 10, 0, vec![], None)),
        ];

        let test_case = InMemoryTrieCursorTestCase { db_nodes, in_memory_nodes, expected_results };
        execute_test(test_case);
    }

    #[test]
    fn test_empty_db_with_in_memory_deletes() {
        let in_memory_nodes = vec![
            (Nibbles::from_nibbles([0x1]), None),
            (
                Nibbles::from_nibbles([0x2]),
                Some(BranchNodeCompact::new(0b0010, 0b0010, 0, vec![], None)),
            ),
            (Nibbles::from_nibbles([0x3]), None),
        ];

        let expected_results = vec![(
            Nibbles::from_nibbles([0x2]),
            BranchNodeCompact::new(0b0010, 0b0010, 0, vec![], None),
        )];

        let test_case =
            InMemoryTrieCursorTestCase { db_nodes: vec![], in_memory_nodes, expected_results };
        execute_test(test_case);
    }

    #[test]
    fn test_current_key_tracking() {
        let db_nodes = vec![(
            Nibbles::from_nibbles([0x2]),
            BranchNodeCompact::new(0b0010, 0b0010, 0, vec![], None),
        )];

        let in_memory_nodes = vec![
            (
                Nibbles::from_nibbles([0x1]),
                Some(BranchNodeCompact::new(0b0001, 0b0001, 0, vec![], None)),
            ),
            (
                Nibbles::from_nibbles([0x3]),
                Some(BranchNodeCompact::new(0b0011, 0b0011, 0, vec![], None)),
            ),
        ];

        let mut cursor = InMemoryTrieCursor::new(Some(mock_db_cursor(db_nodes)), &in_memory_nodes);

        assert_eq!(cursor.current().unwrap(), None);

        cursor.seek(Nibbles::from_nibbles([0x1])).unwrap();
        assert_eq!(cursor.current().unwrap(), Some(Nibbles::from_nibbles([0x1])));

        cursor.next().unwrap();
        assert_eq!(cursor.current().unwrap(), Some(Nibbles::from_nibbles([0x2])));

        cursor.next().unwrap();
        assert_eq!(cursor.current().unwrap(), Some(Nibbles::from_nibbles([0x3])));
    }

    mod proptest_tests {
        use super::*;
        use itertools::Itertools;
        use proptest::prelude::*;

        /// Merge `db_nodes` with `in_memory_nodes`, applying the in-memory overlay.
        /// This properly handles deletions (None values in `in_memory_nodes`).
        fn merge_with_overlay(
            db_nodes: Vec<(Nibbles, BranchNodeCompact)>,
            in_memory_nodes: Vec<(Nibbles, Option<BranchNodeCompact>)>,
        ) -> Vec<(Nibbles, BranchNodeCompact)> {
            db_nodes
                .into_iter()
                .merge_join_by(in_memory_nodes, |db_entry, mem_entry| db_entry.0.cmp(&mem_entry.0))
                .filter_map(|entry| match entry {
                    // Only in db: keep it
                    itertools::EitherOrBoth::Left((key, node)) => Some((key, node)),
                    // Only in memory: keep if not a deletion
                    itertools::EitherOrBoth::Right((key, node_opt)) => {
                        node_opt.map(|node| (key, node))
                    }
                    // In both: memory takes precedence (keep if not a deletion)
                    itertools::EitherOrBoth::Both(_, (key, node_opt)) => {
                        node_opt.map(|node| (key, node))
                    }
                })
                .collect()
        }

        /// Generate a strategy for a `BranchNodeCompact` with simplified parameters.
        /// The constraints are:
        /// - `tree_mask` must be a subset of `state_mask`
        /// - `hash_mask` must be a subset of `state_mask`
        /// - `hash_mask.count_ones()` must equal `hashes.len()`
        ///
        /// To keep it simple, we use an empty hashes vec and `hash_mask` of 0.
        fn branch_node_strategy() -> impl Strategy<Value = BranchNodeCompact> {
            any::<u16>()
                .prop_flat_map(|state_mask| {
                    let tree_mask_strategy = any::<u16>().prop_map(move |tree| tree & state_mask);
                    (Just(state_mask), tree_mask_strategy)
                })
                .prop_map(|(state_mask, tree_mask)| {
                    BranchNodeCompact::new(state_mask, tree_mask, 0, vec![], None)
                })
        }

        /// Generate a sorted vector of (Nibbles, `BranchNodeCompact`) entries
        fn sorted_db_nodes_strategy() -> impl Strategy<Value = Vec<(Nibbles, BranchNodeCompact)>> {
            prop::collection::vec(
                (prop::collection::vec(any::<u8>(), 0..3), branch_node_strategy()),
                0..20,
            )
            .prop_map(|entries| {
                // Convert Vec<u8> to Nibbles and sort
                let mut result: Vec<(Nibbles, BranchNodeCompact)> = entries
                    .into_iter()
                    .map(|(bytes, node)| (Nibbles::from_nibbles_unchecked(bytes), node))
                    .collect();
                result.sort_by(|a, b| a.0.cmp(&b.0));
                result.dedup_by(|a, b| a.0 == b.0);
                result
            })
        }

        /// Generate a sorted vector of (Nibbles, Option<BranchNodeCompact>) entries
        fn sorted_in_memory_nodes_strategy(
        ) -> impl Strategy<Value = Vec<(Nibbles, Option<BranchNodeCompact>)>> {
            prop::collection::vec(
                (
                    prop::collection::vec(any::<u8>(), 0..3),
                    prop::option::of(branch_node_strategy()),
                ),
                0..20,
            )
            .prop_map(|entries| {
                // Convert Vec<u8> to Nibbles and sort
                let mut result: Vec<(Nibbles, Option<BranchNodeCompact>)> = entries
                    .into_iter()
                    .map(|(bytes, node)| (Nibbles::from_nibbles_unchecked(bytes), node))
                    .collect();
                result.sort_by(|a, b| a.0.cmp(&b.0));
                result.dedup_by(|a, b| a.0 == b.0);
                result
            })
        }

        /// Picks one of `keys` that is not before `last_key` (any key when `last_key` is
        /// `None`), using `choice` to select among the candidates. Returns `None` if no key
        /// qualifies.
        fn pick_forward_key(
            keys: &[Nibbles],
            last_key: Option<Nibbles>,
            choice: u8,
        ) -> Option<Nibbles> {
            let candidates: Vec<Nibbles> = keys
                .iter()
                .copied()
                .filter(|key| last_key.is_none_or(|last| *key >= last))
                .collect();
            if candidates.is_empty() {
                return None;
            }
            Some(candidates[(choice as usize / 3) % candidates.len()])
        }

        proptest! {
            #![proptest_config(ProptestConfig::with_cases(1000))]

            #[test]
            fn proptest_in_memory_trie_cursor(
                db_nodes in sorted_db_nodes_strategy(),
                in_memory_nodes in sorted_in_memory_nodes_strategy(),
                op_choices in prop::collection::vec(any::<u8>(), 10..500),
            ) {
                reth_tracing::init_test_tracing();
                use tracing::debug;

                debug!("Starting proptest!");

                // Create the expected results by merging the two sorted vectors,
                // properly handling deletions (None values in in_memory_nodes)
                let expected_combined = merge_with_overlay(db_nodes.clone(), in_memory_nodes.clone());

                // Collect all keys for operation generation
                let all_keys: Vec<Nibbles> = expected_combined.iter().map(|(k, _)| *k).collect();

                // Control cursor: a plain mock cursor over the already-merged view.
                let mut control_cursor = mock_db_cursor(expected_combined);

                // Cursor under test: the mock DB cursor overlaid with the in-memory nodes.
                let mut test_cursor =
                    InMemoryTrieCursor::new(Some(mock_db_cursor(db_nodes)), &in_memory_nodes);

                // Test: seek to the beginning first
                let control_first = control_cursor.seek(Nibbles::default()).unwrap();
                let test_first = test_cursor.seek(Nibbles::default()).unwrap();
                debug!(
                    control=?control_first.as_ref().map(|(k, _)| k),
                    test=?test_first.as_ref().map(|(k, _)| k),
                    "Initial seek returned",
                );
                assert_eq!(control_first, test_first, "Initial seek mismatch");

                // If both cursors returned None, nothing to test
                if control_first.is_none() && test_first.is_none() {
                    return Ok(());
                }

                // Track the last key returned from the cursor
                let mut last_returned_key = control_first.as_ref().map(|(k, _)| *k);

                // Execute a sequence of random operations
                for choice in op_choices {
                    match choice % 3 {
                        0 => {
                            // Next operation
                            let control_result = control_cursor.next().unwrap();
                            let test_result = test_cursor.next().unwrap();
                            debug!(
                                control=?control_result.as_ref().map(|(k, _)| k),
                                test=?test_result.as_ref().map(|(k, _)| k),
                                "Next returned",
                            );
                            assert_eq!(control_result, test_result, "Next operation mismatch");

                            last_returned_key = control_result.as_ref().map(|(k, _)| *k);

                            // Stop if both cursors are exhausted
                            if control_result.is_none() && test_result.is_none() {
                                break;
                            }
                        }
                        1 => {
                            // Seek operation - choose a key >= last_returned_key
                            let Some(key) = pick_forward_key(&all_keys, last_returned_key, choice)
                            else {
                                continue;
                            };

                            let control_result = control_cursor.seek(key).unwrap();
                            let test_result = test_cursor.seek(key).unwrap();
                            debug!(
                                control=?control_result.as_ref().map(|(k, _)| k),
                                test=?test_result.as_ref().map(|(k, _)| k),
                                ?key,
                                "Seek returned",
                            );
                            assert_eq!(control_result, test_result, "Seek operation mismatch for key {:?}", key);

                            last_returned_key = control_result.as_ref().map(|(k, _)| *k);

                            // Stop if both cursors are exhausted
                            if control_result.is_none() && test_result.is_none() {
                                break;
                            }
                        }
                        _ => {
                            // SeekExact operation - choose a key >= last_returned_key
                            let Some(key) = pick_forward_key(&all_keys, last_returned_key, choice)
                            else {
                                continue;
                            };

                            let control_result = control_cursor.seek_exact(key).unwrap();
                            let test_result = test_cursor.seek_exact(key).unwrap();
                            debug!(
                                control=?control_result.as_ref().map(|(k, _)| k),
                                test=?test_result.as_ref().map(|(k, _)| k),
                                ?key,
                                "SeekExact returned",
                            );
                            assert_eq!(control_result, test_result, "SeekExact operation mismatch for key {:?}", key);

                            // seek_exact updates the last_key internally but only if it found something
                            last_returned_key = control_result.as_ref().map(|(k, _)| *k);
                        }
                    }
                }
            }
        }
    }
}