1use super::{TrieCursor, TrieCursorFactory};
2use crate::{forward_cursor::ForwardInMemoryCursor, updates::TrieUpdatesSorted};
3use alloy_primitives::B256;
4use reth_storage_errors::db::DatabaseError;
5use reth_trie_common::{BranchNodeCompact, Nibbles};
6
/// Factory whose trie cursors merge a set of in-memory trie node updates over the cursors
/// produced by an underlying [`TrieCursorFactory`].
#[derive(Debug, Clone)]
pub struct InMemoryTrieCursorFactory<CF, T> {
    /// Factory for the underlying trie cursors that the in-memory updates are layered on.
    cursor_factory: CF,
    /// In-memory trie updates that take precedence over the underlying cursors' data.
    trie_updates: T,
}

impl<CF, T> InMemoryTrieCursorFactory<CF, T> {
    /// Wraps `cursor_factory` so the cursors it hands out see `trie_updates` on top.
    pub const fn new(cursor_factory: CF, trie_updates: T) -> Self {
        Self { trie_updates, cursor_factory }
    }
}
22
23impl<'a, CF, T> TrieCursorFactory for InMemoryTrieCursorFactory<CF, &'a T>
24where
25 CF: TrieCursorFactory,
26 T: AsRef<TrieUpdatesSorted>,
27{
28 type AccountTrieCursor = InMemoryTrieCursor<'a, CF::AccountTrieCursor>;
29 type StorageTrieCursor = InMemoryTrieCursor<'a, CF::StorageTrieCursor>;
30
31 fn account_trie_cursor(&self) -> Result<Self::AccountTrieCursor, DatabaseError> {
32 let cursor = self.cursor_factory.account_trie_cursor()?;
33 Ok(InMemoryTrieCursor::new(Some(cursor), self.trie_updates.as_ref().account_nodes_ref()))
34 }
35
36 fn storage_trie_cursor(
37 &self,
38 hashed_address: B256,
39 ) -> Result<Self::StorageTrieCursor, DatabaseError> {
40 static EMPTY_UPDATES: Vec<(Nibbles, Option<BranchNodeCompact>)> = Vec::new();
42
43 let storage_trie_updates = self.trie_updates.as_ref().storage_tries.get(&hashed_address);
44 let (storage_nodes, cleared) = storage_trie_updates
45 .map(|u| (u.storage_nodes_ref(), u.is_deleted()))
46 .unwrap_or((&EMPTY_UPDATES, false));
47
48 let cursor = if cleared {
49 None
50 } else {
51 Some(self.cursor_factory.storage_trie_cursor(hashed_address)?)
52 };
53
54 Ok(InMemoryTrieCursor::new(cursor, storage_nodes))
55 }
56}
57
58#[derive(Debug)]
61pub struct InMemoryTrieCursor<'a, C> {
62 cursor: Option<C>,
64 in_memory_cursor: ForwardInMemoryCursor<'a, Nibbles, Option<BranchNodeCompact>>,
66 last_key: Option<Nibbles>,
68}
69
70impl<'a, C: TrieCursor> InMemoryTrieCursor<'a, C> {
71 pub fn new(
74 cursor: Option<C>,
75 trie_updates: &'a [(Nibbles, Option<BranchNodeCompact>)],
76 ) -> Self {
77 let in_memory_cursor = ForwardInMemoryCursor::new(trie_updates);
78 Self { cursor, in_memory_cursor, last_key: None }
79 }
80
81 fn seek_inner(
82 &mut self,
83 key: Nibbles,
84 exact: bool,
85 ) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
86 let mut mem_entry = self.in_memory_cursor.seek(&key);
87 let mut db_entry = self.cursor.as_mut().map(|c| c.seek(key)).transpose()?.flatten();
88
89 if exact {
92 return Ok(match (mem_entry, db_entry) {
93 (Some((mem_key, entry_inner)), _) if mem_key == key => {
94 entry_inner.map(|node| (key, node))
95 }
96 (_, Some((db_key, node))) if db_key == key => Some((key, node)),
97 _ => None,
98 })
99 }
100
101 loop {
102 match (mem_entry, &db_entry) {
103 (Some((mem_key, None)), _)
104 if db_entry.as_ref().is_none_or(|(db_key, _)| &mem_key < db_key) =>
105 {
106 mem_entry = self.in_memory_cursor.first_after(&mem_key);
110 }
111 (Some((mem_key, None)), Some((db_key, _))) if &mem_key == db_key => {
112 mem_entry = self.in_memory_cursor.first_after(&mem_key);
115 db_entry = self.cursor.as_mut().map(|c| c.next()).transpose()?.flatten();
116 }
117 (Some((mem_key, Some(node))), _)
118 if db_entry.as_ref().is_none_or(|(db_key, _)| &mem_key <= db_key) =>
119 {
120 return Ok(Some((mem_key, node)))
123 }
124 _ => return Ok(db_entry),
129 }
130 }
131 }
132
133 fn next_inner(
134 &mut self,
135 last: Nibbles,
136 ) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
137 let Some(key) = last.increment() else { return Ok(None) };
138 self.seek_inner(key, false)
139 }
140}
141
142impl<C: TrieCursor> TrieCursor for InMemoryTrieCursor<'_, C> {
143 fn seek_exact(
144 &mut self,
145 key: Nibbles,
146 ) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
147 let entry = self.seek_inner(key, true)?;
148 self.last_key = entry.as_ref().map(|(nibbles, _)| *nibbles);
149 Ok(entry)
150 }
151
152 fn seek(
153 &mut self,
154 key: Nibbles,
155 ) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
156 let entry = self.seek_inner(key, false)?;
157 self.last_key = entry.as_ref().map(|(nibbles, _)| *nibbles);
158 Ok(entry)
159 }
160
161 fn next(&mut self) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
162 let next = match &self.last_key {
163 Some(last) => {
164 let entry = self.next_inner(*last)?;
165 self.last_key = entry.as_ref().map(|entry| entry.0);
166 entry
167 }
168 None => None,
170 };
171 Ok(next)
172 }
173
174 fn current(&mut self) -> Result<Option<Nibbles>, DatabaseError> {
175 match &self.last_key {
176 Some(key) => Ok(Some(*key)),
177 None => Ok(self.cursor.as_mut().map(|c| c.current()).transpose()?.flatten()),
178 }
179 }
180}
181
#[cfg(test)]
mod tests {
    use super::*;
    use crate::trie_cursor::mock::MockTrieCursor;
    use parking_lot::Mutex;
    use std::{collections::BTreeMap, sync::Arc};

