1use alloy_primitives::B256;
2use reth_db_api::{
3 cursor::{DbCursorRO, DbCursorRW, DbDupCursorRO, DbDupCursorRW},
4 tables,
5 transaction::DbTx,
6 DatabaseError,
7};
8use reth_trie::{
9 trie_cursor::{TrieCursor, TrieCursorFactory},
10 updates::StorageTrieUpdatesSorted,
11 BranchNodeCompact, Nibbles, StorageTrieEntry, StoredNibbles, StoredNibblesSubKey,
12};
13
/// Factory that opens trie cursors on top of a database transaction handle.
///
/// The wrapped value is the transaction. Cursors can only be produced when it is a borrowed
/// [`DbTx`]; see the [`TrieCursorFactory`] implementation.
#[derive(Debug, Clone)]
pub struct DatabaseTrieCursorFactory<T>(T);

impl<T> DatabaseTrieCursorFactory<T> {
    /// Wraps the transaction `tx` in a new cursor factory.
    pub const fn new(tx: T) -> Self {
        Self(tx)
    }
}
24
25impl<TX> TrieCursorFactory for DatabaseTrieCursorFactory<&TX>
26where
27 TX: DbTx,
28{
29 type AccountTrieCursor<'a>
30 = DatabaseAccountTrieCursor<<TX as DbTx>::Cursor<tables::AccountsTrie>>
31 where
32 Self: 'a;
33
34 type StorageTrieCursor<'a>
35 = DatabaseStorageTrieCursor<<TX as DbTx>::DupCursor<tables::StoragesTrie>>
36 where
37 Self: 'a;
38
39 fn account_trie_cursor(&self) -> Result<Self::AccountTrieCursor<'_>, DatabaseError> {
40 Ok(DatabaseAccountTrieCursor::new(self.0.cursor_read::<tables::AccountsTrie>()?))
41 }
42
43 fn storage_trie_cursor(
44 &self,
45 hashed_address: B256,
46 ) -> Result<Self::StorageTrieCursor<'_>, DatabaseError> {
47 Ok(DatabaseStorageTrieCursor::new(
48 self.0.cursor_dup_read::<tables::StoragesTrie>()?,
49 hashed_address,
50 ))
51 }
52}
53
/// Account trie cursor that wraps a database cursor `C`.
///
/// It implements [`TrieCursor`] when `C` is a read cursor on [`tables::AccountsTrie`].
#[derive(Debug)]
pub struct DatabaseAccountTrieCursor<C>(pub(crate) C);

