1use super::{HashedCursor, HashedCursorFactory, HashedStorageCursor};
2use crate::forward_cursor::ForwardInMemoryCursor;
3use alloy_primitives::{B256, U256};
4use reth_primitives_traits::Account;
5use reth_storage_errors::db::DatabaseError;
6use reth_trie_common::HashedPostStateSorted;
7
/// Factory for hashed cursors that overlay an in-memory post state on top of the cursors
/// produced by another factory.
///
/// `CF` is the wrapped cursor factory; `T` is the handle through which the post state is
/// reached.
#[derive(Debug, Clone)]
pub struct HashedPostStateCursorFactory<CF, T> {
    /// Factory producing the underlying cursors.
    cursor_factory: CF,
    /// Post state handed to every cursor created by this factory.
    post_state: T,
}

impl<CF, T> HashedPostStateCursorFactory<CF, T> {
    /// Builds a factory from the underlying `cursor_factory` and the `post_state` overlay.
    pub const fn new(cursor_factory: CF, post_state: T) -> Self {
        Self { post_state, cursor_factory }
    }
}
21
22impl<'overlay, CF, T> HashedCursorFactory for HashedPostStateCursorFactory<CF, &'overlay T>
23where
24 CF: HashedCursorFactory,
25 T: AsRef<HashedPostStateSorted>,
26{
27 type AccountCursor<'cursor>
28 = HashedPostStateCursor<'overlay, CF::AccountCursor<'cursor>, Option<Account>>
29 where
30 Self: 'cursor;
31 type StorageCursor<'cursor>
32 = HashedPostStateCursor<'overlay, CF::StorageCursor<'cursor>, U256>
33 where
34 Self: 'cursor;
35
36 fn hashed_account_cursor(&self) -> Result<Self::AccountCursor<'_>, DatabaseError> {
37 let cursor = self.cursor_factory.hashed_account_cursor()?;
38 Ok(HashedPostStateCursor::new_account(cursor, self.post_state.as_ref()))
39 }
40
41 fn hashed_storage_cursor(
42 &self,
43 hashed_address: B256,
44 ) -> Result<Self::StorageCursor<'_>, DatabaseError> {
45 let post_state = self.post_state.as_ref();
46 let cursor = self.cursor_factory.hashed_storage_cursor(hashed_address)?;
47 Ok(HashedPostStateCursor::new_storage(cursor, post_state, hashed_address))
48 }
49}
50
/// A value that can be stored in the in-memory overlay of a [`HashedPostStateCursor`].
///
/// The overlay value may represent either a live value or a removal; [`Self::into_option`]
/// tells the two apart.
pub trait HashedPostStateCursorValue: Copy {
    /// The live value, i.e. the type the cursor hands out to its callers.
    type NonZero: Copy + std::fmt::Debug;

    /// Converts the overlay value into `Some(live value)`, or `None` when the overlay entry
    /// does not carry a live value.
    fn into_option(self) -> Option<Self::NonZero>;
}
69
/// Account overlay values are already optional, so `None` is the "no live value" case.
impl HashedPostStateCursorValue for Option<Account> {
    type NonZero = Account;

