1use clap::{Args, Parser, Subcommand, ValueEnum};
4use reth_db_common::DbTool;
5use reth_provider::{
6 providers::ProviderNodeTypes, DBProvider, DatabaseProviderFactory, StageCheckpointReader,
7 StageCheckpointWriter,
8};
9use reth_stages::StageId;
10
11use crate::common::AccessRights;
12
13#[derive(Debug, Parser)]
15pub struct Command {
16 #[command(subcommand)]
17 command: Subcommands,
18}
19
20impl Command {
21 pub fn access_rights(&self) -> AccessRights {
23 match &self.command {
24 Subcommands::Get { .. } => AccessRights::RO,
25 Subcommands::Set(_) => AccessRights::RW,
26 }
27 }
28
29 pub fn execute<N: ProviderNodeTypes>(self, tool: &DbTool<N>) -> eyre::Result<()> {
31 match self.command {
32 Subcommands::Get { stage } => Self::get(tool, stage),
33 Subcommands::Set(args) => Self::set(tool, args),
34 }
35 }
36
37 fn get<N: ProviderNodeTypes>(tool: &DbTool<N>, stage: Option<StageArg>) -> eyre::Result<()> {
38 let provider = tool.provider_factory.provider()?;
39
40 match stage {
41 Some(stage) => {
42 let stage_id = stage.into();
43 let checkpoint = provider.get_stage_checkpoint(stage_id)?;
44 println!("{stage_id}: {checkpoint:?}");
45 }
46 None => {
47 let mut checkpoints = provider.get_all_checkpoints()?;
48 checkpoints.sort_by(|a, b| a.0.cmp(&b.0));
49 for (stage, checkpoint) in checkpoints {
50 println!("{stage}: {checkpoint:?}");
51 }
52 }
53 }
54
55 Ok(())
56 }
57
58 fn set<N: ProviderNodeTypes>(tool: &DbTool<N>, args: SetArgs) -> eyre::Result<()> {
59 let stage_id: StageId = args.stage.into();
60 let provider_rw = tool.provider_factory.database_provider_rw()?;
61
62 let previous = provider_rw.get_stage_checkpoint(stage_id)?;
63 let mut checkpoint = previous.unwrap_or_default();
64 checkpoint.block_number = args.block_number;
65
66 if args.clear_stage_unit {
67 checkpoint.stage_checkpoint = None;
68 }
69
70 provider_rw.save_stage_checkpoint(stage_id, checkpoint)?;
71
72 provider_rw.commit()?;
73
74 println!("Updated checkpoint for {stage_id}: {checkpoint:?}");
75
76 Ok(())
77 }
78}
79
80#[derive(Debug, Subcommand)]
81enum Subcommands {
82 Get {
84 #[arg(long, value_enum)]
86 stage: Option<StageArg>,
87 },
88 Set(SetArgs),
90}
91
92#[derive(Debug, Args)]
94pub struct SetArgs {
95 #[arg(long, value_enum)]
97 stage: StageArg,
98
99 #[arg(long)]
101 block_number: u64,
102
103 #[arg(long)]
105 clear_stage_unit: bool,
106}
107
108#[derive(Debug, Clone, Copy, ValueEnum)]
110#[clap(rename_all = "kebab-case")]
111pub enum StageArg {
112 Era,
113 Headers,
114 Bodies,
115 SenderRecovery,
116 Execution,
117 PruneSenderRecovery,
118 MerkleUnwind,
119 AccountHashing,
120 StorageHashing,
121 MerkleExecute,
122 TransactionLookup,
123 IndexStorageHistory,
124 IndexAccountHistory,
125 Prune,
126 Finish,
127}
128
129impl From<StageArg> for StageId {
130 fn from(arg: StageArg) -> Self {
131 match arg {
132 StageArg::Era => Self::Era,
133 StageArg::Headers => Self::Headers,
134 StageArg::Bodies => Self::Bodies,
135 StageArg::SenderRecovery => Self::SenderRecovery,
136 StageArg::Execution => Self::Execution,
137 StageArg::PruneSenderRecovery => Self::PruneSenderRecovery,
138 StageArg::MerkleUnwind => Self::MerkleUnwind,
139 StageArg::AccountHashing => Self::AccountHashing,
140 StageArg::StorageHashing => Self::StorageHashing,
141 StageArg::MerkleExecute => Self::MerkleExecute,
142 StageArg::TransactionLookup => Self::TransactionLookup,
143 StageArg::IndexStorageHistory => Self::IndexStorageHistory,
144 StageArg::IndexAccountHistory => Self::IndexAccountHistory,
145 StageArg::Prune => Self::Prune,
146 StageArg::Finish => Self::Finish,
147 }
148 }
149}
150
#[cfg(test)]
mod tests {
    use super::*;
    use clap::Parser;
    use reth_provider::{
        test_utils::create_test_provider_factory, DBProvider, DatabaseProviderFactory,
        StageCheckpointReader, StageCheckpointWriter,
    };
    use reth_stages::StageCheckpoint;

