Skip to main content

reth_cli_commands/db/
stage_checkpoints.rs

//! `reth db stage-checkpoints` command for viewing and setting stage checkpoint values.

3use clap::{Args, Parser, Subcommand, ValueEnum};
4use reth_db_common::DbTool;
5use reth_provider::{
6    providers::ProviderNodeTypes, DBProvider, DatabaseProviderFactory, StageCheckpointReader,
7    StageCheckpointWriter,
8};
9use reth_stages::StageId;
10
11use crate::common::AccessRights;
12
13/// `reth db stage-checkpoints` subcommand
14#[derive(Debug, Parser)]
15pub struct Command {
16    #[command(subcommand)]
17    command: Subcommands,
18}
19
20impl Command {
21    /// Returns database access rights required for the command.
22    pub fn access_rights(&self) -> AccessRights {
23        match &self.command {
24            Subcommands::Get { .. } => AccessRights::RO,
25            Subcommands::Set(_) => AccessRights::RW,
26        }
27    }
28
29    /// Execute the command
30    pub fn execute<N: ProviderNodeTypes>(self, tool: &DbTool<N>) -> eyre::Result<()> {
31        match self.command {
32            Subcommands::Get { stage } => Self::get(tool, stage),
33            Subcommands::Set(args) => Self::set(tool, args),
34        }
35    }
36
37    fn get<N: ProviderNodeTypes>(tool: &DbTool<N>, stage: Option<StageArg>) -> eyre::Result<()> {
38        let provider = tool.provider_factory.provider()?;
39
40        match stage {
41            Some(stage) => {
42                let stage_id = stage.into();
43                let checkpoint = provider.get_stage_checkpoint(stage_id)?;
44                println!("{stage_id}: {checkpoint:?}");
45            }
46            None => {
47                let mut checkpoints = provider.get_all_checkpoints()?;
48                checkpoints.sort_by(|a, b| a.0.cmp(&b.0));
49                for (stage, checkpoint) in checkpoints {
50                    println!("{stage}: {checkpoint:?}");
51                }
52            }
53        }
54
55        Ok(())
56    }
57
58    fn set<N: ProviderNodeTypes>(tool: &DbTool<N>, args: SetArgs) -> eyre::Result<()> {
59        let stage_id: StageId = args.stage.into();
60        let provider_rw = tool.provider_factory.database_provider_rw()?;
61
62        let previous = provider_rw.get_stage_checkpoint(stage_id)?;
63        let mut checkpoint = previous.unwrap_or_default();
64        checkpoint.block_number = args.block_number;
65
66        if args.clear_stage_unit {
67            checkpoint.stage_checkpoint = None;
68        }
69
70        provider_rw.save_stage_checkpoint(stage_id, checkpoint)?;
71
72        provider_rw.commit()?;
73
74        println!("Updated checkpoint for {stage_id}: {checkpoint:?}");
75
76        Ok(())
77    }
78}
79
80#[derive(Debug, Subcommand)]
81enum Subcommands {
82    /// Get stage checkpoint(s) from database.
83    Get {
84        /// Specific stage to query. If omitted, shows all stages.
85        #[arg(long, value_enum)]
86        stage: Option<StageArg>,
87    },
88    /// Set a stage checkpoint.
89    Set(SetArgs),
90}
91
92/// Arguments for the `set` subcommand.
93#[derive(Debug, Args)]
94pub struct SetArgs {
95    /// Stage to update.
96    #[arg(long, value_enum)]
97    stage: StageArg,
98
99    /// Block number to set as stage checkpoint.
100    #[arg(long)]
101    block_number: u64,
102
103    /// Clear stage-specific unit checkpoint payload.
104    #[arg(long)]
105    clear_stage_unit: bool,
106}
107
108/// CLI-friendly stage names.
109#[derive(Debug, Clone, Copy, ValueEnum)]
110#[clap(rename_all = "kebab-case")]
111pub enum StageArg {
112    Era,
113    Headers,
114    Bodies,
115    SenderRecovery,
116    Execution,
117    PruneSenderRecovery,
118    MerkleUnwind,
119    AccountHashing,
120    StorageHashing,
121    MerkleExecute,
122    TransactionLookup,
123    IndexStorageHistory,
124    IndexAccountHistory,
125    Prune,
126    Finish,
127}
128
129impl From<StageArg> for StageId {
130    fn from(arg: StageArg) -> Self {
131        match arg {
132            StageArg::Era => Self::Era,
133            StageArg::Headers => Self::Headers,
134            StageArg::Bodies => Self::Bodies,
135            StageArg::SenderRecovery => Self::SenderRecovery,
136            StageArg::Execution => Self::Execution,
137            StageArg::PruneSenderRecovery => Self::PruneSenderRecovery,
138            StageArg::MerkleUnwind => Self::MerkleUnwind,
139            StageArg::AccountHashing => Self::AccountHashing,
140            StageArg::StorageHashing => Self::StorageHashing,
141            StageArg::MerkleExecute => Self::MerkleExecute,
142            StageArg::TransactionLookup => Self::TransactionLookup,
143            StageArg::IndexStorageHistory => Self::IndexStorageHistory,
144            StageArg::IndexAccountHistory => Self::IndexAccountHistory,
145            StageArg::Prune => Self::Prune,
146            StageArg::Finish => Self::Finish,
147        }
148    }
149}
150
#[cfg(test)]
mod tests {
    use super::*;
    use clap::Parser;
    use reth_provider::{
        test_utils::create_test_provider_factory, DBProvider, DatabaseProviderFactory,
        StageCheckpointReader, StageCheckpointWriter,
    };
    use reth_stages::StageCheckpoint;

