reth_trie/trie_cursor/in_memory.rs

1use super::{TrieCursor, TrieCursorFactory};
2use crate::{forward_cursor::ForwardInMemoryCursor, updates::TrieUpdatesSorted};
3use alloy_primitives::B256;
4use reth_storage_errors::db::DatabaseError;
5use reth_trie_common::{BranchNodeCompact, Nibbles};
6
/// Trie cursor factory that layers a set of trie updates over another cursor factory.
///
/// Cursors handed out by this factory consult `trie_updates` first and fall back to cursors
/// opened through `cursor_factory` for everything the updates do not mention.
#[derive(Debug, Clone)]
pub struct InMemoryTrieCursorFactory<CF, T> {
    /// Factory used to open cursors over the underlying (database) trie.
    cursor_factory: CF,
    /// The sorted trie updates that take precedence over the underlying trie.
    trie_updates: T,
}

impl<CF, T> InMemoryTrieCursorFactory<CF, T> {
    /// Builds a factory from the underlying `cursor_factory` and the `trie_updates` to overlay.
    pub const fn new(cursor_factory: CF, trie_updates: T) -> Self {
        Self { trie_updates, cursor_factory }
    }
}
22
23impl<'a, CF, T> TrieCursorFactory for InMemoryTrieCursorFactory<CF, &'a T>
24where
25    CF: TrieCursorFactory,
26    T: AsRef<TrieUpdatesSorted>,
27{
28    type AccountTrieCursor = InMemoryTrieCursor<'a, CF::AccountTrieCursor>;
29    type StorageTrieCursor = InMemoryTrieCursor<'a, CF::StorageTrieCursor>;
30
31    fn account_trie_cursor(&self) -> Result<Self::AccountTrieCursor, DatabaseError> {
32        let cursor = self.cursor_factory.account_trie_cursor()?;
33        Ok(InMemoryTrieCursor::new(Some(cursor), self.trie_updates.as_ref().account_nodes_ref()))
34    }
35
36    fn storage_trie_cursor(
37        &self,
38        hashed_address: B256,
39    ) -> Result<Self::StorageTrieCursor, DatabaseError> {
40        // if the storage trie has no updates then we use this as the in-memory overlay.
41        static EMPTY_UPDATES: Vec<(Nibbles, Option<BranchNodeCompact>)> = Vec::new();
42
43        let storage_trie_updates = self.trie_updates.as_ref().storage_tries.get(&hashed_address);
44        let (storage_nodes, cleared) = storage_trie_updates
45            .map(|u| (u.storage_nodes_ref(), u.is_deleted()))
46            .unwrap_or((&EMPTY_UPDATES, false));
47
48        let cursor = if cleared {
49            None
50        } else {
51            Some(self.cursor_factory.storage_trie_cursor(hashed_address)?)
52        };
53
54        Ok(InMemoryTrieCursor::new(cursor, storage_nodes))
55    }
56}
57
58/// A cursor to iterate over trie updates and corresponding database entries.
59/// It will always give precedence to the data from the trie updates.
60#[derive(Debug)]
61pub struct InMemoryTrieCursor<'a, C> {
62    /// The underlying cursor. If None then it is assumed there is no DB data.
63    cursor: Option<C>,
64    /// Forward-only in-memory cursor over storage trie nodes.
65    in_memory_cursor: ForwardInMemoryCursor<'a, Nibbles, Option<BranchNodeCompact>>,
66    /// Last key returned by the cursor.
67    last_key: Option<Nibbles>,
68}
69
70impl<'a, C: TrieCursor> InMemoryTrieCursor<'a, C> {
71    /// Create new trie cursor which combines a DB cursor (None to assume empty DB) and a set of
72    /// in-memory trie nodes.
73    pub fn new(
74        cursor: Option<C>,
75        trie_updates: &'a [(Nibbles, Option<BranchNodeCompact>)],
76    ) -> Self {
77        let in_memory_cursor = ForwardInMemoryCursor::new(trie_updates);
78        Self { cursor, in_memory_cursor, last_key: None }
79    }
80
81    fn seek_inner(
82        &mut self,
83        key: Nibbles,
84        exact: bool,
85    ) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
86        let mut mem_entry = self.in_memory_cursor.seek(&key);
87        let mut db_entry = self.cursor.as_mut().map(|c| c.seek(key)).transpose()?.flatten();
88
89        // exact matching is easy, if overlay has a value then return that (updated or removed), or
90        // if db has a value then return that.
91        if exact {
92            return Ok(match (mem_entry, db_entry) {
93                (Some((mem_key, entry_inner)), _) if mem_key == key => {
94                    entry_inner.map(|node| (key, node))
95                }
96                (_, Some((db_key, node))) if db_key == key => Some((key, node)),
97                _ => None,
98            })
99        }
100
101        loop {
102            match (mem_entry, &db_entry) {
103                (Some((mem_key, None)), _)
104                    if db_entry.as_ref().is_none_or(|(db_key, _)| &mem_key < db_key) =>
105                {
106                    // If overlay has a removed node but DB cursor is exhausted or ahead of the
107                    // in-memory cursor then move ahead in-memory, as there might be further
108                    // non-removed overlay nodes.
109                    mem_entry = self.in_memory_cursor.first_after(&mem_key);
110                }
111                (Some((mem_key, None)), Some((db_key, _))) if &mem_key == db_key => {
112                    // If overlay has a removed node which is returned from DB then move both
113                    // cursors ahead to the next key.
114                    mem_entry = self.in_memory_cursor.first_after(&mem_key);
115                    db_entry = self.cursor.as_mut().map(|c| c.next()).transpose()?.flatten();
116                }
117                (Some((mem_key, Some(node))), _)
118                    if db_entry.as_ref().is_none_or(|(db_key, _)| &mem_key <= db_key) =>
119                {
120                    // If overlay returns a node prior to the DB's node, or the DB is exhausted,
121                    // then we return the overlay's node.
122                    return Ok(Some((mem_key, node)))
123                }
124                // All other cases:
125                // - mem_key > db_key
126                // - overlay is exhausted
127                // Return the db_entry. If DB is also exhausted then this returns None.
128                _ => return Ok(db_entry),
129            }
130        }
131    }
132
133    fn next_inner(
134        &mut self,
135        last: Nibbles,
136    ) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
137        let Some(key) = last.increment() else { return Ok(None) };
138        self.seek_inner(key, false)
139    }
140}
141
142impl<C: TrieCursor> TrieCursor for InMemoryTrieCursor<'_, C> {
143    fn seek_exact(
144        &mut self,
145        key: Nibbles,
146    ) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
147        let entry = self.seek_inner(key, true)?;
148        self.last_key = entry.as_ref().map(|(nibbles, _)| *nibbles);
149        Ok(entry)
150    }
151
152    fn seek(
153        &mut self,
154        key: Nibbles,
155    ) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
156        let entry = self.seek_inner(key, false)?;
157        self.last_key = entry.as_ref().map(|(nibbles, _)| *nibbles);
158        Ok(entry)
159    }
160
161    fn next(&mut self) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
162        let next = match &self.last_key {
163            Some(last) => {
164                let entry = self.next_inner(*last)?;
165                self.last_key = entry.as_ref().map(|entry| entry.0);
166                entry
167            }
168            // no previous entry was found
169            None => None,
170        };
171        Ok(next)
172    }
173
174    fn current(&mut self) -> Result<Option<Nibbles>, DatabaseError> {
175        match &self.last_key {
176            Some(key) => Ok(Some(*key)),
177            None => Ok(self.cursor.as_mut().map(|c| c.current()).transpose()?.flatten()),
178        }
179    }
180}
181
#[cfg(test)]
mod tests {
    use super::*;
    use crate::trie_cursor::mock::MockTrieCursor;
    use parking_lot::Mutex;
    use std::{collections::BTreeMap, sync::Arc};