    #[derive(Debug)]
    struct InMemoryTrieCursorTestCase {
        db_nodes: Vec<(Nibbles, BranchNodeCompact)>,
        in_memory_nodes: Vec<(Nibbles, Option<BranchNodeCompact>)>,
        expected_results: Vec<(Nibbles, BranchNodeCompact)>,
    }

    /// Single-nibble trie path `[nibble]`.
    fn path(nibble: u8) -> Nibbles {
        Nibbles::from_nibbles([nibble])
    }

    /// Branch node with the given state/tree masks, no hash mask and no hashes.
    fn branch(state_mask: u16, tree_mask: u16) -> BranchNodeCompact {
        BranchNodeCompact::new(state_mask, tree_mask, 0, vec![], None)
    }

    /// Mock database cursor over `nodes`, with a fresh visited-keys log.
    fn mock_cursor(nodes: Vec<(Nibbles, BranchNodeCompact)>) -> MockTrieCursor {
        let db: BTreeMap<Nibbles, BranchNodeCompact> = nodes.into_iter().collect();
        MockTrieCursor::new(Arc::new(db), Arc::new(Mutex::new(Vec::new())))
    }

    /// Seeks to the first expected key, then walks with `next` until exhaustion (or an error).
    /// Asserts that the collected entries equal `expected_results`.
    fn execute_test(test_case: InMemoryTrieCursorTestCase) {
        let InMemoryTrieCursorTestCase { db_nodes, in_memory_nodes, expected_results } =
            test_case;
        let mut cursor = InMemoryTrieCursor::new(Some(mock_cursor(db_nodes)), &in_memory_nodes);

        let mut results: Vec<(Nibbles, BranchNodeCompact)> = expected_results
            .first()
            .and_then(|(first_key, _)| cursor.seek(*first_key).ok().flatten())
            .into_iter()
            .collect();
        results.extend(std::iter::from_fn(|| cursor.next().ok().flatten()));

        assert_eq!(
            results, expected_results,
            "Results mismatch.\nGot: {:?}\nExpected: {:?}",
            results, expected_results
        );
    }

    #[test]
    fn test_empty_db_and_memory() {
        execute_test(InMemoryTrieCursorTestCase {
            db_nodes: Vec::new(),
            in_memory_nodes: Vec::new(),
            expected_results: Vec::new(),
        });
    }

    #[test]
    fn test_only_db_nodes() {
        let db_nodes = vec![
            (path(0x1), branch(0b0011, 0b0001)),
            (path(0x2), branch(0b0011, 0b0010)),
            (path(0x3), branch(0b0011, 0b0011)),
        ];

        execute_test(InMemoryTrieCursorTestCase {
            expected_results: db_nodes.clone(),
            db_nodes,
            in_memory_nodes: Vec::new(),
        });
    }

    #[test]
    fn test_only_in_memory_nodes() {
        let in_memory_nodes = vec![
            (path(0x1), Some(branch(0b0011, 0b0001))),
            (path(0x2), Some(branch(0b0011, 0b0010))),
            (path(0x3), Some(branch(0b0011, 0b0011))),
        ];
        let expected_results: Vec<(Nibbles, BranchNodeCompact)> = in_memory_nodes
            .iter()
            .filter_map(|(nibbles, node)| node.clone().map(|node| (*nibbles, node)))
            .collect();

        execute_test(InMemoryTrieCursorTestCase {
            db_nodes: Vec::new(),
            in_memory_nodes,
            expected_results,
        });
    }

