1use alloy_primitives::B256;
2use reth_db_api::{
3 cursor::{DbCursorRO, DbCursorRW, DbDupCursorRO, DbDupCursorRW},
4 tables,
5 transaction::DbTx,
6 DatabaseError,
7};
8use reth_trie::{
9 trie_cursor::{TrieCursor, TrieCursorFactory},
10 updates::StorageTrieUpdates,
11 BranchNodeCompact, Nibbles, StorageTrieEntry, StoredNibbles, StoredNibblesSubKey,
12};
13
/// Factory that creates account and storage trie cursors backed by a database transaction.
///
/// The wrapped value is the transaction (or a reference to it); the
/// `TrieCursorFactory` implementation is provided for `DatabaseTrieCursorFactory<&TX>`
/// where `TX: DbTx`.
#[derive(Debug, Clone)]
pub struct DatabaseTrieCursorFactory<T>(T);

impl<T> DatabaseTrieCursorFactory<T> {
    /// Wraps `tx` so that trie cursors can be opened from it.
    pub const fn new(tx: T) -> Self {
        // Single-field tuple struct: the factory is nothing more than the transaction.
        Self(tx)
    }
}
24
25impl<TX> TrieCursorFactory for DatabaseTrieCursorFactory<&TX>
26where
27 TX: DbTx,
28{
29 type AccountTrieCursor = DatabaseAccountTrieCursor<<TX as DbTx>::Cursor<tables::AccountsTrie>>;
30 type StorageTrieCursor =
31 DatabaseStorageTrieCursor<<TX as DbTx>::DupCursor<tables::StoragesTrie>>;
32
33 fn account_trie_cursor(&self) -> Result<Self::AccountTrieCursor, DatabaseError> {
34 Ok(DatabaseAccountTrieCursor::new(self.0.cursor_read::<tables::AccountsTrie>()?))
35 }
36
37 fn storage_trie_cursor(
38 &self,
39 hashed_address: B256,
40 ) -> Result<Self::StorageTrieCursor, DatabaseError> {
41 Ok(DatabaseStorageTrieCursor::new(
42 self.0.cursor_dup_read::<tables::StoragesTrie>()?,
43 hashed_address,
44 ))
45 }
46}
47
/// Account trie cursor that reads branch nodes from the `AccountsTrie` table.
///
/// The wrapped value is the raw table cursor.
#[derive(Debug)]
pub struct DatabaseAccountTrieCursor<C>(pub(crate) C);