    /// One merge scenario: database contents, overlay contents, and the complete sequence of
    /// entries the merged cursor is expected to yield.
    #[derive(Debug)]
    struct InMemoryTrieCursorTestCase {
        db_nodes: Vec<(Nibbles, BranchNodeCompact)>,
        in_memory_nodes: Vec<(Nibbles, Option<BranchNodeCompact>)>,
        expected_results: Vec<(Nibbles, BranchNodeCompact)>,
    }

    /// Builds a mock database cursor over `db_nodes`.
    fn mock_db_cursor(db_nodes: Vec<(Nibbles, BranchNodeCompact)>) -> MockTrieCursor {
        let db_nodes_map: BTreeMap<Nibbles, BranchNodeCompact> = db_nodes.into_iter().collect();
        MockTrieCursor::new(Arc::new(db_nodes_map), Arc::new(Mutex::new(Vec::new())))
    }

    /// Walks the merged cursor from the smallest possible key to exhaustion and compares every
    /// yielded entry with the expectation. Cursor errors fail the test instead of silently ending
    /// the walk.
    fn execute_test(test_case: InMemoryTrieCursorTestCase) {
        let InMemoryTrieCursorTestCase { db_nodes, in_memory_nodes, expected_results } = test_case;
        let mut cursor = InMemoryTrieCursor::new(Some(mock_db_cursor(db_nodes)), &in_memory_nodes);

        let mut results = Vec::new();
        // The empty path sorts before every other key, so unexpected entries preceding the first
        // expected one are caught as well.
        let mut entry = cursor.seek(Nibbles::default()).expect("seek failed");
        while let Some(found) = entry {
            results.push(found);
            entry = cursor.next().expect("next failed");
        }

        assert_eq!(results, expected_results);
    }

