1use std::{fmt, str::FromStr, time::Duration};
4
5use crate::version::default_client_version;
6use clap::{
7 builder::{PossibleValue, TypedValueParser},
8 error::ErrorKind,
9 Arg, Args, Command, Error,
10};
11use reth_db::{mdbx::MaxReadTransactionDuration, ClientVersion};
12use reth_storage_errors::db::LogLevel;
13
14#[derive(Debug, Args, PartialEq, Eq, Default, Clone, Copy)]
16#[command(next_help_heading = "Database")]
17pub struct DatabaseArgs {
18 #[arg(long = "db.log-level", value_parser = LogLevelValueParser::default())]
20 pub log_level: Option<LogLevel>,
21 #[arg(long = "db.exclusive")]
24 pub exclusive: Option<bool>,
25 #[arg(long = "db.max-size", value_parser = parse_byte_size)]
27 pub max_size: Option<usize>,
28 #[arg(long = "db.growth-step", value_parser = parse_byte_size)]
30 pub growth_step: Option<usize>,
31 #[arg(long = "db.read-transaction-timeout")]
33 pub read_transaction_timeout: Option<u64>,
34}
35
36impl DatabaseArgs {
37 pub fn database_args(&self) -> reth_db::mdbx::DatabaseArguments {
39 self.get_database_args(default_client_version())
40 }
41
42 pub fn get_database_args(
45 &self,
46 client_version: ClientVersion,
47 ) -> reth_db::mdbx::DatabaseArguments {
48 let max_read_transaction_duration = match self.read_transaction_timeout {
49 None => None, Some(0) => Some(MaxReadTransactionDuration::Unbounded), Some(secs) => Some(MaxReadTransactionDuration::Set(Duration::from_secs(secs))),
52 };
53
54 reth_db::mdbx::DatabaseArguments::new(client_version)
55 .with_log_level(self.log_level)
56 .with_exclusive(self.exclusive)
57 .with_max_read_transaction_duration(max_read_transaction_duration)
58 .with_geometry_max_size(self.max_size)
59 .with_growth_step(self.growth_step)
60 }
61}
62
63#[derive(Clone, Debug, Default)]
65#[non_exhaustive]
66struct LogLevelValueParser;
67
68impl TypedValueParser for LogLevelValueParser {
69 type Value = LogLevel;
70
71 fn parse_ref(
72 &self,
73 _cmd: &Command,
74 arg: Option<&Arg>,
75 value: &std::ffi::OsStr,
76 ) -> Result<Self::Value, Error> {
77 let val =
78 value.to_str().ok_or_else(|| Error::raw(ErrorKind::InvalidUtf8, "Invalid UTF-8"))?;
79
80 val.parse::<LogLevel>().map_err(|err| {
81 let arg = arg.map(|a| a.to_string()).unwrap_or_else(|| "...".to_owned());
82 let possible_values = LogLevel::value_variants()
83 .iter()
84 .map(|v| format!("- {:?}: {}", v, v.help_message()))
85 .collect::<Vec<_>>()
86 .join("\n");
87 let msg = format!(
88 "Invalid value '{val}' for {arg}: {err}.\n Possible values:\n{possible_values}"
89 );
90 clap::Error::raw(clap::error::ErrorKind::InvalidValue, msg)
91 })
92 }
93
94 fn possible_values(&self) -> Option<Box<dyn Iterator<Item = PossibleValue> + '_>> {
95 let values = LogLevel::value_variants()
96 .iter()
97 .map(|v| PossibleValue::new(v.variant_name()).help(v.help_message()));
98 Some(Box::new(values))
99 }
100}
101
/// A byte count parsed from a human-readable size such as `"4TB"`, `"8 MB"` or `"1048576"`.
///
/// Units are binary (powers of 1024) and case-insensitive: `B`, `KB`, `MB`, `GB`, `TB`.
/// A number without a unit is interpreted as bytes.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct ByteSize(pub usize);

impl ByteSize {
    const KB: usize = 1024;
    const MB: usize = Self::KB * 1024;
    const GB: usize = Self::MB * 1024;
    const TB: usize = Self::GB * 1024;
}

impl From<ByteSize> for usize {
    fn from(s: ByteSize) -> Self {
        s.0
    }
}

impl FromStr for ByteSize {
    type Err = String;

    /// Parses `<number><unit>` or `<number> <unit>`. The unit is optional (bytes) and
    /// case-insensitive.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message in these cases:
    /// - the input is not one or two whitespace-separated tokens,
    /// - the number is not a valid `usize`,
    /// - the unit is unknown,
    /// - the resulting byte count does not fit in `usize`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let s = s.trim().to_uppercase();
        let parts: Vec<&str> = s.split_whitespace().collect();