impl<C> DatabaseAccountTrieCursor<C> {
    /// Wraps an existing `AccountsTrie` table cursor.
    pub const fn new(cursor: C) -> Self {
        Self(cursor)
    }
}
58
59impl<C> TrieCursor for DatabaseAccountTrieCursor<C>
60where
61 C: DbCursorRO<tables::AccountsTrie> + Send + Sync,
62{
63 fn seek_exact(
65 &mut self,
66 key: Nibbles,
67 ) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
68 Ok(self.0.seek_exact(StoredNibbles(key))?.map(|value| (value.0 .0, value.1)))
69 }
70
71 fn seek(
73 &mut self,
74 key: Nibbles,
75 ) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
76 Ok(self.0.seek(StoredNibbles(key))?.map(|value| (value.0 .0, value.1)))
77 }
78
79 fn next(&mut self) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
81 Ok(self.0.next()?.map(|value| (value.0 .0, value.1)))
82 }
83
84 fn current(&mut self) -> Result<Option<Nibbles>, DatabaseError> {
86 Ok(self.0.current()?.map(|(k, _)| k.0))
87 }
88}
89
90#[derive(Debug)]
92pub struct DatabaseStorageTrieCursor<C> {
93 pub cursor: C,
95 hashed_address: B256,
97}
98
99impl<C> DatabaseStorageTrieCursor<C> {
100 pub const fn new(cursor: C, hashed_address: B256) -> Self {
102 Self { cursor, hashed_address }
103 }
104}
105
106impl<C> DatabaseStorageTrieCursor<C>
107where
108 C: DbCursorRO<tables::StoragesTrie>
109 + DbCursorRW<tables::StoragesTrie>
110 + DbDupCursorRO<tables::StoragesTrie>
111 + DbDupCursorRW<tables::StoragesTrie>,
112{
113 pub fn write_storage_trie_updates(
115 &mut self,
116 updates: &StorageTrieUpdates,
117 ) -> Result<usize, DatabaseError> {
118 if updates.is_deleted() && self.cursor.seek_exact(self.hashed_address)?.is_some() {
120 self.cursor.delete_current_duplicates()?;
121 }
122
123 let mut storage_updates = updates
125 .removed_nodes_ref()
126 .iter()
127 .filter_map(|n| (!updates.storage_nodes_ref().contains_key(n)).then_some((n, None)))
128 .collect::<Vec<_>>();
129 storage_updates.extend(
130 updates.storage_nodes_ref().iter().map(|(nibbles, node)| (nibbles, Some(node))),
131 );
132
133 storage_updates.sort_unstable_by(|a, b| a.0.cmp(b.0));
135
136 let mut num_entries = 0;
137 for (nibbles, maybe_updated) in storage_updates.into_iter().filter(|(n, _)| !n.is_empty()) {
138 num_entries += 1;
139 let nibbles = StoredNibblesSubKey(*nibbles);
140 if self
142 .cursor
143 .seek_by_key_subkey(self.hashed_address, nibbles.clone())?
144 .filter(|e| e.nibbles == nibbles)
145 .is_some()
146 {
147 self.cursor.delete_current()?;
148 }
149
150 if let Some(node) = maybe_updated {
152 self.cursor.upsert(
153 self.hashed_address,
154 &StorageTrieEntry { nibbles, node: node.clone() },
155 )?;
156 }
157 }
158
159 Ok(num_entries)
160 }
161}
162
163impl<C> TrieCursor for DatabaseStorageTrieCursor<C>
164where
165 C: DbCursorRO<tables::StoragesTrie> + DbDupCursorRO<tables::StoragesTrie> + Send + Sync,
166{
167 fn seek_exact(
169 &mut self,
170 key: Nibbles,
171 ) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
172 Ok(self
173 .cursor
174 .seek_by_key_subkey(self.hashed_address, StoredNibblesSubKey(key))?
175 .filter(|e| e.nibbles == StoredNibblesSubKey(key))
176 .map(|value| (value.nibbles.0, value.node)))
177 }
178
179 fn seek(
181 &mut self,
182 key: Nibbles,
183 ) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
184 Ok(self
185 .cursor
186 .seek_by_key_subkey(self.hashed_address, StoredNibblesSubKey(key))?
187 .map(|value| (value.nibbles.0, value.node)))
188 }
189
190 fn next(&mut self) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
192 Ok(self.cursor.next_dup()?.map(|(_, v)| (v.nibbles.0, v.node)))
193 }
194
195 fn current(&mut self) -> Result<Option<Nibbles>, DatabaseError> {
197 Ok(self.cursor.current()?.map(|(_, v)| v.nibbles.0))
198 }
199}
200
#[cfg(test)]
mod tests {
    use super::*;
    use alloy_primitives::hex_literal::hex;
    use reth_db_api::{cursor::DbCursorRW, transaction::DbTxMut};
    use reth_provider::test_utils::create_test_provider_factory;

    #[test]
    fn test_account_trie_order() {
        let factory = create_test_provider_factory();
        let provider = factory.provider_rw().unwrap();
        let mut cursor = provider.tx_ref().cursor_write::<tables::AccountsTrie>().unwrap();

        // Listed in the order the table is expected to return them.
        let keys = vec![
            hex!("0303040e").to_vec(),
            hex!("030305").to_vec(),
            hex!("03030500").to_vec(),
            hex!("0303050a").to_vec(),
        ];
        let node = BranchNodeCompact::new(
            0b0000_0010_0000_0001,
            0b0000_0010_0000_0001,
            0,
            Vec::default(),
            None,
        );

        for key in &keys {
            cursor.upsert(key.clone().into(), &node).unwrap();
        }

        let rows = cursor.walk_range(..).unwrap().collect::<Result<Vec<_>, _>>().unwrap();
        for (idx, expected) in keys.iter().enumerate() {
            assert_eq!(rows[idx].0 .0.to_vec(), *expected);
        }

        // A key that is absent must land on the next greater stored key.
        let landed =
            cursor.seek(hex!("0303040f").to_vec().into()).unwrap().map(|(k, _)| k.0.to_vec());
        assert_eq!(landed, Some(keys[1].clone()));
    }

    #[test]
    fn test_storage_cursor_abstraction() {
        let factory = create_test_provider_factory();
        let provider = factory.provider_rw().unwrap();
        let mut raw_cursor =
            provider.tx_ref().cursor_dup_write::<tables::StoragesTrie>().unwrap();

        let hashed_address = B256::random();
        let key = StoredNibblesSubKey::from(vec![0x2, 0x3]);
        let value = BranchNodeCompact::new(1, 1, 1, vec![B256::random()], None);

        let entry = StorageTrieEntry { nibbles: key.clone(), node: value.clone() };
        raw_cursor.upsert(hashed_address, &entry).unwrap();

        let mut trie_cursor = DatabaseStorageTrieCursor::new(raw_cursor, hashed_address);
        let (_, found) = trie_cursor.seek(key.into()).unwrap().unwrap();
        assert_eq!(found, value);
    }
}