1use super::{HashedCursor, HashedCursorFactory, HashedStorageCursor};
2use crate::forward_cursor::ForwardInMemoryCursor;
3use alloy_primitives::{B256, U256};
4use reth_primitives_traits::Account;
5use reth_storage_errors::db::DatabaseError;
6use reth_trie_common::HashedPostStateSorted;
7
/// Cursor factory that wraps the cursors produced by an underlying factory with an in-memory
/// post state overlay.
#[derive(Clone, Debug)]
pub struct HashedPostStateCursorFactory<CF, T> {
    /// Underlying factory whose cursors are wrapped.
    cursor_factory: CF,
    /// Post state overlay applied on top of the underlying cursors.
    post_state: T,
}

impl<CF, T> HashedPostStateCursorFactory<CF, T> {
    /// Builds a factory from the underlying `cursor_factory` and the `post_state` overlay.
    pub const fn new(cursor_factory: CF, post_state: T) -> Self {
        Self { post_state, cursor_factory }
    }
}
21
22impl<'overlay, CF, T> HashedCursorFactory for HashedPostStateCursorFactory<CF, &'overlay T>
23where
24 CF: HashedCursorFactory,
25 T: AsRef<HashedPostStateSorted>,
26{
27 type AccountCursor<'cursor>
28 = HashedPostStateCursor<'overlay, CF::AccountCursor<'cursor>, Option<Account>>
29 where
30 Self: 'cursor;
31 type StorageCursor<'cursor>
32 = HashedPostStateCursor<'overlay, CF::StorageCursor<'cursor>, U256>
33 where
34 Self: 'cursor;
35
36 fn hashed_account_cursor(&self) -> Result<Self::AccountCursor<'_>, DatabaseError> {
37 let cursor = self.cursor_factory.hashed_account_cursor()?;
38 Ok(HashedPostStateCursor::new_account(cursor, self.post_state.as_ref()))
39 }
40
41 fn hashed_storage_cursor(
42 &self,
43 hashed_address: B256,
44 ) -> Result<Self::StorageCursor<'_>, DatabaseError> {
45 let post_state = self.post_state.as_ref();
46 let cursor = self.cursor_factory.hashed_storage_cursor(hashed_address)?;
47 Ok(HashedPostStateCursor::new_storage(cursor, post_state, hashed_address))
48 }
49}
50
/// Value type stored in a post state overlay.
///
/// An overlay value either carries a live value or marks the key as removed; the cursor skips
/// keys whose overlay value converts to `None` and hides any database entry at the same key.
pub trait HashedPostStateCursorValue: Copy {
    /// The live value type yielded by the merged cursor.
    type NonZero: std::fmt::Debug + Copy;

