1use crate::common::EnvironmentArgs;
2use clap::Parser;
3use eyre::Result;
4use lz4::Decoder;
5use reqwest::{blocking::Client as BlockingClient, header::RANGE, Client, StatusCode};
6use reth_chainspec::{EthChainSpec, EthereumHardforks};
7use reth_cli::chainspec::ChainSpecParser;
8use reth_fs_util as fs;
9use std::{
10 borrow::Cow,
11 fs::OpenOptions,
12 io::{self, BufWriter, Read, Write},
13 path::{Path, PathBuf},
14 sync::{Arc, OnceLock},
15 time::{Duration, Instant},
16};
17use tar::Archive;
18use tokio::task;
19use tracing::info;
20use url::Url;
21use zstd::stream::read::Decoder as ZstdDecoder;
22
23const BYTE_UNITS: [&str; 4] = ["B", "KB", "MB", "GB"];
24const MERKLE_BASE_URL: &str = "https://downloads.merkle.io";
25const EXTENSION_TAR_LZ4: &str = ".tar.lz4";
26const EXTENSION_TAR_ZSTD: &str = ".tar.zst";
27
28static DOWNLOAD_DEFAULTS: OnceLock<DownloadDefaults> = OnceLock::new();
30
31#[derive(Debug, Clone)]
35pub struct DownloadDefaults {
36 pub available_snapshots: Vec<Cow<'static, str>>,
38 pub default_base_url: Cow<'static, str>,
40 pub long_help: Option<String>,
42}
43
44impl DownloadDefaults {
45 pub fn try_init(self) -> Result<(), Self> {
47 DOWNLOAD_DEFAULTS.set(self)
48 }
49
50 pub fn get_global() -> &'static DownloadDefaults {
52 DOWNLOAD_DEFAULTS.get_or_init(DownloadDefaults::default_download_defaults)
53 }
54
55 pub fn default_download_defaults() -> Self {
57 Self {
58 available_snapshots: vec![
59 Cow::Borrowed("https://www.merkle.io/snapshots (default, mainnet archive)"),
60 Cow::Borrowed("https://publicnode.com/snapshots (full nodes & testnets)"),
61 ],
62 default_base_url: Cow::Borrowed(MERKLE_BASE_URL),
63 long_help: None,
64 }
65 }
66
67 pub fn long_help(&self) -> String {
72 if let Some(ref custom_help) = self.long_help {
73 return custom_help.clone();
74 }
75
76 let mut help = String::from(
77 "Specify a snapshot URL or let the command propose a default one.\n\nAvailable snapshot sources:\n",
78 );
79
80 for source in &self.available_snapshots {
81 help.push_str("- ");
82 help.push_str(source);
83 help.push('\n');
84 }
85
86 help.push_str(
87 "\nIf no URL is provided, the latest mainnet archive snapshot\nwill be proposed for download from ",
88 );
89 help.push_str(self.default_base_url.as_ref());
90 help.push_str(
91 ".\n\nLocal file:// URLs are also supported for extracting snapshots from disk.",
92 );
93 help
94 }
95
96 pub fn with_snapshot(mut self, source: impl Into<Cow<'static, str>>) -> Self {
98 self.available_snapshots.push(source.into());
99 self
100 }
101
102 pub fn with_snapshots(mut self, sources: Vec<Cow<'static, str>>) -> Self {
104 self.available_snapshots = sources;
105 self
106 }
107
108 pub fn with_base_url(mut self, url: impl Into<Cow<'static, str>>) -> Self {
110 self.default_base_url = url.into();
111 self
112 }
113
114 pub fn with_long_help(mut self, help: impl Into<String>) -> Self {
116 self.long_help = Some(help.into());
117 self
118 }
119}
120
121impl Default for DownloadDefaults {
122 fn default() -> Self {
123 Self::default_download_defaults()
124 }
125}
126
127#[derive(Debug, Parser)]
128pub struct DownloadCommand<C: ChainSpecParser> {
129 #[command(flatten)]
130 env: EnvironmentArgs<C>,
131
132 #[arg(long, short, long_help = DownloadDefaults::get_global().long_help())]
134 url: Option<String>,
135}
136
137impl<C: ChainSpecParser<ChainSpec: EthChainSpec + EthereumHardforks>> DownloadCommand<C> {
138 pub async fn execute<N>(self) -> Result<()> {
139 let data_dir = self.env.datadir.resolve_datadir(self.env.chain.chain());
140 fs::create_dir_all(&data_dir)?;
141
142 let url = match self.url {
143 Some(url) => url,
144 None => {
145 let url = get_latest_snapshot_url().await?;
146 info!(target: "reth::cli", "Using default snapshot URL: {}", url);
147 url
148 }
149 };
150
151 info!(target: "reth::cli",
152 chain = %self.env.chain.chain(),
153 dir = ?data_dir.data_dir(),
154 url = %url,
155 "Starting snapshot download and extraction"
156 );
157
158 stream_and_extract(&url, data_dir.data_dir()).await?;
159 info!(target: "reth::cli", "Snapshot downloaded and extracted successfully");
160
161 Ok(())
162 }
163}
164
165impl<C: ChainSpecParser> DownloadCommand<C> {
166 pub fn chain_spec(&self) -> Option<&Arc<C::ChainSpec>> {
168 Some(&self.env.chain)
169 }
170}
171
172struct DownloadProgress {
175 downloaded: u64,
176 total_size: u64,
177 last_displayed: Instant,
178 started_at: Instant,
179}
180
181impl DownloadProgress {
182 fn new(total_size: u64) -> Self {
184 let now = Instant::now();
185 Self { downloaded: 0, total_size, last_displayed: now, started_at: now }
186 }
187
188 fn format_size(size: u64) -> String {
190 let mut size = size as f64;
191 let mut unit_index = 0;
192
193 while size >= 1024.0 && unit_index < BYTE_UNITS.len() - 1 {
194 size /= 1024.0;
195 unit_index += 1;
196 }
197
198 format!("{:.2} {}", size, BYTE_UNITS[unit_index])
199 }
200
201 fn format_duration(duration: Duration) -> String {
203 let secs = duration.as_secs();
204 if secs < 60 {
205 format!("{secs}s")
206 } else if secs < 3600 {
207 format!("{}m {}s", secs / 60, secs % 60)
208 } else {
209 format!("{}h {}m", secs / 3600, (secs % 3600) / 60)
210 }
211 }
212
213 fn update(&mut self, chunk_size: u64) -> Result<()> {
215 self.downloaded += chunk_size;
216
217 if self.last_displayed.elapsed() >= Duration::from_millis(100) {
219 let formatted_downloaded = Self::format_size(self.downloaded);
220 let formatted_total = Self::format_size(self.total_size);
221 let progress = (self.downloaded as f64 / self.total_size as f64) * 100.0;
222
223 let elapsed = self.started_at.elapsed();
225 let eta = if self.downloaded > 0 {
226 let remaining = self.total_size.saturating_sub(self.downloaded);
227 let speed = self.downloaded as f64 / elapsed.as_secs_f64();
228 if speed > 0.0 {
229 Duration::from_secs_f64(remaining as f64 / speed)
230 } else {
231 Duration::ZERO
232 }
233 } else {
234 Duration::ZERO
235 };
236 let eta_str = Self::format_duration(eta);
237
238 print!(
240 "\rDownloading and extracting... {progress:.2}% ({formatted_downloaded} / {formatted_total}) ETA: {eta_str} ",
241 );
242 io::stdout().flush()?;
243 self.last_displayed = Instant::now();
244 }
245
246 Ok(())
247 }
248}
249
250struct ProgressReader<R> {
252 reader: R,
253 progress: DownloadProgress,
254}
255
256impl<R: Read> ProgressReader<R> {
257 fn new(reader: R, total_size: u64) -> Self {
258 Self { reader, progress: DownloadProgress::new(total_size) }
259 }
260}
261
262impl<R: Read> Read for ProgressReader<R> {
263 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
264 let bytes = self.reader.read(buf)?;
265 if bytes > 0 &&
266 let Err(e) = self.progress.update(bytes as u64)
267 {
268 return Err(io::Error::other(e));
269 }
270 Ok(bytes)
271 }
272}
273
274#[derive(Debug, Clone, Copy)]
276enum CompressionFormat {
277 Lz4,
278 Zstd,
279}
280
281impl CompressionFormat {
282 fn from_url(url: &str) -> Result<Self> {
284 let path =
285 Url::parse(url).map(|u| u.path().to_string()).unwrap_or_else(|_| url.to_string());
286
287 if path.ends_with(EXTENSION_TAR_LZ4) {
288 Ok(Self::Lz4)
289 } else if path.ends_with(EXTENSION_TAR_ZSTD) {
290 Ok(Self::Zstd)
291 } else {
292 Err(eyre::eyre!(
293 "Unsupported file format. Expected .tar.lz4 or .tar.zst, got: {}",
294 path
295 ))
296 }
297 }
298}
299
300fn extract_archive<R: Read>(
302 reader: R,
303 total_size: u64,
304 format: CompressionFormat,
305 target_dir: &Path,
306) -> Result<()> {
307 let progress_reader = ProgressReader::new(reader, total_size);
308
309 match format {
310 CompressionFormat::Lz4 => {
311 let decoder = Decoder::new(progress_reader)?;
312 Archive::new(decoder).unpack(target_dir)?;
313 }
314 CompressionFormat::Zstd => {
315 let decoder = ZstdDecoder::new(progress_reader)?;
316 Archive::new(decoder).unpack(target_dir)?;
317 }
318 }
319
320 info!(target: "reth::cli", "Extraction complete.");
321 Ok(())
322}
323
324fn extract_from_file(path: &Path, format: CompressionFormat, target_dir: &Path) -> Result<()> {
326 let file = std::fs::File::open(path)?;
327 let total_size = file.metadata()?.len();
328 extract_archive(file, total_size, format, target_dir)
329}
330
/// Pause, in seconds, between consecutive download attempts.
const RETRY_BACKOFF_SECS: u64 = 5;
/// Maximum number of attempts `resumable_download` makes before giving up.
const MAX_DOWNLOAD_RETRIES: u32 = 10;
333
334struct ProgressWriter<W> {
337 inner: W,
338 progress: DownloadProgress,
339}
340
341impl<W: Write> Write for ProgressWriter<W> {
342 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
343 let n = self.inner.write(buf)?;
344 let _ = self.progress.update(n as u64);
345 Ok(n)
346 }
347
348 fn flush(&mut self) -> io::Result<()> {
349 self.inner.flush()
350 }
351}
352
353fn resumable_download(url: &str, target_dir: &Path) -> Result<(PathBuf, u64)> {
357 let file_name = Url::parse(url)
358 .ok()
359 .and_then(|u| u.path_segments()?.next_back().map(|s| s.to_string()))
360 .unwrap_or_else(|| "snapshot.tar".to_string());
361
362 let final_path = target_dir.join(&file_name);
363 let part_path = target_dir.join(format!("{file_name}.part"));
364
365 let client = BlockingClient::builder().timeout(Duration::from_secs(30)).build()?;
366
367 let mut total_size: Option<u64> = None;
368 let mut last_error: Option<eyre::Error> = None;
369
370 for attempt in 1..=MAX_DOWNLOAD_RETRIES {
371 let existing_size = fs::metadata(&part_path).map(|m| m.len()).unwrap_or(0);
372
373 if let Some(total) = total_size &&
374 existing_size >= total
375 {
376 fs::rename(&part_path, &final_path)?;
377 info!(target: "reth::cli", "Download complete: {}", final_path.display());
378 return Ok((final_path, total));
379 }
380
381 if attempt > 1 {
382 info!(target: "reth::cli",
383 "Retry attempt {}/{} - resuming from {} bytes",
384 attempt, MAX_DOWNLOAD_RETRIES, existing_size
385 );
386 }
387
388 let mut request = client.get(url);
389 if existing_size > 0 {
390 request = request.header(RANGE, format!("bytes={existing_size}-"));
391 if attempt == 1 {
392 info!(target: "reth::cli", "Resuming download from {} bytes", existing_size);
393 }
394 }
395
396 let response = match request.send().and_then(|r| r.error_for_status()) {
397 Ok(r) => r,
398 Err(e) => {
399 last_error = Some(e.into());
400 if attempt < MAX_DOWNLOAD_RETRIES {
401 info!(target: "reth::cli",
402 "Download failed, retrying in {} seconds...", RETRY_BACKOFF_SECS
403 );
404 std::thread::sleep(Duration::from_secs(RETRY_BACKOFF_SECS));
405 }
406 continue;
407 }
408 };
409
410 let is_partial = response.status() == StatusCode::PARTIAL_CONTENT;
411
412 let size = if is_partial {
413 response
414 .headers()
415 .get("Content-Range")
416 .and_then(|v| v.to_str().ok())
417 .and_then(|v| v.split('/').next_back())
418 .and_then(|v| v.parse().ok())
419 } else {
420 response.content_length()
421 };
422
423 if total_size.is_none() {
424 total_size = size;
425 }
426
427 let current_total = total_size.ok_or_else(|| {
428 eyre::eyre!("Server did not provide Content-Length or Content-Range header")
429 })?;
430
431 let file = if is_partial && existing_size > 0 {
432 OpenOptions::new()
433 .append(true)
434 .open(&part_path)
435 .map_err(|e| fs::FsPathError::open(e, &part_path))?
436 } else {
437 fs::create_file(&part_path)?
438 };
439
440 let start_offset = if is_partial { existing_size } else { 0 };
441 let mut progress = DownloadProgress::new(current_total);
442 progress.downloaded = start_offset;
443
444 let mut writer = ProgressWriter { inner: BufWriter::new(file), progress };
445 let mut reader = response;
446
447 let copy_result = io::copy(&mut reader, &mut writer);
448 let flush_result = writer.inner.flush();
449 println!();
450
451 if let Err(e) = copy_result.and(flush_result) {
452 last_error = Some(e.into());
453 if attempt < MAX_DOWNLOAD_RETRIES {
454 info!(target: "reth::cli",
455 "Download interrupted, retrying in {} seconds...", RETRY_BACKOFF_SECS
456 );
457 std::thread::sleep(Duration::from_secs(RETRY_BACKOFF_SECS));
458 }
459 continue;
460 }
461
462 fs::rename(&part_path, &final_path)?;
463 info!(target: "reth::cli", "Download complete: {}", final_path.display());
464 return Ok((final_path, current_total));
465 }
466
467 Err(last_error
468 .unwrap_or_else(|| eyre::eyre!("Download failed after {} attempts", MAX_DOWNLOAD_RETRIES)))
469}
470
471fn download_and_extract(url: &str, format: CompressionFormat, target_dir: &Path) -> Result<()> {
473 let (downloaded_path, total_size) = resumable_download(url, target_dir)?;
474
475 info!(target: "reth::cli", "Extracting snapshot...");
476 let file = fs::open(&downloaded_path)?;
477 extract_archive(file, total_size, format, target_dir)?;
478
479 fs::remove_file(&downloaded_path)?;
480 info!(target: "reth::cli", "Removed downloaded archive");
481
482 Ok(())
483}
484
485fn blocking_download_and_extract(url: &str, target_dir: &Path) -> Result<()> {
489 let format = CompressionFormat::from_url(url)?;
490
491 if let Ok(parsed_url) = Url::parse(url) &&
492 parsed_url.scheme() == "file"
493 {
494 let file_path = parsed_url
495 .to_file_path()
496 .map_err(|_| eyre::eyre!("Invalid file:// URL path: {}", url))?;
497 extract_from_file(&file_path, format, target_dir)
498 } else {
499 download_and_extract(url, format, target_dir)
500 }
501}
502
503async fn stream_and_extract(url: &str, target_dir: &Path) -> Result<()> {
504 let target_dir = target_dir.to_path_buf();
505 let url = url.to_string();
506 task::spawn_blocking(move || blocking_download_and_extract(&url, &target_dir)).await??;
507
508 Ok(())
509}
510
511async fn get_latest_snapshot_url() -> Result<String> {
513 let base_url = &DownloadDefaults::get_global().default_base_url;
514 let latest_url = format!("{base_url}/latest.txt");
515 let filename = Client::new()
516 .get(latest_url)
517 .send()
518 .await?
519 .error_for_status()?
520 .text()
521 .await?
522 .trim()
523 .to_string();
524
525 Ok(format!("{base_url}/{filename}"))
526}
527
#[cfg(test)]
mod tests {
    // Unit tests for the download defaults builder, help text, format detection and the
    // progress formatting helpers.
    use super::*;