    #[test]
    fn test_empty_db_and_memory() {
        let test_case = InMemoryTrieCursorTestCase {
            db_nodes: vec![],
            in_memory_nodes: vec![],
            expected_results: vec![],
        };
        execute_test(test_case);
    }

    #[test]
    fn test_only_db_nodes() {
        let db_nodes = vec![
            (Nibbles::from_nibbles([0x1]), BranchNodeCompact::new(0b0011, 0b0001, 0, vec![], None)),
            (Nibbles::from_nibbles([0x2]), BranchNodeCompact::new(0b0011, 0b0010, 0, vec![], None)),
            (Nibbles::from_nibbles([0x3]), BranchNodeCompact::new(0b0011, 0b0011, 0, vec![], None)),
        ];

        let test_case = InMemoryTrieCursorTestCase {
            db_nodes: db_nodes.clone(),
            in_memory_nodes: vec![],
            expected_results: db_nodes,
        };
        execute_test(test_case);
    }

    #[test]
    fn test_only_in_memory_nodes() {
        let in_memory_nodes = vec![
            (
                Nibbles::from_nibbles([0x1]),
                Some(BranchNodeCompact::new(0b0011, 0b0001, 0, vec![], None)),
            ),
            (
                Nibbles::from_nibbles([0x2]),
                Some(BranchNodeCompact::new(0b0011, 0b0010, 0, vec![], None)),
            ),
            (
                Nibbles::from_nibbles([0x3]),
                Some(BranchNodeCompact::new(0b0011, 0b0011, 0, vec![], None)),
            ),
        ];

        let expected_results: Vec<(Nibbles, BranchNodeCompact)> = in_memory_nodes
            .iter()
            .filter_map(|(k, v)| v.as_ref().map(|node| (*k, node.clone())))
            .collect();

        let test_case =
            InMemoryTrieCursorTestCase { db_nodes: vec![], in_memory_nodes, expected_results };
        execute_test(test_case);
    }

    #[test]
    fn test_in_memory_overwrites_db() {
        let db_nodes = vec![
            (Nibbles::from_nibbles([0x1]), BranchNodeCompact::new(0b0011, 0b0001, 0, vec![], None)),
            (Nibbles::from_nibbles([0x2]), BranchNodeCompact::new(0b0011, 0b0010, 0, vec![], None)),
        ];

        let in_memory_nodes = vec![
            (
                Nibbles::from_nibbles([0x1]),
                Some(BranchNodeCompact::new(0b1111, 0b1111, 0, vec![], None)),
            ),
            (
                Nibbles::from_nibbles([0x3]),
                Some(BranchNodeCompact::new(0b0011, 0b0011, 0, vec![], None)),
            ),
        ];

        let expected_results = vec![
            (Nibbles::from_nibbles([0x1]), BranchNodeCompact::new(0b1111, 0b1111, 0, vec![], None)),
            (Nibbles::from_nibbles([0x2]), BranchNodeCompact::new(0b0011, 0b0010, 0, vec![], None)),
            (Nibbles::from_nibbles([0x3]), BranchNodeCompact::new(0b0011, 0b0011, 0, vec![], None)),
        ];

        let test_case = InMemoryTrieCursorTestCase { db_nodes, in_memory_nodes, expected_results };
        execute_test(test_case);
    }

