reth_stages_api/
test_utils.rs

1#![allow(missing_docs)]
2
3use crate::{ExecInput, ExecOutput, Stage, StageError, StageId, UnwindInput, UnwindOutput};
4use std::{
5    collections::VecDeque,
6    sync::{
7        atomic::{AtomicUsize, Ordering},
8        Arc,
9    },
10};
11
12/// A test stage that can be used for testing.
13///
14/// This can be used to mock expected outputs of [`Stage::execute`] and [`Stage::unwind`]
15#[derive(Debug)]
16pub struct TestStage {
17    id: StageId,
18    exec_outputs: VecDeque<Result<ExecOutput, StageError>>,
19    unwind_outputs: VecDeque<Result<UnwindOutput, StageError>>,
20    post_execute_commit_counter: Arc<AtomicUsize>,
21    post_unwind_commit_counter: Arc<AtomicUsize>,
22}
23
24impl TestStage {
25    pub fn new(id: StageId) -> Self {
26        Self {
27            id,
28            exec_outputs: VecDeque::new(),
29            unwind_outputs: VecDeque::new(),
30            post_execute_commit_counter: Arc::new(AtomicUsize::new(0)),
31            post_unwind_commit_counter: Arc::new(AtomicUsize::new(0)),
32        }
33    }
34
35    pub fn with_exec(mut self, exec_outputs: VecDeque<Result<ExecOutput, StageError>>) -> Self {
36        self.exec_outputs = exec_outputs;
37        self
38    }
39
40    pub fn with_unwind(
41        mut self,
42        unwind_outputs: VecDeque<Result<UnwindOutput, StageError>>,
43    ) -> Self {
44        self.unwind_outputs = unwind_outputs;
45        self
46    }
47
48    pub fn add_exec(mut self, output: Result<ExecOutput, StageError>) -> Self {
49        self.exec_outputs.push_back(output);
50        self
51    }
52
53    pub fn add_unwind(mut self, output: Result<UnwindOutput, StageError>) -> Self {
54        self.unwind_outputs.push_back(output);
55        self
56    }
57
58    pub fn with_post_execute_commit_counter(mut self) -> (Self, Arc<AtomicUsize>) {
59        let counter = Arc::new(AtomicUsize::new(0));
60        self.post_execute_commit_counter = counter.clone();
61        (self, counter)
62    }
63
64    pub fn with_post_unwind_commit_counter(mut self) -> (Self, Arc<AtomicUsize>) {
65        let counter = Arc::new(AtomicUsize::new(0));
66        self.post_unwind_commit_counter = counter.clone();
67        (self, counter)
68    }
69}
70
71impl<Provider> Stage<Provider> for TestStage {
72    fn id(&self) -> StageId {
73        self.id
74    }
75
76    fn execute(&mut self, _: &Provider, _input: ExecInput) -> Result<ExecOutput, StageError> {
77        self.exec_outputs
78            .pop_front()
79            .unwrap_or_else(|| panic!("Test stage {} executed too many times.", self.id))
80    }
81
82    fn post_execute_commit(&mut self) -> Result<(), StageError> {
83        self.post_execute_commit_counter.fetch_add(1, Ordering::Relaxed);
84
85        Ok(())
86    }
87
88    fn unwind(&mut self, _: &Provider, _input: UnwindInput) -> Result<UnwindOutput, StageError> {
89        self.unwind_outputs
90            .pop_front()
91            .unwrap_or_else(|| panic!("Test stage {} unwound too many times.", self.id))
92    }
93
94    fn post_unwind_commit(&mut self) -> Result<(), StageError> {
95        self.post_unwind_commit_counter.fetch_add(1, Ordering::Relaxed);
96
97        Ok(())
98    }
99}