    #[test]
    fn test_download_defaults_builder() {
        let defaults = DownloadDefaults::default()
            .with_snapshot("https://example.com/snapshots (example)")
            .with_base_url("https://example.com");

        assert_eq!(defaults.default_base_url, "https://example.com");
        // Two built-in sources plus the one added above.
        assert_eq!(defaults.available_snapshots.len(), 3);
    }

    #[test]
    fn test_download_defaults_replace_snapshots() {
        let defaults = DownloadDefaults::default().with_snapshots(vec![
            Cow::Borrowed("https://custom1.com"),
            Cow::Borrowed("https://custom2.com"),
        ]);

        assert_eq!(defaults.available_snapshots.len(), 2);
        assert_eq!(defaults.available_snapshots[0], "https://custom1.com");
        assert_eq!(defaults.available_snapshots[1], "https://custom2.com");
    }

    #[test]
    fn test_long_help_generation() {
        let defaults = DownloadDefaults::default();
        let help = defaults.long_help();

        assert!(help.contains("Available snapshot sources:"));
        assert!(help.contains("merkle.io"));
        assert!(help.contains("publicnode.com"));
        assert!(help.contains("file://"));
    }

    #[test]
    fn test_long_help_override() {
        let custom_help = "This is custom help text for downloading snapshots.";
        let defaults = DownloadDefaults::default().with_long_help(custom_help);

        let help = defaults.long_help();
        assert_eq!(help, custom_help);
        assert!(!help.contains("Available snapshot sources:"));
    }

