1use crate::{DatabaseHashedCursorFactory, DatabaseTrieCursorFactory, PrefixSetLoader};
2use alloy_primitives::{
3 map::{AddressMap, B256Map},
4 BlockNumber, B256, U256,
5};
6use reth_db_api::{
7 cursor::DbCursorRO,
8 models::{AccountBeforeTx, BlockNumberAddress, BlockNumberAddressRange},
9 tables,
10 transaction::DbTx,
11 DatabaseError,
12};
13use reth_execution_errors::StateRootError;
14use reth_trie::{
15 hashed_cursor::HashedPostStateCursorFactory, trie_cursor::InMemoryTrieCursorFactory,
16 updates::TrieUpdates, HashedPostState, HashedStorage, KeccakKeyHasher, KeyHasher, StateRoot,
17 StateRootProgress, TrieInput,
18};
19use std::{
20 collections::HashMap,
21 ops::{RangeBounds, RangeInclusive},
22};
23use tracing::debug;
24
25pub trait DatabaseStateRoot<'a, TX>: Sized {
27 fn from_tx(tx: &'a TX) -> Self;
29
30 fn incremental_root_calculator(
37 tx: &'a TX,
38 range: RangeInclusive<BlockNumber>,
39 ) -> Result<Self, StateRootError>;
40
41 fn incremental_root(
48 tx: &'a TX,
49 range: RangeInclusive<BlockNumber>,
50 ) -> Result<B256, StateRootError>;
51
52 fn incremental_root_with_updates(
61 tx: &'a TX,
62 range: RangeInclusive<BlockNumber>,
63 ) -> Result<(B256, TrieUpdates), StateRootError>;
64
65 fn incremental_root_with_progress(
72 tx: &'a TX,
73 range: RangeInclusive<BlockNumber>,
74 ) -> Result<StateRootProgress, StateRootError>;
75
76 fn overlay_root(tx: &'a TX, post_state: HashedPostState) -> Result<B256, StateRootError>;
109
110 fn overlay_root_with_updates(
113 tx: &'a TX,
114 post_state: HashedPostState,
115 ) -> Result<(B256, TrieUpdates), StateRootError>;
116
117 fn overlay_root_from_nodes(tx: &'a TX, input: TrieInput) -> Result<B256, StateRootError>;
119
120 fn overlay_root_from_nodes_with_updates(
123 tx: &'a TX,
124 input: TrieInput,
125 ) -> Result<(B256, TrieUpdates), StateRootError>;
126}
127
128pub trait DatabaseHashedPostState<TX>: Sized {
130 fn from_reverts<KH: KeyHasher>(
133 tx: &TX,
134 range: impl RangeBounds<BlockNumber>,
135 ) -> Result<Self, DatabaseError>;
136}
137
138impl<'a, TX: DbTx> DatabaseStateRoot<'a, TX>
139 for StateRoot<DatabaseTrieCursorFactory<&'a TX>, DatabaseHashedCursorFactory<&'a TX>>
140{
141 fn from_tx(tx: &'a TX) -> Self {
142 Self::new(DatabaseTrieCursorFactory::new(tx), DatabaseHashedCursorFactory::new(tx))
143 }
144
145 fn incremental_root_calculator(
146 tx: &'a TX,
147 range: RangeInclusive<BlockNumber>,
148 ) -> Result<Self, StateRootError> {
149 let loaded_prefix_sets = PrefixSetLoader::<_, KeccakKeyHasher>::new(tx).load(range)?;
150 Ok(Self::from_tx(tx).with_prefix_sets(loaded_prefix_sets))
151 }
152
153 fn incremental_root(
154 tx: &'a TX,
155 range: RangeInclusive<BlockNumber>,
156 ) -> Result<B256, StateRootError> {
157 debug!(target: "trie::loader", ?range, "incremental state root");
158 Self::incremental_root_calculator(tx, range)?.root()
159 }
160
161 fn incremental_root_with_updates(
162 tx: &'a TX,
163 range: RangeInclusive<BlockNumber>,
164 ) -> Result<(B256, TrieUpdates), StateRootError> {
165 debug!(target: "trie::loader", ?range, "incremental state root");
166 Self::incremental_root_calculator(tx, range)?.root_with_updates()
167 }
168
169 fn incremental_root_with_progress(
170 tx: &'a TX,
171 range: RangeInclusive<BlockNumber>,
172 ) -> Result<StateRootProgress, StateRootError> {
173 debug!(target: "trie::loader", ?range, "incremental state root with progress");
174 Self::incremental_root_calculator(tx, range)?.root_with_progress()
175 }
176
177 fn overlay_root(tx: &'a TX, post_state: HashedPostState) -> Result<B256, StateRootError> {
178 let prefix_sets = post_state.construct_prefix_sets().freeze();
179 let state_sorted = post_state.into_sorted();
180 StateRoot::new(
181 DatabaseTrieCursorFactory::new(tx),
182 HashedPostStateCursorFactory::new(DatabaseHashedCursorFactory::new(tx), &state_sorted),
183 )
184 .with_prefix_sets(prefix_sets)
185 .root()
186 }
187
188 fn overlay_root_with_updates(
189 tx: &'a TX,
190 post_state: HashedPostState,
191 ) -> Result<(B256, TrieUpdates), StateRootError> {
192 let prefix_sets = post_state.construct_prefix_sets().freeze();
193 let state_sorted = post_state.into_sorted();
194 StateRoot::new(
195 DatabaseTrieCursorFactory::new(tx),
196 HashedPostStateCursorFactory::new(DatabaseHashedCursorFactory::new(tx), &state_sorted),
197 )
198 .with_prefix_sets(prefix_sets)
199 .root_with_updates()
200 }
201
202 fn overlay_root_from_nodes(tx: &'a TX, input: TrieInput) -> Result<B256, StateRootError> {
203 let state_sorted = input.state.into_sorted();
204 let nodes_sorted = input.nodes.into_sorted();
205 StateRoot::new(
206 InMemoryTrieCursorFactory::new(DatabaseTrieCursorFactory::new(tx), &nodes_sorted),
207 HashedPostStateCursorFactory::new(DatabaseHashedCursorFactory::new(tx), &state_sorted),
208 )
209 .with_prefix_sets(input.prefix_sets.freeze())
210 .root()
211 }
212
213 fn overlay_root_from_nodes_with_updates(
214 tx: &'a TX,
215 input: TrieInput,
216 ) -> Result<(B256, TrieUpdates), StateRootError> {
217 let state_sorted = input.state.into_sorted();
218 let nodes_sorted = input.nodes.into_sorted();
219 StateRoot::new(
220 InMemoryTrieCursorFactory::new(DatabaseTrieCursorFactory::new(tx), &nodes_sorted),
221 HashedPostStateCursorFactory::new(DatabaseHashedCursorFactory::new(tx), &state_sorted),
222 )
223 .with_prefix_sets(input.prefix_sets.freeze())
224 .root_with_updates()
225 }
226}
227
228impl<TX: DbTx> DatabaseHashedPostState<TX> for HashedPostState {
229 fn from_reverts<KH: KeyHasher>(
230 tx: &TX,
231 range: impl RangeBounds<BlockNumber>,
232 ) -> Result<Self, DatabaseError> {
233 let account_range = (range.start_bound(), range.end_bound()); let mut accounts = HashMap::new();
236 let mut account_changesets_cursor = tx.cursor_read::<tables::AccountChangeSets>()?;
237 for entry in account_changesets_cursor.walk_range(account_range)? {
238 let (_, AccountBeforeTx { address, info }) = entry?;
239 accounts.entry(address).or_insert(info);
240 }
241
242 let storage_range: BlockNumberAddressRange = range.into();
244 let mut storages = AddressMap::<B256Map<U256>>::default();
245 let mut storage_changesets_cursor = tx.cursor_read::<tables::StorageChangeSets>()?;
246 for entry in storage_changesets_cursor.walk_range(storage_range)? {
247 let (BlockNumberAddress((_, address)), storage) = entry?;
248 let account_storage = storages.entry(address).or_default();
249 account_storage.entry(storage.key).or_insert(storage.value);
250 }
251
252 let hashed_accounts =
253 accounts.into_iter().map(|(address, info)| (KH::hash_key(address), info)).collect();
254
255 let hashed_storages = storages
256 .into_iter()
257 .map(|(address, storage)| {
258 (
259 KH::hash_key(address),
260 HashedStorage::from_iter(
261 false,
265 storage.into_iter().map(|(slot, value)| (KH::hash_key(slot), value)),
266 ),
267 )
268 })
269 .collect();
270
271 Ok(Self { accounts: hashed_accounts, storages: hashed_storages })
272 }
273}
274
#[cfg(test)]
mod tests {
    use super::*;
    use alloy_primitives::{hex, map::HashMap, Address, U256};
    use reth_db::test_utils::create_test_rw_db;
    use reth_db_api::database::Database;
    use reth_trie::KeccakKeyHasher;
    use revm::state::AccountInfo;
    use revm_database::BundleState;

    #[test]
    fn from_bundle_state_with_rayon() {
        // Two accounts, each with one storage slot moved from zero to a non-zero value.
        let (addr_a, addr_b) = (Address::with_last_byte(1), Address::with_last_byte(2));
        let (slot_a, slot_b) = (U256::from(1015), U256::from(2015));
        let info_a = AccountInfo { nonce: 1, ..Default::default() };
        let info_b = AccountInfo { nonce: 2, ..Default::default() };

        let bundle = BundleState::builder(2..=2)
            .state_present_account_info(addr_a, info_a)
            .state_present_account_info(addr_b, info_b)
            .state_storage(addr_a, HashMap::from_iter([(slot_a, (U256::ZERO, U256::from(10)))]))
            .state_storage(addr_b, HashMap::from_iter([(slot_b, (U256::ZERO, U256::from(20)))]))
            .build();
        assert_eq!(bundle.reverts.len(), 1);

        let hashed = HashedPostState::from_bundle_state::<KeccakKeyHasher>(&bundle.state);
        assert_eq!(hashed.accounts.len(), 2);
        assert_eq!(hashed.storages.len(), 2);

        // Overlay the hashed post state on a freshly created test database.
        let db = create_test_rw_db();
        let tx = db.tx().expect("failed to create transaction");
        let root = StateRoot::overlay_root(&tx, hashed).unwrap();
        assert_eq!(
            root,
            hex!("b464525710cafcf5d4044ac85b72c08b1e76231b8d91f288fe438cc41d8eaafd")
        );
    }
}