1use crate::{error::StageError, StageCheckpoint, StageId};
2use alloy_primitives::{BlockNumber, TxNumber};
3use reth_provider::{BlockReader, ProviderError, StaticFileProviderFactory, StaticFileSegment};
4use std::{
5 cmp::{max, min},
6 future::{poll_fn, Future},
7 ops::{Range, RangeInclusive},
8 task::{Context, Poll},
9};
10use tracing::instrument;
11
/// Input describing how far a stage has progressed and how far it should go.
#[derive(Debug, Default, PartialEq, Eq, Clone, Copy)]
pub struct ExecInput {
    /// Block number to execute up to. When `None`, [`ExecInput::target`] falls back to the
    /// default block number.
    pub target: Option<BlockNumber>,
    /// Progress recorded so far. When `None`, [`ExecInput::checkpoint`] falls back to the
    /// default checkpoint and [`ExecInput::is_first_range`] returns `true`.
    pub checkpoint: Option<StageCheckpoint>,
}
20
/// Block range produced by [`ExecInput::next_block_range_with_threshold`].
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct BlockRangeOutput {
    /// Inclusive range of blocks, starting at the block after the checkpoint.
    pub block_range: RangeInclusive<BlockNumber>,
    /// `true` when the end of `block_range` equals the target block.
    pub is_final_range: bool,
}
29
/// Transaction and block ranges produced by
/// [`ExecInput::next_block_range_with_transaction_threshold`].
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct TransactionRangeOutput {
    /// Half-open range of transaction numbers, from the first transaction of the start block up
    /// to (but excluding) the transaction number that follows the end block.
    pub tx_range: Range<TxNumber>,
    /// Inclusive range of blocks that contain the transactions in `tx_range`.
    pub block_range: RangeInclusive<BlockNumber>,
    /// `true` when the range extends all the way to the target block.
    pub is_final_range: bool,
}
40
41impl ExecInput {
42 pub fn checkpoint(&self) -> StageCheckpoint {
44 self.checkpoint.unwrap_or_default()
45 }
46
47 pub fn next_block(&self) -> BlockNumber {
50 let current_block = self.checkpoint();
51 current_block.block_number + 1
52 }
53
54 pub fn target_reached(&self) -> bool {
56 self.checkpoint().block_number >= self.target()
57 }
58
59 pub fn target(&self) -> BlockNumber {
61 self.target.unwrap_or_default()
62 }
63
64 pub fn next_block_range(&self) -> RangeInclusive<BlockNumber> {
66 self.next_block_range_with_threshold(u64::MAX).block_range
67 }
68
69 pub const fn is_first_range(&self) -> bool {
71 self.checkpoint.is_none()
72 }
73
74 pub fn next_block_range_with_threshold(&self, threshold: u64) -> BlockRangeOutput {
76 let current_block = self.checkpoint();
77 let start = current_block.block_number + 1;
78 let target = self.target();
79
80 let end = min(target, current_block.block_number.saturating_add(threshold));
81
82 let is_final_range = end == target;
83 BlockRangeOutput { block_range: start..=end, is_final_range }
84 }
85
86 #[instrument(level = "debug", target = "sync::stages", skip(provider), ret)]
92 pub fn next_block_range_with_transaction_threshold<Provider>(
93 &self,
94 provider: &Provider,
95 tx_threshold: u64,
96 ) -> Result<Option<TransactionRangeOutput>, StageError>
97 where
98 Provider: StaticFileProviderFactory + BlockReader,
99 {
100 let Some(lowest_transactions_block) =
102 provider.static_file_provider().get_lowest_range_start(StaticFileSegment::Transactions)
103 else {
104 return Ok(None)
105 };
106
107 let start_block = self.next_block().max(lowest_transactions_block);
114 let target_block = self.target();
115
116 if start_block > target_block {
120 return Ok(None)
121 }
122
123 let start_block_body = provider
124 .block_body_indices(start_block)?
125 .ok_or(ProviderError::BlockBodyIndicesNotFound(start_block))?;
126 let first_tx_num = start_block_body.first_tx_num();
127
128 let target_block_body = provider
129 .block_body_indices(target_block)?
130 .ok_or(ProviderError::BlockBodyIndicesNotFound(target_block))?;
131
132 let all_tx_cnt = target_block_body.next_tx_num() - first_tx_num;
134
135 if all_tx_cnt == 0 {
136 return Ok(None)
138 }
139
140 let (end_block, is_final_range, next_tx_num) = if all_tx_cnt <= tx_threshold {
142 (target_block, true, target_block_body.next_tx_num())
143 } else {
144 let end_block_number = provider
147 .block_by_transaction_id(first_tx_num + tx_threshold)?
148 .expect("block of tx must exist");
149 let end_block_body = provider
152 .block_body_indices(end_block_number)?
153 .ok_or(ProviderError::BlockBodyIndicesNotFound(end_block_number))?;
154 (end_block_number, false, end_block_body.next_tx_num())
155 };
156
157 let tx_range = first_tx_num..next_tx_num;
158 Ok(Some(TransactionRangeOutput {
159 tx_range,
160 block_range: start_block..=end_block,
161 is_final_range,
162 }))
163 }
164}
165
/// Input describing an unwind: where the stage currently is and where it should go back to.
#[derive(Debug, Default, PartialEq, Eq, Clone, Copy)]
pub struct UnwindInput {
    /// Current progress; its block number is the upper end of the unwind range.
    pub checkpoint: StageCheckpoint,
    /// Block to unwind to. This block itself is excluded from the unwind range.
    pub unwind_to: BlockNumber,
    /// NOTE(review): presumably the block that triggered the unwind, if any — not read by the
    /// range helpers on this type; confirm against callers.
    pub bad_block: Option<BlockNumber>,
}
176
177impl UnwindInput {
178 pub fn unwind_block_range(&self) -> RangeInclusive<BlockNumber> {
180 self.unwind_block_range_with_threshold(u64::MAX).0
181 }
182
183 pub fn unwind_block_range_with_threshold(
185 &self,
186 threshold: u64,
187 ) -> (RangeInclusive<BlockNumber>, BlockNumber, bool) {
188 let mut start = self.unwind_to + 1;
190 let end = self.checkpoint;
191
192 start = max(start, end.block_number.saturating_sub(threshold));
193
194 let unwind_to = start - 1;
195
196 let is_final_range = unwind_to == self.unwind_to;
197 (start..=end.block_number, unwind_to, is_final_range)
198 }
199}
200
/// Result of a stage execution, as returned by [`Stage::execute`].
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct ExecOutput {
    /// Checkpoint reported by the stage.
    pub checkpoint: StageCheckpoint,
    /// Whether the stage reported itself as done (`true`) or still in progress (`false`).
    pub done: bool,
}
209
210impl ExecOutput {
211 pub const fn in_progress(checkpoint: StageCheckpoint) -> Self {
213 Self { checkpoint, done: false }
214 }
215
216 pub const fn done(checkpoint: StageCheckpoint) -> Self {
218 Self { checkpoint, done: true }
219 }
220}
221
/// Result of a stage unwind, as returned by [`Stage::unwind`].
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct UnwindOutput {
    /// Checkpoint reported by the stage after unwinding.
    pub checkpoint: StageCheckpoint,
}
228
/// A unit of work that can be executed forward and unwound backward against a `Provider`.
///
/// Implementors must be `Send + Sync`. The `auto_impl` attribute additionally provides the
/// trait for `Box`ed stages.
#[auto_impl::auto_impl(Box)]
pub trait Stage<Provider>: Send + Sync {
    /// Returns the identifier of this stage.
    fn id(&self) -> StageId;

