1use crate::{error::StageError, StageCheckpoint, StageId};
2use alloy_primitives::{BlockNumber, TxNumber};
3use reth_provider::{BlockReader, ProviderError, StaticFileProviderFactory, StaticFileSegment};
4use std::{
5 cmp::{max, min},
6 future::{poll_fn, Future},
7 ops::{Range, RangeInclusive},
8 task::{Context, Poll},
9};
10use tracing::instrument;
11
/// Input passed to [`Stage::execute`].
#[derive(Debug, Default, PartialEq, Eq, Clone, Copy)]
pub struct ExecInput {
    /// The block number the stage should execute up to, if one was set.
    pub target: Option<BlockNumber>,
    /// The last checkpoint recorded for the stage. `None` is treated as the first range
    /// (see `ExecInput::is_first_range`).
    pub checkpoint: Option<StageCheckpoint>,
}
20
/// Block range produced by `ExecInput::next_block_range_with_threshold`.
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct BlockRangeOutput {
    /// Inclusive range of blocks to process next.
    pub block_range: RangeInclusive<BlockNumber>,
    /// `true` when the range ends exactly at the execution target.
    pub is_final_range: bool,
}
29
/// Transaction and block ranges produced by
/// `ExecInput::next_block_range_with_transaction_threshold`.
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct TransactionRangeOutput {
    /// Half-open range of transaction numbers covered by `block_range`.
    pub tx_range: Range<TxNumber>,
    /// Inclusive range of blocks whose transactions make up `tx_range`.
    pub block_range: RangeInclusive<BlockNumber>,
    /// `true` when `block_range` ends at the execution target.
    pub is_final_range: bool,
}
40
41impl ExecInput {
42 pub fn checkpoint(&self) -> StageCheckpoint {
44 self.checkpoint.unwrap_or_default()
45 }
46
47 pub fn next_block(&self) -> BlockNumber {
50 let current_block = self.checkpoint();
51 current_block.block_number + 1
52 }
53
54 pub fn target_reached(&self) -> bool {
56 self.checkpoint().block_number >= self.target()
57 }
58
59 pub fn target(&self) -> BlockNumber {
61 self.target.unwrap_or_default()
62 }
63
64 pub fn next_block_range(&self) -> RangeInclusive<BlockNumber> {
66 self.next_block_range_with_threshold(u64::MAX).block_range
67 }
68
69 pub const fn is_first_range(&self) -> bool {
71 self.checkpoint.is_none()
72 }
73
74 pub fn next_block_range_with_threshold(&self, threshold: u64) -> BlockRangeOutput {
76 let current_block = self.checkpoint();
77 let start = current_block.block_number + 1;
78 let target = self.target();
79
80 let end = min(target, current_block.block_number.saturating_add(threshold));
81
82 let is_final_range = end == target;
83 BlockRangeOutput { block_range: start..=end, is_final_range }
84 }
85
86 #[instrument(level = "debug", target = "sync::stages", skip(provider), ret)]
92 pub fn next_block_range_with_transaction_threshold<Provider>(
93 &self,
94 provider: &Provider,
95 tx_threshold: u64,
96 ) -> Result<Option<TransactionRangeOutput>, StageError>
97 where
98 Provider: StaticFileProviderFactory + BlockReader,
99 {
100 let Some(lowest_transactions_block) =
102 provider.static_file_provider().get_lowest_range_start(StaticFileSegment::Transactions)
103 else {
104 return Ok(None)
105 };
106
107 let start_block = self.next_block().max(lowest_transactions_block);
114 let target_block = self.target();
115
116 if start_block > target_block {
120 return Ok(None)
121 }
122
123 let start_block_body = provider
124 .block_body_indices(start_block)?
125 .ok_or(ProviderError::BlockBodyIndicesNotFound(start_block))?;
126 let first_tx_num = start_block_body.first_tx_num();
127
128 let target_block_body = provider
129 .block_body_indices(target_block)?
130 .ok_or(ProviderError::BlockBodyIndicesNotFound(target_block))?;
131
132 let all_tx_cnt = target_block_body.next_tx_num() - first_tx_num;
134
135 if all_tx_cnt == 0 {
136 return Ok(None)
138 }
139
140 let (end_block, is_final_range, next_tx_num) = if all_tx_cnt <= tx_threshold {
142 (target_block, true, target_block_body.next_tx_num())
143 } else {
144 let end_block_number = provider
147 .block_by_transaction_id(first_tx_num + tx_threshold)?
148 .expect("block of tx must exist");
149 let end_block_body = provider
152 .block_body_indices(end_block_number)?
153 .ok_or(ProviderError::BlockBodyIndicesNotFound(end_block_number))?;
154 (end_block_number, false, end_block_body.next_tx_num())
155 };
156
157 let tx_range = first_tx_num..next_tx_num;
158 Ok(Some(TransactionRangeOutput {
159 tx_range,
160 block_range: start_block..=end_block,
161 is_final_range,
162 }))
163 }
164}
165
/// Input passed to [`Stage::unwind`].
#[derive(Debug, Default, PartialEq, Eq, Clone, Copy)]
pub struct UnwindInput {
    /// The current checkpoint of the stage; unwinding proceeds downward from its block.
    pub checkpoint: StageCheckpoint,
    /// The block to unwind to. This block is kept; every block above it is unwound.
    pub unwind_to: BlockNumber,
    /// NOTE(review): presumably the block that triggered the unwind, if any. It is not read in
    /// this file; confirm against callers.
    pub bad_block: Option<BlockNumber>,
}
176
177impl UnwindInput {
178 pub fn unwind_block_range(&self) -> RangeInclusive<BlockNumber> {
180 self.unwind_block_range_with_threshold(u64::MAX).0
181 }
182
183 pub fn unwind_block_range_with_threshold(
185 &self,
186 threshold: u64,
187 ) -> (RangeInclusive<BlockNumber>, BlockNumber, bool) {
188 let mut start = self.unwind_to + 1;
190 let end = self.checkpoint;
191
192 start = max(start, end.block_number.saturating_sub(threshold));
193
194 let unwind_to = start - 1;
195
196 let is_final_range = unwind_to == self.unwind_to;
197 (start..=end.block_number, unwind_to, is_final_range)
198 }
199}
200
/// Result of [`Stage::execute`].
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct ExecOutput {
    /// The checkpoint reached by this execution.
    pub checkpoint: StageCheckpoint,
    /// Whether the stage reported itself finished (`ExecOutput::done`) rather than still in
    /// progress (`ExecOutput::in_progress`).
    pub done: bool,
}
209
210impl ExecOutput {
211 pub const fn in_progress(checkpoint: StageCheckpoint) -> Self {
213 Self { checkpoint, done: false }
214 }
215
216 pub const fn done(checkpoint: StageCheckpoint) -> Self {
218 Self { checkpoint, done: true }
219 }
220}
221
/// Result of [`Stage::unwind`].
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct UnwindOutput {
    /// The checkpoint of the stage after unwinding.
    pub checkpoint: StageCheckpoint,
}
228
/// A unit of sync work that can be executed toward a target block and unwound back to an
/// earlier block, operating on a `Provider`.
///
/// `auto_impl(Box)` generates an implementation for `Box<S>` that forwards to the boxed stage.
#[auto_impl::auto_impl(Box)]
pub trait Stage<Provider>: Send + Sync {
    /// Returns the identifier of this stage.
    fn id(&self) -> StageId;

