reth_trie/trie_cursor/in_memory.rs

1use super::{TrieCursor, TrieCursorFactory};
2use crate::{forward_cursor::ForwardInMemoryCursor, updates::TrieUpdatesSorted};
3use alloy_primitives::B256;
4use reth_storage_errors::db::DatabaseError;
5use reth_trie_common::{BranchNodeCompact, Nibbles};
6
7/// The trie cursor factory for the trie updates.
8#[derive(Debug, Clone)]
9pub struct InMemoryTrieCursorFactory<'a, CF> {
10    /// Underlying trie cursor factory.
11    cursor_factory: CF,
12    /// Reference to sorted trie updates.
13    trie_updates: &'a TrieUpdatesSorted,
14}
15
16impl<'a, CF> InMemoryTrieCursorFactory<'a, CF> {
17    /// Create a new trie cursor factory.
18    pub const fn new(cursor_factory: CF, trie_updates: &'a TrieUpdatesSorted) -> Self {
19        Self { cursor_factory, trie_updates }
20    }
21}
22
23impl<'a, CF: TrieCursorFactory> TrieCursorFactory for InMemoryTrieCursorFactory<'a, CF> {
24    type AccountTrieCursor = InMemoryTrieCursor<'a, CF::AccountTrieCursor>;
25    type StorageTrieCursor = InMemoryTrieCursor<'a, CF::StorageTrieCursor>;
26
27    fn account_trie_cursor(&self) -> Result<Self::AccountTrieCursor, DatabaseError> {
28        let cursor = self.cursor_factory.account_trie_cursor()?;
29        Ok(InMemoryTrieCursor::new(Some(cursor), self.trie_updates.account_nodes_ref()))
30    }
31
32    fn storage_trie_cursor(
33        &self,
34        hashed_address: B256,
35    ) -> Result<Self::StorageTrieCursor, DatabaseError> {
36        // if the storage trie has no updates then we use this as the in-memory overlay.
37        static EMPTY_UPDATES: Vec<(Nibbles, Option<BranchNodeCompact>)> = Vec::new();
38
39        let storage_trie_updates = self.trie_updates.storage_tries.get(&hashed_address);
40        let (storage_nodes, cleared) = storage_trie_updates
41            .map(|u| (u.storage_nodes_ref(), u.is_deleted()))
42            .unwrap_or((&EMPTY_UPDATES, false));
43
44        let cursor = if cleared {
45            None
46        } else {
47            Some(self.cursor_factory.storage_trie_cursor(hashed_address)?)
48        };
49
50        Ok(InMemoryTrieCursor::new(cursor, storage_nodes))
51    }
52}
53
54/// A cursor to iterate over trie updates and corresponding database entries.
55/// It will always give precedence to the data from the trie updates.
56#[derive(Debug)]
57pub struct InMemoryTrieCursor<'a, C> {
58    /// The underlying cursor. If None then it is assumed there is no DB data.
59    cursor: Option<C>,
60    /// Forward-only in-memory cursor over storage trie nodes.
61    in_memory_cursor: ForwardInMemoryCursor<'a, Nibbles, Option<BranchNodeCompact>>,
62    /// Last key returned by the cursor.
63    last_key: Option<Nibbles>,
64}
65
66impl<'a, C: TrieCursor> InMemoryTrieCursor<'a, C> {
67    /// Create new trie cursor which combines a DB cursor (None to assume empty DB) and a set of
68    /// in-memory trie nodes.
69    pub fn new(
70        cursor: Option<C>,
71        trie_updates: &'a [(Nibbles, Option<BranchNodeCompact>)],
72    ) -> Self {
73        let in_memory_cursor = ForwardInMemoryCursor::new(trie_updates);
74        Self { cursor, in_memory_cursor, last_key: None }
75    }
76
77    fn seek_inner(
78        &mut self,
79        key: Nibbles,
80        exact: bool,
81    ) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
82        let mut mem_entry = self.in_memory_cursor.seek(&key);
83        let mut db_entry = self.cursor.as_mut().map(|c| c.seek(key)).transpose()?.flatten();
84
85        // exact matching is easy, if overlay has a value then return that (updated or removed), or
86        // if db has a value then return that.
87        if exact {
88            return Ok(match (mem_entry, db_entry) {
89                (Some((mem_key, entry_inner)), _) if mem_key == key => {
90                    entry_inner.map(|node| (key, node))
91                }
92                (_, Some((db_key, node))) if db_key == key => Some((key, node)),
93                _ => None,
94            })
95        }
96
97        loop {
98            match (mem_entry, &db_entry) {
99                (Some((mem_key, None)), _)
100                    if db_entry.as_ref().is_none_or(|(db_key, _)| &mem_key < db_key) =>
101                {
102                    // If overlay has a removed node but DB cursor is exhausted or ahead of the
103                    // in-memory cursor then move ahead in-memory, as there might be further
104                    // non-removed overlay nodes.
105                    mem_entry = self.in_memory_cursor.first_after(&mem_key);
106                }
107                (Some((mem_key, None)), Some((db_key, _))) if &mem_key == db_key => {
108                    // If overlay has a removed node which is returned from DB then move both
109                    // cursors ahead to the next key.
110                    mem_entry = self.in_memory_cursor.first_after(&mem_key);
111                    db_entry = self.cursor.as_mut().map(|c| c.next()).transpose()?.flatten();
112                }
113                (Some((mem_key, Some(node))), _)
114                    if db_entry.as_ref().is_none_or(|(db_key, _)| &mem_key <= db_key) =>
115                {
116                    // If overlay returns a node prior to the DB's node, or the DB is exhausted,
117                    // then we return the overlay's node.
118                    return Ok(Some((mem_key, node)))
119                }
120                // All other cases:
121                // - mem_key > db_key
122                // - overlay is exhausted
123                // Return the db_entry. If DB is also exhausted then this returns None.
124                _ => return Ok(db_entry),
125            }
126        }
127    }
128
129    fn next_inner(
130        &mut self,
131        last: Nibbles,
132    ) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
133        let Some(key) = last.increment() else { return Ok(None) };
134        self.seek_inner(key, false)
135    }
136}
137
138impl<C: TrieCursor> TrieCursor for InMemoryTrieCursor<'_, C> {
139    fn seek_exact(
140        &mut self,
141        key: Nibbles,
142    ) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
143        let entry = self.seek_inner(key, true)?;
144        self.last_key = entry.as_ref().map(|(nibbles, _)| *nibbles);
145        Ok(entry)
146    }
147
148    fn seek(
149        &mut self,
150        key: Nibbles,
151    ) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
152        let entry = self.seek_inner(key, false)?;
153        self.last_key = entry.as_ref().map(|(nibbles, _)| *nibbles);
154        Ok(entry)
155    }
156
157    fn next(&mut self) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
158        let next = match &self.last_key {
159            Some(last) => {
160                let entry = self.next_inner(*last)?;
161                self.last_key = entry.as_ref().map(|entry| entry.0);
162                entry
163            }
164            // no previous entry was found
165            None => None,
166        };
167        Ok(next)
168    }
169
170    fn current(&mut self) -> Result<Option<Nibbles>, DatabaseError> {
171        match &self.last_key {
172            Some(key) => Ok(Some(*key)),
173            None => Ok(self.cursor.as_mut().map(|c| c.current()).transpose()?.flatten()),
174        }
175    }
176}
177
#[cfg(test)]
mod tests {
    use super::*;
    use crate::trie_cursor::mock::MockTrieCursor;
    use parking_lot::Mutex;
    use std::{collections::BTreeMap, sync::Arc};