impl<C> DatabaseAccountTrieCursor<C> {
    /// Wraps an already-opened database `cursor`.
    pub const fn new(cursor: C) -> Self {
        Self(cursor)
    }
}
64
65impl<C> TrieCursor for DatabaseAccountTrieCursor<C>
66where
67 C: DbCursorRO<tables::AccountsTrie> + Send + Sync,
68{
69 fn seek_exact(
71 &mut self,
72 key: Nibbles,
73 ) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
74 Ok(self.0.seek_exact(StoredNibbles(key))?.map(|value| (value.0 .0, value.1)))
75 }
76
77 fn seek(
79 &mut self,
80 key: Nibbles,
81 ) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
82 Ok(self.0.seek(StoredNibbles(key))?.map(|value| (value.0 .0, value.1)))
83 }
84
85 fn next(&mut self) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
87 Ok(self.0.next()?.map(|value| (value.0 .0, value.1)))
88 }
89
90 fn current(&mut self) -> Result<Option<Nibbles>, DatabaseError> {
92 Ok(self.0.current()?.map(|(k, _)| k.0))
93 }
94}
95
96#[derive(Debug)]
98pub struct DatabaseStorageTrieCursor<C> {
99 pub cursor: C,
101 hashed_address: B256,
103}
104
105impl<C> DatabaseStorageTrieCursor<C> {
106 pub const fn new(cursor: C, hashed_address: B256) -> Self {
108 Self { cursor, hashed_address }
109 }
110}
111
112impl<C> DatabaseStorageTrieCursor<C>
113where
114 C: DbCursorRO<tables::StoragesTrie>
115 + DbCursorRW<tables::StoragesTrie>
116 + DbDupCursorRO<tables::StoragesTrie>
117 + DbDupCursorRW<tables::StoragesTrie>,
118{
119 pub fn write_storage_trie_updates_sorted(
121 &mut self,
122 updates: &StorageTrieUpdatesSorted,
123 ) -> Result<usize, DatabaseError> {
124 if updates.is_deleted() && self.cursor.seek_exact(self.hashed_address)?.is_some() {
126 self.cursor.delete_current_duplicates()?;
127 }
128
129 let mut num_entries = 0;
130 for (nibbles, maybe_updated) in updates.storage_nodes.iter().filter(|(n, _)| !n.is_empty())
131 {
132 num_entries += 1;
133 let nibbles = StoredNibblesSubKey(*nibbles);
134 if self
136 .cursor
137 .seek_by_key_subkey(self.hashed_address, nibbles.clone())?
138 .filter(|e| e.nibbles == nibbles)
139 .is_some()
140 {
141 self.cursor.delete_current()?;
142 }
143
144 if let Some(node) = maybe_updated {
146 self.cursor.upsert(
147 self.hashed_address,
148 &StorageTrieEntry { nibbles, node: node.clone() },
149 )?;
150 }
151 }
152
153 Ok(num_entries)
154 }
155}
156
157impl<C> TrieCursor for DatabaseStorageTrieCursor<C>
158where
159 C: DbCursorRO<tables::StoragesTrie> + DbDupCursorRO<tables::StoragesTrie> + Send + Sync,
160{
161 fn seek_exact(
163 &mut self,
164 key: Nibbles,
165 ) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
166 Ok(self
167 .cursor
168 .seek_by_key_subkey(self.hashed_address, StoredNibblesSubKey(key))?
169 .filter(|e| e.nibbles == StoredNibblesSubKey(key))
170 .map(|value| (value.nibbles.0, value.node)))
171 }
172
173 fn seek(
175 &mut self,
176 key: Nibbles,
177 ) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
178 Ok(self
179 .cursor
180 .seek_by_key_subkey(self.hashed_address, StoredNibblesSubKey(key))?
181 .map(|value| (value.nibbles.0, value.node)))
182 }
183
184 fn next(&mut self) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
186 Ok(self.cursor.next_dup()?.map(|(_, v)| (v.nibbles.0, v.node)))
187 }
188
189 fn current(&mut self) -> Result<Option<Nibbles>, DatabaseError> {
191 Ok(self.cursor.current()?.map(|(_, v)| v.nibbles.0))
192 }
193}
194
#[cfg(test)]
mod tests {
    use super::*;
    use alloy_primitives::hex_literal::hex;
    use reth_db_api::{cursor::DbCursorRW, transaction::DbTxMut};
    use reth_provider::test_utils::create_test_provider_factory;

    /// Inserts account trie keys, then checks two things: the table yields them in byte order,
    /// and `seek` lands on the following key when the sought key is absent.
    #[test]
    fn test_account_trie_order() {
        let factory = create_test_provider_factory();
        let provider = factory.provider_rw().unwrap();
        let mut cursor = provider.tx_ref().cursor_write::<tables::AccountsTrie>().unwrap();

        let data = vec![
            hex!("0303040e").to_vec(),
            hex!("030305").to_vec(),
            hex!("03030500").to_vec(),
            hex!("0303050a").to_vec(),
        ];

        // The node payload does not affect ordering, so all keys share one value.
        let node = BranchNodeCompact::new(
            0b0000_0010_0000_0001,
            0b0000_0010_0000_0001,
            0,
            Vec::default(),
            None,
        );
        for key in &data {
            cursor.upsert(key.clone().into(), &node).unwrap();
        }

        // Compare the complete key list so that missing or unexpected rows also fail, with a diff.
        let db_keys = cursor
            .walk_range(..)
            .unwrap()
            .map(|entry| entry.map(|(key, _)| key.0.to_vec()))
            .collect::<Result<Vec<_>, _>>()
            .unwrap();
        assert_eq!(db_keys, data);

        assert_eq!(
            cursor.seek(hex!("0303040f").to_vec().into()).unwrap().map(|(k, _)| k.0.to_vec()),
            Some(data[1].clone())
        );
    }

    /// Checks that an entry upserted through the raw dup cursor is found by
    /// `DatabaseStorageTrieCursor::seek`.
    #[test]
    fn test_storage_cursor_abstraction() {
        let factory = create_test_provider_factory();
        let provider = factory.provider_rw().unwrap();
        let mut cursor = provider.tx_ref().cursor_dup_write::<tables::StoragesTrie>().unwrap();

        let hashed_address = B256::random();
        let key = StoredNibblesSubKey::from(vec![0x2, 0x3]);
        let value = BranchNodeCompact::new(1, 1, 1, vec![B256::random()], None);

        let entry = StorageTrieEntry { nibbles: key.clone(), node: value.clone() };
        cursor.upsert(hashed_address, &entry).unwrap();

        let mut cursor = DatabaseStorageTrieCursor::new(cursor, hashed_address);
        let (_, found) = cursor.seek(key.into()).unwrap().expect("entry was just inserted");
        assert_eq!(found, value);
    }
}