    /// Builds a `set` command directly, bypassing command-line parsing.
    fn set_command(stage: StageArg, block_number: u64, clear_stage_unit: bool) -> Command {
        Command { command: Subcommands::Set(SetArgs { stage, block_number, clear_stage_unit }) }
    }

    #[test]
    fn parse_set_args() {
        let argv = ["stage-checkpoints", "set", "--stage", "headers", "--block-number", "123"];
        let parsed = Command::parse_from(argv);

        let Subcommands::Set(args) = parsed.command else {
            panic!("expected the `set` subcommand");
        };
        assert!(matches!(args.stage, StageArg::Headers));
        assert_eq!(args.block_number, 123);
        assert!(!args.clear_stage_unit);
    }

    #[test]
    fn set_overwrites_block_number() {
        let factory = create_test_provider_factory();
        let tool = DbTool::new(factory.clone()).expect("db tool");

        // Seed an existing checkpoint; `commit` consumes the writer.
        let writer = factory.database_provider_rw().expect("rw provider");
        writer
            .save_stage_checkpoint(StageId::Headers, StageCheckpoint::new(10))
            .expect("save checkpoint");
        writer.commit().expect("commit initial checkpoint");

        set_command(StageArg::Headers, 42, false).execute(&tool).expect("execute command");

        let stored = factory
            .provider()
            .expect("provider")
            .get_stage_checkpoint(StageId::Headers)
            .expect("get stage checkpoint")
            .expect("missing stage checkpoint");
        assert_eq!(stored.block_number, 42);
    }

    #[test]
    fn set_preserves_stage_unit_checkpoint_unless_cleared() {
        let factory = create_test_provider_factory();
        let tool = DbTool::new(factory.clone()).expect("db tool");

        let seeded = StageCheckpoint::new(10).with_block_range(&StageId::Execution, 5, 10);
        let writer = factory.database_provider_rw().expect("rw provider");
        writer.save_stage_checkpoint(StageId::Execution, seeded).expect("save checkpoint");
        writer.commit().expect("commit initial checkpoint");

        // Without the flag, the stage-specific unit checkpoint survives the update.
        set_command(StageArg::Execution, 11, false).execute(&tool).expect("execute command");
        let kept = factory
            .provider()
            .expect("provider")
            .get_stage_checkpoint(StageId::Execution)
            .expect("get stage checkpoint")
            .expect("missing stage checkpoint");
        assert!(kept.stage_checkpoint.is_some());

        // With the flag, it is reset.
        set_command(StageArg::Execution, 12, true).execute(&tool).expect("execute command");
        let cleared = factory
            .provider()
            .expect("provider")
            .get_stage_checkpoint(StageId::Execution)
            .expect("get stage checkpoint")
            .expect("missing stage checkpoint");
        assert!(cleared.stage_checkpoint.is_none());
    }

    #[test]
    fn set_preserves_checkpoint_progress() {
        let factory = create_test_provider_factory();
        let tool = DbTool::new(factory.clone()).expect("db tool");

        let writer = factory.database_provider_rw().expect("rw provider");
        writer
            .save_stage_checkpoint(StageId::MerkleExecute, StageCheckpoint::new(10))
            .expect("save checkpoint");
        writer
            .save_stage_checkpoint_progress(StageId::MerkleExecute, vec![1, 2, 3])
            .expect("save progress");
        writer.commit().expect("commit initial checkpoint");

        set_command(StageArg::MerkleExecute, 20, false).execute(&tool).expect("execute command");

        let progress = factory
            .provider()
            .expect("provider")
            .get_stage_checkpoint_progress(StageId::MerkleExecute)
            .expect("get stage checkpoint progress");
        assert_eq!(progress, Some(vec![1, 2, 3]));
    }
}