reth_stages_api/pipeline/
set.rs

1use crate::{Stage, StageId};
2use std::{
3    collections::HashMap,
4    fmt::{Debug, Formatter},
5};
6
7/// Combines multiple [`Stage`]s into a single unit.
8///
9/// A [`StageSet`] is a logical chunk of stages that depend on each other. It is up to the
10/// individual stage sets to determine what kind of configuration they expose.
11///
12/// Individual stages in the set can be added, removed and overridden using [`StageSetBuilder`].
13pub trait StageSet<Provider>: Sized {
14    /// Configures the stages in the set.
15    fn builder(self) -> StageSetBuilder<Provider>;
16
17    /// Overrides the given [`Stage`], if it is in this set.
18    ///
19    /// # Panics
20    ///
21    /// Panics if the [`Stage`] is not in this set.
22    fn set<S: Stage<Provider> + 'static>(self, stage: S) -> StageSetBuilder<Provider> {
23        self.builder().set(stage)
24    }
25}
26
27struct StageEntry<Provider> {
28    stage: Box<dyn Stage<Provider>>,
29    enabled: bool,
30}
31
32impl<Provider> Debug for StageEntry<Provider> {
33    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
34        f.debug_struct("StageEntry")
35            .field("stage", &self.stage.id())
36            .field("enabled", &self.enabled)
37            .finish()
38    }
39}
40
41/// Helper to create and configure a [`StageSet`].
42///
43/// The builder provides ordering helpers to ensure that stages that depend on each other are added
44/// to the final sync pipeline before/after their dependencies.
45///
46/// Stages inside the set can be disabled, enabled, overridden and reordered.
47pub struct StageSetBuilder<Provider> {
48    stages: HashMap<StageId, StageEntry<Provider>>,
49    order: Vec<StageId>,
50}
51
52impl<Provider> Default for StageSetBuilder<Provider> {
53    fn default() -> Self {
54        Self { stages: HashMap::default(), order: Vec::new() }
55    }
56}
57
58impl<Provider> Debug for StageSetBuilder<Provider> {
59    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
60        f.debug_struct("StageSetBuilder")
61            .field("stages", &self.stages)
62            .field("order", &self.order)
63            .finish()
64    }
65}
66
67impl<Provider> StageSetBuilder<Provider> {
68    fn index_of(&self, stage_id: StageId) -> usize {
69        let index = self.order.iter().position(|&id| id == stage_id);
70
71        index.unwrap_or_else(|| panic!("Stage does not exist in set: {stage_id}"))
72    }
73
74    fn upsert_stage_state(&mut self, stage: Box<dyn Stage<Provider>>, added_at_index: usize) {
75        let stage_id = stage.id();
76        if self.stages.insert(stage.id(), StageEntry { stage, enabled: true }).is_some() {
77            if let Some(to_remove) = self
78                .order
79                .iter()
80                .enumerate()
81                .find(|(i, id)| *i != added_at_index && **id == stage_id)
82                .map(|(i, _)| i)
83            {
84                self.order.remove(to_remove);
85            }
86        }
87    }
88
89    /// Overrides the given [`Stage`], if it is in this set.
90    ///
91    /// # Panics
92    ///
93    /// Panics if the [`Stage`] is not in this set.
94    pub fn set<S: Stage<Provider> + 'static>(mut self, stage: S) -> Self {
95        let entry = self
96            .stages
97            .get_mut(&stage.id())
98            .unwrap_or_else(|| panic!("Stage does not exist in set: {}", stage.id()));
99        entry.stage = Box::new(stage);
100        self
101    }
102
103    /// Adds the given [`Stage`] at the end of this set.
104    ///
105    /// If the stage was already in the group, it is removed from its previous place.
106    pub fn add_stage<S: Stage<Provider> + 'static>(mut self, stage: S) -> Self {
107        let target_index = self.order.len();
108        self.order.push(stage.id());
109        self.upsert_stage_state(Box::new(stage), target_index);
110        self
111    }
112
113    /// Adds the given [`Stage`] at the end of this set if it's [`Some`].
114    ///
115    /// If the stage was already in the group, it is removed from its previous place.
116    pub fn add_stage_opt<S: Stage<Provider> + 'static>(self, stage: Option<S>) -> Self {
117        if let Some(stage) = stage {
118            self.add_stage(stage)
119        } else {
120            self
121        }
122    }
123
124    /// Adds the given [`StageSet`] to the end of this set.
125    ///
126    /// If a stage is in both sets, it is removed from its previous place in this set. Because of
127    /// this, it is advisable to merge sets first and re-order stages after if needed.
128    pub fn add_set<Set: StageSet<Provider>>(mut self, set: Set) -> Self {
129        for stage in set.builder().build() {
130            let target_index = self.order.len();
131            self.order.push(stage.id());
132            self.upsert_stage_state(stage, target_index);
133        }
134        self
135    }
136
137    /// Adds the given [`Stage`] before the stage with the given [`StageId`].
138    ///
139    /// If the stage was already in the group, it is removed from its previous place.
140    ///
141    /// # Panics
142    ///
143    /// Panics if the dependency stage is not in this set.
144    pub fn add_before<S: Stage<Provider> + 'static>(mut self, stage: S, before: StageId) -> Self {
145        let target_index = self.index_of(before);
146        self.order.insert(target_index, stage.id());
147        self.upsert_stage_state(Box::new(stage), target_index);
148        self
149    }
150
151    /// Adds the given [`Stage`] after the stage with the given [`StageId`].
152    ///
153    /// If the stage was already in the group, it is removed from its previous place.
154    ///
155    /// # Panics
156    ///
157    /// Panics if the dependency stage is not in this set.
158    pub fn add_after<S: Stage<Provider> + 'static>(mut self, stage: S, after: StageId) -> Self {
159        let target_index = self.index_of(after) + 1;
160        self.order.insert(target_index, stage.id());
161        self.upsert_stage_state(Box::new(stage), target_index);
162        self
163    }
164
165    /// Enables the given stage.
166    ///
167    /// All stages within a [`StageSet`] are enabled by default.
168    ///
169    /// # Panics
170    ///
171    /// Panics if the stage is not in this set.
172    pub fn enable(mut self, stage_id: StageId) -> Self {
173        let entry =
174            self.stages.get_mut(&stage_id).expect("Cannot enable a stage that is not in the set.");
175        entry.enabled = true;
176        self
177    }
178
179    /// Disables the given stage.
180    ///
181    /// The disabled [`Stage`] keeps its place in the set, so it can be used for ordering with
182    /// [`StageSetBuilder::add_before`] or [`StageSetBuilder::add_after`], or it can be re-enabled.
183    ///
184    /// All stages within a [`StageSet`] are enabled by default.
185    ///
186    /// # Panics
187    ///
188    /// Panics if the stage is not in this set.
189    #[track_caller]
190    pub fn disable(mut self, stage_id: StageId) -> Self {
191        let entry = self
192            .stages
193            .get_mut(&stage_id)
194            .unwrap_or_else(|| panic!("Cannot disable a stage that is not in the set: {stage_id}"));
195        entry.enabled = false;
196        self
197    }
198
199    /// Disables all given stages. See [`disable`](Self::disable).
200    ///
201    /// If any of the stages is not in this set, it is ignored.
202    pub fn disable_all(mut self, stages: &[StageId]) -> Self {
203        for stage_id in stages {
204            let Some(entry) = self.stages.get_mut(stage_id) else { continue };
205            entry.enabled = false;
206        }
207        self
208    }
209
210    /// Disables the given stage if the given closure returns true.
211    ///
212    /// See [`Self::disable`]
213    #[track_caller]
214    pub fn disable_if<F>(self, stage_id: StageId, f: F) -> Self
215    where
216        F: FnOnce() -> bool,
217    {
218        if f() {
219            return self.disable(stage_id)
220        }
221        self
222    }
223
224    /// Disables all given stages if the given closure returns true.
225    ///
226    /// See [`Self::disable`]
227    #[track_caller]
228    pub fn disable_all_if<F>(self, stages: &[StageId], f: F) -> Self
229    where
230        F: FnOnce() -> bool,
231    {
232        if f() {
233            return self.disable_all(stages)
234        }
235        self
236    }
237
238    /// Consumes the builder and returns the contained [`Stage`]s in the order specified.
239    pub fn build(mut self) -> Vec<Box<dyn Stage<Provider>>> {
240        let mut stages = Vec::new();
241        for id in &self.order {
242            if let Some(entry) = self.stages.remove(id) {
243                if entry.enabled {
244                    stages.push(entry.stage);
245                }
246            }
247        }
248        stages
249    }
250}
251
252impl<Provider> StageSet<Provider> for StageSetBuilder<Provider> {
253    fn builder(self) -> Self {
254        self
255    }
256}