    /// Converts the overlay value into `Some(live value)`, or `None` if it marks a removal.
    fn into_option(self) -> Option<Self::NonZero>;
}
69
70impl HashedPostStateCursorValue for Option<Account> {
71 type NonZero = Account;
72
73 fn into_option(self) -> Option<Self::NonZero> {
74 self
75 }
76}
77
78impl HashedPostStateCursorValue for U256 {
79 type NonZero = Self;
80
81 fn into_option(self) -> Option<Self::NonZero> {
82 (self != Self::ZERO).then_some(self)
83 }
84}
85
86#[derive(Debug)]
89pub struct HashedPostStateCursor<'a, C, V>
90where
91 V: HashedPostStateCursorValue,
92{
93 cursor: C,
95 cursor_wiped: bool,
97 cursor_entry: Option<(B256, V::NonZero)>,
99 post_state_cursor: ForwardInMemoryCursor<'a, B256, V>,
101 last_key: Option<B256>,
104 seeked: bool,
107 post_state: &'a HashedPostStateSorted,
109}
110
111impl<'a, C> HashedPostStateCursor<'a, C, Option<Account>>
112where
113 C: HashedCursor<Value = Account>,
114{
115 pub fn new_account(cursor: C, post_state: &'a HashedPostStateSorted) -> Self {
117 let post_state_cursor = ForwardInMemoryCursor::new(&post_state.accounts);
118 Self {
119 cursor,
120 cursor_wiped: false,
121 cursor_entry: None,
122 post_state_cursor,
123 last_key: None,
124 seeked: false,
125 post_state,
126 }
127 }
128}
129
130impl<'a, C> HashedPostStateCursor<'a, C, U256>
131where
132 C: HashedStorageCursor<Value = U256>,
133{
134 pub fn new_storage(
137 cursor: C,
138 post_state: &'a HashedPostStateSorted,
139 hashed_address: B256,
140 ) -> Self {
141 let (post_state_cursor, cursor_wiped) =
142 Self::get_storage_overlay(post_state, hashed_address);
143 Self {
144 cursor,
145 cursor_wiped,
146 cursor_entry: None,
147 post_state_cursor,
148 last_key: None,
149 seeked: false,
150 post_state,
151 }
152 }
153
154 fn get_storage_overlay(
156 post_state: &'a HashedPostStateSorted,
157 hashed_address: B256,
158 ) -> (ForwardInMemoryCursor<'a, B256, U256>, bool) {
159 let post_state_storage = post_state.storages.get(&hashed_address);
160 let cursor_wiped = post_state_storage.is_some_and(|u| u.is_wiped());
161 let storage_slots = post_state_storage.map(|u| u.storage_slots_ref()).unwrap_or(&[]);
162
163 (ForwardInMemoryCursor::new(storage_slots), cursor_wiped)
164 }
165}
166
167impl<'a, C, V> HashedPostStateCursor<'a, C, V>
168where
169 C: HashedCursor<Value = V::NonZero>,
170 V: HashedPostStateCursorValue,
171{
172 fn get_cursor_mut(&mut self) -> Option<&mut C> {
174 (!self.cursor_wiped).then_some(&mut self.cursor)
175 }
176
177 fn set_last_key(&mut self, next_entry: &Option<(B256, V::NonZero)>) {
180 let next_key = next_entry.as_ref().map(|e| e.0);
181 debug_assert!(
182 self.last_key.is_none_or(|last| next_key.is_none_or(|next| next >= last)),
183 "Cannot return entry {:?} previous to the last returned entry at {:?}",
184 next_key,
185 self.last_key,
186 );
187 self.last_key = next_key;
188 }
189
190 fn cursor_seek(&mut self, key: B256) -> Result<(), DatabaseError> {
192 let should_seek = match self.cursor_entry.as_ref() {
196 Some(entry) => entry.0 < key,
197 None => !self.seeked,
198 };
199
200 if should_seek {
201 self.cursor_entry = self.get_cursor_mut().map(|c| c.seek(key)).transpose()?.flatten();
202 }
203
204 Ok(())
205 }
206
207 fn cursor_next(&mut self) -> Result<(), DatabaseError> {
209 debug_assert!(self.seeked);
210
211 if self.cursor_entry.is_some() {
214 self.cursor_entry = self.get_cursor_mut().map(|c| c.next()).transpose()?.flatten();
215 }
216
217 Ok(())
218 }
219
220 fn choose_next_entry(&mut self) -> Result<Option<(B256, V::NonZero)>, DatabaseError> {
226 loop {
227 let post_state_current =
228 self.post_state_cursor.current().copied().map(|(k, v)| (k, v.into_option()));
229
230 match (post_state_current, &self.cursor_entry) {
231 (Some((mem_key, None)), _)
232 if self.cursor_entry.as_ref().is_none_or(|(db_key, _)| &mem_key < db_key) =>
233 {
234 self.post_state_cursor.first_after(&mem_key);
238 }
239 (Some((mem_key, None)), Some((db_key, _))) if &mem_key == db_key => {
240 self.post_state_cursor.first_after(&mem_key);
243 self.cursor_next()?;
244 }
245 (Some((mem_key, Some(value))), _)
246 if self.cursor_entry.as_ref().is_none_or(|(db_key, _)| &mem_key <= db_key) =>
247 {
248 return Ok(Some((mem_key, value)))
251 }
252 _ => return Ok(self.cursor_entry),
257 }
258 }
259 }
260}
261
262impl<C, V> HashedCursor for HashedPostStateCursor<'_, C, V>
263where
264 C: HashedCursor<Value = V::NonZero>,
265 V: HashedPostStateCursorValue,
266{
267 type Value = V::NonZero;
268
269 fn seek(&mut self, key: B256) -> Result<Option<(B256, Self::Value)>, DatabaseError> {
278 self.cursor_seek(key)?;
279 self.post_state_cursor.seek(&key);
280
281 self.seeked = true;
282
283 let entry = self.choose_next_entry()?;
284 self.set_last_key(&entry);
285 Ok(entry)
286 }
287
288 fn next(&mut self) -> Result<Option<(B256, Self::Value)>, DatabaseError> {
295 debug_assert!(self.seeked, "Cursor must be seek'd before next is called");
296
297 let Some(last_key) = self.last_key else {
299 return Ok(None);
300 };
301
302 if let Some((key, _)) = self.post_state_cursor.current() &&
305 key == &last_key
306 {
307 self.post_state_cursor.first_after(&last_key);
308 }
309
310 if let Some((key, _)) = &self.cursor_entry &&
311 key == &last_key
312 {
313 self.cursor_next()?;
314 }
315
316 let entry = self.choose_next_entry()?;
317 self.set_last_key(&entry);
318 Ok(entry)
319 }
320
321 fn reset(&mut self) {
322 let Self {
323 cursor,
324 cursor_wiped,
325 cursor_entry,
326 post_state_cursor,
327 last_key,
328 seeked,
329 post_state: _,
330 } = self;
331
332 cursor.reset();
333 post_state_cursor.reset();
334
335 *cursor_wiped = false;
336 *cursor_entry = None;
337 *last_key = None;
338 *seeked = false;
339 }
340}
341
342impl<C> HashedStorageCursor for HashedPostStateCursor<'_, C, U256>
345where
346 C: HashedStorageCursor<Value = U256>,
347{
348 fn is_storage_empty(&mut self) -> Result<bool, DatabaseError> {
353 if self.post_state_cursor.has_any(|(_, value)| value.into_option().is_some()) {
355 return Ok(false);
356 }
357
358 self.get_cursor_mut().map_or(Ok(true), |c| c.is_storage_empty())
361 }
362
363 fn set_hashed_address(&mut self, hashed_address: B256) {
364 self.reset();
365 self.cursor.set_hashed_address(hashed_address);
366 (self.post_state_cursor, self.cursor_wiped) =
367 HashedPostStateCursor::<C, U256>::get_storage_overlay(self.post_state, hashed_address);
368 }
369}
370
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hashed_cursor::mock::MockHashedCursor;
    use parking_lot::Mutex;
    use std::{collections::BTreeMap, sync::Arc};

    mod proptest_tests {
        use super::*;
        use itertools::{EitherOrBoth, Itertools};
        use proptest::prelude::*;

        /// Reference model of the cursor.
        ///
        /// Merges sorted database entries with sorted post state entries. The post state wins on
        /// equal keys, and keys whose post state value converts to `None` are dropped.
        fn merge_with_overlay<V>(
            db_nodes: Vec<(B256, V::NonZero)>,
            post_state_nodes: Vec<(B256, V)>,
        ) -> Vec<(B256, V::NonZero)>
        where
            V: HashedPostStateCursorValue,
        {
            db_nodes
                .into_iter()
                .merge_join_by(post_state_nodes, |db_entry, mem_entry| db_entry.0.cmp(&mem_entry.0))
                .filter_map(|entry| match entry {
                    // Only in the database: passes through unchanged.
                    EitherOrBoth::Left(db_entry) => Some(db_entry),
                    // In the post state (possibly shadowing the database): the post state decides.
                    EitherOrBoth::Right((key, value)) | EitherOrBoth::Both(_, (key, value)) => {
                        value.into_option().map(|value| (key, value))
                    }
                })
                .collect()
        }

        /// Non-zero storage values, used for database entries.
        fn nonzero_u256_strategy() -> impl Strategy<Value = U256> {
            (1..=u64::MAX).prop_map(U256::from)
        }

        /// Post state storage values.
        ///
        /// Zero marks a cleared slot. It is generated with a fixed weight because a uniform `u64`
        /// practically never yields zero, which would leave the deletion paths untested.
        fn post_state_u256_strategy() -> impl Strategy<Value = U256> {
            prop_oneof![1 => Just(U256::ZERO), 3 => nonzero_u256_strategy()]
        }

        /// Sorted entries with unique keys, with values drawn from `values`.
        fn sorted_nodes_strategy(
            values: impl Strategy<Value = U256>,
        ) -> impl Strategy<Value = Vec<(B256, U256)>> {
            prop::collection::vec((any::<u8>(), values), 0..20).prop_map(|entries| {
                let mut result: Vec<(B256, U256)> = entries
                    .into_iter()
                    .map(|(byte, value)| (B256::repeat_byte(byte), value))
                    .collect();
                result.sort_by(|a, b| a.0.cmp(&b.0));
                result.dedup_by(|a, b| a.0 == b.0);
                result
            })
        }

        fn sorted_db_nodes_strategy() -> impl Strategy<Value = Vec<(B256, U256)>> {
            sorted_nodes_strategy(nonzero_u256_strategy())
        }

        fn sorted_post_state_nodes_strategy() -> impl Strategy<Value = Vec<(B256, U256)>> {
            sorted_nodes_strategy(post_state_u256_strategy())
        }

        proptest! {
            #![proptest_config(ProptestConfig::with_cases(1000))]
            /// Drives the overlay cursor and a plain cursor over the expected merged entries with
            /// the same random sequence of forward `seek`/`next` operations, and checks that both
            /// return identical results.
            #[test]
            fn proptest_hashed_post_state_cursor(
                db_nodes in sorted_db_nodes_strategy(),
                post_state_nodes in sorted_post_state_nodes_strategy(),
                wiped in any::<bool>(),
                op_choices in prop::collection::vec(any::<u8>(), 10..500),
            ) {
                reth_tracing::init_test_tracing();
                use tracing::debug;

                debug!("Starting proptest!");

                // A wiped storage hides every database entry, leaving only the overlay.
                let visible_db_nodes = if wiped { Vec::new() } else { db_nodes.clone() };
                let expected_combined =
                    merge_with_overlay(visible_db_nodes, post_state_nodes.clone());

                // Keys that seek operations may target.
                let all_keys: Vec<B256> = expected_combined.iter().map(|(k, _)| *k).collect();

                // Control: a plain cursor over the expected merged entries.
                let control_db_map: BTreeMap<B256, U256> = expected_combined.into_iter().collect();
                let control_visited_keys = Arc::new(Mutex::new(Vec::new()));
                let mut control_cursor =
                    MockHashedCursor::new(Arc::new(control_db_map), control_visited_keys);

                // Under test: the overlay cursor over the raw database entries.
                let db_nodes_map: BTreeMap<B256, U256> = db_nodes.into_iter().collect();
                let visited_keys = Arc::new(Mutex::new(Vec::new()));
                let mock_cursor = MockHashedCursor::new(Arc::new(db_nodes_map), visited_keys);

                let hashed_address = B256::ZERO;
                let storage_sorted = reth_trie_common::HashedStorageSorted {
                    storage_slots: post_state_nodes,
                    wiped,
                };
                let mut storages = alloy_primitives::map::B256Map::default();
                storages.insert(hashed_address, storage_sorted);
                let post_state = HashedPostStateSorted::new(Vec::new(), storages);

                let mut test_cursor =
                    HashedPostStateCursor::new_storage(mock_cursor, &post_state, hashed_address);

                let control_first = control_cursor.seek(B256::ZERO).unwrap();
                let test_first = test_cursor.seek(B256::ZERO).unwrap();
                debug!(
                    control=?control_first.as_ref().map(|(k, _)| k),
                    test=?test_first.as_ref().map(|(k, _)| k),
                    "Initial seek returned",
                );
                assert_eq!(control_first, test_first, "Initial seek mismatch");

                // Both views are empty; nothing more to compare.
                if control_first.is_none() {
                    return Ok(());
                }

                let mut last_returned_key = control_first.as_ref().map(|(k, _)| *k);

                for choice in op_choices {
                    match choice % 2 {
                        0 => {
                            let control_result = control_cursor.next().unwrap();
                            let test_result = test_cursor.next().unwrap();
                            debug!(
                                control=?control_result.as_ref().map(|(k, _)| k),
                                test=?test_result.as_ref().map(|(k, _)| k),
                                "Next returned",
                            );
                            assert_eq!(control_result, test_result, "Next operation mismatch");

                            last_returned_key = control_result.as_ref().map(|(k, _)| *k);

                            // Both cursors are exhausted.
                            if control_result.is_none() {
                                break;
                            }
                        }
                        _ => {
                            // Seeks must not move backwards past the last returned key.
                            let valid_keys: Vec<&B256> = all_keys
                                .iter()
                                .filter(|k| last_returned_key.is_none_or(|last| **k >= last))
                                .collect();
                            if valid_keys.is_empty() {
                                continue;
                            }

                            let key = *valid_keys[usize::from(choice / 2) % valid_keys.len()];

                            let control_result = control_cursor.seek(key).unwrap();
                            let test_result = test_cursor.seek(key).unwrap();
                            debug!(
                                control=?control_result.as_ref().map(|(k, _)| k),
                                test=?test_result.as_ref().map(|(k, _)| k),
                                ?key,
                                "Seek returned",
                            );
                            assert_eq!(control_result, test_result, "Seek operation mismatch for key {:?}", key);

                            last_returned_key = control_result.as_ref().map(|(k, _)| *k);

                            // Both cursors are exhausted.
                            if control_result.is_none() {
                                break;
                            }
                        }
                    }
                }
            }
        }
    }
}