    /// Polls whether the stage is ready to run [`Stage::execute`] for `input`.
    ///
    /// The default implementation reports ready immediately. NOTE(review): presumably
    /// overridden by stages that must wait on asynchronous work before executing; confirm.
    fn poll_execute_ready(
        &mut self,
        _cx: &mut Context<'_>,
        _input: ExecInput,
    ) -> Poll<Result<(), StageError>> {
        Poll::Ready(Ok(()))
    }

    /// Executes the stage with `provider` for the work described by `input`. Returns the new
    /// checkpoint and whether the stage is done.
    fn execute(&mut self, provider: &Provider, input: ExecInput) -> Result<ExecOutput, StageError>;

    /// Hook to run after execution changes have been committed (presumably by the pipeline;
    /// confirm with callers). The default implementation does nothing.
    fn post_execute_commit(&mut self) -> Result<(), StageError> {
        Ok(())
    }

    /// Unwinds the stage with `provider` as described by `input`. Returns the checkpoint after
    /// unwinding.
    fn unwind(
        &mut self,
        provider: &Provider,
        input: UnwindInput,
    ) -> Result<UnwindOutput, StageError>;

    /// Hook to run after unwind changes have been committed (presumably by the pipeline;
    /// confirm with callers). The default implementation does nothing.
    fn post_unwind_commit(&mut self) -> Result<(), StageError> {
        Ok(())
    }
}
309
310pub trait StageExt<Provider>: Stage<Provider> {
312 fn execute_ready(
315 &mut self,
316 input: ExecInput,
317 ) -> impl Future<Output = Result<(), StageError>> + Send {
318 poll_fn(move |cx| self.poll_execute_ready(cx, input))
319 }
320}
321
322impl<Provider, S: Stage<Provider> + ?Sized> StageExt<Provider> for S {}
323
#[cfg(test)]
mod tests {
    use reth_chainspec::MAINNET;
    use reth_db::test_utils::{create_test_rw_db, create_test_static_files_dir};
    use reth_db_api::{models::StoredBlockBodyIndices, tables, transaction::DbTxMut};
    use reth_provider::{
        test_utils::MockNodeTypesWithDB, ProviderFactory, StaticFileProviderBuilder,
        StaticFileProviderFactory, StaticFileSegment,
    };
    use reth_stages_types::StageCheckpoint;
    use reth_testing_utils::generators::{self, random_signed_tx};