        let (num_str, unit) = match *parts.as_slice() {
            // No whitespace between number and unit (e.g. "12MB"): the unit starts at the
            // first letter, or is empty for a bare number.
            [single] => {
                single.split_at(single.find(|c: char| c.is_alphabetic()).unwrap_or(single.len()))
            }
            [num, unit] => (num, unit),
            _ => {
                return Err("Invalid format. Use '<number><unit>' or '<number> <unit>'.".to_string())
            }
        };

        let num: usize = num_str.parse().map_err(|_| "Invalid number".to_string())?;

        let multiplier = match unit {
            "B" | "" => 1,
            "KB" => Self::KB,
            "MB" => Self::MB,
            "GB" => Self::GB,
            "TB" => Self::TB,
            _ => return Err(format!("Invalid unit: {}. Use B, KB, MB, GB, or TB.", unit)),
        };

        // A plain `num * multiplier` panics in debug builds and silently wraps in release
        // builds once the size exceeds `usize::MAX`; reject such sizes instead.
        num.checked_mul(multiplier)
            .map(Self)
            .ok_or_else(|| format!("Size too large: {s} exceeds {} bytes", usize::MAX))
    }
}

impl fmt::Display for ByteSize {
    /// Formats the size in the largest unit it reaches, with two decimals
    /// (e.g. `1536` becomes `1.50KB` and `512` becomes `512.00B`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let (scale, unit) =
            [(Self::TB, "TB"), (Self::GB, "GB"), (Self::MB, "MB"), (Self::KB, "KB")]
                .iter()
                .copied()
                .find(|&(scale, _)| self.0 >= scale)
                .unwrap_or((1, "B"));

        write!(f, "{:.2}{}", self.0 as f64 / scale as f64, unit)
    }
}
168
169fn parse_byte_size(s: &str) -> Result<usize, String> {
171 s.parse::<ByteSize>().map(Into::into)
172}
173
#[cfg(test)]
mod tests {
    use super::*;
    use clap::Parser;
    use reth_db::mdbx::{GIGABYTE, KILOBYTE, MEGABYTE, TERABYTE};

    /// A helper type to parse [`Args`] more easily.
    #[derive(Parser)]
    struct CommandParser<T: Args> {
        #[command(flatten)]
        args: T,
    }

    #[test]
    fn test_default_database_args() {
        let default_args = DatabaseArgs::default();
        let args = CommandParser::<DatabaseArgs>::parse_from(["reth"]).args;
        assert_eq!(args, default_args);
    }

    #[test]
    fn test_command_parser_with_valid_max_size() {
        let cmd = CommandParser::<DatabaseArgs>::try_parse_from([
            "reth",
            "--db.max-size",
            "4398046511104",
        ])
        .unwrap();
        assert_eq!(cmd.args.max_size, Some(TERABYTE * 4));
    }

    #[test]
    fn test_command_parser_with_invalid_max_size() {
        let result =
            CommandParser::<DatabaseArgs>::try_parse_from(["reth", "--db.max-size", "invalid"]);
        assert!(result.is_err());
    }

    #[test]
    fn test_command_parser_with_valid_growth_step() {
        let cmd = CommandParser::<DatabaseArgs>::try_parse_from([
            "reth",
            "--db.growth-step",
            "4294967296",
        ])
        .unwrap();
        assert_eq!(cmd.args.growth_step, Some(GIGABYTE * 4));
    }

    #[test]
    fn test_command_parser_with_invalid_growth_step() {
        let result =
            CommandParser::<DatabaseArgs>::try_parse_from(["reth", "--db.growth-step", "invalid"]);
        assert!(result.is_err());
    }

    #[test]
    fn test_command_parser_with_valid_max_size_and_growth_step_from_str() {
        let cmd = CommandParser::<DatabaseArgs>::try_parse_from([
            "reth",
            "--db.max-size",
            "2TB",
            "--db.growth-step",
            "1GB",
        ])
        .unwrap();
        assert_eq!(cmd.args.max_size, Some(TERABYTE * 2));
        assert_eq!(cmd.args.growth_step, Some(GIGABYTE));

        let cmd = CommandParser::<DatabaseArgs>::try_parse_from([
            "reth",
            "--db.max-size",
            "12MB",
            "--db.growth-step",
            "2KB",
        ])
        .unwrap();
        assert_eq!(cmd.args.max_size, Some(MEGABYTE * 12));
        assert_eq!(cmd.args.growth_step, Some(KILOBYTE * 2));

        // The unit may also be separated from the number by whitespace.
        let cmd = CommandParser::<DatabaseArgs>::try_parse_from([
            "reth",
            "--db.max-size",
            "12 MB",
            "--db.growth-step",
            "2 KB",
        ])
        .unwrap();
        assert_eq!(cmd.args.max_size, Some(MEGABYTE * 12));
        assert_eq!(cmd.args.growth_step, Some(KILOBYTE * 2));

        // A bare number is interpreted as bytes.
        let cmd = CommandParser::<DatabaseArgs>::try_parse_from([
            "reth",
            "--db.max-size",
            "1073741824",
            "--db.growth-step",
            "1048576",
        ])
        .unwrap();
        assert_eq!(cmd.args.max_size, Some(GIGABYTE));
        assert_eq!(cmd.args.growth_step, Some(MEGABYTE));
    }

