reth_stages/stages/utils.rs

//! Utils for `stages`.
use alloy_primitives::{BlockNumber, TxNumber};
use reth_config::config::EtlConfig;
use reth_db_api::{
    cursor::{DbCursorRO, DbCursorRW},
    models::sharded_key::NUM_OF_INDICES_IN_SHARD,
    table::{Decompress, Table},
    transaction::{DbTx, DbTxMut},
    BlockNumberList, DatabaseError,
};
use reth_etl::Collector;
use reth_provider::{
    providers::StaticFileProvider, BlockReader, DBProvider, ProviderError,
    StaticFileProviderFactory,
};
use reth_stages_api::StageError;
use reth_static_file_types::StaticFileSegment;
use std::{collections::HashMap, hash::Hash, ops::RangeBounds};
use tracing::info;

/// Number of blocks before pushing indices from cache to [`Collector`]
const DEFAULT_CACHE_THRESHOLD: u64 = 100_000;

/// Collects all history (`H`) indices for a range of changesets (`CS`) and stores them in a
/// [`Collector`].
///
/// ## Process
/// The function uses a `HashMap` cache that maps a `PartialKey` (`P`, an `Address` or an
/// `Address.StorageKey`) to a list of block numbers. Whenever the cache grows past its
/// threshold, its contents are moved to a [`Collector`], where each entry's key is the
/// concatenation of the `PartialKey` and the highest block number in its list.
///
/// ## Example
/// 1. Initial cache state: `{ Address1: [1,2,3], ... }`
/// 2. The cache is flushed to the `Collector`.
/// 3. Updated cache state: `{ Address1: [100,300], ... }`
/// 4. The cache is flushed again.
///
/// As a result, the `Collector` will contain entries such as `(Address1.3, [1,2,3])` and
/// `(Address1.300, [100,300])`. The entries may be stored across one or more files.
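///
/// ## Usage sketch
/// A minimal sketch of a call site for account history. The table names and closure
/// shapes below are illustrative assumptions modeled on the history index stages, not
/// definitions from this module:
/// ```text
/// let collector = collect_history_indices::<_, tables::AccountChangeSets, tables::AccountsHistory, _>(
///     provider,
///     block_range,
///     ShardedKey::new,
///     |(block_number, changeset)| (block_number, changeset.address),
///     &etl_config,
/// )?;
/// ```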
pub(crate) fn collect_history_indices<Provider, CS, H, P>(
    provider: &Provider,
    range: impl RangeBounds<CS::Key>,
    sharded_key_factory: impl Fn(P, BlockNumber) -> H::Key,
    partial_key_factory: impl Fn((CS::Key, CS::Value)) -> (u64, P),
    etl_config: &EtlConfig,
) -> Result<Collector<H::Key, H::Value>, StageError>
where
    Provider: DBProvider,
    CS: Table,
    H: Table<Value = BlockNumberList>,
    P: Copy + Eq + Hash,
{
    let mut changeset_cursor = provider.tx_ref().cursor_read::<CS>()?;

    let mut collector = Collector::new(etl_config.file_size, etl_config.dir.clone());
    let mut cache: HashMap<P, Vec<u64>> = HashMap::default();

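    // Drains the cache into the collector. Each entry is keyed by its partial key
    // concatenated with the highest block number in its (already sorted) index list.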
    let mut collect = |cache: &HashMap<P, Vec<u64>>| {
        for (key, indices) in cache {
            let last = indices.last().expect("cache entries always contain at least one index");
            collector.insert(
                sharded_key_factory(*key, *last),
                BlockNumberList::new_pre_sorted(indices.iter().copied()),
            )?;
        }
        Ok::<(), StageError>(())
    };

    // observability
    let total_changesets = provider.tx_ref().entries::<CS>()?;
    let interval = (total_changesets / 1000).max(1);

    let mut flush_counter = 0;
    let mut current_block_number = u64::MAX;
    for (idx, entry) in changeset_cursor.walk_range(range)?.enumerate() {
        let (block_number, key) = partial_key_factory(entry?);
        cache.entry(key).or_default().push(block_number);

        if idx > 0 && idx % interval == 0 && total_changesets > 1000 {
            info!(target: "sync::stages::index_history", progress = %format!("{:.4}%", (idx as f64 / total_changesets as f64) * 100.0), "Collecting indices");
        }

        // Make sure we only flush the cache after seeing `DEFAULT_CACHE_THRESHOLD` distinct
        // blocks.
        if current_block_number != block_number {
            current_block_number = block_number;
            flush_counter += 1;
            if flush_counter > DEFAULT_CACHE_THRESHOLD {
                collect(&cache)?;
                cache.clear();
                flush_counter = 0;
            }
        }
    }
    collect(&cache)?;

    Ok(collector)
}

/// Given a [`Collector`] created by [`collect_history_indices`], iterates over all of its
/// entries and loads the indices into the database in shards.
///
/// ## Process
/// Iterates over elements, grouping indices by their partial keys (e.g., `Address` or
/// `Address.StorageKey`). It flushes indices to disk when a shard reaches its maximum length
/// (`NUM_OF_INDICES_IN_SHARD`) or when the partial key changes, making sure that the final
/// shard of the previous partial key is stored.
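///
/// ## Example
/// A sketch of the resulting layout, assuming all merged lists fit in a single shard
/// (the real [`NUM_OF_INDICES_IN_SHARD`] is far larger than these lists):
/// ```text
/// collector: (Address1.3, [1,2,3]), (Address1.300, [100,300]), (Address2.400, [400])
/// database : Address1.u64::MAX -> [1,2,3,100,300]
///            Address2.u64::MAX -> [400]
/// ```
/// The last shard of each partial key is keyed with `u64::MAX` as its block number, so
/// lookups can always seek to the final shard.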
pub(crate) fn load_history_indices<Provider, H, P>(
    provider: &Provider,
    mut collector: Collector<H::Key, H::Value>,
    append_only: bool,
    sharded_key_factory: impl Clone + Fn(P, u64) -> <H as Table>::Key,
    decode_key: impl Fn(Vec<u8>) -> Result<<H as Table>::Key, DatabaseError>,
    get_partial: impl Fn(<H as Table>::Key) -> P,
) -> Result<(), StageError>
where
    Provider: DBProvider<Tx: DbTxMut>,
    H: Table<Value = BlockNumberList>,
    P: Copy + Default + Eq,
{
    let mut write_cursor = provider.tx_ref().cursor_write::<H>()?;
    let mut current_partial = P::default();
    let mut current_list = Vec::<u64>::new();

    // observability
    let total_entries = collector.len();
    let interval = (total_entries / 100).max(1);

    for (index, element) in collector.iter()?.enumerate() {
        let (k, v) = element?;
        let sharded_key = decode_key(k)?;
        let new_list = BlockNumberList::decompress_owned(v)?;

        if index > 0 && index % interval == 0 && total_entries > 100 {
            info!(target: "sync::stages::index_history", progress = %format!("{:.2}%", (index as f64 / total_entries as f64) * 100.0), "Writing indices");
        }

        // AccountsHistory: `Address`.
        // StorageHistory: `Address.StorageKey`.
        let partial_key = get_partial(sharded_key);

        if current_partial != partial_key {
            // We have reached the end of this subset of keys, so we need to flush the
            // last index shard of the previous partial key.
            load_indices(
                &mut write_cursor,
                current_partial,
                &mut current_list,
                &sharded_key_factory,
                append_only,
                LoadMode::Flush,
            )?;

            current_partial = partial_key;
            current_list.clear();

            // If this is not the first sync, there might be an existing shard in the
            // database already, so we need to merge it with the one coming from the
            // collector.
            if !append_only {
                if let Some((_, last_database_shard)) =
                    write_cursor.seek_exact(sharded_key_factory(current_partial, u64::MAX))?
                {
                    current_list.extend(last_database_shard.iter());
                }
            }
        }

        current_list.extend(new_list.iter());
        load_indices(
            &mut write_cursor,
            current_partial,
            &mut current_list,
            &sharded_key_factory,
            append_only,
            LoadMode::KeepLast,
        )?;
    }

    // There will be one remaining shard that needs to be flushed to the database.
    load_indices(
        &mut write_cursor,
        current_partial,
        &mut current_list,
        &sharded_key_factory,
        append_only,
        LoadMode::Flush,
    )?;

    Ok(())
}

/// Shard and insert the indices list according to [`LoadMode`] and its length.
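///
/// ## Example
/// A sketch with an illustrative shard size of 3 (the real [`NUM_OF_INDICES_IN_SHARD`]
/// is far larger), `list = [1, 2, 3, 4, 5]` and partial key `P`:
/// ```text
/// LoadMode::Flush    -> writes P.3 = [1,2,3] and P.u64::MAX = [4,5]
/// LoadMode::KeepLast -> writes P.3 = [1,2,3] and keeps list = [4,5] in memory
/// ```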
pub(crate) fn load_indices<H, C, P>(
    cursor: &mut C,
    partial_key: P,
    list: &mut Vec<BlockNumber>,
    sharded_key_factory: &impl Fn(P, BlockNumber) -> <H as Table>::Key,
    append_only: bool,
    mode: LoadMode,
) -> Result<(), StageError>
where
    C: DbCursorRO<H> + DbCursorRW<H>,
    H: Table<Value = BlockNumberList>,
    P: Copy,
{
    if list.len() > NUM_OF_INDICES_IN_SHARD || mode.is_flush() {
        let chunks = list
            .chunks(NUM_OF_INDICES_IN_SHARD)
            .map(|chunk| chunk.to_vec())
            .collect::<Vec<Vec<u64>>>();

        let mut iter = chunks.into_iter().peekable();
        while let Some(chunk) = iter.next() {
            let mut highest = *chunk.last().expect("at least one index");

            if !mode.is_flush() && iter.peek().is_none() {
                // Keep the last (possibly incomplete) shard in memory for the caller.
                *list = chunk;
            } else {
                if iter.peek().is_none() {
                    // The final shard of a partial key is stored under the `u64::MAX`
                    // sentinel so that lookups can always seek to it.
                    highest = u64::MAX;
                }
                let key = sharded_key_factory(partial_key, highest);
                let value = BlockNumberList::new_pre_sorted(chunk);

                if append_only {
                    cursor.append(key, &value)?;
                } else {
                    cursor.upsert(key, &value)?;
                }
            }
        }
    }

    Ok(())
}

/// Mode on how to load index shards into the database.
pub(crate) enum LoadMode {
    /// Keep the last shard in memory and don't flush it to the database.
    KeepLast,
    /// Flush all shards into the database.
    Flush,
}

impl LoadMode {
    const fn is_flush(&self) -> bool {
        matches!(self, Self::Flush)
    }
}

/// Called when the database is ahead of static files. Attempts to find the first block for
/// which transactions are missing from the static files.
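///
/// ## Example
/// If the static files hold transactions up to `last_tx_num`, but the database's block
/// body indices extend past it, this walks back block by block to the last block whose
/// body ends at or below `last_tx_num` and reports the following block as missing.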
pub(crate) fn missing_static_data_error<Provider>(
    last_tx_num: TxNumber,
    static_file_provider: &StaticFileProvider<Provider::Primitives>,
    provider: &Provider,
    segment: StaticFileSegment,
) -> Result<StageError, ProviderError>
where
    Provider: BlockReader + StaticFileProviderFactory,
{
    let mut last_block =
        static_file_provider.get_highest_static_file_block(segment).unwrap_or_default();

    // To be extra safe, we make sure that the last tx num matches the last block from its
    // indices. If it doesn't, walk back block by block until it does.
    loop {
        if let Some(indices) = provider.block_body_indices(last_block)? {
            if indices.last_tx_num() <= last_tx_num {
                break
            }
        }
        if last_block == 0 {
            break
        }
        last_block -= 1;
    }

    let missing_block = Box::new(provider.sealed_header(last_block + 1)?.unwrap_or_default());

    Ok(StageError::MissingStaticFileData {
        block: Box::new(missing_block.block_with_parent()),
        segment,
    })
}