    #[test]
    fn test_builder_chaining() {
        let defaults = DownloadDefaults::default()
            .with_base_url("https://custom.example.com")
            .with_snapshot("https://snapshot1.com")
            .with_snapshot("https://snapshot2.com")
            .with_long_help("Custom help for snapshots");

        assert_eq!(defaults.default_base_url, "https://custom.example.com");
        // Two built-in sources plus the two added above.
        assert_eq!(defaults.available_snapshots.len(), 4);
        assert_eq!(defaults.long_help, Some("Custom help for snapshots".to_string()));
    }

    #[test]
    fn test_compression_format_detection() {
        assert!(matches!(
            CompressionFormat::from_url("https://example.com/snapshot.tar.lz4"),
            Ok(CompressionFormat::Lz4)
        ));
        assert!(matches!(
            CompressionFormat::from_url("https://example.com/snapshot.tar.zst"),
            Ok(CompressionFormat::Zstd)
        ));
        assert!(matches!(
            CompressionFormat::from_url("file:///path/to/snapshot.tar.lz4"),
            Ok(CompressionFormat::Lz4)
        ));
        assert!(matches!(
            CompressionFormat::from_url("file:///path/to/snapshot.tar.zst"),
            Ok(CompressionFormat::Zstd)
        ));
        // Only the URL path is inspected, so query strings don't break detection.
        assert!(matches!(
            CompressionFormat::from_url("https://example.com/snapshot.tar.zst?token=abc"),
            Ok(CompressionFormat::Zstd)
        ));
        assert!(CompressionFormat::from_url("https://example.com/snapshot.tar.gz").is_err());
    }

    #[test]
    fn test_format_size() {
        assert_eq!(DownloadProgress::format_size(0), "0.00 B");
        assert_eq!(DownloadProgress::format_size(1023), "1023.00 B");
        assert_eq!(DownloadProgress::format_size(1536), "1.50 KB");
        assert_eq!(DownloadProgress::format_size(5 * 1024 * 1024 * 1024), "5.00 GB");
        // Sizes beyond the largest unit stay expressed in GB.
        assert_eq!(DownloadProgress::format_size(2 * 1024u64.pow(4)), "2048.00 GB");
    }

    #[test]
    fn test_format_duration() {
        assert_eq!(DownloadProgress::format_duration(Duration::from_secs(0)), "0s");
        assert_eq!(DownloadProgress::format_duration(Duration::from_secs(59)), "59s");
        assert_eq!(DownloadProgress::format_duration(Duration::from_secs(61)), "1m 1s");
        assert_eq!(DownloadProgress::format_duration(Duration::from_secs(3600)), "1h 0m");
        assert_eq!(DownloadProgress::format_duration(Duration::from_secs(3661)), "1h 1m");
    }
}