    #[test]
    fn test_in_memory_deletes_db_nodes() {
        let db_nodes = vec![
            (Nibbles::from_nibbles([0x1]), BranchNodeCompact::new(0b0011, 0b0001, 0, vec![], None)),
            (Nibbles::from_nibbles([0x2]), BranchNodeCompact::new(0b0011, 0b0010, 0, vec![], None)),
            (Nibbles::from_nibbles([0x3]), BranchNodeCompact::new(0b0011, 0b0011, 0, vec![], None)),
        ];

        let in_memory_nodes = vec![(Nibbles::from_nibbles([0x2]), None)];

        let expected_results = vec![
            (Nibbles::from_nibbles([0x1]), BranchNodeCompact::new(0b0011, 0b0001, 0, vec![], None)),
            (Nibbles::from_nibbles([0x3]), BranchNodeCompact::new(0b0011, 0b0011, 0, vec![], None)),
        ];

        let test_case = InMemoryTrieCursorTestCase { db_nodes, in_memory_nodes, expected_results };
        execute_test(test_case);
    }

    #[test]
    fn test_complex_interleaving() {
        let db_nodes = vec![
            (Nibbles::from_nibbles([0x1]), BranchNodeCompact::new(0b0001, 0b0001, 0, vec![], None)),
            (Nibbles::from_nibbles([0x3]), BranchNodeCompact::new(0b0011, 0b0011, 0, vec![], None)),
            (Nibbles::from_nibbles([0x5]), BranchNodeCompact::new(0b0101, 0b0101, 0, vec![], None)),
            (Nibbles::from_nibbles([0x7]), BranchNodeCompact::new(0b0111, 0b0111, 0, vec![], None)),
        ];

        let in_memory_nodes = vec![
            (
                Nibbles::from_nibbles([0x2]),
                Some(BranchNodeCompact::new(0b0010, 0b0010, 0, vec![], None)),
            ),
            (Nibbles::from_nibbles([0x3]), None),
            (
                Nibbles::from_nibbles([0x4]),
                Some(BranchNodeCompact::new(0b0100, 0b0100, 0, vec![], None)),
            ),
            (
                Nibbles::from_nibbles([0x6]),
                Some(BranchNodeCompact::new(0b0110, 0b0110, 0, vec![], None)),
            ),
            (Nibbles::from_nibbles([0x7]), None),
            (
                Nibbles::from_nibbles([0x8]),
                Some(BranchNodeCompact::new(0b1000, 0b1000, 0, vec![], None)),
            ),
        ];

        let expected_results = vec![
            (Nibbles::from_nibbles([0x1]), BranchNodeCompact::new(0b0001, 0b0001, 0, vec![], None)),
            (Nibbles::from_nibbles([0x2]), BranchNodeCompact::new(0b0010, 0b0010, 0, vec![], None)),
            (Nibbles::from_nibbles([0x4]), BranchNodeCompact::new(0b0100, 0b0100, 0, vec![], None)),
            (Nibbles::from_nibbles([0x5]), BranchNodeCompact::new(0b0101, 0b0101, 0, vec![], None)),
            (Nibbles::from_nibbles([0x6]), BranchNodeCompact::new(0b0110, 0b0110, 0, vec![], None)),
            (Nibbles::from_nibbles([0x8]), BranchNodeCompact::new(0b1000, 0b1000, 0, vec![], None)),
        ];

        let test_case = InMemoryTrieCursorTestCase { db_nodes, in_memory_nodes, expected_results };
        execute_test(test_case);
    }

    #[test]
    fn test_seek_exact() {
        let db_nodes = vec![
            (Nibbles::from_nibbles([0x1]), BranchNodeCompact::new(0b0001, 0b0001, 0, vec![], None)),
            (Nibbles::from_nibbles([0x3]), BranchNodeCompact::new(0b0011, 0b0011, 0, vec![], None)),
        ];

        let in_memory_nodes = vec![(
            Nibbles::from_nibbles([0x2]),
            Some(BranchNodeCompact::new(0b0010, 0b0010, 0, vec![], None)),
        )];

        let mut cursor = InMemoryTrieCursor::new(Some(mock_db_cursor(db_nodes)), &in_memory_nodes);

        let result = cursor.seek_exact(Nibbles::from_nibbles([0x2])).unwrap();
        assert_eq!(
            result,
            Some((
                Nibbles::from_nibbles([0x2]),
                BranchNodeCompact::new(0b0010, 0b0010, 0, vec![], None)
            ))
        );

        let result = cursor.seek_exact(Nibbles::from_nibbles([0x3])).unwrap();
        assert_eq!(
            result,
            Some((
                Nibbles::from_nibbles([0x3]),
                BranchNodeCompact::new(0b0011, 0b0011, 0, vec![], None)
            ))
        );

        let result = cursor.seek_exact(Nibbles::from_nibbles([0x4])).unwrap();
        assert_eq!(result, None);
    }

