reth_stages/stages/utils.rs

//! Utils for `stages`.
use alloy_primitives::{Address, BlockNumber, TxNumber};
use reth_config::config::EtlConfig;
use reth_db_api::{
    cursor::{DbCursorRO, DbCursorRW},
    models::{sharded_key::NUM_OF_INDICES_IN_SHARD, AccountBeforeTx, ShardedKey},
    table::{Decompress, Table},
    transaction::{DbTx, DbTxMut},
    BlockNumberList, DatabaseError,
};
use reth_etl::Collector;
use reth_provider::{
    providers::StaticFileProvider, to_range, BlockReader, DBProvider, ProviderError,
    StaticFileProviderFactory,
};
use reth_stages_api::StageError;
use reth_static_file_types::StaticFileSegment;
use reth_storage_api::ChangeSetReader;
use std::{collections::HashMap, hash::Hash, ops::RangeBounds};
use tracing::info;

/// Number of blocks before pushing indices from cache to [`Collector`]
const DEFAULT_CACHE_THRESHOLD: u64 = 100_000;

/// Collects all history (`H`) indices for a range of changesets (`CS`) and stores them in a
/// [`Collector`].
///
/// ## Process
/// The function uses a `HashMap` cache keyed by `PartialKey` (`P`) (an `Address` or
/// `Address.StorageKey`), mapping to a list of block numbers. Once changes from more than
/// `DEFAULT_CACHE_THRESHOLD` blocks have been accumulated, the cache contents are moved to a
/// [`Collector`]. There, each entry's key is the concatenation of the `PartialKey` and the
/// highest block number in its list.
///
/// ## Example
/// 1. Initial Cache State: `{ Address1: [1,2,3], ... }`
/// 2. Cache is flushed to the `Collector`.
/// 3. Updated Cache State: `{ Address1: [100,300], ... }`
/// 4. Cache is flushed again.
///
/// As a result, the `Collector` will contain entries such as `(Address1.3, [1,2,3])` and
/// `(Address1.300, [100,300])`. The entries may be stored across one or more files.
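///
/// ## Usage sketch
/// A hedged illustration of how a caller might instantiate the generics for account history;
/// the table names, variables, and closures below are assumptions for the example, not a
/// verbatim call site:
///
/// ```ignore
/// // Assumed: `tables::AccountChangeSets` as the changeset table (`CS`) and
/// // `tables::AccountsHistory` as the history table (`H`); the partial key is the address.
/// let collector =
///     collect_history_indices::<_, tables::AccountChangeSets, tables::AccountsHistory, _>(
///         provider,
///         block_range,
///         // Shard key = (address, highest block number in the shard).
///         ShardedKey::new,
///         // Extract (block number, address) from a changeset entry.
///         |(block_number, account)| (block_number, account.address),
///         &etl_config,
///     )?;
/// ```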
pub(crate) fn collect_history_indices<Provider, CS, H, P>(
    provider: &Provider,
    range: impl RangeBounds<CS::Key>,
    sharded_key_factory: impl Fn(P, BlockNumber) -> H::Key,
    partial_key_factory: impl Fn((CS::Key, CS::Value)) -> (u64, P),
    etl_config: &EtlConfig,
) -> Result<Collector<H::Key, H::Value>, StageError>
where
    Provider: DBProvider,
    CS: Table,
    H: Table<Value = BlockNumberList>,
    P: Copy + Eq + Hash,
{
    let mut changeset_cursor = provider.tx_ref().cursor_read::<CS>()?;

    let mut collector = Collector::new(etl_config.file_size, etl_config.dir.clone());
    let mut cache: HashMap<P, Vec<u64>> = HashMap::default();

    let mut collect = |cache: &HashMap<P, Vec<u64>>| {
        for (key, indices) in cache {
            let last = indices.last().expect("qed");
            collector.insert(
                sharded_key_factory(*key, *last),
                BlockNumberList::new_pre_sorted(indices.iter().copied()),
            )?;
        }
        Ok::<(), StageError>(())
    };

    // observability
    let total_changesets = provider.tx_ref().entries::<CS>()?;
    let interval = (total_changesets / 1000).max(1);

    let mut flush_counter = 0;
    let mut current_block_number = u64::MAX;
    for (idx, entry) in changeset_cursor.walk_range(range)?.enumerate() {
        let (block_number, key) = partial_key_factory(entry?);
        cache.entry(key).or_default().push(block_number);

        if idx > 0 && idx.is_multiple_of(interval) && total_changesets > 1000 {
            info!(target: "sync::stages::index_history", progress = %format!("{:.4}%", (idx as f64 / total_changesets as f64) * 100.0), "Collecting indices");
        }

        // Make sure we only flush the cache every `DEFAULT_CACHE_THRESHOLD` blocks.
        if current_block_number != block_number {
            current_block_number = block_number;
            flush_counter += 1;
            if flush_counter > DEFAULT_CACHE_THRESHOLD {
                collect(&cache)?;
                cache.clear();
                flush_counter = 0;
            }
        }
    }
    collect(&cache)?;

    Ok(collector)
}

/// Allows collecting indices from a cache with a custom insert fn
fn collect_indices<F>(
    cache: impl Iterator<Item = (Address, Vec<u64>)>,
    mut insert_fn: F,
) -> Result<(), StageError>
where
    F: FnMut(Address, Vec<u64>) -> Result<(), StageError>,
{
    for (address, indices) in cache {
        insert_fn(address, indices)?
    }
    Ok::<(), StageError>(())
}

/// Collects account history indices using a provider that implements `ChangeSetReader`.
pub(crate) fn collect_account_history_indices<Provider>(
    provider: &Provider,
    range: impl RangeBounds<BlockNumber>,
    etl_config: &EtlConfig,
) -> Result<Collector<ShardedKey<Address>, BlockNumberList>, StageError>
where
    Provider: DBProvider + ChangeSetReader + StaticFileProviderFactory,
{
    let mut collector = Collector::new(etl_config.file_size, etl_config.dir.clone());
    let mut cache: HashMap<Address, Vec<u64>> = HashMap::default();

    let mut insert_fn = |address: Address, indices: Vec<u64>| {
        let last = indices.last().expect("qed");
        collector.insert(
            ShardedKey::new(address, *last),
            BlockNumberList::new_pre_sorted(indices.into_iter()),
        )?;
        Ok::<(), StageError>(())
    };

    // Convert range bounds to a concrete range
    let range = to_range(range);

    // Walk the account changesets stored in static files lazily
    let static_file_provider = provider.static_file_provider();

    // Get the total count for progress reporting
    let total_changesets = static_file_provider.account_changeset_count()?;
    let interval = (total_changesets / 1000).max(1);

    let walker = static_file_provider.walk_account_changeset_range(range);

    let mut flush_counter = 0;
    let mut current_block_number = u64::MAX;

    for (idx, changeset_result) in walker.enumerate() {
        let (block_number, AccountBeforeTx { address, .. }) = changeset_result?;
        cache.entry(address).or_default().push(block_number);

        if idx > 0 && idx % interval == 0 && total_changesets > 1000 {
            info!(target: "sync::stages::index_history", progress = %format!("{:.4}%", (idx as f64 / total_changesets as f64) * 100.0), "Collecting indices");
        }

        if block_number != current_block_number {
            current_block_number = block_number;
            flush_counter += 1;
        }

        if flush_counter > DEFAULT_CACHE_THRESHOLD {
            collect_indices(cache.drain(), &mut insert_fn)?;
            flush_counter = 0;
        }
    }
    collect_indices(cache.into_iter(), insert_fn)?;

    Ok(collector)
}

/// Given a [`Collector`] created by [`collect_history_indices`] it iterates all entries, loading
/// the indices into the database in shards.
///
/// ## Process
/// Iterates over elements, grouping indices by their partial keys (e.g., `Address` or
/// `Address.StorageKey`). It flushes indices to disk when reaching a shard's max length
/// (`NUM_OF_INDICES_IN_SHARD`) or when the partial key changes, ensuring that the last shard of
/// the previous partial key is stored.
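///
/// ## Usage sketch
/// A hedged illustration of how a caller might wire this up for account history; the table and
/// helper names below are assumptions for the example, not a verbatim call site:
///
/// ```ignore
/// load_history_indices::<_, tables::AccountsHistory, _>(
///     provider,
///     collector,
///     append_only,
///     // Rebuild the sharded key from (address, highest block number).
///     ShardedKey::new,
///     // Decode the raw collector key back into a `ShardedKey<Address>`.
///     ShardedKey::<Address>::decode_owned,
///     // The partial key is the address portion of the sharded key.
///     |key| key.key,
/// )?;
/// ```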
pub(crate) fn load_history_indices<Provider, H, P>(
    provider: &Provider,
    mut collector: Collector<H::Key, H::Value>,
    append_only: bool,
    sharded_key_factory: impl Clone + Fn(P, u64) -> <H as Table>::Key,
    decode_key: impl Fn(Vec<u8>) -> Result<<H as Table>::Key, DatabaseError>,
    get_partial: impl Fn(<H as Table>::Key) -> P,
) -> Result<(), StageError>
where
    Provider: DBProvider<Tx: DbTxMut>,
    H: Table<Value = BlockNumberList>,
    P: Copy + Default + Eq,
{
    let mut write_cursor = provider.tx_ref().cursor_write::<H>()?;
    let mut current_partial = P::default();
    let mut current_list = Vec::<u64>::new();

    // observability
    let total_entries = collector.len();
    let interval = (total_entries / 10).max(1);

    for (index, element) in collector.iter()?.enumerate() {
        let (k, v) = element?;
        let sharded_key = decode_key(k)?;
        let new_list = BlockNumberList::decompress_owned(v)?;

        if index > 0 && index.is_multiple_of(interval) && total_entries > 10 {
            info!(target: "sync::stages::index_history", progress = %format!("{:.2}%", (index as f64 / total_entries as f64) * 100.0), "Writing indices");
        }

        // AccountsHistory: `Address`.
        // StorageHistory: `Address.StorageKey`.
        let partial_key = get_partial(sharded_key);

        if current_partial != partial_key {
            // We have reached the end of this subset of keys, so we need to flush its last index
            // shard.
            load_indices(
                &mut write_cursor,
                current_partial,
                &mut current_list,
                &sharded_key_factory,
                append_only,
                LoadMode::Flush,
            )?;

            current_partial = partial_key;
            current_list.clear();

            // If it's not the first sync, there might be an existing shard already, so we need to
            // merge it with the one coming from the collector.
            if !append_only &&
                let Some((_, last_database_shard)) =
                    write_cursor.seek_exact(sharded_key_factory(current_partial, u64::MAX))?
            {
                current_list.extend(last_database_shard.iter());
            }
        }

        current_list.extend(new_list.iter());
        load_indices(
            &mut write_cursor,
            current_partial,
            &mut current_list,
            &sharded_key_factory,
            append_only,
            LoadMode::KeepLast,
        )?;
    }

    // There will be one remaining shard that needs to be flushed to DB.
    load_indices(
        &mut write_cursor,
        current_partial,
        &mut current_list,
        &sharded_key_factory,
        append_only,
        LoadMode::Flush,
    )?;

    Ok(())
}

/// Shard and insert the indices list according to [`LoadMode`] and its length.
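///
/// For example (illustrative; assuming `NUM_OF_INDICES_IN_SHARD` is 2_000), flushing a list of
/// 4_100 indices for `Address1` writes three shards: the first two keyed by the highest block
/// number of their chunk, and the last keyed by `u64::MAX` (the last shard of a key is always
/// stored under `u64::MAX`). In [`LoadMode::KeepLast`], that final chunk would instead stay in
/// `list` so further indices can be appended to it.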
pub(crate) fn load_indices<H, C, P>(
    cursor: &mut C,
    partial_key: P,
    list: &mut Vec<BlockNumber>,
    sharded_key_factory: &impl Fn(P, BlockNumber) -> <H as Table>::Key,
    append_only: bool,
    mode: LoadMode,
) -> Result<(), StageError>
where
    C: DbCursorRO<H> + DbCursorRW<H>,
    H: Table<Value = BlockNumberList>,
    P: Copy,
{
    if list.len() > NUM_OF_INDICES_IN_SHARD || mode.is_flush() {
        let chunks = list
            .chunks(NUM_OF_INDICES_IN_SHARD)
            .map(|chunk| chunk.to_vec())
            .collect::<Vec<Vec<u64>>>();

        let mut iter = chunks.into_iter().peekable();
        while let Some(chunk) = iter.next() {
            let mut highest = *chunk.last().expect("at least one index");

            if !mode.is_flush() && iter.peek().is_none() {
                *list = chunk;
            } else {
                if iter.peek().is_none() {
                    highest = u64::MAX;
                }
                let key = sharded_key_factory(partial_key, highest);
                let value = BlockNumberList::new_pre_sorted(chunk);

                if append_only {
                    cursor.append(key, &value)?;
                } else {
                    cursor.upsert(key, &value)?;
                }
            }
        }
    }

    Ok(())
}

/// Mode describing how index shards are loaded into the database.
pub(crate) enum LoadMode {
    /// Keep the last shard in memory and don't flush it to the database.
    KeepLast,
    /// Flush all shards into the database.
    Flush,
}

impl LoadMode {
    const fn is_flush(&self) -> bool {
        matches!(self, Self::Flush)
    }
}

/// Called when the database is ahead of static files. Attempts to find the first block we are
/// missing transactions for.
pub(crate) fn missing_static_data_error<Provider>(
    last_tx_num: TxNumber,
    static_file_provider: &StaticFileProvider<Provider::Primitives>,
    provider: &Provider,
    segment: StaticFileSegment,
) -> Result<StageError, ProviderError>
where
    Provider: BlockReader + StaticFileProviderFactory,
{
    let mut last_block =
        static_file_provider.get_highest_static_file_block(segment).unwrap_or_default();

    // To be extra safe, we make sure that the last tx num matches the last block from its indices.
    // If not, walk back until we find the last block fully covered by the static file.
    loop {
        if let Some(indices) = provider.block_body_indices(last_block)? &&
            indices.last_tx_num() <= last_tx_num
        {
            break
        }
        if last_block == 0 {
            break
        }
        last_block -= 1;
    }

    let missing_block = Box::new(provider.sealed_header(last_block + 1)?.unwrap_or_default());

    Ok(StageError::MissingStaticFileData {
        block: Box::new(missing_block.block_with_parent()),
        segment,
    })
}