    #[test]
    fn test_in_memory_overwrites_db() {
        execute_test(InMemoryTrieCursorTestCase {
            db_nodes: vec![(path(0x1), branch(0b0011, 0b0001)), (path(0x2), branch(0b0011, 0b0010))],
            in_memory_nodes: vec![
                (path(0x1), Some(branch(0b1111, 0b1111))),
                (path(0x3), Some(branch(0b0011, 0b0011))),
            ],
            expected_results: vec![
                (path(0x1), branch(0b1111, 0b1111)),
                (path(0x2), branch(0b0011, 0b0010)),
                (path(0x3), branch(0b0011, 0b0011)),
            ],
        });
    }

    #[test]
    fn test_in_memory_deletes_db_nodes() {
        execute_test(InMemoryTrieCursorTestCase {
            db_nodes: vec![
                (path(0x1), branch(0b0011, 0b0001)),
                (path(0x2), branch(0b0011, 0b0010)),
                (path(0x3), branch(0b0011, 0b0011)),
            ],
            in_memory_nodes: vec![(path(0x2), None)],
            expected_results: vec![
                (path(0x1), branch(0b0011, 0b0001)),
                (path(0x3), branch(0b0011, 0b0011)),
            ],
        });
    }

    #[test]
    fn test_complex_interleaving() {
        execute_test(InMemoryTrieCursorTestCase {
            db_nodes: vec![
                (path(0x1), branch(0b0001, 0b0001)),
                (path(0x3), branch(0b0011, 0b0011)),
                (path(0x5), branch(0b0101, 0b0101)),
                (path(0x7), branch(0b0111, 0b0111)),
            ],
            in_memory_nodes: vec![
                (path(0x2), Some(branch(0b0010, 0b0010))),
                (path(0x3), None),
                (path(0x4), Some(branch(0b0100, 0b0100))),
                (path(0x6), Some(branch(0b0110, 0b0110))),
                (path(0x7), None),
                (path(0x8), Some(branch(0b1000, 0b1000))),
            ],
            expected_results: vec![
                (path(0x1), branch(0b0001, 0b0001)),
                (path(0x2), branch(0b0010, 0b0010)),
                (path(0x4), branch(0b0100, 0b0100)),
                (path(0x5), branch(0b0101, 0b0101)),
                (path(0x6), branch(0b0110, 0b0110)),
                (path(0x8), branch(0b1000, 0b1000)),
            ],
        });
    }

    #[test]
    fn test_seek_exact() {
        let in_memory_nodes = vec![(path(0x2), Some(branch(0b0010, 0b0010)))];
        let db = mock_cursor(vec![
            (path(0x1), branch(0b0001, 0b0001)),
            (path(0x3), branch(0b0011, 0b0011)),
        ]);
        let mut cursor = InMemoryTrieCursor::new(Some(db), &in_memory_nodes);

        assert_eq!(
            cursor.seek_exact(path(0x2)).unwrap(),
            Some((path(0x2), branch(0b0010, 0b0010)))
        );
        assert_eq!(
            cursor.seek_exact(path(0x3)).unwrap(),
            Some((path(0x3), branch(0b0011, 0b0011)))
        );
        assert_eq!(cursor.seek_exact(path(0x4)).unwrap(), None);
    }

    #[test]
    fn test_multiple_consecutive_deletes() {
        let db_nodes: Vec<(Nibbles, BranchNodeCompact)> =
            (1..=10u8).map(|i| (path(i), branch(i.into(), i.into()))).collect();
        let in_memory_nodes: Vec<(Nibbles, Option<BranchNodeCompact>)> =
            (3..=6u8).map(|i| (path(i), None)).collect();
        let expected_results: Vec<(Nibbles, BranchNodeCompact)> = [1u8, 2, 7, 8, 9, 10]
            .into_iter()
            .map(|i| (path(i), branch(i.into(), i.into())))
            .collect();

        execute_test(InMemoryTrieCursorTestCase { db_nodes, in_memory_nodes, expected_results });
    }

    #[test]
    fn test_empty_db_with_in_memory_deletes() {
        execute_test(InMemoryTrieCursorTestCase {
            db_nodes: Vec::new(),
            in_memory_nodes: vec![
                (path(0x1), None),
                (path(0x2), Some(branch(0b0010, 0b0010))),
                (path(0x3), None),
            ],
            expected_results: vec![(path(0x2), branch(0b0010, 0b0010))],
        });
    }

    #[test]
    fn test_current_key_tracking() {
        let in_memory_nodes = vec![
            (path(0x1), Some(branch(0b0001, 0b0001))),
            (path(0x3), Some(branch(0b0011, 0b0011))),
        ];
        let db = mock_cursor(vec![(path(0x2), branch(0b0010, 0b0010))]);
        let mut cursor = InMemoryTrieCursor::new(Some(db), &in_memory_nodes);

        assert_eq!(cursor.current().unwrap(), None);

        cursor.seek(path(0x1)).unwrap();
        assert_eq!(cursor.current().unwrap(), Some(path(0x1)));

        for expected in [path(0x2), path(0x3)] {
            cursor.next().unwrap();
            assert_eq!(cursor.current().unwrap(), Some(expected));
        }
    }
}