    #[derive(Debug)]
    struct InMemoryTrieCursorTestCase {
        db_nodes: Vec<(Nibbles, BranchNodeCompact)>,
        in_memory_nodes: Vec<(Nibbles, Option<BranchNodeCompact>)>,
        expected_results: Vec<(Nibbles, BranchNodeCompact)>,
    }

    /// Builds a mock database trie cursor over `db_nodes`.
    fn mock_cursor(db_nodes: Vec<(Nibbles, BranchNodeCompact)>) -> impl TrieCursor {
        let db_nodes: BTreeMap<Nibbles, BranchNodeCompact> = db_nodes.into_iter().collect();
        MockTrieCursor::new(Arc::new(db_nodes), Arc::new(Mutex::new(Vec::new())))
    }

    /// Walks the whole merged view and asserts it yields exactly `expected_results`.
    ///
    /// The walk seeks to the smallest possible key and then follows `next`. Cursor errors fail the
    /// test instead of being silently treated as the end of iteration.
    fn execute_test(test_case: InMemoryTrieCursorTestCase) {
        let db_cursor = mock_cursor(test_case.db_nodes);
        let mut cursor = InMemoryTrieCursor::new(Some(db_cursor), &test_case.in_memory_nodes);

        let mut results = Vec::new();

        // Start from the empty key so entries (and deletions) before the first expected key are
        // exercised too, and so an empty expectation still checks that nothing is yielded.
        if let Some(entry) = cursor.seek(Nibbles::default()).expect("seek should succeed") {
            results.push(entry);
        }

        while let Some(entry) = cursor.next().expect("next should succeed") {
            results.push(entry);
        }

        assert_eq!(
            results, test_case.expected_results,
            "Results mismatch.\nGot: {:?}\nExpected: {:?}",
            results, test_case.expected_results
        );
    }