    /// Polls whether the stage is ready to execute the given input.
    ///
    /// The default implementation is immediately ready with `Ok(())`. Stages that have to wait
    /// for something before [`Stage::execute`] can run may override it and return
    /// `Poll::Pending`.
    ///
    /// NOTE(review): presumably the caller is expected to poll this to completion before
    /// calling `execute` — confirm against the pipeline implementation.
    fn poll_execute_ready(
        &mut self,
        _cx: &mut Context<'_>,
        _input: ExecInput,
    ) -> Poll<Result<(), StageError>> {
        Poll::Ready(Ok(()))
    }

    /// Executes the stage for the given input using `provider` and reports the resulting
    /// checkpoint and whether the stage is done.
    fn execute(&mut self, provider: &Provider, input: ExecInput) -> Result<ExecOutput, StageError>;

    /// Hook with a no-op default implementation that returns `Ok(())`.
    ///
    /// NOTE(review): by its name this is meant to run after the data written by `execute` has
    /// been committed — confirm against the caller.
    fn post_execute_commit(&mut self) -> Result<(), StageError> {
        Ok(())
    }

    /// Unwinds the stage for the given input using `provider` and reports the resulting
    /// checkpoint.
    fn unwind(
        &mut self,
        provider: &Provider,
        input: UnwindInput,
    ) -> Result<UnwindOutput, StageError>;