    #[test]
    fn test_seek_exact_removed_node() {
        let db_nodes = vec![(
            Nibbles::from_nibbles([0x1]),
            BranchNodeCompact::new(0b0001, 0b0001, 0, vec![], None),
        )];
        let in_memory_nodes: Vec<(Nibbles, Option<BranchNodeCompact>)> =
            vec![(Nibbles::from_nibbles([0x1]), None)];

        let mut cursor = InMemoryTrieCursor::new(Some(mock_db_cursor(db_nodes)), &in_memory_nodes);

        // An overlay removal must hide the database node on an exact lookup too.
        assert_eq!(cursor.seek_exact(Nibbles::from_nibbles([0x1])).unwrap(), None);
    }

    #[test]
    fn test_multiple_consecutive_deletes() {
        let db_nodes: Vec<(Nibbles, BranchNodeCompact)> = (1..=10)
            .map(|i| {
                (
                    Nibbles::from_nibbles([i]),
                    BranchNodeCompact::new(i as u16, i as u16, 0, vec![], None),
                )
            })
            .collect();

        let in_memory_nodes = vec![
            (Nibbles::from_nibbles([0x3]), None),
            (Nibbles::from_nibbles([0x4]), None),
            (Nibbles::from_nibbles([0x5]), None),
            (Nibbles::from_nibbles([0x6]), None),
        ];

        let expected_results = vec![
            (Nibbles::from_nibbles([0x1]), BranchNodeCompact::new(1, 1, 0, vec![], None)),
            (Nibbles::from_nibbles([0x2]), BranchNodeCompact::new(2, 2, 0, vec![], None)),
            (Nibbles::from_nibbles([0x7]), BranchNodeCompact::new(7, 7, 0, vec![], None)),
            (Nibbles::from_nibbles([0x8]), BranchNodeCompact::new(8, 8, 0, vec![], None)),
            (Nibbles::from_nibbles([0x9]), BranchNodeCompact::new(9, 9, 0, vec![], None)),
            (Nibbles::from_nibbles([0xa]), BranchNodeCompact::new(10, 10, 0, vec![], None)),
        ];

        let test_case = InMemoryTrieCursorTestCase { db_nodes, in_memory_nodes, expected_results };
        execute_test(test_case);
    }

    #[test]
    fn test_empty_db_with_in_memory_deletes() {
        let in_memory_nodes = vec![
            (Nibbles::from_nibbles([0x1]), None),
            (
                Nibbles::from_nibbles([0x2]),
                Some(BranchNodeCompact::new(0b0010, 0b0010, 0, vec![], None)),
            ),
            (Nibbles::from_nibbles([0x3]), None),
        ];

        let expected_results = vec![(
            Nibbles::from_nibbles([0x2]),
            BranchNodeCompact::new(0b0010, 0b0010, 0, vec![], None),
        )];

        let test_case =
            InMemoryTrieCursorTestCase { db_nodes: vec![], in_memory_nodes, expected_results };
        execute_test(test_case);
    }

    #[test]
    fn test_current_key_tracking() {
        let db_nodes = vec![(
            Nibbles::from_nibbles([0x2]),
            BranchNodeCompact::new(0b0010, 0b0010, 0, vec![], None),
        )];

        let in_memory_nodes = vec![
            (
                Nibbles::from_nibbles([0x1]),
                Some(BranchNodeCompact::new(0b0001, 0b0001, 0, vec![], None)),
            ),
            (
                Nibbles::from_nibbles([0x3]),
                Some(BranchNodeCompact::new(0b0011, 0b0011, 0, vec![], None)),
            ),
        ];

        let mut cursor = InMemoryTrieCursor::new(Some(mock_db_cursor(db_nodes)), &in_memory_nodes);

        assert_eq!(cursor.current().unwrap(), None);

        cursor.seek(Nibbles::from_nibbles([0x1])).unwrap();
        assert_eq!(cursor.current().unwrap(), Some(Nibbles::from_nibbles([0x1])));

        cursor.next().unwrap();
        assert_eq!(cursor.current().unwrap(), Some(Nibbles::from_nibbles([0x2])));

        cursor.next().unwrap();
        assert_eq!(cursor.current().unwrap(), Some(Nibbles::from_nibbles([0x3])));
    }
}