    #[test]
    fn test_empty_db_and_memory() {
        let test_case = InMemoryTrieCursorTestCase {
            db_nodes: vec![],
            in_memory_nodes: vec![],
            expected_results: vec![],
        };
        execute_test(test_case);
    }

    #[test]
    fn test_only_db_nodes() {
        let db_nodes = vec![
            (Nibbles::from_nibbles([0x1]), BranchNodeCompact::new(0b0011, 0b0001, 0, vec![], None)),
            (Nibbles::from_nibbles([0x2]), BranchNodeCompact::new(0b0011, 0b0010, 0, vec![], None)),
            (Nibbles::from_nibbles([0x3]), BranchNodeCompact::new(0b0011, 0b0011, 0, vec![], None)),
        ];

        let test_case = InMemoryTrieCursorTestCase {
            db_nodes: db_nodes.clone(),
            in_memory_nodes: vec![],
            expected_results: db_nodes,
        };
        execute_test(test_case);
    }

    #[test]
    fn test_only_in_memory_nodes() {
        let in_memory_nodes = vec![
            (
                Nibbles::from_nibbles([0x1]),
                Some(BranchNodeCompact::new(0b0011, 0b0001, 0, vec![], None)),
            ),
            (
                Nibbles::from_nibbles([0x2]),
                Some(BranchNodeCompact::new(0b0011, 0b0010, 0, vec![], None)),
            ),
            (
                Nibbles::from_nibbles([0x3]),
                Some(BranchNodeCompact::new(0b0011, 0b0011, 0, vec![], None)),
            ),
        ];

        let expected_results: Vec<(Nibbles, BranchNodeCompact)> = in_memory_nodes
            .iter()
            .filter_map(|(k, v)| v.as_ref().map(|node| (*k, node.clone())))
            .collect();

        let test_case =
            InMemoryTrieCursorTestCase { db_nodes: vec![], in_memory_nodes, expected_results };
        execute_test(test_case);
    }

    #[test]
    fn test_in_memory_overwrites_db() {
        let db_nodes = vec![
            (Nibbles::from_nibbles([0x1]), BranchNodeCompact::new(0b0011, 0b0001, 0, vec![], None)),
            (Nibbles::from_nibbles([0x2]), BranchNodeCompact::new(0b0011, 0b0010, 0, vec![], None)),
        ];

        let in_memory_nodes = vec![
            (
                Nibbles::from_nibbles([0x1]),
                Some(BranchNodeCompact::new(0b1111, 0b1111, 0, vec![], None)),
            ),
            (
                Nibbles::from_nibbles([0x3]),
                Some(BranchNodeCompact::new(0b0011, 0b0011, 0, vec![], None)),
            ),
        ];

        let expected_results = vec![
            (Nibbles::from_nibbles([0x1]), BranchNodeCompact::new(0b1111, 0b1111, 0, vec![], None)),
            (Nibbles::from_nibbles([0x2]), BranchNodeCompact::new(0b0011, 0b0010, 0, vec![], None)),
            (Nibbles::from_nibbles([0x3]), BranchNodeCompact::new(0b0011, 0b0011, 0, vec![], None)),
        ];

        let test_case = InMemoryTrieCursorTestCase { db_nodes, in_memory_nodes, expected_results };
        execute_test(test_case);
    }

    #[test]
    fn test_in_memory_deletes_db_nodes() {
        let db_nodes = vec![
            (Nibbles::from_nibbles([0x1]), BranchNodeCompact::new(0b0011, 0b0001, 0, vec![], None)),
            (Nibbles::from_nibbles([0x2]), BranchNodeCompact::new(0b0011, 0b0010, 0, vec![], None)),
            (Nibbles::from_nibbles([0x3]), BranchNodeCompact::new(0b0011, 0b0011, 0, vec![], None)),
        ];

        let in_memory_nodes = vec![(Nibbles::from_nibbles([0x2]), None)];

        let expected_results = vec![
            (Nibbles::from_nibbles([0x1]), BranchNodeCompact::new(0b0011, 0b0001, 0, vec![], None)),
            (Nibbles::from_nibbles([0x3]), BranchNodeCompact::new(0b0011, 0b0011, 0, vec![], None)),
        ];

        let test_case = InMemoryTrieCursorTestCase { db_nodes, in_memory_nodes, expected_results };
        execute_test(test_case);
    }