    #[test]
    fn test_command_parser_max_size_and_growth_step_from_str_invalid_unit() {
        let result =
            CommandParser::<DatabaseArgs>::try_parse_from(["reth", "--db.growth-step", "1 PB"]);
        assert!(result.is_err());

        let result =
            CommandParser::<DatabaseArgs>::try_parse_from(["reth", "--db.max-size", "2PB"]);
        assert!(result.is_err());
    }

    #[test]
    fn test_byte_size_display() {
        assert_eq!(ByteSize(512).to_string(), "512.00B");
        assert_eq!(ByteSize(KILOBYTE + KILOBYTE / 2).to_string(), "1.50KB");
        assert_eq!(ByteSize(GIGABYTE * 3).to_string(), "3.00GB");
        assert_eq!(ByteSize(TERABYTE * 2).to_string(), "2.00TB");
    }

    #[test]
    fn test_possible_values() {
        // Create a LogLevelValueParser instance
        let parser = LogLevelValueParser;

        // Collect the possible values advertised to clap
        let possible_values: Vec<PossibleValue> = parser.possible_values().unwrap().collect();

        // Expected possible values, in declaration order
        let expected_values = vec![
            PossibleValue::new("fatal")
                .help("Enables logging for critical conditions, i.e. assertion failures"),
            PossibleValue::new("error").help("Enables logging for error conditions"),
            PossibleValue::new("warn").help("Enables logging for warning conditions"),
            PossibleValue::new("notice")
                .help("Enables logging for normal but significant condition"),
            PossibleValue::new("verbose").help("Enables logging for verbose informational"),
            PossibleValue::new("debug").help("Enables logging for debug-level messages"),
            PossibleValue::new("trace").help("Enables logging for trace debug-level messages"),
            PossibleValue::new("extra").help("Enables logging for extra debug-level messages"),
        ];

        // Compare name and help of each possible value pairwise
        assert_eq!(possible_values.len(), expected_values.len());
        for (actual, expected) in possible_values.iter().zip(expected_values.iter()) {
            assert_eq!(actual.get_name(), expected.get_name());
            assert_eq!(actual.get_help(), expected.get_help());
        }
    }

    #[test]
    fn test_command_parser_with_valid_log_level() {
        let cmd =
            CommandParser::<DatabaseArgs>::try_parse_from(["reth", "--db.log-level", "Debug"])
                .unwrap();
        assert_eq!(cmd.args.log_level, Some(LogLevel::Debug));
    }

    #[test]
    fn test_command_parser_with_invalid_log_level() {
        let result =
            CommandParser::<DatabaseArgs>::try_parse_from(["reth", "--db.log-level", "invalid"]);
        assert!(result.is_err());
    }

    #[test]
    fn test_command_parser_without_log_level() {
        let cmd = CommandParser::<DatabaseArgs>::try_parse_from(["reth"]).unwrap();
        assert_eq!(cmd.args.log_level, None);
    }

    #[test]
    fn test_command_parser_with_exclusive() {
        let cmd =
            CommandParser::<DatabaseArgs>::try_parse_from(["reth", "--db.exclusive", "true"])
                .unwrap();
        assert_eq!(cmd.args.exclusive, Some(true));

        let cmd =
            CommandParser::<DatabaseArgs>::try_parse_from(["reth", "--db.exclusive", "false"])
                .unwrap();
        assert_eq!(cmd.args.exclusive, Some(false));
    }

    #[test]
    fn test_command_parser_with_read_transaction_timeout() {
        let cmd = CommandParser::<DatabaseArgs>::try_parse_from([
            "reth",
            "--db.read-transaction-timeout",
            "30",
        ])
        .unwrap();
        assert_eq!(cmd.args.read_transaction_timeout, Some(30));

        // `0` is accepted at parse time; `get_database_args` maps it to an unbounded timeout.
        let cmd = CommandParser::<DatabaseArgs>::try_parse_from([
            "reth",
            "--db.read-transaction-timeout",
            "0",
        ])
        .unwrap();
        assert_eq!(cmd.args.read_transaction_timeout, Some(0));

        let result = CommandParser::<DatabaseArgs>::try_parse_from([
            "reth",
            "--db.read-transaction-timeout",
            "invalid",
        ]);
        assert!(result.is_err());
    }
}