    /// Returns the option unchanged.
    fn into_option(self) -> Option<Self::NonZero> {
        self
    }
}
77
78impl HashedPostStateCursorValue for U256 {
79 type NonZero = Self;
80
81 fn into_option(self) -> Option<Self::NonZero> {
82 (!self.is_zero()).then_some(self)
83 }
84}
85
/// Cursor that merges the entries of an underlying hashed cursor with a sorted in-memory
/// overlay.
///
/// Overlay entries whose value converts to `None` (see
/// [`HashedPostStateCursorValue::into_option`]) are never returned and hide the underlying
/// entry stored at the same key.
#[derive(Debug)]
pub struct HashedPostStateCursor<'a, C, V>
where
    V: HashedPostStateCursorValue,
{
    /// The underlying cursor.
    cursor: C,
    /// Cached entry of `cursor`, or the reason why there is none (not positioned yet,
    /// exhausted, or wiped).
    db_cursor_state: DbCursorState<V::NonZero>,
    /// In-memory cursor over the overlay entries.
    post_state_cursor: ForwardInMemoryCursor<'a, B256, V>,
    /// Key of the entry most recently returned by `seek`/`next`, `None` if the last call
    /// returned nothing or nothing has been returned yet.
    last_key: Option<B256>,
    /// Debug-only flag recording that `seek` has been called since construction or `reset`.
    #[cfg(debug_assertions)]
    seeked: bool,
    /// The full post state, kept so that a storage cursor can be pointed at another address.
    post_state: &'a HashedPostStateSorted,
}
108
/// State of the underlying cursor as tracked by [`HashedPostStateCursor`].
#[derive(Debug)]
enum DbCursorState<V> {
    /// No entry is cached; the underlying cursor has to be seeked before it can be used.
    NeedsPosition,
    /// The underlying cursor is positioned at the cached entry.
    Positioned((B256, V)),
    /// The last `seek`/`next` on the underlying cursor returned no entry.
    Exhausted,
    /// The underlying cursor must not be consulted at all.
    Wiped,
}
116
117impl<V> DbCursorState<V> {
118 const fn new(cursor_wiped: bool) -> Self {
119 if cursor_wiped {
120 Self::Wiped
121 } else {
122 Self::NeedsPosition
123 }
124 }
125
126 const fn entry(&self) -> Option<&(B256, V)> {
127 match self {
128 Self::Positioned(entry) => Some(entry),
129 Self::NeedsPosition | Self::Exhausted | Self::Wiped => None,
130 }
131 }
132
133 fn set_entry(&mut self, entry: Option<(B256, V)>) {
134 *self = match entry {
135 Some(entry) => Self::Positioned(entry),
136 None => Self::Exhausted,
137 };
138 }
139}
140
141impl<'a, C> HashedPostStateCursor<'a, C, Option<Account>>
142where
143 C: HashedCursor<Value = Account>,
144{
145 pub fn new_account(cursor: C, post_state: &'a HashedPostStateSorted) -> Self {
147 let post_state_cursor = ForwardInMemoryCursor::new(&post_state.accounts);
148 Self {
149 cursor,
150 db_cursor_state: DbCursorState::NeedsPosition,
151 post_state_cursor,
152 last_key: None,
153 #[cfg(debug_assertions)]
154 seeked: false,
155 post_state,
156 }
157 }
158}
159
160impl<'a, C> HashedPostStateCursor<'a, C, U256>
161where
162 C: HashedStorageCursor<Value = U256>,
163{
164 pub fn new_storage(
167 cursor: C,
168 post_state: &'a HashedPostStateSorted,
169 hashed_address: B256,
170 ) -> Self {
171 let (post_state_cursor, cursor_wiped) =
172 Self::get_storage_overlay(post_state, hashed_address);
173 Self {
174 cursor,
175 db_cursor_state: DbCursorState::new(cursor_wiped),
176 post_state_cursor,
177 last_key: None,
178 #[cfg(debug_assertions)]
179 seeked: false,
180 post_state,
181 }
182 }
183
184 fn get_storage_overlay(
186 post_state: &'a HashedPostStateSorted,
187 hashed_address: B256,
188 ) -> (ForwardInMemoryCursor<'a, B256, U256>, bool) {
189 let post_state_storage = post_state.storages.get(&hashed_address);
190 let cursor_wiped = post_state_storage.is_some_and(|u| u.is_wiped());
191 let storage_slots = post_state_storage.map(|u| u.storage_slots_ref()).unwrap_or(&[]);
192
193 (ForwardInMemoryCursor::new(storage_slots), cursor_wiped)
194 }
195}
196
197impl<'a, C, V> HashedPostStateCursor<'a, C, V>
198where
199 C: HashedCursor<Value = V::NonZero>,
200 V: HashedPostStateCursorValue,
201{
202 fn get_cursor_mut(&mut self) -> Option<&mut C> {
204 (!matches!(self.db_cursor_state, DbCursorState::Wiped)).then_some(&mut self.cursor)
205 }
206
207 fn set_last_key(&mut self, next_entry: &Option<(B256, V::NonZero)>) {
210 let next_key = next_entry.as_ref().map(|e| e.0);
211 debug_assert!(
212 self.last_key.is_none_or(|last| next_key.is_none_or(|next| next >= last)),
213 "Cannot return entry {:?} previous to the last returned entry at {:?}",
214 next_key,
215 self.last_key,
216 );
217 self.last_key = next_key;
218 }
219
220 fn cursor_seek(&mut self, key: B256) -> Result<(), DatabaseError> {
222 let should_seek = match &self.db_cursor_state {
226 DbCursorState::NeedsPosition => true,
227 DbCursorState::Positioned((entry_key, _)) => entry_key < &key,
228 DbCursorState::Exhausted | DbCursorState::Wiped => false,
229 };
230
231 if should_seek {
232 let entry = self.get_cursor_mut().map(|c| c.seek(key)).transpose()?.flatten();
233 self.db_cursor_state.set_entry(entry);
234 }
235
236 Ok(())
237 }
238
239 fn cursor_next(&mut self) -> Result<(), DatabaseError> {
241 #[cfg(debug_assertions)]
242 {
243 debug_assert!(self.seeked);
244 }
245
246 if matches!(self.db_cursor_state, DbCursorState::Positioned(_)) {
248 let entry = self.get_cursor_mut().map(|c| c.next()).transpose()?.flatten();
249 self.db_cursor_state.set_entry(entry);
250 }
251
252 Ok(())
253 }
254
255 fn choose_next_entry(&mut self) -> Result<Option<(B256, V::NonZero)>, DatabaseError> {
261 loop {
262 let post_state_current =
263 self.post_state_cursor.current().copied().map(|(k, v)| (k, v.into_option()));
264 let db_entry = self.db_cursor_state.entry();
265
266 match (post_state_current, db_entry) {
267 (Some((mem_key, None)), _)
268 if db_entry.is_none_or(|(db_key, _)| &mem_key < db_key) =>
269 {
270 self.post_state_cursor.first_after(&mem_key);
274 }
275 (Some((mem_key, None)), Some((db_key, _))) if &mem_key == db_key => {
276 self.post_state_cursor.first_after(&mem_key);
279 self.cursor_next()?;
280 }
281 (Some((mem_key, Some(value))), _)
282 if db_entry.is_none_or(|(db_key, _)| &mem_key <= db_key) =>
283 {
284 return Ok(Some((mem_key, value)))
287 }
288 _ => return Ok(db_entry.copied()),
293 }
294 }
295 }
296}
297
298impl<C, V> HashedCursor for HashedPostStateCursor<'_, C, V>
299where
300 C: HashedCursor<Value = V::NonZero>,
301 V: HashedPostStateCursorValue,
302{
303 type Value = V::NonZero;
304
305 fn seek(&mut self, key: B256) -> Result<Option<(B256, Self::Value)>, DatabaseError> {
314 let post_state_entry =
315 self.post_state_cursor.seek(&key).copied().map(|(k, v)| (k, v.into_option()));
316
317 if let Some((mem_key, Some(value))) = post_state_entry &&
318 mem_key == key
319 {
320 #[cfg(debug_assertions)]
321 {
322 self.seeked = true;
323 }
324
325 if matches!(&self.db_cursor_state, DbCursorState::Positioned((db_key, _)) if db_key < &key)
328 {
329 self.db_cursor_state = DbCursorState::NeedsPosition;
330 }
331
332 let entry = Some((key, value));
333 self.set_last_key(&entry);
334 return Ok(entry)
335 }
336
337 self.cursor_seek(key)?;
338
339 #[cfg(debug_assertions)]
340 {
341 self.seeked = true;
342 }
343
344 let entry = self.choose_next_entry()?;
345 self.set_last_key(&entry);
346 Ok(entry)
347 }
348
349 fn next(&mut self) -> Result<Option<(B256, Self::Value)>, DatabaseError> {
356 #[cfg(debug_assertions)]
357 {
358 debug_assert!(self.seeked, "Cursor must be seek'd before next is called");
359 }
360
361 let Some(last_key) = self.last_key else {
363 return Ok(None);
364 };
365
366 if let Some((key, _)) = self.post_state_cursor.current() &&
369 key == &last_key
370 {
371 self.post_state_cursor.first_after(&last_key);
372 }
373
374 if matches!(self.db_cursor_state, DbCursorState::NeedsPosition) {
375 self.cursor_seek(last_key)?;
376 }
377
378 if let Some((key, _)) = self.db_cursor_state.entry() &&
379 key == &last_key
380 {
381 self.cursor_next()?;
382 }
383
384 let entry = self.choose_next_entry()?;
385 self.set_last_key(&entry);
386 Ok(entry)
387 }
388
389 fn reset(&mut self) {
390 let Self { cursor, db_cursor_state, post_state_cursor, last_key, .. } = self;
391
392 cursor.reset();
393 post_state_cursor.reset();
394
395 *db_cursor_state = DbCursorState::NeedsPosition;
396 *last_key = None;
397 #[cfg(debug_assertions)]
398 {
399 self.seeked = false;
400 }
401 }
402}
403
404impl<C> HashedStorageCursor for HashedPostStateCursor<'_, C, U256>
407where
408 C: HashedStorageCursor<Value = U256>,
409{
410 fn is_storage_empty(&mut self) -> Result<bool, DatabaseError> {
415 if self.post_state_cursor.has_any(|(_, value)| !value.is_zero()) {
417 return Ok(false);
418 }
419
420 self.get_cursor_mut().map_or(Ok(true), |c| c.is_storage_empty())
423 }
424
425 fn set_hashed_address(&mut self, hashed_address: B256) {
426 self.reset();
427 self.cursor.set_hashed_address(hashed_address);
428 let (post_state_cursor, cursor_wiped) =
429 HashedPostStateCursor::<C, U256>::get_storage_overlay(self.post_state, hashed_address);
430 self.post_state_cursor = post_state_cursor;
431 self.db_cursor_state = DbCursorState::new(cursor_wiped);
432 }
433}
434
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hashed_cursor::mock::MockHashedCursor;
    use parking_lot::Mutex;
    use std::{collections::BTreeMap, sync::Arc};

    /// Builds a key whose 32 bytes all equal `byte`.
    fn key(byte: u8) -> B256 {
        B256::repeat_byte(byte)
    }

    /// Builds a post state with no accounts and a single, non-wiped storage overlay
    /// registered under `B256::ZERO`.
    fn storage_post_state(storage_slots: Vec<(B256, U256)>) -> HashedPostStateSorted {
        let storage_sorted = reth_trie_common::HashedStorageSorted { storage_slots, wiped: false };
        let mut storages = alloy_primitives::map::B256Map::default();
        storages.insert(B256::ZERO, storage_sorted);
        HashedPostStateSorted::new(Vec::new(), storages)
    }

    /// A seek that hits a live overlay value exactly must return it without consulting the
    /// DB cursor; the DB cursor is only positioned by the following `next`.
    #[test]
    fn test_seek_overlay_exact_hit_does_not_touch_db_until_next() {
        let db_nodes = vec![(key(0x02), U256::from(2)), (key(0x03), U256::from(3))];
        let post_state_nodes = vec![(key(0x02), U256::from(42))];

        let db_nodes_map: BTreeMap<B256, U256> = db_nodes.into_iter().collect();
        let db_nodes_arc = Arc::new(db_nodes_map);
        // The assertions below treat `visited_keys` as the log of DB cursor accesses.
        let visited_keys = Arc::new(Mutex::new(Vec::new()));
        let mock_cursor = MockHashedCursor::new(db_nodes_arc, visited_keys.clone());

        let post_state = storage_post_state(post_state_nodes);
        let mut cursor = HashedPostStateCursor::new_storage(mock_cursor, &post_state, B256::ZERO);

        let result = cursor.seek(key(0x02)).unwrap();
        assert_eq!(result, Some((key(0x02), U256::from(42))));
        assert!(visited_keys.lock().is_empty(), "exact overlay hit should not touch the DB cursor");

        let result = cursor.next().unwrap();
        assert_eq!(result, Some((key(0x03), U256::from(3))));
        assert!(!visited_keys.lock().is_empty(), "next should lazily position the DB cursor");
    }

    /// After an exact overlay hit, a DB entry cached from an earlier seek that lies before
    /// the hit must not be reused: `next` has to move the DB cursor past the hit.
    #[test]
    fn test_seek_overlay_exact_hit_repositions_stale_db_on_next() {
        let db_nodes = vec![(key(0x01), U256::from(1)), (key(0x03), U256::from(3))];
        let post_state_nodes = vec![(key(0x02), U256::from(2))];

        let db_nodes_map: BTreeMap<B256, U256> = db_nodes.into_iter().collect();
        let db_nodes_arc = Arc::new(db_nodes_map);
        let visited_keys = Arc::new(Mutex::new(Vec::new()));
        let mock_cursor = MockHashedCursor::new(db_nodes_arc, visited_keys.clone());

        let post_state = storage_post_state(post_state_nodes);
        let mut cursor = HashedPostStateCursor::new_storage(mock_cursor, &post_state, B256::ZERO);

        // Positions the DB cursor at 0x01.
        let result = cursor.seek(key(0x01)).unwrap();
        assert_eq!(result, Some((key(0x01), U256::from(1))));
        assert_eq!(visited_keys.lock().len(), 1);

        // Exact overlay hit at 0x02: the DB cursor stays where it was.
        let result = cursor.seek(key(0x02)).unwrap();
        assert_eq!(result, Some((key(0x02), U256::from(2))));
        assert_eq!(visited_keys.lock().len(), 1, "exact overlay hit should not seek the DB");

        let result = cursor.next().unwrap();
        assert_eq!(result, Some((key(0x03), U256::from(3))));
    }

    /// An overlay removal (zero value) at the seeked key is not a hit: the DB is consulted
    /// and the removed key is skipped.
    #[test]
    fn test_seek_overlay_exact_deletion_still_seeks_db() {
        let db_nodes = vec![(key(0x02), U256::from(2)), (key(0x03), U256::from(3))];
        let post_state_nodes = vec![(key(0x02), U256::ZERO)];

        let db_nodes_map: BTreeMap<B256, U256> = db_nodes.into_iter().collect();
        let db_nodes_arc = Arc::new(db_nodes_map);
        let visited_keys = Arc::new(Mutex::new(Vec::new()));
        let mock_cursor = MockHashedCursor::new(db_nodes_arc, visited_keys.clone());

        let post_state = storage_post_state(post_state_nodes);
        let mut cursor = HashedPostStateCursor::new_storage(mock_cursor, &post_state, B256::ZERO);

        let result = cursor.seek(key(0x02)).unwrap();
        assert_eq!(result, Some((key(0x03), U256::from(3))));
        assert!(!visited_keys.lock().is_empty(), "exact overlay deletion should consult the DB");
    }

    mod proptest_tests {
        use super::*;
        use itertools::Itertools;
        use proptest::prelude::*;

        /// Reference merge of sorted DB entries with sorted overlay entries.
        ///
        /// Overlay entries replace DB entries at the same key, and overlay entries without a
        /// live value are dropped (together with the DB entry they shadow).
        fn merge_with_overlay<V>(
            db_nodes: Vec<(B256, V::NonZero)>,
            post_state_nodes: Vec<(B256, V)>,
        ) -> Vec<(B256, V::NonZero)>
        where
            V: HashedPostStateCursorValue,
            V::NonZero: Copy,
        {
            db_nodes
                .into_iter()
                .merge_join_by(post_state_nodes, |db_entry, mem_entry| db_entry.0.cmp(&mem_entry.0))
                .filter_map(|entry| match entry {
                    // Only in the DB: keep as is.
                    itertools::EitherOrBoth::Left((key, node)) => Some((key, node)),
                    // Only in the overlay: keep if it carries a live value.
                    itertools::EitherOrBoth::Right((key, wrapped)) => {
                        wrapped.into_option().map(|val| (key, val))
                    }
                    // In both: the overlay wins, a removal drops the key.
                    itertools::EitherOrBoth::Both(_, (key, wrapped)) => {
                        wrapped.into_option().map(|val| (key, val))
                    }
                })
                .collect()
        }

        /// Strategy producing `U256` values from arbitrary `u64`s (zero included).
        fn u256_strategy() -> impl Strategy<Value = U256> {
            any::<u64>().prop_map(U256::from)
        }

        /// Strategy producing up to 20 DB entries, sorted by key and with unique keys.
        fn sorted_db_nodes_strategy() -> impl Strategy<Value = Vec<(B256, U256)>> {
            prop::collection::vec((any::<u8>(), u256_strategy()), 0..20).prop_map(|entries| {
                let mut result: Vec<(B256, U256)> = entries
                    .into_iter()
                    .map(|(byte, value)| (B256::repeat_byte(byte), value))
                    .collect();
                result.sort_by_key(|a| a.0);
                result.dedup_by(|a, b| a.0 == b.0);
                result
            })
        }

        /// Strategy producing up to 20 overlay entries, sorted by key and with unique keys.
        /// Entries flagged as deletions get the value zero.
        fn sorted_post_state_nodes_strategy() -> impl Strategy<Value = Vec<(B256, U256)>> {
            prop::collection::vec((any::<u8>(), u256_strategy(), any::<bool>()), 0..20).prop_map(
                |entries| {
                    let mut result: Vec<(B256, U256)> = entries
                        .into_iter()
                        .map(|(byte, value, is_deletion)| {
                            let effective_value = if is_deletion { U256::ZERO } else { value };
                            (B256::repeat_byte(byte), effective_value)
                        })
                        .collect();
                    result.sort_by_key(|a| a.0);
                    result.dedup_by(|a, b| a.0 == b.0);
                    result
                },
            )
        }

        proptest! {
            #![proptest_config(ProptestConfig::with_cases(1000))]
            // Runs a random sequence of `next`/`seek` operations against both the cursor
            // under test and a control cursor over the pre-merged data, and requires the
            // two to agree on every result.
            #[test]
            fn proptest_hashed_post_state_cursor(
                db_nodes in sorted_db_nodes_strategy(),
                post_state_nodes in sorted_post_state_nodes_strategy(),
                op_choices in prop::collection::vec(any::<u8>(), 10..500),
            ) {
                reth_tracing::init_test_tracing();
                use tracing::debug;

                debug!("Starting proptest!");

                // Expected merged view, computed independently of the cursor under test.
                let expected_combined = merge_with_overlay(db_nodes.clone(), post_state_nodes.clone());

                // Keys present in the merged view; seek targets are drawn from these.
                let all_keys: Vec<B256> = expected_combined.iter().map(|(k, _)| *k).collect();

                // Control cursor: a plain mock cursor over the merged view.
                let control_db_map: BTreeMap<B256, U256> = expected_combined.into_iter().collect();
                let control_db_arc = Arc::new(control_db_map);
                let control_visited_keys = Arc::new(Mutex::new(Vec::new()));
                let mut control_cursor = MockHashedCursor::new(control_db_arc, control_visited_keys);

                // Underlying cursor of the cursor under test: a mock over the DB entries only.
                let db_nodes_map: BTreeMap<B256, U256> = db_nodes.into_iter().collect();
                let db_nodes_arc = Arc::new(db_nodes_map);
                let visited_keys = Arc::new(Mutex::new(Vec::new()));
                let mock_cursor = MockHashedCursor::new(db_nodes_arc, visited_keys);

                // Overlay: the generated entries as the (non-wiped) storage of one address.
                let hashed_address = B256::ZERO;
                let storage_sorted = reth_trie_common::HashedStorageSorted {
                    storage_slots: post_state_nodes,
                    wiped: false,
                };
                let mut storages = alloy_primitives::map::B256Map::default();
                storages.insert(hashed_address, storage_sorted);
                let post_state = HashedPostStateSorted::new(Vec::new(), storages);

                let mut test_cursor = HashedPostStateCursor::new_storage(mock_cursor, &post_state, hashed_address);

                // Both cursors must be seeked before `next` may be called.
                let control_first = control_cursor.seek(B256::ZERO).unwrap();
                let test_first = test_cursor.seek(B256::ZERO).unwrap();
                debug!(
                    control=?control_first.as_ref().map(|(k, _)| k),
                    test=?test_first.as_ref().map(|(k, _)| k),
                    "Initial seek returned",
                );
                assert_eq!(control_first, test_first, "Initial seek mismatch");

                // Nothing to iterate when both views are empty.
                if control_first.is_none() && test_first.is_none() {
                    return Ok(());
                }

                // Tracks the current position so that seeks never go backwards.
                let mut last_returned_key = control_first.as_ref().map(|(k, _)| *k);

                for choice in op_choices {
                    // Even choices run `next`, odd choices run `seek`.
                    let op_type = choice % 2;
                    match op_type {
                        0 => {
                            let control_result = control_cursor.next().unwrap();
                            let test_result = test_cursor.next().unwrap();
                            debug!(
                                control=?control_result.as_ref().map(|(k, _)| k),
                                test=?test_result.as_ref().map(|(k, _)| k),
                                "Next returned",
                            );
                            assert_eq!(control_result, test_result, "Next operation mismatch");

                            last_returned_key = control_result.as_ref().map(|(k, _)| *k);

                            // Stop once both cursors ran out of entries.
                            if control_result.is_none() && test_result.is_none() {
                                break;
                            }
                        }
                        _ => {
                            if all_keys.is_empty() {
                                continue;
                            }

                            // Only keys at or after the current position are valid targets.
                            let valid_keys: Vec<_> = all_keys
                                .iter()
                                .filter(|k| last_returned_key.is_none_or(|last| **k >= last))
                                .collect();

                            if valid_keys.is_empty() {
                                continue;
                            }

                            // The remaining bits of `choice` select the target key.
                            let key = *valid_keys[(choice as usize / 2) % valid_keys.len()];

                            let control_result = control_cursor.seek(key).unwrap();
                            let test_result = test_cursor.seek(key).unwrap();
                            debug!(
                                control=?control_result.as_ref().map(|(k, _)| k),
                                test=?test_result.as_ref().map(|(k, _)| k),
                                ?key,
                                "Seek returned",
                            );
                            assert_eq!(control_result, test_result, "Seek operation mismatch for key {:?}", key);

                            last_returned_key = control_result.as_ref().map(|(k, _)| *k);

                            // Stop once both cursors ran out of entries.
                            if control_result.is_none() && test_result.is_none() {
                                break;
                            }
                        }
                    }
                }
            }
        }
    }
}