    #[test]
    fn test_complex_interleaving() {
        let db_nodes = vec![
            (Nibbles::from_nibbles([0x1]), BranchNodeCompact::new(0b0001, 0b0001, 0, vec![], None)),
            (Nibbles::from_nibbles([0x3]), BranchNodeCompact::new(0b0011, 0b0011, 0, vec![], None)),
            (Nibbles::from_nibbles([0x5]), BranchNodeCompact::new(0b0101, 0b0101, 0, vec![], None)),
            (Nibbles::from_nibbles([0x7]), BranchNodeCompact::new(0b0111, 0b0111, 0, vec![], None)),
        ];

        let in_memory_nodes = vec![
            (
                Nibbles::from_nibbles([0x2]),
                Some(BranchNodeCompact::new(0b0010, 0b0010, 0, vec![], None)),
            ),
            (Nibbles::from_nibbles([0x3]), None),
            (
                Nibbles::from_nibbles([0x4]),
                Some(BranchNodeCompact::new(0b0100, 0b0100, 0, vec![], None)),
            ),
            (
                Nibbles::from_nibbles([0x6]),
                Some(BranchNodeCompact::new(0b0110, 0b0110, 0, vec![], None)),
            ),
            (Nibbles::from_nibbles([0x7]), None),
            (
                Nibbles::from_nibbles([0x8]),
                Some(BranchNodeCompact::new(0b1000, 0b1000, 0, vec![], None)),
            ),
        ];

        let expected_results = vec![
            (Nibbles::from_nibbles([0x1]), BranchNodeCompact::new(0b0001, 0b0001, 0, vec![], None)),
            (Nibbles::from_nibbles([0x2]), BranchNodeCompact::new(0b0010, 0b0010, 0, vec![], None)),
            (Nibbles::from_nibbles([0x4]), BranchNodeCompact::new(0b0100, 0b0100, 0, vec![], None)),
            (Nibbles::from_nibbles([0x5]), BranchNodeCompact::new(0b0101, 0b0101, 0, vec![], None)),
            (Nibbles::from_nibbles([0x6]), BranchNodeCompact::new(0b0110, 0b0110, 0, vec![], None)),
            (Nibbles::from_nibbles([0x8]), BranchNodeCompact::new(0b1000, 0b1000, 0, vec![], None)),
        ];

        let test_case = InMemoryTrieCursorTestCase { db_nodes, in_memory_nodes, expected_results };
        execute_test(test_case);
    }

    #[test]
    fn test_seek_exact() {
        let db_nodes = vec![
            (Nibbles::from_nibbles([0x1]), BranchNodeCompact::new(0b0001, 0b0001, 0, vec![], None)),
            (Nibbles::from_nibbles([0x3]), BranchNodeCompact::new(0b0011, 0b0011, 0, vec![], None)),
        ];

        let in_memory_nodes = vec![(
            Nibbles::from_nibbles([0x2]),
            Some(BranchNodeCompact::new(0b0010, 0b0010, 0, vec![], None)),
        )];

        let mut cursor = InMemoryTrieCursor::new(Some(mock_cursor(db_nodes)), &in_memory_nodes);

        let result = cursor.seek_exact(Nibbles::from_nibbles([0x2])).unwrap();
        assert_eq!(
            result,
            Some((
                Nibbles::from_nibbles([0x2]),
                BranchNodeCompact::new(0b0010, 0b0010, 0, vec![], None)
            ))
        );

        let result = cursor.seek_exact(Nibbles::from_nibbles([0x3])).unwrap();
        assert_eq!(
            result,
            Some((
                Nibbles::from_nibbles([0x3]),
                BranchNodeCompact::new(0b0011, 0b0011, 0, vec![], None)
            ))
        );

        let result = cursor.seek_exact(Nibbles::from_nibbles([0x4])).unwrap();
        assert_eq!(result, None);
    }

    #[test]
    fn test_seek_exact_removed_node() {
        let db_nodes = vec![
            (Nibbles::from_nibbles([0x1]), BranchNodeCompact::new(0b0001, 0b0001, 0, vec![], None)),
            (Nibbles::from_nibbles([0x2]), BranchNodeCompact::new(0b0010, 0b0010, 0, vec![], None)),
        ];

        let in_memory_nodes: Vec<(Nibbles, Option<BranchNodeCompact>)> =
            vec![(Nibbles::from_nibbles([0x1]), None)];

        let mut cursor = InMemoryTrieCursor::new(Some(mock_cursor(db_nodes)), &in_memory_nodes);

        // A node removed in memory must stay hidden even though the DB still stores it.
        assert_eq!(cursor.seek_exact(Nibbles::from_nibbles([0x1])).unwrap(), None);

        let result = cursor.seek_exact(Nibbles::from_nibbles([0x2])).unwrap();
        assert_eq!(
            result,
            Some((
                Nibbles::from_nibbles([0x2]),
                BranchNodeCompact::new(0b0010, 0b0010, 0, vec![], None)
            ))
        );
    }

