1use crate::stages::merkle::INVALID_STATE_ROOT_ERROR_MESSAGE;
2use alloy_consensus::BlockHeader;
3use alloy_primitives::BlockNumber;
4use reth_consensus::ConsensusError;
5use reth_primitives_traits::{GotExpected, SealedHeader};
6use reth_provider::{
7 ChainStateBlockReader, DBProvider, HeaderProvider, ProviderError, PruneCheckpointReader,
8 PruneCheckpointWriter, StageCheckpointReader, TrieWriter,
9};
10use reth_prune_types::{PruneCheckpoint, PruneMode, PruneSegment};
11use reth_stages_api::{
12 BlockErrorKind, ExecInput, ExecOutput, Stage, StageCheckpoint, StageError, StageId,
13 UnwindInput, UnwindOutput,
14};
15use reth_trie::{updates::TrieUpdates, HashedPostState, KeccakKeyHasher, StateRoot, TrieInput};
16use reth_trie_db::{DatabaseHashedPostState, DatabaseStateRoot};
17use std::ops::Range;
18use tracing::{debug, error};
19
/// Pipeline stage that computes trie changesets for a window of recent blocks and persists them
/// via `TrieWriter::write_trie_changesets`.
///
/// The window ends at the `MerkleExecute` stage checkpoint (inclusive). It starts at the later of
/// the block after the last finalized block and `retention_blocks` blocks before that checkpoint,
/// and never at the genesis block.
#[derive(Clone, Debug)]
pub struct MerkleChangeSets {
    /// Maximum distance, in blocks, between the start of the computed window and the
    /// `MerkleExecute` checkpoint.
    retention_blocks: u64,
}
29
30impl MerkleChangeSets {
31 pub const fn new() -> Self {
33 Self { retention_blocks: 64 }
34 }
35
36 pub const fn with_retention_blocks(retention_blocks: u64) -> Self {
38 Self { retention_blocks }
39 }
40
41 fn computed_range<Provider>(
44 provider: &Provider,
45 checkpoint: Option<StageCheckpoint>,
46 ) -> Result<Range<BlockNumber>, StageError>
47 where
48 Provider: PruneCheckpointReader,
49 {
50 let to = checkpoint.map(|chk| chk.block_number).unwrap_or_default();
51
52 let Some(from) = provider
55 .get_prune_checkpoint(PruneSegment::MerkleChangeSets)?
56 .and_then(|chk| chk.block_number)
57 .map(|block_number| block_number + 1)
60 else {
61 return Ok(0..0)
62 };
63
64 Ok(from..to + 1)
65 }
66
67 fn determine_target_range<Provider>(
72 &self,
73 provider: &Provider,
74 ) -> Result<Range<BlockNumber>, StageError>
75 where
76 Provider: StageCheckpointReader + ChainStateBlockReader,
77 {
78 let merkle_checkpoint = provider
80 .get_stage_checkpoint(StageId::MerkleExecute)?
81 .map(|checkpoint| checkpoint.block_number)
82 .unwrap_or(0);
83
84 let target_end = merkle_checkpoint + 1; let finalized_block = provider.last_finalized_block_number()?;
89
90 let retention_based_start = merkle_checkpoint.saturating_sub(self.retention_blocks);
92
93 let mut target_start = finalized_block
99 .map(|finalized| finalized.saturating_add(1).max(retention_based_start))
100 .unwrap_or(retention_based_start);
101
102 target_start = target_start.max(1);
104
105 Ok(target_start..target_end)
106 }
107
108 fn calculate_block_trie_updates<Provider: DBProvider + HeaderProvider>(
111 provider: &Provider,
112 block_number: BlockNumber,
113 input: TrieInput,
114 ) -> Result<TrieUpdates, StageError> {
115 let (root, trie_updates) =
116 StateRoot::overlay_root_from_nodes_with_updates(provider.tx_ref(), input).map_err(
117 |e| {
118 error!(
119 target: "sync::stages::merkle_changesets",
120 %e,
121 ?block_number,
122 "Incremental state root failed! {INVALID_STATE_ROOT_ERROR_MESSAGE}");
123 StageError::Fatal(Box::new(e))
124 },
125 )?;
126
127 let block = provider
128 .header_by_number(block_number)?
129 .ok_or_else(|| ProviderError::HeaderNotFound(block_number.into()))?;
130
131 let (got, expected) = (root, block.state_root());
132 if got != expected {
133 let header = SealedHeader::seal_slow(block);
135 error!(
136 target: "sync::stages::merkle_changesets",
137 ?block_number,
138 ?got,
139 ?expected,
140 "Failed to verify block state root! {INVALID_STATE_ROOT_ERROR_MESSAGE}",
141 );
142 return Err(StageError::Block {
143 error: BlockErrorKind::Validation(ConsensusError::BodyStateRootDiff(
144 GotExpected { got, expected }.into(),
145 )),
146 block: Box::new(header.block_with_parent()),
147 })
148 }
149
150 Ok(trie_updates)
151 }
152
153 fn populate_range<Provider>(
154 provider: &Provider,
155 target_range: Range<BlockNumber>,
156 ) -> Result<(), StageError>
157 where
158 Provider: StageCheckpointReader
159 + TrieWriter
160 + DBProvider
161 + HeaderProvider
162 + ChainStateBlockReader,
163 {
164 let target_start = target_range.start;
165 let target_end = target_range.end;
166 debug!(
167 target: "sync::stages::merkle_changesets",
168 ?target_range,
169 "Starting trie changeset computation",
170 );
171
172 debug!(
189 target: "sync::stages::merkle_changesets",
190 ?target_range,
191 "Computing per-block state reverts",
192 );
193 let mut per_block_state_reverts = Vec::new();
194 for block_number in target_range.clone() {
195 per_block_state_reverts.push(HashedPostState::from_reverts::<KeccakKeyHasher>(
196 provider.tx_ref(),
197 block_number..=block_number,
198 )?);
199 }
200
201 let get_block_state_revert = |block_number: BlockNumber| -> &HashedPostState {
203 let index = (block_number - target_start) as usize;
204 &per_block_state_reverts[index]
205 };
206
207 let compute_cumulative_state_revert = |block_number: BlockNumber| -> HashedPostState {
209 let mut cumulative_revert = HashedPostState::default();
210 for n in (block_number..target_end).rev() {
211 cumulative_revert.extend_ref(get_block_state_revert(n))
212 }
213 cumulative_revert
214 };
215
216 debug!(
226 target: "sync::stages::merkle_changesets",
227 ?target_start,
228 "Computing trie state at starting block",
229 );
230 let mut input = TrieInput::default();
231 input.state = compute_cumulative_state_revert(target_start);
232 input.prefix_sets = input.state.construct_prefix_sets();
233 input.nodes =
235 Self::calculate_block_trie_updates(provider, target_start - 1, input.clone())?;
236
237 for block_number in target_range {
238 debug!(
239 target: "sync::stages::merkle_changesets",
240 ?block_number,
241 "Computing trie updates for block",
242 );
243 input.state = compute_cumulative_state_revert(block_number + 1);
246
247 input.prefix_sets = get_block_state_revert(block_number).construct_prefix_sets();
250
251 let this_trie_updates =
255 Self::calculate_block_trie_updates(provider, block_number, input.clone())?;
256
257 let trie_overlay = input.nodes.clone().into_sorted();
258 input.nodes.extend_ref(&this_trie_updates);
259 let this_trie_updates = this_trie_updates.into_sorted();
260
261 debug!(
264 target: "sync::stages::merkle_changesets",
265 ?block_number,
266 "Writing trie changesets for block",
267 );
268 provider.write_trie_changesets(
269 block_number,
270 &this_trie_updates,
271 Some(&trie_overlay),
272 )?;
273 }
274
275 Ok(())
276 }
277}
278
279impl Default for MerkleChangeSets {
280 fn default() -> Self {
281 Self::new()
282 }
283}
284
285impl<Provider> Stage<Provider> for MerkleChangeSets
286where
287 Provider: StageCheckpointReader
288 + TrieWriter
289 + DBProvider
290 + HeaderProvider
291 + ChainStateBlockReader
292 + PruneCheckpointReader
293 + PruneCheckpointWriter,
294{
295 fn id(&self) -> StageId {
296 StageId::MerkleChangeSets
297 }
298
299 fn execute(&mut self, provider: &Provider, input: ExecInput) -> Result<ExecOutput, StageError> {
300 let merkle_checkpoint = provider
302 .get_stage_checkpoint(StageId::MerkleExecute)?
303 .map(|checkpoint| checkpoint.block_number)
304 .unwrap_or(0);
305
306 if input.target.is_none_or(|target| merkle_checkpoint != target) {
307 return Err(StageError::Fatal(eyre::eyre!("Cannot sync stage to block {:?} when MerkleExecute is at block {merkle_checkpoint:?}", input.target).into()))
308 }
309
310 let mut target_range = self.determine_target_range(provider)?;
311
312 let mut computed_range = Self::computed_range(provider, input.checkpoint)?;
315 debug!(
316 target: "sync::stages::merkle_changesets",
317 ?computed_range,
318 ?target_range,
319 "Got computed and target ranges",
320 );
321
322 if target_range.start >= computed_range.start {
340 target_range.start = target_range.start.max(computed_range.end);
341 }
342
343 if target_range.start >= target_range.end {
346 return Ok(ExecOutput::done(StageCheckpoint::new(target_range.end.saturating_sub(1))));
347 }
348
349 if target_range.start == computed_range.end {
352 provider.clear_trie_changesets_from(target_range.start)?;
354 computed_range.end = target_range.end;
355 } else {
356 provider.clear_trie_changesets()?;
359 computed_range = target_range.clone();
360 }
361
362 Self::populate_range(provider, target_range)?;
364
365 provider.save_prune_checkpoint(
368 PruneSegment::MerkleChangeSets,
369 PruneCheckpoint {
370 block_number: Some(computed_range.start.saturating_sub(1)),
371 tx_number: None,
372 prune_mode: PruneMode::Before(computed_range.start),
373 },
374 )?;
375
376 let checkpoint = StageCheckpoint::new(computed_range.end.saturating_sub(1));
378
379 Ok(ExecOutput::done(checkpoint))
380 }
381
382 fn unwind(
383 &mut self,
384 provider: &Provider,
385 input: UnwindInput,
386 ) -> Result<UnwindOutput, StageError> {
387 provider.clear_trie_changesets_from(input.unwind_to + 1)?;
389
390 let mut computed_range = Self::computed_range(provider, Some(input.checkpoint))?;
391 computed_range.end = input.unwind_to + 1;
392 if computed_range.start > computed_range.end {
393 computed_range.start = computed_range.end;
394 }
395
396 let checkpoint = StageCheckpoint::new(computed_range.end.saturating_sub(1));
398
399 Ok(UnwindOutput { checkpoint })
400 }
401}