    /// Writes an initial checkpoint (and optional progress bytes) for `stage` and commits.
    fn seed<N: ProviderNodeTypes>(
        tool: &DbTool<N>,
        stage: StageId,
        checkpoint: StageCheckpoint,
        progress: Option<Vec<u8>>,
    ) {
        let provider_rw = tool.provider_factory.database_provider_rw().expect("rw provider");
        provider_rw.save_stage_checkpoint(stage, checkpoint).expect("save checkpoint");
        if let Some(bytes) = progress {
            provider_rw.save_stage_checkpoint_progress(stage, bytes).expect("save progress");
        }
        provider_rw.commit().expect("commit initial checkpoint");
    }

    /// Builds a `set` command from the given values and executes it against `tool`.
    fn run_set<N: ProviderNodeTypes>(
        tool: &DbTool<N>,
        stage: StageArg,
        block_number: u64,
        clear_stage_unit: bool,
    ) {
        let command = Command {
            command: Subcommands::Set(SetArgs { stage, block_number, clear_stage_unit }),
        };
        command.execute(tool).expect("execute command");
    }

    /// Reads back the committed checkpoint for `stage`, failing if none is stored.
    fn stored_checkpoint<N: ProviderNodeTypes>(
        tool: &DbTool<N>,
        stage: StageId,
    ) -> StageCheckpoint {
        tool.provider_factory
            .provider()
            .expect("provider")
            .get_stage_checkpoint(stage)
            .expect("get stage checkpoint")
            .expect("missing stage checkpoint")
    }

    #[test]
    fn parse_set_args() {
        let parsed = Command::parse_from([
            "stage-checkpoints",
            "set",
            "--stage",
            "headers",
            "--block-number",
            "123",
        ]);

        let Subcommands::Set(args) = parsed.command else {
            panic!("expected the `set` subcommand");
        };
        assert!(matches!(args.stage, StageArg::Headers));
        assert_eq!(args.block_number, 123);
        assert!(!args.clear_stage_unit);
    }

    #[test]
    fn set_overwrites_block_number() {
        let tool = DbTool::new(create_test_provider_factory()).expect("db tool");
        seed(&tool, StageId::Headers, StageCheckpoint::new(10), None);

        run_set(&tool, StageArg::Headers, 42, false);

        assert_eq!(stored_checkpoint(&tool, StageId::Headers).block_number, 42);
    }

    #[test]
    fn set_preserves_stage_unit_checkpoint_unless_cleared() {
        let tool = DbTool::new(create_test_provider_factory()).expect("db tool");
        let initial = StageCheckpoint::new(10).with_block_range(&StageId::Execution, 5, 10);
        seed(&tool, StageId::Execution, initial, None);

        // Without `clear_stage_unit` the stage-specific payload survives the update.
        run_set(&tool, StageArg::Execution, 11, false);
        assert!(stored_checkpoint(&tool, StageId::Execution).stage_checkpoint.is_some());

        // With it, the payload is dropped.
        run_set(&tool, StageArg::Execution, 12, true);
        assert!(stored_checkpoint(&tool, StageId::Execution).stage_checkpoint.is_none());
    }

    #[test]
    fn set_preserves_checkpoint_progress() {
        let tool = DbTool::new(create_test_provider_factory()).expect("db tool");
        seed(&tool, StageId::MerkleExecute, StageCheckpoint::new(10), Some(vec![1, 2, 3]));

        run_set(&tool, StageArg::MerkleExecute, 20, false);

        let progress = tool
            .provider_factory
            .provider()
            .expect("provider")
            .get_stage_checkpoint_progress(StageId::MerkleExecute)
            .expect("get stage checkpoint progress");
        assert_eq!(progress, Some(vec![1, 2, 3]));
    }
}