reth_stages_api/pipeline/
set.rs1use crate::{Stage, StageId};
2use std::{
3 collections::HashMap,
4 fmt::{Debug, Formatter},
5};
6
7pub trait StageSet<Provider>: Sized {
14 fn builder(self) -> StageSetBuilder<Provider>;
16
17 fn set<S: Stage<Provider> + 'static>(self, stage: S) -> StageSetBuilder<Provider> {
23 self.builder().set(stage)
24 }
25}
26
27struct StageEntry<Provider> {
28 stage: Box<dyn Stage<Provider>>,
29 enabled: bool,
30}
31
32impl<Provider> Debug for StageEntry<Provider> {
33 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
34 f.debug_struct("StageEntry")
35 .field("stage", &self.stage.id())
36 .field("enabled", &self.enabled)
37 .finish()
38 }
39}
40
41pub struct StageSetBuilder<Provider> {
48 stages: HashMap<StageId, StageEntry<Provider>>,
49 order: Vec<StageId>,
50}
51
52impl<Provider> Default for StageSetBuilder<Provider> {
53 fn default() -> Self {
54 Self { stages: HashMap::default(), order: Vec::new() }
55 }
56}
57
58impl<Provider> Debug for StageSetBuilder<Provider> {
59 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
60 f.debug_struct("StageSetBuilder")
61 .field("stages", &self.stages)
62 .field("order", &self.order)
63 .finish()
64 }
65}
66
67impl<Provider> StageSetBuilder<Provider> {
68 fn index_of(&self, stage_id: StageId) -> usize {
69 let index = self.order.iter().position(|&id| id == stage_id);
70
71 index.unwrap_or_else(|| panic!("Stage does not exist in set: {stage_id}"))
72 }
73
74 fn upsert_stage_state(&mut self, stage: Box<dyn Stage<Provider>>, added_at_index: usize) {
75 let stage_id = stage.id();
76 if self.stages.insert(stage.id(), StageEntry { stage, enabled: true }).is_some() &&
77 let Some(to_remove) = self
78 .order
79 .iter()
80 .enumerate()
81 .find(|(i, id)| *i != added_at_index && **id == stage_id)
82 .map(|(i, _)| i)
83 {
84 self.order.remove(to_remove);
85 }
86 }
87
88 pub fn set<S: Stage<Provider> + 'static>(mut self, stage: S) -> Self {
94 let entry = self
95 .stages
96 .get_mut(&stage.id())
97 .unwrap_or_else(|| panic!("Stage does not exist in set: {}", stage.id()));
98 entry.stage = Box::new(stage);
99 self
100 }
101
102 pub fn stages(&self) -> impl Iterator<Item = StageId> + '_ {
105 self.order.iter().copied()
106 }
107
108 pub fn replace<S: Stage<Provider> + 'static>(mut self, stage_id: StageId, stage: S) -> Self {
113 self.stages
114 .get(&stage_id)
115 .unwrap_or_else(|| panic!("Stage does not exist in set: {stage_id}"));
116
117 if stage.id() == stage_id {
118 return self.set(stage);
119 }
120 let index = self.index_of(stage_id);
121 self.stages.remove(&stage_id);
122 self.order[index] = stage.id();
123 self.upsert_stage_state(Box::new(stage), index);
124 self
125 }
126
127 pub fn add_stage<S: Stage<Provider> + 'static>(mut self, stage: S) -> Self {
131 let target_index = self.order.len();
132 self.order.push(stage.id());
133 self.upsert_stage_state(Box::new(stage), target_index);
134 self
135 }
136
137 pub fn add_stage_opt<S: Stage<Provider> + 'static>(self, stage: Option<S>) -> Self {
141 if let Some(stage) = stage {
142 self.add_stage(stage)
143 } else {
144 self
145 }
146 }
147
148 pub fn add_set<Set: StageSet<Provider>>(mut self, set: Set) -> Self {
153 for stage in set.builder().build() {
154 let target_index = self.order.len();
155 self.order.push(stage.id());
156 self.upsert_stage_state(stage, target_index);
157 }
158 self
159 }
160
161 pub fn add_before<S: Stage<Provider> + 'static>(mut self, stage: S, before: StageId) -> Self {
169 let target_index = self.index_of(before);
170 self.order.insert(target_index, stage.id());
171 self.upsert_stage_state(Box::new(stage), target_index);
172 self
173 }
174
175 pub fn add_after<S: Stage<Provider> + 'static>(mut self, stage: S, after: StageId) -> Self {
183 let target_index = self.index_of(after) + 1;
184 self.order.insert(target_index, stage.id());
185 self.upsert_stage_state(Box::new(stage), target_index);
186 self
187 }
188
189 pub fn enable(mut self, stage_id: StageId) -> Self {
197 let entry =
198 self.stages.get_mut(&stage_id).expect("Cannot enable a stage that is not in the set.");
199 entry.enabled = true;
200 self
201 }
202
203 #[track_caller]
214 pub fn disable(mut self, stage_id: StageId) -> Self {
215 let entry = self
216 .stages
217 .get_mut(&stage_id)
218 .unwrap_or_else(|| panic!("Cannot disable a stage that is not in the set: {stage_id}"));
219 entry.enabled = false;
220 self
221 }
222
223 pub fn disable_all(mut self, stages: &[StageId]) -> Self {
227 for stage_id in stages {
228 let Some(entry) = self.stages.get_mut(stage_id) else { continue };
229 entry.enabled = false;
230 }
231 self
232 }
233
234 #[track_caller]
238 pub fn disable_if<F>(self, stage_id: StageId, f: F) -> Self
239 where
240 F: FnOnce() -> bool,
241 {
242 if f() {
243 return self.disable(stage_id)
244 }
245 self
246 }
247
248 #[track_caller]
252 pub fn disable_all_if<F>(self, stages: &[StageId], f: F) -> Self
253 where
254 F: FnOnce() -> bool,
255 {
256 if f() {
257 return self.disable_all(stages)
258 }
259 self
260 }
261
262 pub fn build(mut self) -> Vec<Box<dyn Stage<Provider>>> {
264 let mut stages = Vec::new();
265 for id in &self.order {
266 if let Some(entry) = self.stages.remove(id) &&
267 entry.enabled
268 {
269 stages.push(entry.stage);
270 }
271 }
272 stages
273 }
274}
275
276impl<Provider> StageSet<Provider> for StageSetBuilder<Provider> {
277 fn builder(self) -> Self {
278 self
279 }
280}