reth_provider/providers/state/
overlay.rs1use alloy_primitives::{BlockNumber, B256};
2use reth_db_api::DatabaseError;
3use reth_errors::{ProviderError, ProviderResult};
4use reth_prune_types::PruneSegment;
5use reth_stages_types::StageId;
6use reth_storage_api::{
7 BlockNumReader, DBProvider, DatabaseProviderFactory, DatabaseProviderROFactory,
8 PruneCheckpointReader, StageCheckpointReader, TrieReader,
9};
10use reth_trie::{
11 hashed_cursor::{HashedCursorFactory, HashedPostStateCursorFactory},
12 trie_cursor::{InMemoryTrieCursorFactory, TrieCursorFactory},
13 updates::TrieUpdatesSorted,
14 HashedPostState, HashedPostStateSorted, KeccakKeyHasher,
15};
16use reth_trie_db::{
17 DatabaseHashedCursorFactory, DatabaseHashedPostState, DatabaseTrieCursorFactory,
18};
19use std::sync::Arc;
20use tracing::debug;
21
22#[derive(Debug, Clone)]
27pub struct OverlayStateProviderFactory<F> {
28 factory: F,
30 block_hash: Option<B256>,
32 trie_overlay: Option<Arc<TrieUpdatesSorted>>,
34 hashed_state_overlay: Option<Arc<HashedPostStateSorted>>,
36}
37
38impl<F> OverlayStateProviderFactory<F> {
39 pub const fn new(factory: F) -> Self {
41 Self { factory, block_hash: None, trie_overlay: None, hashed_state_overlay: None }
42 }
43
44 pub const fn with_block_hash(mut self, block_hash: Option<B256>) -> Self {
47 self.block_hash = block_hash;
48 self
49 }
50
51 pub fn with_trie_overlay(mut self, trie_overlay: Option<Arc<TrieUpdatesSorted>>) -> Self {
55 self.trie_overlay = trie_overlay;
56 self
57 }
58
59 pub fn with_hashed_state_overlay(
63 mut self,
64 hashed_state_overlay: Option<Arc<HashedPostStateSorted>>,
65 ) -> Self {
66 self.hashed_state_overlay = hashed_state_overlay;
67 self
68 }
69}
70
71impl<F> OverlayStateProviderFactory<F>
72where
73 F: DatabaseProviderFactory,
74 F::Provider: TrieReader + StageCheckpointReader + PruneCheckpointReader + BlockNumReader,
75{
76 fn get_block_number(&self, provider: &F::Provider) -> ProviderResult<Option<BlockNumber>> {
78 if let Some(block_hash) = self.block_hash {
79 Ok(Some(
80 provider
81 .convert_hash_or_number(block_hash.into())?
82 .ok_or_else(|| ProviderError::BlockHashNotFound(block_hash))?,
83 ))
84 } else {
85 Ok(None)
86 }
87 }
88
89 fn reverts_required(
96 &self,
97 provider: &F::Provider,
98 requested_block: BlockNumber,
99 ) -> ProviderResult<bool> {
100 let stage_checkpoint = provider.get_stage_checkpoint(StageId::MerkleChangeSets)?;
102 let prune_checkpoint = provider.get_prune_checkpoint(PruneSegment::MerkleChangeSets)?;
103
104 let upper_bound =
106 stage_checkpoint.as_ref().map(|chk| chk.block_number).ok_or_else(|| {
107 ProviderError::InsufficientChangesets {
108 requested: requested_block,
109 available: 0..=0,
110 }
111 })?;
112
113 if upper_bound == requested_block {
116 return Ok(false)
117 }
118
119 let lower_bound = prune_checkpoint
123 .and_then(|chk| chk.block_number)
124 .map(|block_number| block_number + 1)
125 .ok_or_else(|| ProviderError::InsufficientChangesets {
126 requested: requested_block,
127 available: 0..=upper_bound,
128 })?;
129
130 let available_range = lower_bound..=upper_bound;
131
132 if !available_range.contains(&requested_block) {
134 return Err(ProviderError::InsufficientChangesets {
135 requested: requested_block,
136 available: available_range,
137 });
138 }
139
140 Ok(true)
141 }
142}
143
144impl<F> DatabaseProviderROFactory for OverlayStateProviderFactory<F>
145where
146 F: DatabaseProviderFactory,
147 F::Provider: TrieReader + StageCheckpointReader + PruneCheckpointReader + BlockNumReader,
148{
149 type Provider = OverlayStateProvider<F::Provider>;
150
151 fn database_provider_ro(&self) -> ProviderResult<OverlayStateProvider<F::Provider>> {
153 let provider = self.factory.database_provider_ro()?;
155
156 let (trie_updates, hashed_state) = if let Some(from_block) =
158 self.get_block_number(&provider)? &&
159 self.reverts_required(&provider, from_block)?
160 {
161 let mut trie_reverts = provider.trie_reverts(from_block + 1)?;
163
164 let mut hashed_state_reverts = HashedPostState::from_reverts::<KeccakKeyHasher>(
169 provider.tx_ref(),
170 from_block + 1..,
171 )?
172 .into_sorted();
173
174 let trie_updates = match self.trie_overlay.as_ref() {
177 Some(trie_overlay) if trie_reverts.is_empty() => Arc::clone(trie_overlay),
178 Some(trie_overlay) => {
179 trie_reverts.extend_ref(trie_overlay);
180 Arc::new(trie_reverts)
181 }
182 None => Arc::new(trie_reverts),
183 };
184
185 let hashed_state_updates = match self.hashed_state_overlay.as_ref() {
186 Some(hashed_state_overlay) if hashed_state_reverts.is_empty() => {
187 Arc::clone(hashed_state_overlay)
188 }
189 Some(hashed_state_overlay) => {
190 hashed_state_reverts.extend_ref(hashed_state_overlay);
191 Arc::new(hashed_state_reverts)
192 }
193 None => Arc::new(hashed_state_reverts),
194 };
195
196 debug!(
197 target: "providers::state::overlay",
198 block_hash = ?self.block_hash,
199 ?from_block,
200 num_trie_updates = ?trie_updates.total_len(),
201 num_state_updates = ?hashed_state_updates.total_len(),
202 "Reverted to target block",
203 );
204
205 (trie_updates, hashed_state_updates)
206 } else {
207 let trie_updates =
209 self.trie_overlay.clone().unwrap_or_else(|| Arc::new(TrieUpdatesSorted::default()));
210 let hashed_state = self
211 .hashed_state_overlay
212 .clone()
213 .unwrap_or_else(|| Arc::new(HashedPostStateSorted::default()));
214
215 (trie_updates, hashed_state)
216 };
217
218 Ok(OverlayStateProvider::new(provider, trie_updates, hashed_state))
219 }
220}
221
222#[derive(Debug)]
228pub struct OverlayStateProvider<Provider: DBProvider> {
229 provider: Provider,
230 trie_updates: Arc<TrieUpdatesSorted>,
231 hashed_post_state: Arc<HashedPostStateSorted>,
232}
233
234impl<Provider> OverlayStateProvider<Provider>
235where
236 Provider: DBProvider,
237{
238 pub const fn new(
241 provider: Provider,
242 trie_updates: Arc<TrieUpdatesSorted>,
243 hashed_post_state: Arc<HashedPostStateSorted>,
244 ) -> Self {
245 Self { provider, trie_updates, hashed_post_state }
246 }
247}
248
249impl<Provider> TrieCursorFactory for OverlayStateProvider<Provider>
250where
251 Provider: DBProvider,
252{
253 type AccountTrieCursor<'a>
254 = <InMemoryTrieCursorFactory<
255 DatabaseTrieCursorFactory<&'a Provider::Tx>,
256 &'a TrieUpdatesSorted,
257 > as TrieCursorFactory>::AccountTrieCursor<'a>
258 where
259 Self: 'a;
260
261 type StorageTrieCursor<'a>
262 = <InMemoryTrieCursorFactory<
263 DatabaseTrieCursorFactory<&'a Provider::Tx>,
264 &'a TrieUpdatesSorted,
265 > as TrieCursorFactory>::StorageTrieCursor<'a>
266 where
267 Self: 'a;
268
269 fn account_trie_cursor(&self) -> Result<Self::AccountTrieCursor<'_>, DatabaseError> {
270 let db_trie_cursor_factory = DatabaseTrieCursorFactory::new(self.provider.tx_ref());
271 let trie_cursor_factory =
272 InMemoryTrieCursorFactory::new(db_trie_cursor_factory, self.trie_updates.as_ref());
273 trie_cursor_factory.account_trie_cursor()
274 }
275
276 fn storage_trie_cursor(
277 &self,
278 hashed_address: B256,
279 ) -> Result<Self::StorageTrieCursor<'_>, DatabaseError> {
280 let db_trie_cursor_factory = DatabaseTrieCursorFactory::new(self.provider.tx_ref());
281 let trie_cursor_factory =
282 InMemoryTrieCursorFactory::new(db_trie_cursor_factory, self.trie_updates.as_ref());
283 trie_cursor_factory.storage_trie_cursor(hashed_address)
284 }
285}
286
287impl<Provider> HashedCursorFactory for OverlayStateProvider<Provider>
288where
289 Provider: DBProvider,
290{
291 type AccountCursor<'a>
292 = <HashedPostStateCursorFactory<
293 DatabaseHashedCursorFactory<&'a Provider::Tx>,
294 &'a Arc<HashedPostStateSorted>,
295 > as HashedCursorFactory>::AccountCursor<'a>
296 where
297 Self: 'a;
298
299 type StorageCursor<'a>
300 = <HashedPostStateCursorFactory<
301 DatabaseHashedCursorFactory<&'a Provider::Tx>,
302 &'a Arc<HashedPostStateSorted>,
303 > as HashedCursorFactory>::StorageCursor<'a>
304 where
305 Self: 'a;
306
307 fn hashed_account_cursor(&self) -> Result<Self::AccountCursor<'_>, DatabaseError> {
308 let db_hashed_cursor_factory = DatabaseHashedCursorFactory::new(self.provider.tx_ref());
309 let hashed_cursor_factory =
310 HashedPostStateCursorFactory::new(db_hashed_cursor_factory, &self.hashed_post_state);
311 hashed_cursor_factory.hashed_account_cursor()
312 }
313
314 fn hashed_storage_cursor(
315 &self,
316 hashed_address: B256,
317 ) -> Result<Self::StorageCursor<'_>, DatabaseError> {
318 let db_hashed_cursor_factory = DatabaseHashedCursorFactory::new(self.provider.tx_ref());
319 let hashed_cursor_factory =
320 HashedPostStateCursorFactory::new(db_hashed_cursor_factory, &self.hashed_post_state);
321 hashed_cursor_factory.hashed_storage_cursor(hashed_address)
322 }
323}