    /// Hook with a no-op default implementation that returns `Ok(())`.
    ///
    /// NOTE(review): by its name this is meant to run after the changes made by `unwind` have
    /// been committed — confirm against the caller.
    fn post_unwind_commit(&mut self) -> Result<(), StageError> {
        Ok(())
    }
}
309
/// Extension methods available on every [`Stage`].
pub trait StageExt<Provider>: Stage<Provider> {
    /// Returns a future that resolves once [`Stage::poll_execute_ready`] returns
    /// `Poll::Ready` for the given input, yielding that result.
    fn execute_ready(
        &mut self,
        input: ExecInput,
    ) -> impl Future<Output = Result<(), StageError>> + Send {
        // Adapt the poll-based readiness check into a future.
        poll_fn(move |cx| self.poll_execute_ready(cx, input))
    }
}

// Blanket implementation: every stage, including unsized ones, gets the extension methods.
impl<Provider, S: Stage<Provider> + ?Sized> StageExt<Provider> for S {}
323
#[cfg(test)]
mod tests {
    use reth_chainspec::MAINNET;
    use reth_db::test_utils::{
        create_test_rocksdb_dir, create_test_rw_db, create_test_static_files_dir,
    };
    use reth_db_api::{models::StoredBlockBodyIndices, tables, transaction::DbTxMut};
    use reth_provider::{
        providers::RocksDBProvider, test_utils::MockNodeTypesWithDB, ProviderFactory,
        StaticFileProviderBuilder, StaticFileProviderFactory, StaticFileSegment,
    };
    use reth_stages_types::StageCheckpoint;
    use reth_testing_utils::generators::{self, random_signed_tx};

    use crate::ExecInput;