    #[test]
    fn test_multiple_consecutive_deletes() {
        let db_nodes: Vec<(Nibbles, BranchNodeCompact)> = (1..=10)
            .map(|i| {
                (
                    Nibbles::from_nibbles([i]),
                    BranchNodeCompact::new(i as u16, i as u16, 0, vec![], None),
                )
            })
            .collect();

        let in_memory_nodes = vec![
            (Nibbles::from_nibbles([0x3]), None),
            (Nibbles::from_nibbles([0x4]), None),
            (Nibbles::from_nibbles([0x5]), None),
            (Nibbles::from_nibbles([0x6]), None),
        ];

        let expected_results = vec![
            (Nibbles::from_nibbles([0x1]), BranchNodeCompact::new(1, 1, 0, vec![], None)),
            (Nibbles::from_nibbles([0x2]), BranchNodeCompact::new(2, 2, 0, vec![], None)),
            (Nibbles::from_nibbles([0x7]), BranchNodeCompact::new(7, 7, 0, vec![], None)),
            (Nibbles::from_nibbles([0x8]), BranchNodeCompact::new(8, 8, 0, vec![], None)),
            (Nibbles::from_nibbles([0x9]), BranchNodeCompact::new(9, 9, 0, vec![], None)),
            (Nibbles::from_nibbles([0xa]), BranchNodeCompact::new(10, 10, 0, vec![], None)),
        ];

        let test_case = InMemoryTrieCursorTestCase { db_nodes, in_memory_nodes, expected_results };
        execute_test(test_case);
    }

    #[test]
    fn test_empty_db_with_in_memory_deletes() {
        let in_memory_nodes = vec![
            (Nibbles::from_nibbles([0x1]), None),
            (
                Nibbles::from_nibbles([0x2]),
                Some(BranchNodeCompact::new(0b0010, 0b0010, 0, vec![], None)),
            ),
            (Nibbles::from_nibbles([0x3]), None),
        ];

        let expected_results = vec![(
            Nibbles::from_nibbles([0x2]),
            BranchNodeCompact::new(0b0010, 0b0010, 0, vec![], None),
        )];

        let test_case =
            InMemoryTrieCursorTestCase { db_nodes: vec![], in_memory_nodes, expected_results };
        execute_test(test_case);
    }

    #[test]
    fn test_current_key_tracking() {
        let db_nodes = vec![(
            Nibbles::from_nibbles([0x2]),
            BranchNodeCompact::new(0b0010, 0b0010, 0, vec![], None),
        )];

        let in_memory_nodes = vec![
            (
                Nibbles::from_nibbles([0x1]),
                Some(BranchNodeCompact::new(0b0001, 0b0001, 0, vec![], None)),
            ),
            (
                Nibbles::from_nibbles([0x3]),
                Some(BranchNodeCompact::new(0b0011, 0b0011, 0, vec![], None)),
            ),
        ];

        let mut cursor = InMemoryTrieCursor::new(Some(mock_cursor(db_nodes)), &in_memory_nodes);

        assert_eq!(cursor.current().unwrap(), None);

        let entry = cursor.seek(Nibbles::from_nibbles([0x1])).unwrap();
        assert_eq!(entry.map(|(key, _)| key), Some(Nibbles::from_nibbles([0x1])));
        assert_eq!(cursor.current().unwrap(), Some(Nibbles::from_nibbles([0x1])));

        let entry = cursor.next().unwrap();
        assert_eq!(entry.map(|(key, _)| key), Some(Nibbles::from_nibbles([0x2])));
        assert_eq!(cursor.current().unwrap(), Some(Nibbles::from_nibbles([0x2])));

        let entry = cursor.next().unwrap();
        assert_eq!(entry.map(|(key, _)| key), Some(Nibbles::from_nibbles([0x3])));
        assert_eq!(cursor.current().unwrap(), Some(Nibbles::from_nibbles([0x3])));

        assert_eq!(cursor.next().unwrap(), None);
    }
}
490        assert_eq!(cursor.current().unwrap(), Some(Nibbles::from_nibbles([0x3])));
491    }
492}