reth_provider/providers/state/
overlay.rs1use alloy_primitives::{BlockNumber, B256};
2use metrics::{Counter, Histogram};
3use parking_lot::RwLock;
4use reth_db_api::DatabaseError;
5use reth_errors::{ProviderError, ProviderResult};
6use reth_metrics::Metrics;
7use reth_prune_types::PruneSegment;
8use reth_stages_types::StageId;
9use reth_storage_api::{
10 BlockNumReader, DBProvider, DatabaseProviderFactory, DatabaseProviderROFactory,
11 PruneCheckpointReader, StageCheckpointReader, TrieReader,
12};
13use reth_trie::{
14 hashed_cursor::{HashedCursorFactory, HashedPostStateCursorFactory},
15 trie_cursor::{InMemoryTrieCursorFactory, TrieCursorFactory},
16 updates::TrieUpdatesSorted,
17 HashedPostState, HashedPostStateSorted, KeccakKeyHasher,
18};
19use reth_trie_db::{
20 DatabaseHashedCursorFactory, DatabaseHashedPostState, DatabaseTrieCursorFactory,
21};
22use std::{
23 collections::{hash_map::Entry, HashMap},
24 sync::Arc,
25 time::{Duration, Instant},
26};
27use tracing::{debug, debug_span, instrument};
28
29#[derive(Clone, Metrics)]
31#[metrics(scope = "storage.providers.overlay")]
32pub(crate) struct OverlayStateProviderMetrics {
33 create_provider_duration: Histogram,
35 retrieve_trie_reverts_duration: Histogram,
37 retrieve_hashed_state_reverts_duration: Histogram,
39 trie_updates_size: Histogram,
41 hashed_state_size: Histogram,
43 database_provider_ro_duration: Histogram,
45 overlay_cache_misses: Counter,
47}
48
49#[derive(Debug, Clone)]
51struct Overlay {
52 trie_updates: Arc<TrieUpdatesSorted>,
53 hashed_post_state: Arc<HashedPostStateSorted>,
54}
55
56#[derive(Debug, Clone)]
61pub struct OverlayStateProviderFactory<F> {
62 factory: F,
64 block_hash: Option<B256>,
66 trie_overlay: Option<Arc<TrieUpdatesSorted>>,
68 hashed_state_overlay: Option<Arc<HashedPostStateSorted>>,
70 metrics: OverlayStateProviderMetrics,
72 overlay_cache: Arc<RwLock<HashMap<BlockNumber, Overlay>>>,
75}
76
77impl<F> OverlayStateProviderFactory<F> {
78 pub fn new(factory: F) -> Self {
80 Self {
81 factory,
82 block_hash: None,
83 trie_overlay: None,
84 hashed_state_overlay: None,
85 metrics: OverlayStateProviderMetrics::default(),
86 overlay_cache: Default::default(),
87 }
88 }
89
90 pub const fn with_block_hash(mut self, block_hash: Option<B256>) -> Self {
93 self.block_hash = block_hash;
94 self
95 }
96
97 pub fn with_trie_overlay(mut self, trie_overlay: Option<Arc<TrieUpdatesSorted>>) -> Self {
101 self.trie_overlay = trie_overlay;
102 self
103 }
104
105 pub fn with_hashed_state_overlay(
109 mut self,
110 hashed_state_overlay: Option<Arc<HashedPostStateSorted>>,
111 ) -> Self {
112 self.hashed_state_overlay = hashed_state_overlay;
113 self
114 }
115}
116
117impl<F> OverlayStateProviderFactory<F>
118where
119 F: DatabaseProviderFactory,
120 F::Provider: TrieReader + StageCheckpointReader + PruneCheckpointReader + BlockNumReader,
121{
122 fn get_requested_block_number(
124 &self,
125 provider: &F::Provider,
126 ) -> ProviderResult<Option<BlockNumber>> {
127 if let Some(block_hash) = self.block_hash {
128 Ok(Some(
129 provider
130 .convert_hash_or_number(block_hash.into())?
131 .ok_or_else(|| ProviderError::BlockHashNotFound(block_hash))?,
132 ))
133 } else {
134 Ok(None)
135 }
136 }
137
138 fn get_db_tip_block_number(&self, provider: &F::Provider) -> ProviderResult<BlockNumber> {
141 provider
142 .get_stage_checkpoint(StageId::MerkleChangeSets)?
143 .as_ref()
144 .map(|chk| chk.block_number)
145 .ok_or_else(|| ProviderError::InsufficientChangesets { requested: 0, available: 0..=0 })
146 }
147
148 fn reverts_required(
155 &self,
156 provider: &F::Provider,
157 db_tip_block: BlockNumber,
158 requested_block: BlockNumber,
159 ) -> ProviderResult<bool> {
160 if db_tip_block == requested_block {
163 return Ok(false)
164 }
165
166 let prune_checkpoint = provider.get_prune_checkpoint(PruneSegment::MerkleChangeSets)?;
169
170 let lower_bound = prune_checkpoint
179 .and_then(|chk| chk.block_number)
180 .map(|block_number| block_number + 1)
181 .unwrap_or_default();
182
183 let available_range = lower_bound..=db_tip_block;
184
185 if !available_range.contains(&requested_block) {
187 return Err(ProviderError::InsufficientChangesets {
188 requested: requested_block,
189 available: available_range,
190 });
191 }
192
193 Ok(true)
194 }
195
196 #[instrument(
198 level = "debug",
199 target = "providers::state::overlay",
200 skip_all,
201 fields(db_tip_block)
202 )]
203 fn calculate_overlay(
204 &self,
205 provider: &F::Provider,
206 db_tip_block: BlockNumber,
207 ) -> ProviderResult<Overlay> {
208 let retrieve_trie_reverts_duration;
211 let retrieve_hashed_state_reverts_duration;
212 let trie_updates_total_len;
213 let hashed_state_updates_total_len;
214
215 let (trie_updates, hashed_post_state) = if let Some(from_block) =
217 self.get_requested_block_number(provider)? &&
218 self.reverts_required(provider, db_tip_block, from_block)?
219 {
220 let mut trie_reverts = {
222 let _guard =
223 debug_span!(target: "providers::state::overlay", "Retrieving trie reverts")
224 .entered();
225
226 let start = Instant::now();
227 let res = provider.trie_reverts(from_block + 1)?;
228 retrieve_trie_reverts_duration = start.elapsed();
229 res
230 };
231
232 let mut hashed_state_reverts = {
234 let _guard = debug_span!(target: "providers::state::overlay", "Retrieving hashed state reverts").entered();
235
236 let start = Instant::now();
237 let res = HashedPostState::from_reverts::<KeccakKeyHasher>(
240 provider.tx_ref(),
241 from_block + 1..,
242 )?
243 .into_sorted();
244 retrieve_hashed_state_reverts_duration = start.elapsed();
245 res
246 };
247
248 let trie_updates = match self.trie_overlay.as_ref() {
251 Some(trie_overlay) if trie_reverts.is_empty() => Arc::clone(trie_overlay),
252 Some(trie_overlay) => {
253 trie_reverts.extend_ref(trie_overlay);
254 Arc::new(trie_reverts)
255 }
256 None => Arc::new(trie_reverts),
257 };
258
259 let hashed_state_updates = match self.hashed_state_overlay.as_ref() {
260 Some(hashed_state_overlay) if hashed_state_reverts.is_empty() => {
261 Arc::clone(hashed_state_overlay)
262 }
263 Some(hashed_state_overlay) => {
264 hashed_state_reverts.extend_ref(hashed_state_overlay);
265 Arc::new(hashed_state_reverts)
266 }
267 None => Arc::new(hashed_state_reverts),
268 };
269
270 trie_updates_total_len = trie_updates.total_len();
271 hashed_state_updates_total_len = hashed_state_updates.total_len();
272
273 debug!(
274 target: "providers::state::overlay",
275 block_hash = ?self.block_hash,
276 ?from_block,
277 num_trie_updates = ?trie_updates_total_len,
278 num_state_updates = ?hashed_state_updates_total_len,
279 "Reverted to target block",
280 );
281
282 (trie_updates, hashed_state_updates)
283 } else {
284 let trie_updates =
286 self.trie_overlay.clone().unwrap_or_else(|| Arc::new(TrieUpdatesSorted::default()));
287 let hashed_state = self
288 .hashed_state_overlay
289 .clone()
290 .unwrap_or_else(|| Arc::new(HashedPostStateSorted::default()));
291
292 retrieve_trie_reverts_duration = Duration::ZERO;
293 retrieve_hashed_state_reverts_duration = Duration::ZERO;
294 trie_updates_total_len = trie_updates.total_len();
295 hashed_state_updates_total_len = hashed_state.total_len();
296
297 (trie_updates, hashed_state)
298 };
299
300 self.metrics
302 .retrieve_trie_reverts_duration
303 .record(retrieve_trie_reverts_duration.as_secs_f64());
304 self.metrics
305 .retrieve_hashed_state_reverts_duration
306 .record(retrieve_hashed_state_reverts_duration.as_secs_f64());
307 self.metrics.trie_updates_size.record(trie_updates_total_len as f64);
308 self.metrics.hashed_state_size.record(hashed_state_updates_total_len as f64);
309
310 Ok(Overlay { trie_updates, hashed_post_state })
311 }
312
313 #[instrument(level = "debug", target = "providers::state::overlay", skip_all)]
316 fn get_overlay(&self, provider: &F::Provider) -> ProviderResult<Overlay> {
317 if self.block_hash.is_none() {
320 let trie_updates =
321 self.trie_overlay.clone().unwrap_or_else(|| Arc::new(TrieUpdatesSorted::default()));
322 let hashed_post_state = self
323 .hashed_state_overlay
324 .clone()
325 .unwrap_or_else(|| Arc::new(HashedPostStateSorted::default()));
326 return Ok(Overlay { trie_updates, hashed_post_state })
327 }
328
329 let db_tip_block = self.get_db_tip_block_number(provider)?;
330
331 if let Some(overlay) = self.overlay_cache.as_ref().read().get(&db_tip_block) {
333 return Ok(overlay.clone());
334 }
335
336 let mut cache_miss = false;
340 let overlay = match self.overlay_cache.as_ref().write().entry(db_tip_block) {
341 Entry::Occupied(entry) => entry.get().clone(),
342 Entry::Vacant(entry) => {
343 cache_miss = true;
344 let overlay = self.calculate_overlay(provider, db_tip_block)?;
345 entry.insert(overlay.clone());
346 overlay
347 }
348 };
349
350 if cache_miss {
351 self.metrics.overlay_cache_misses.increment(1);
352 }
353
354 Ok(overlay)
355 }
356}
357
358impl<F> DatabaseProviderROFactory for OverlayStateProviderFactory<F>
359where
360 F: DatabaseProviderFactory,
361 F::Provider: TrieReader + StageCheckpointReader + PruneCheckpointReader + BlockNumReader,
362{
363 type Provider = OverlayStateProvider<F::Provider>;
364
365 #[instrument(level = "debug", target = "providers::state::overlay", skip_all)]
367 fn database_provider_ro(&self) -> ProviderResult<OverlayStateProvider<F::Provider>> {
368 let overall_start = Instant::now();
369
370 let provider = {
372 let _guard =
373 debug_span!(target: "providers::state::overlay", "Creating db provider").entered();
374
375 let start = Instant::now();
376 let res = self.factory.database_provider_ro()?;
377 self.metrics.create_provider_duration.record(start.elapsed());
378 res
379 };
380
381 let Overlay { trie_updates, hashed_post_state } = self.get_overlay(&provider)?;
382
383 self.metrics.database_provider_ro_duration.record(overall_start.elapsed());
384 Ok(OverlayStateProvider::new(provider, trie_updates, hashed_post_state))
385 }
386}
387
388#[derive(Debug)]
394pub struct OverlayStateProvider<Provider: DBProvider> {
395 provider: Provider,
396 trie_updates: Arc<TrieUpdatesSorted>,
397 hashed_post_state: Arc<HashedPostStateSorted>,
398}
399
400impl<Provider> OverlayStateProvider<Provider>
401where
402 Provider: DBProvider,
403{
404 pub const fn new(
407 provider: Provider,
408 trie_updates: Arc<TrieUpdatesSorted>,
409 hashed_post_state: Arc<HashedPostStateSorted>,
410 ) -> Self {
411 Self { provider, trie_updates, hashed_post_state }
412 }
413}
414
415impl<Provider> TrieCursorFactory for OverlayStateProvider<Provider>
416where
417 Provider: DBProvider,
418{
419 type AccountTrieCursor<'a>
420 = <InMemoryTrieCursorFactory<
421 DatabaseTrieCursorFactory<&'a Provider::Tx>,
422 &'a TrieUpdatesSorted,
423 > as TrieCursorFactory>::AccountTrieCursor<'a>
424 where
425 Self: 'a;
426
427 type StorageTrieCursor<'a>
428 = <InMemoryTrieCursorFactory<
429 DatabaseTrieCursorFactory<&'a Provider::Tx>,
430 &'a TrieUpdatesSorted,
431 > as TrieCursorFactory>::StorageTrieCursor<'a>
432 where
433 Self: 'a;
434
435 fn account_trie_cursor(&self) -> Result<Self::AccountTrieCursor<'_>, DatabaseError> {
436 let db_trie_cursor_factory = DatabaseTrieCursorFactory::new(self.provider.tx_ref());
437 let trie_cursor_factory =
438 InMemoryTrieCursorFactory::new(db_trie_cursor_factory, self.trie_updates.as_ref());
439 trie_cursor_factory.account_trie_cursor()
440 }
441
442 fn storage_trie_cursor(
443 &self,
444 hashed_address: B256,
445 ) -> Result<Self::StorageTrieCursor<'_>, DatabaseError> {
446 let db_trie_cursor_factory = DatabaseTrieCursorFactory::new(self.provider.tx_ref());
447 let trie_cursor_factory =
448 InMemoryTrieCursorFactory::new(db_trie_cursor_factory, self.trie_updates.as_ref());
449 trie_cursor_factory.storage_trie_cursor(hashed_address)
450 }
451}
452
453impl<Provider> HashedCursorFactory for OverlayStateProvider<Provider>
454where
455 Provider: DBProvider,
456{
457 type AccountCursor<'a>
458 = <HashedPostStateCursorFactory<
459 DatabaseHashedCursorFactory<&'a Provider::Tx>,
460 &'a Arc<HashedPostStateSorted>,
461 > as HashedCursorFactory>::AccountCursor<'a>
462 where
463 Self: 'a;
464
465 type StorageCursor<'a>
466 = <HashedPostStateCursorFactory<
467 DatabaseHashedCursorFactory<&'a Provider::Tx>,
468 &'a Arc<HashedPostStateSorted>,
469 > as HashedCursorFactory>::StorageCursor<'a>
470 where
471 Self: 'a;
472
473 fn hashed_account_cursor(&self) -> Result<Self::AccountCursor<'_>, DatabaseError> {
474 let db_hashed_cursor_factory = DatabaseHashedCursorFactory::new(self.provider.tx_ref());
475 let hashed_cursor_factory =
476 HashedPostStateCursorFactory::new(db_hashed_cursor_factory, &self.hashed_post_state);
477 hashed_cursor_factory.hashed_account_cursor()
478 }
479
480 fn hashed_storage_cursor(
481 &self,
482 hashed_address: B256,
483 ) -> Result<Self::StorageCursor<'_>, DatabaseError> {
484 let db_hashed_cursor_factory = DatabaseHashedCursorFactory::new(self.provider.tx_ref());
485 let hashed_cursor_factory =
486 HashedPostStateCursorFactory::new(db_hashed_cursor_factory, &self.hashed_post_state);
487 hashed_cursor_factory.hashed_storage_cursor(hashed_address)
488 }
489}