    /// Exercises `ExecInput::next_block_range_with_transaction_threshold` against a test
    /// provider whose block body indices and transaction static files are built up (and partly
    /// deleted) step by step.
    ///
    /// All scoped sections share the same provider factory and build on the state left behind
    /// by the previous ones, so their order matters.
    #[test]
    fn test_exec_input_next_block_range_with_transaction_threshold() {
        let mut rng = generators::rng();
        // Test provider backed by temporary directories; the static file provider is built
        // with `with_blocks_per_file(1)`.
        let provider_factory = ProviderFactory::<MockNodeTypesWithDB>::new(
            create_test_rw_db(),
            MAINNET.clone(),
            StaticFileProviderBuilder::read_write(create_test_static_files_dir().0.keep())
                .unwrap()
                .with_blocks_per_file(1)
                .build()
                .unwrap(),
            RocksDBProvider::builder(create_test_rocksdb_dir().0.keep()).build().unwrap(),
        )
        .unwrap();

        // No checkpoint, and nothing has been written to the provider yet: no range expected.
        {
            let exec_input = ExecInput { target: Some(100), checkpoint: None };

            let range_output = exec_input
                .next_block_range_with_transaction_threshold(&provider_factory, 10)
                .unwrap();
            assert!(range_output.is_none());
        }

        // Checkpoint at block 10 with target block 1, still with nothing written: no range
        // expected.
        {
            let exec_input =
                ExecInput { target: Some(1), checkpoint: Some(StageCheckpoint::new(10)) };

            let range_output = exec_input
                .next_block_range_with_transaction_threshold(&provider_factory, 10)
                .unwrap();
            assert!(range_output.is_none());
        }

        // No checkpoint. Block 1 is registered with two transactions (numbers 0 and 1), both
        // of which are appended to the transactions static file. With a threshold of 10 the
        // whole block is expected as the final range.
        {
            let exec_input = ExecInput { target: Some(1), checkpoint: None };

            let mut provider_rw = provider_factory.provider_rw().unwrap();
            provider_rw
                .tx_mut()
                .put::<tables::BlockBodyIndices>(
                    1,
                    StoredBlockBodyIndices { first_tx_num: 0, tx_count: 2 },
                )
                .unwrap();
            let mut writer =
                provider_rw.get_static_file_writer(0, StaticFileSegment::Transactions).unwrap();
            writer.increment_block(0).unwrap();
            writer.increment_block(1).unwrap();
            writer.append_transaction(0, &random_signed_tx(&mut rng)).unwrap();
            writer.append_transaction(1, &random_signed_tx(&mut rng)).unwrap();
            // Release the writer before committing the provider.
            drop(writer);
            provider_rw.commit().unwrap();

            let range_output = exec_input
                .next_block_range_with_transaction_threshold(&provider_factory, 10)
                .unwrap()
                .unwrap();
            assert_eq!(range_output.tx_range, 0..2);
            assert_eq!(range_output.block_range, 1..=1);
            assert!(range_output.is_final_range);
        }

        // Checkpoint at block 1. Block 2 is added with a single transaction (number 2), so the
        // expected range covers exactly that block and transaction.
        {
            let exec_input =
                ExecInput { target: Some(2), checkpoint: Some(StageCheckpoint::new(1)) };

            let mut provider_rw = provider_factory.provider_rw().unwrap();
            provider_rw
                .tx_mut()
                .put::<tables::BlockBodyIndices>(
                    2,
                    StoredBlockBodyIndices { first_tx_num: 2, tx_count: 1 },
                )
                .unwrap();
            let mut writer =
                provider_rw.get_static_file_writer(1, StaticFileSegment::Transactions).unwrap();
            writer.increment_block(2).unwrap();
            writer.append_transaction(2, &random_signed_tx(&mut rng)).unwrap();
            drop(writer);
            provider_rw.commit().unwrap();

            let range_output = exec_input
                .next_block_range_with_transaction_threshold(&provider_factory, 10)
                .unwrap()
                .unwrap();
            assert_eq!(range_output.tx_range, 2..3);
            assert_eq!(range_output.block_range, 2..=2);
            assert!(range_output.is_final_range);
        }

        // No checkpoint, but `delete_jar` is first called on the transactions segment for 0
        // and 1. The expected range then starts at block 2 instead of block 1 — presumably
        // because block 2 is now the lowest block with transaction static files.
        {
            let exec_input = ExecInput { target: Some(2), checkpoint: None };

            provider_factory
                .static_file_provider()
                .delete_jar(StaticFileSegment::Transactions, 0)
                .unwrap();
            provider_factory
                .static_file_provider()
                .delete_jar(StaticFileSegment::Transactions, 1)
                .unwrap();

            let range_output = exec_input
                .next_block_range_with_transaction_threshold(&provider_factory, 10)
                .unwrap()
                .unwrap();
            assert_eq!(range_output.tx_range, 2..3);
            assert_eq!(range_output.block_range, 2..=2);
            assert!(range_output.is_final_range);
        }

        // Checkpoint at block 2. Block 3 is added with a single transaction (number 3), so the
        // expected range covers exactly that block and transaction.
        {
            let exec_input =
                ExecInput { target: Some(3), checkpoint: Some(StageCheckpoint::new(2)) };

            let mut provider_rw = provider_factory.provider_rw().unwrap();
            provider_rw
                .tx_mut()
                .put::<tables::BlockBodyIndices>(
                    3,
                    StoredBlockBodyIndices { first_tx_num: 3, tx_count: 1 },
                )
                .unwrap();
            let mut writer =
                provider_rw.get_static_file_writer(1, StaticFileSegment::Transactions).unwrap();
            writer.increment_block(3).unwrap();
            writer.append_transaction(3, &random_signed_tx(&mut rng)).unwrap();
            drop(writer);
            provider_rw.commit().unwrap();

            let range_output = exec_input
                .next_block_range_with_transaction_threshold(&provider_factory, 10)
                .unwrap()
                .unwrap();
            assert_eq!(range_output.tx_range, 3..4);
            assert_eq!(range_output.block_range, 3..=3);
            assert!(range_output.is_final_range);
        }
    }
}