reth_stages/stages/merkle_changesets.rs

use crate::stages::merkle::INVALID_STATE_ROOT_ERROR_MESSAGE;
use alloy_consensus::BlockHeader;
use alloy_primitives::BlockNumber;
use reth_consensus::ConsensusError;
use reth_primitives_traits::{GotExpected, SealedHeader};
use reth_provider::{
    ChainStateBlockReader, DBProvider, HeaderProvider, ProviderError, PruneCheckpointReader,
    PruneCheckpointWriter, StageCheckpointReader, TrieWriter,
};
use reth_prune_types::{PruneCheckpoint, PruneMode, PruneSegment};
use reth_stages_api::{
    BlockErrorKind, ExecInput, ExecOutput, Stage, StageCheckpoint, StageError, StageId,
    UnwindInput, UnwindOutput,
};
use reth_trie::{updates::TrieUpdates, HashedPostState, KeccakKeyHasher, StateRoot, TrieInput};
use reth_trie_db::{DatabaseHashedPostState, DatabaseStateRoot};
use std::ops::Range;
use tracing::{debug, error};

/// The `MerkleChangeSets` stage.
///
/// This stage processes and maintains trie changesets from the finalized block to the latest block.
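///
/// # Example
///
/// A minimal construction sketch; the import path is assumed and may differ.
///
/// ```ignore
/// use reth_stages::stages::MerkleChangeSets;
///
/// // Retain changesets for the default 64 blocks (two beacon chain epochs).
/// let stage = MerkleChangeSets::new();
///
/// // Or retain changesets for a custom number of blocks.
/// let stage = MerkleChangeSets::with_retention_blocks(128);
/// ```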
#[derive(Debug, Clone)]
pub struct MerkleChangeSets {
    /// The number of blocks to retain changesets for, used as a fallback when the finalized block
    /// is not found. Defaults to 64 (2 epochs in beacon chain).
    retention_blocks: u64,
}

impl MerkleChangeSets {
    /// Creates a new `MerkleChangeSets` stage with the default retention of 64 blocks.
    pub const fn new() -> Self {
        Self { retention_blocks: 64 }
    }

    /// Creates a new `MerkleChangeSets` stage with a custom number of retention blocks.
    pub const fn with_retention_blocks(retention_blocks: u64) -> Self {
        Self { retention_blocks }
    }

    /// Returns the range of blocks which are already computed. Will return an empty range if none
    /// have been computed.
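    ///
    /// For example (illustrative numbers): with a `MerkleChangeSets` prune checkpoint at block 99
    /// and a stage checkpoint at block 150, the computed range is `100..151`.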
    fn computed_range<Provider>(
        provider: &Provider,
        checkpoint: Option<StageCheckpoint>,
    ) -> Result<Range<BlockNumber>, StageError>
    where
        Provider: PruneCheckpointReader,
    {
        let to = checkpoint.map(|chk| chk.block_number).unwrap_or_default();

        // Get the prune checkpoint for MerkleChangeSets to use as the lower bound. If there's no
        // prune checkpoint or if the pruned block number is None, return empty range
        let Some(from) = provider
            .get_prune_checkpoint(PruneSegment::MerkleChangeSets)?
            .and_then(|chk| chk.block_number)
            // prune checkpoint indicates the last block pruned, so the block after is the start of
            // the computed data
            .map(|block_number| block_number + 1)
        else {
            return Ok(0..0)
        };

        Ok(from..to + 1)
    }

    /// Determines the target range for changeset computation based on the checkpoint and provider
    /// state.
    ///
    /// Returns the target range (exclusive end) to compute changesets for.
    fn determine_target_range<Provider>(
        &self,
        provider: &Provider,
    ) -> Result<Range<BlockNumber>, StageError>
    where
        Provider: StageCheckpointReader + ChainStateBlockReader,
    {
        // Get merkle checkpoint which represents our target end block
        let merkle_checkpoint = provider
            .get_stage_checkpoint(StageId::MerkleExecute)?
            .map(|checkpoint| checkpoint.block_number)
            .unwrap_or(0);

        let target_end = merkle_checkpoint + 1; // exclusive

        // Calculate the target range based on the finalized block and the target block.
        // We maintain changesets from the finalized block to the latest block.
        let finalized_block = provider.last_finalized_block_number()?;

        // Calculate the fallback start position based on retention blocks
        let retention_based_start = merkle_checkpoint.saturating_sub(self.retention_blocks);

        // If the finalized block was way in the past then we don't want to generate changesets for
        // all of those past blocks; we only care about the recent history.
        //
        // If the finalized block exists, start at the maximum of `finalized_block + 1` and
        // `retention_based_start`; otherwise just use `retention_based_start`.
        let mut target_start = finalized_block
            .map(|finalized| finalized.saturating_add(1).max(retention_based_start))
            .unwrap_or(retention_based_start);

        // We cannot revert the genesis block; target_start must be > 0
        target_start = target_start.max(1);
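        // For example (illustrative numbers): with merkle_checkpoint = 1000, retention_blocks = 64
        // and a finalized block of 900, retention_based_start = 936 and
        // target_start = max(901, 936) = 936, giving a target range of 936..1001.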

        Ok(target_start..target_end)
    }

    /// Calculates the trie updates given a [`TrieInput`], asserting that the resulting state root
    /// matches the expected one for the block.
    fn calculate_block_trie_updates<Provider: DBProvider + HeaderProvider>(
        provider: &Provider,
        block_number: BlockNumber,
        input: TrieInput,
    ) -> Result<TrieUpdates, StageError> {
        let (root, trie_updates) =
            StateRoot::overlay_root_from_nodes_with_updates(provider.tx_ref(), input).map_err(
                |e| {
                    error!(
                        target: "sync::stages::merkle_changesets",
                        %e,
                        ?block_number,
                        "Incremental state root failed! {INVALID_STATE_ROOT_ERROR_MESSAGE}",
                    );
                    StageError::Fatal(Box::new(e))
                },
            )?;

        let block = provider
            .header_by_number(block_number)?
            .ok_or_else(|| ProviderError::HeaderNotFound(block_number.into()))?;

        let (got, expected) = (root, block.state_root());
        if got != expected {
            // Only seal the header when we need it for the error
            let header = SealedHeader::seal_slow(block);
            error!(
                target: "sync::stages::merkle_changesets",
                ?block_number,
                ?got,
                ?expected,
                "Failed to verify block state root! {INVALID_STATE_ROOT_ERROR_MESSAGE}",
            );
            return Err(StageError::Block {
                error: BlockErrorKind::Validation(ConsensusError::BodyStateRootDiff(
                    GotExpected { got, expected }.into(),
                )),
                block: Box::new(header.block_with_parent()),
            })
        }

        Ok(trie_updates)
    }

    fn populate_range<Provider>(
        provider: &Provider,
        target_range: Range<BlockNumber>,
    ) -> Result<(), StageError>
    where
        Provider: StageCheckpointReader
            + TrieWriter
            + DBProvider
            + HeaderProvider
            + ChainStateBlockReader,
    {
        let target_start = target_range.start;
        let target_end = target_range.end;
        debug!(
            target: "sync::stages::merkle_changesets",
            ?target_range,
            "Starting trie changeset computation",
        );

        // We need to distinguish between a cumulative revert and a per-block revert. A cumulative
        // revert reverts changes starting at the db tip all the way back to a block. A per-block
        // revert only reverts a single block's changes.
        //
        // We need to calculate the cumulative HashedPostState reverts for every block in the
        // target range. The cumulative HashedPostState revert for block N can be calculated as:
        //
        // ```
        // // where `extend` overwrites any shared keys
        // cumulative_state_revert(N) = cumulative_state_revert(N + 1).extend(get_block_state_revert(N))
        // ```
        //
        // We need per-block reverts to calculate the prefix set for each individual block. By
        // using the per-block reverts to calculate cumulative reverts on-the-fly we can save a
        // bunch of memory.
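        //
        // For example (illustrative values): with target_end = 12, if block 10's revert is
        // {A: a10} and block 11's revert is {A: a11, B: b11}, then cumulative_state_revert(11) =
        // {A: a11, B: b11} and cumulative_state_revert(10) = {A: a10, B: b11}, i.e. the oldest
        // pre-state wins for the shared key A.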
        debug!(
            target: "sync::stages::merkle_changesets",
            ?target_range,
            "Computing per-block state reverts",
        );
        let mut per_block_state_reverts = Vec::new();
        for block_number in target_range.clone() {
            per_block_state_reverts.push(HashedPostState::from_reverts::<KeccakKeyHasher>(
                provider.tx_ref(),
                block_number..=block_number,
            )?);
        }

        // Helper to retrieve state revert data for a specific block from the pre-computed array
        let get_block_state_revert = |block_number: BlockNumber| -> &HashedPostState {
            let index = (block_number - target_start) as usize;
            &per_block_state_reverts[index]
        };

        // Helper to accumulate state reverts from a given block to the target end
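        // For example (illustrative): compute_cumulative_state_revert(5) with target_end = 8
        // extends an empty HashedPostState with the reverts of blocks 7, 6, then 5, which is the
        // iterative form of the recursion described above.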
        let compute_cumulative_state_revert = |block_number: BlockNumber| -> HashedPostState {
            let mut cumulative_revert = HashedPostState::default();
            for n in (block_number..target_end).rev() {
                cumulative_revert.extend_ref(get_block_state_revert(n));
            }
            cumulative_revert
        };

        // To calculate the changeset for a block, we first need the TrieUpdates which are
        // generated as a result of processing the block. To get these we need:
        // 1) The TrieUpdates which revert the db's trie to _prior_ to the block
        // 2) The HashedPostState to revert the db's state to _after_ the block
        //
        // To get (1) for `target_start` we need to do a big state root calculation which takes
        // into account all changes between that block and db tip. For each block after the
        // `target_start` we can update (1) using the TrieUpdates which were output by the previous
        // block, only targeting the state changes of that block.
        debug!(
            target: "sync::stages::merkle_changesets",
            ?target_start,
            "Computing trie state at starting block",
        );
        let mut input = TrieInput::default();
        input.state = compute_cumulative_state_revert(target_start);
        input.prefix_sets = input.state.construct_prefix_sets();
        // target_start will be >= 1, see `determine_target_range`.
        input.nodes =
            Self::calculate_block_trie_updates(provider, target_start - 1, input.clone())?;
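        // `input.nodes` now holds the trie updates which revert the db's trie to its state as of
        // block `target_start - 1`, i.e. prior to `target_start` (item (1) above).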

        for block_number in target_range {
            debug!(
                target: "sync::stages::merkle_changesets",
                ?block_number,
                "Computing trie updates for block",
            );
            // Revert the state to the point where this block has just been processed, which means
            // taking the cumulative revert of the subsequent block.
            input.state = compute_cumulative_state_revert(block_number + 1);

            // Construct prefix sets from only this block's `HashedPostState`, because we only care
            // about trie updates which occurred as a result of this block being processed.
            input.prefix_sets = get_block_state_revert(block_number).construct_prefix_sets();

            // Calculate the trie updates for this block, then apply those updates to the reverts.
            // The overlay passed to the write step below is computed from the trie reverts prior
            // to them being updated.
            let this_trie_updates =
                Self::calculate_block_trie_updates(provider, block_number, input.clone())?;

            let trie_overlay = input.nodes.clone().into_sorted();
            input.nodes.extend_ref(&this_trie_updates);
            let this_trie_updates = this_trie_updates.into_sorted();

            // Write the changesets to the DB using the trie updates produced by the block, and the
            // trie reverts as the overlay.
            debug!(
                target: "sync::stages::merkle_changesets",
                ?block_number,
                "Writing trie changesets for block",
            );
            provider.write_trie_changesets(
                block_number,
                &this_trie_updates,
                Some(&trie_overlay),
            )?;
        }

        Ok(())
    }
}

impl Default for MerkleChangeSets {
    fn default() -> Self {
        Self::new()
    }
}

impl<Provider> Stage<Provider> for MerkleChangeSets
where
    Provider: StageCheckpointReader
        + TrieWriter
        + DBProvider
        + HeaderProvider
        + ChainStateBlockReader
        + PruneCheckpointReader
        + PruneCheckpointWriter,
{
    fn id(&self) -> StageId {
        StageId::MerkleChangeSets
    }

    fn execute(&mut self, provider: &Provider, input: ExecInput) -> Result<ExecOutput, StageError> {
        // Get merkle checkpoint and assert that the target is the same.
        let merkle_checkpoint = provider
            .get_stage_checkpoint(StageId::MerkleExecute)?
            .map(|checkpoint| checkpoint.block_number)
            .unwrap_or(0);

        if input.target.is_none_or(|target| merkle_checkpoint != target) {
            return Err(StageError::Fatal(eyre::eyre!("Cannot sync stage to block {:?} when MerkleExecute is at block {merkle_checkpoint:?}", input.target).into()))
        }

        let mut target_range = self.determine_target_range(provider)?;

        // Get the previously computed range. This will be updated to reflect the populating of the
        // target range.
        let mut computed_range = Self::computed_range(provider, input.checkpoint)?;
        debug!(
            target: "sync::stages::merkle_changesets",
            ?computed_range,
            ?target_range,
            "Got computed and target ranges",
        );

        // We want the target range to not include any data already computed previously, if
        // possible, so we start the target range from the end of the computed range if that is
        // greater.
        //
        // ------------------------------> Block #
        //    |------computed-----|
        //              |-----target-----|
        //                        |--actual--|
        //
        // However, if the target start is less than the previously computed start, we don't want to
        // do this, as it would leave a gap of data at `target_range.start..=computed_range.start`.
        //
        // ------------------------------> Block #
        //         |---computed---|
        //      |-------target-------|
        //      |-------actual-------|
        //
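        // For example (illustrative ranges): with computed = 100..151 and target = 120..181 the
        // adjusted target becomes 151..181; but with computed = 100..151 and target = 90..181 the
        // target stays 90..181 and the previously computed data is cleared further below.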
        if target_range.start >= computed_range.start {
            target_range.start = target_range.start.max(computed_range.end);
        }

        // If the target range is empty (target_start >= target_end), the stage has already
        // executed successfully.
        if target_range.start >= target_range.end {
            return Ok(ExecOutput::done(StageCheckpoint::new(target_range.end.saturating_sub(1))));
        }

        // If our target range is a continuation of the already computed range then we can keep the
        // already computed data.
        if target_range.start == computed_range.end {
            // Clear from target_start onwards to ensure no stale data exists
            provider.clear_trie_changesets_from(target_range.start)?;
            computed_range.end = target_range.end;
        } else {
            // If our target range is not a continuation of the already computed range then we
            // simply clear all computed data, to make sure there are no gaps or conflicts.
            provider.clear_trie_changesets()?;
            computed_range = target_range.clone();
        }

        // Populate the target range with changesets
        Self::populate_range(provider, target_range)?;

        // Update the prune checkpoint to reflect that all data before `computed_range.start`
        // is not available.
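        //
        // For example (illustrative): if `computed_range` is 936..1001, the checkpoint records
        // block 935 as the last pruned block, with `PruneMode::Before(936)`.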
        provider.save_prune_checkpoint(
            PruneSegment::MerkleChangeSets,
            PruneCheckpoint {
                block_number: Some(computed_range.start.saturating_sub(1)),
                tx_number: None,
                prune_mode: PruneMode::Before(computed_range.start),
            },
        )?;

        // `computed_range.end` is exclusive.
        let checkpoint = StageCheckpoint::new(computed_range.end.saturating_sub(1));

        Ok(ExecOutput::done(checkpoint))
    }

    fn unwind(
        &mut self,
        provider: &Provider,
        input: UnwindInput,
    ) -> Result<UnwindOutput, StageError> {
        // Unwinding is trivial; just clear everything after the target block.
        provider.clear_trie_changesets_from(input.unwind_to + 1)?;

        let mut computed_range = Self::computed_range(provider, Some(input.checkpoint))?;
        computed_range.end = input.unwind_to + 1;
        if computed_range.start > computed_range.end {
            computed_range.start = computed_range.end;
        }
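        // E.g. (illustrative): unwinding to block 950 clears changesets from block 951 onward and
        // caps `computed_range.end` at 951, so the resulting checkpoint is block 950.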

        // `computed_range.end` is exclusive
        let checkpoint = StageCheckpoint::new(computed_range.end.saturating_sub(1));

        Ok(UnwindOutput { checkpoint })
    }
401}