    use crate::ExecInput;

    /// Walks `ExecInput::next_block_range_with_transaction_threshold` through a sequence of
    /// scenarios sharing one provider. Each scenario builds on the database and static file
    /// state left by the previous one, so their order matters.
    #[test]
    fn test_exec_input_next_block_range_with_transaction_threshold() {
        let mut rng = generators::rng();
        // Static files are configured with one block per file (`with_blocks_per_file(1)`).
        let provider_factory = ProviderFactory::<MockNodeTypesWithDB>::new(
            create_test_rw_db(),
            MAINNET.clone(),
            StaticFileProviderBuilder::read_write(create_test_static_files_dir().0.keep())
                .unwrap()
                .with_blocks_per_file(1)
                .build()
                .unwrap(),
        )
        .unwrap();

        {
            // No transaction static files exist yet, so there is no range to return.
            let exec_input = ExecInput { target: Some(100), checkpoint: None };

            let range_output = exec_input
                .next_block_range_with_transaction_threshold(&provider_factory, 10)
                .unwrap();
            assert!(range_output.is_none());
        }

        {
            // Still no transaction static files, and the checkpoint (10) is past the target (1):
            // `None` either way.
            let exec_input =
                ExecInput { target: Some(1), checkpoint: Some(StageCheckpoint::new(10)) };

            let range_output = exec_input
                .next_block_range_with_transaction_threshold(&provider_factory, 10)
                .unwrap();
            assert!(range_output.is_none());
        }

        {
            // Block 1 holds transactions 0 and 1. The transactions static file writer is
            // advanced over blocks 0 and 1 and both transactions are appended. Both fit into the
            // threshold of 10, so this is the final range.
            let exec_input = ExecInput { target: Some(1), checkpoint: None };

            let mut provider_rw = provider_factory.provider_rw().unwrap();
            provider_rw
                .tx_mut()
                .put::<tables::BlockBodyIndices>(
                    1,
                    StoredBlockBodyIndices { first_tx_num: 0, tx_count: 2 },
                )
                .unwrap();
            let mut writer =
                provider_rw.get_static_file_writer(0, StaticFileSegment::Transactions).unwrap();
            writer.increment_block(0).unwrap();
            writer.increment_block(1).unwrap();
            writer.append_transaction(0, &random_signed_tx(&mut rng)).unwrap();
            writer.append_transaction(1, &random_signed_tx(&mut rng)).unwrap();
            drop(writer);
            provider_rw.commit().unwrap();

            let range_output = exec_input
                .next_block_range_with_transaction_threshold(&provider_factory, 10)
                .unwrap()
                .unwrap();
            assert_eq!(range_output.tx_range, 0..2);
            assert_eq!(range_output.block_range, 1..=1);
            assert!(range_output.is_final_range);
        }

        {
            // Checkpoint at block 1, target 2: block 2 holds only transaction 2.
            let exec_input =
                ExecInput { target: Some(2), checkpoint: Some(StageCheckpoint::new(1)) };

            let mut provider_rw = provider_factory.provider_rw().unwrap();
            provider_rw
                .tx_mut()
                .put::<tables::BlockBodyIndices>(
                    2,
                    StoredBlockBodyIndices { first_tx_num: 2, tx_count: 1 },
                )
                .unwrap();
            let mut writer =
                provider_rw.get_static_file_writer(1, StaticFileSegment::Transactions).unwrap();
            writer.increment_block(2).unwrap();
            writer.append_transaction(2, &random_signed_tx(&mut rng)).unwrap();
            drop(writer);
            provider_rw.commit().unwrap();

            let range_output = exec_input
                .next_block_range_with_transaction_threshold(&provider_factory, 10)
                .unwrap()
                .unwrap();
            assert_eq!(range_output.tx_range, 2..3);
            assert_eq!(range_output.block_range, 2..=2);
            assert!(range_output.is_final_range);
        }

        {
            // Delete the transaction jars 0 and 1 (presumably the ones for blocks 0 and 1).
            // Without a checkpoint the start block would be 1, but the expected range starts at
            // block 2, the lowest block still covered by transaction static files.
            let exec_input = ExecInput { target: Some(2), checkpoint: None };

            provider_factory
                .static_file_provider()
                .delete_jar(StaticFileSegment::Transactions, 0)
                .unwrap();
            provider_factory
                .static_file_provider()
                .delete_jar(StaticFileSegment::Transactions, 1)
                .unwrap();

            let range_output = exec_input
                .next_block_range_with_transaction_threshold(&provider_factory, 10)
                .unwrap()
                .unwrap();
            assert_eq!(range_output.tx_range, 2..3);
            assert_eq!(range_output.block_range, 2..=2);
            assert!(range_output.is_final_range);
        }

        {
            // Checkpoint at block 2, target 3: block 3 holds only transaction 3.
            let exec_input =
                ExecInput { target: Some(3), checkpoint: Some(StageCheckpoint::new(2)) };

            let mut provider_rw = provider_factory.provider_rw().unwrap();
            provider_rw
                .tx_mut()
                .put::<tables::BlockBodyIndices>(
                    3,
                    StoredBlockBodyIndices { first_tx_num: 3, tx_count: 1 },
                )
                .unwrap();
            let mut writer =
                provider_rw.get_static_file_writer(1, StaticFileSegment::Transactions).unwrap();
            writer.increment_block(3).unwrap();
            writer.append_transaction(3, &random_signed_tx(&mut rng)).unwrap();
            drop(writer);
            provider_rw.commit().unwrap();

            let range_output = exec_input
                .next_block_range_with_transaction_threshold(&provider_factory, 10)
                .unwrap()
                .unwrap();
            assert_eq!(range_output.tx_range, 3..4);
            assert_eq!(range_output.block_range, 3..=3);
            assert!(range_output.is_final_range);
        }
    }
}