// reth_cli_commands/download.rs

1use crate::common::EnvironmentArgs;
2use clap::Parser;
3use eyre::Result;
4use lz4::Decoder;
5use reqwest::{blocking::Client as BlockingClient, header::RANGE, Client, StatusCode};
6use reth_chainspec::{EthChainSpec, EthereumHardforks};
7use reth_cli::chainspec::ChainSpecParser;
8use reth_fs_util as fs;
9use std::{
10    borrow::Cow,
11    fs::OpenOptions,
12    io::{self, BufWriter, Read, Write},
13    path::{Path, PathBuf},
14    sync::{Arc, OnceLock},
15    time::{Duration, Instant},
16};
17use tar::Archive;
18use tokio::task;
19use tracing::info;
20use url::Url;
21use zstd::stream::read::Decoder as ZstdDecoder;
22
/// Unit suffixes for human-readable byte counts, ordered from smallest to largest.
const BYTE_UNITS: [&str; 4] = ["B", "KB", "MB", "GB"];
/// Base URL used by the built-in download defaults.
const MERKLE_BASE_URL: &str = "https://downloads.merkle.io";
/// File name suffix identifying an LZ4-compressed tar archive.
const EXTENSION_TAR_LZ4: &str = ".tar.lz4";
/// File name suffix identifying a Zstandard-compressed tar archive.
const EXTENSION_TAR_ZSTD: &str = ".tar.zst";

/// Process-wide download defaults; the cell is written at most once.
static DOWNLOAD_DEFAULTS: OnceLock<DownloadDefaults> = OnceLock::new();

/// Defaults that drive the snapshot download command.
///
/// A custom configuration can be installed process-wide via [`DownloadDefaults::try_init`].
#[derive(Debug, Clone)]
pub struct DownloadDefaults {
    /// Descriptions of the known snapshot sources, listed in the generated help text
    pub available_snapshots: Vec<Cow<'static, str>>,
    /// Base URL used when no chain-aware base URL is configured
    pub default_base_url: Cow<'static, str>,
    /// Base URL for chain-aware snapshots.
    ///
    /// When present, the chain ID is appended to build the final location:
    /// `{base_url}/{chain_id}`. With `https://snapshots.example.com` and chain ID `1`
    /// this gives `https://snapshots.example.com/1`.
    ///
    /// When `None`, [`default_base_url`](Self::default_base_url) is used instead.
    pub default_chain_aware_base_url: Option<Cow<'static, str>>,
    /// Help text that replaces the generated one when present
    pub long_help: Option<String>,
}

impl DownloadDefaults {
    /// Installs `self` as the global download defaults.
    ///
    /// The configuration is handed back as `Err` when the global value was already set.
    pub fn try_init(self) -> Result<(), Self> {
        DOWNLOAD_DEFAULTS.set(self)
    }

    /// Returns the global download defaults, falling back to (and installing) the built-in
    /// ones when nothing was configured before the first access.
    pub fn get_global() -> &'static Self {
        DOWNLOAD_DEFAULTS.get_or_init(Self::default_download_defaults)
    }

    /// Built-in configuration: merkle.io as the default base URL, with merkle.io and
    /// publicnode listed as snapshot sources.
    pub fn default_download_defaults() -> Self {
        let available_snapshots = [
            "https://www.merkle.io/snapshots (default, mainnet archive)",
            "https://publicnode.com/snapshots (full nodes & testnets)",
        ]
        .into_iter()
        .map(Cow::Borrowed)
        .collect();

        Self {
            available_snapshots,
            default_base_url: Cow::Borrowed(MERKLE_BASE_URL),
            default_chain_aware_base_url: None,
            long_help: None,
        }
    }

    /// Produces the long help text shown for the download URL argument.
    ///
    /// A custom text set through [`Self::with_long_help`] is returned verbatim. Otherwise the
    /// text is assembled from `available_snapshots` and the effective default base URL.
    pub fn long_help(&self) -> String {
        if let Some(custom) = &self.long_help {
            return custom.clone();
        }

        // The chain-aware base URL, when configured, takes precedence in the description.
        let origin: &str = match self.default_chain_aware_base_url.as_deref() {
            Some(chain_aware) => chain_aware,
            None => &*self.default_base_url,
        };

        let mut text = String::from(
            "Specify a snapshot URL or let the command propose a default one.\n\nAvailable snapshot sources:\n",
        );
        text.extend(self.available_snapshots.iter().map(|entry| format!("- {entry}\n")));
        text += "\nIf no URL is provided, the latest archive snapshot for the selected chain\nwill be proposed for download from ";
        text += origin;
        text += ".\n\nLocal file:// URLs are also supported for extracting snapshots from disk.";
        text
    }

    /// Builder: appends one snapshot source to the list.
    pub fn with_snapshot(mut self, source: impl Into<Cow<'static, str>>) -> Self {
        let entry = source.into();
        self.available_snapshots.push(entry);
        self
    }

    /// Builder: discards the current snapshot sources in favour of `sources`.
    pub fn with_snapshots(self, sources: Vec<Cow<'static, str>>) -> Self {
        Self { available_snapshots: sources, ..self }
    }

    /// Builder: sets the default base URL, e.g. `https://downloads.merkle.io`.
    pub fn with_base_url(self, url: impl Into<Cow<'static, str>>) -> Self {
        Self { default_base_url: url.into(), ..self }
    }

    /// Builder: sets the chain-aware base URL.
    pub fn with_chain_aware_base_url(self, url: impl Into<Cow<'static, str>>) -> Self {
        Self { default_chain_aware_base_url: Some(url.into()), ..self }
    }

    /// Builder: sets a custom long help text that replaces the generated one.
    pub fn with_long_help(self, help: impl Into<String>) -> Self {
        Self { long_help: Some(help.into()), ..self }
    }
}

impl Default for DownloadDefaults {
    /// Same as [`DownloadDefaults::default_download_defaults`].
    fn default() -> Self {
        Self::default_download_defaults()
    }
}
143
144#[derive(Debug, Parser)]
145pub struct DownloadCommand<C: ChainSpecParser> {
146    #[command(flatten)]
147    env: EnvironmentArgs<C>,
148
149    /// Custom URL to download the snapshot from
150    #[arg(long, short, long_help = DownloadDefaults::get_global().long_help())]
151    url: Option<String>,
152}
153
154impl<C: ChainSpecParser<ChainSpec: EthChainSpec + EthereumHardforks>> DownloadCommand<C> {
155    pub async fn execute<N>(self) -> Result<()> {
156        let data_dir = self.env.datadir.resolve_datadir(self.env.chain.chain());
157        fs::create_dir_all(&data_dir)?;
158
159        let url = match self.url {
160            Some(url) => url,
161            None => {
162                let url = get_latest_snapshot_url(self.env.chain.chain().id()).await?;
163                info!(target: "reth::cli", "Using default snapshot URL: {}", url);
164                url
165            }
166        };
167
168        info!(target: "reth::cli",
169            chain = %self.env.chain.chain(),
170            dir = ?data_dir.data_dir(),
171            url = %url,
172            "Starting snapshot download and extraction"
173        );
174
175        stream_and_extract(&url, data_dir.data_dir()).await?;
176        info!(target: "reth::cli", "Snapshot downloaded and extracted successfully");
177
178        Ok(())
179    }
180}
181
182impl<C: ChainSpecParser> DownloadCommand<C> {
183    /// Returns the underlying chain being used to run this command
184    pub fn chain_spec(&self) -> Option<&Arc<C::ChainSpec>> {
185        Some(&self.env.chain)
186    }
187}
188
189// Monitor process status and display progress every 100ms
190// to avoid overwhelming stdout
191struct DownloadProgress {
192    downloaded: u64,
193    total_size: u64,
194    last_displayed: Instant,
195    started_at: Instant,
196}
197
198impl DownloadProgress {
199    /// Creates new progress tracker with given total size
200    fn new(total_size: u64) -> Self {
201        let now = Instant::now();
202        Self { downloaded: 0, total_size, last_displayed: now, started_at: now }
203    }
204
205    /// Converts bytes to human readable format (B, KB, MB, GB)
206    fn format_size(size: u64) -> String {
207        let mut size = size as f64;
208        let mut unit_index = 0;
209
210        while size >= 1024.0 && unit_index < BYTE_UNITS.len() - 1 {
211            size /= 1024.0;
212            unit_index += 1;
213        }
214
215        format!("{:.2} {}", size, BYTE_UNITS[unit_index])
216    }
217
218    /// Format duration as human readable string
219    fn format_duration(duration: Duration) -> String {
220        let secs = duration.as_secs();
221        if secs < 60 {
222            format!("{secs}s")
223        } else if secs < 3600 {
224            format!("{}m {}s", secs / 60, secs % 60)
225        } else {
226            format!("{}h {}m", secs / 3600, (secs % 3600) / 60)
227        }
228    }
229
230    /// Updates progress bar
231    fn update(&mut self, chunk_size: u64) -> Result<()> {
232        self.downloaded += chunk_size;
233
234        // Only update display at most 10 times per second for efficiency
235        if self.last_displayed.elapsed() >= Duration::from_millis(100) {
236            let formatted_downloaded = Self::format_size(self.downloaded);
237            let formatted_total = Self::format_size(self.total_size);
238            let progress = (self.downloaded as f64 / self.total_size as f64) * 100.0;
239
240            // Calculate ETA based on current speed
241            let elapsed = self.started_at.elapsed();
242            let eta = if self.downloaded > 0 {
243                let remaining = self.total_size.saturating_sub(self.downloaded);
244                let speed = self.downloaded as f64 / elapsed.as_secs_f64();
245                if speed > 0.0 {
246                    Duration::from_secs_f64(remaining as f64 / speed)
247                } else {
248                    Duration::ZERO
249                }
250            } else {
251                Duration::ZERO
252            };
253            let eta_str = Self::format_duration(eta);
254
255            // Pad with spaces to clear any previous longer line
256            print!(
257                "\rDownloading and extracting... {progress:.2}% ({formatted_downloaded} / {formatted_total}) ETA: {eta_str}     ",
258            );
259            io::stdout().flush()?;
260            self.last_displayed = Instant::now();
261        }
262
263        Ok(())
264    }
265}
266
267/// Adapter to track progress while reading
268struct ProgressReader<R> {
269    reader: R,
270    progress: DownloadProgress,
271}
272
273impl<R: Read> ProgressReader<R> {
274    fn new(reader: R, total_size: u64) -> Self {
275        Self { reader, progress: DownloadProgress::new(total_size) }
276    }
277}
278
279impl<R: Read> Read for ProgressReader<R> {
280    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
281        let bytes = self.reader.read(buf)?;
282        if bytes > 0 &&
283            let Err(e) = self.progress.update(bytes as u64)
284        {
285            return Err(io::Error::other(e));
286        }
287        Ok(bytes)
288    }
289}
290
291/// Supported compression formats for snapshots
292#[derive(Debug, Clone, Copy)]
293enum CompressionFormat {
294    Lz4,
295    Zstd,
296}
297
298impl CompressionFormat {
299    /// Detect compression format from file extension
300    fn from_url(url: &str) -> Result<Self> {
301        let path =
302            Url::parse(url).map(|u| u.path().to_string()).unwrap_or_else(|_| url.to_string());
303
304        if path.ends_with(EXTENSION_TAR_LZ4) {
305            Ok(Self::Lz4)
306        } else if path.ends_with(EXTENSION_TAR_ZSTD) {
307            Ok(Self::Zstd)
308        } else {
309            Err(eyre::eyre!(
310                "Unsupported file format. Expected .tar.lz4 or .tar.zst, got: {}",
311                path
312            ))
313        }
314    }
315}
316
317/// Extracts a compressed tar archive to the target directory with progress tracking.
318fn extract_archive<R: Read>(
319    reader: R,
320    total_size: u64,
321    format: CompressionFormat,
322    target_dir: &Path,
323) -> Result<()> {
324    let progress_reader = ProgressReader::new(reader, total_size);
325
326    match format {
327        CompressionFormat::Lz4 => {
328            let decoder = Decoder::new(progress_reader)?;
329            Archive::new(decoder).unpack(target_dir)?;
330        }
331        CompressionFormat::Zstd => {
332            let decoder = ZstdDecoder::new(progress_reader)?;
333            Archive::new(decoder).unpack(target_dir)?;
334        }
335    }
336
337    info!(target: "reth::cli", "Extraction complete.");
338    Ok(())
339}
340
341/// Extracts a snapshot from a local file.
342fn extract_from_file(path: &Path, format: CompressionFormat, target_dir: &Path) -> Result<()> {
343    let file = std::fs::File::open(path)?;
344    let total_size = file.metadata()?.len();
345    extract_archive(file, total_size, format, target_dir)
346}
347
/// Maximum number of attempts made by `resumable_download` before giving up.
const MAX_DOWNLOAD_RETRIES: u32 = 10;
/// Seconds `resumable_download` sleeps between a failed attempt and the next one.
const RETRY_BACKOFF_SECS: u64 = 5;
350
351/// Wrapper that tracks download progress while writing data.
352/// Used with [`io::copy`] to display progress during downloads.
353struct ProgressWriter<W> {
354    inner: W,
355    progress: DownloadProgress,
356}
357
358impl<W: Write> Write for ProgressWriter<W> {
359    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
360        let n = self.inner.write(buf)?;
361        let _ = self.progress.update(n as u64);
362        Ok(n)
363    }
364
365    fn flush(&mut self) -> io::Result<()> {
366        self.inner.flush()
367    }
368}
369
370/// Downloads a file with resume support using HTTP Range requests.
371/// Automatically retries on failure, resuming from where it left off.
372/// Returns the path to the downloaded file and its total size.
373fn resumable_download(url: &str, target_dir: &Path) -> Result<(PathBuf, u64)> {
374    let file_name = Url::parse(url)
375        .ok()
376        .and_then(|u| u.path_segments()?.next_back().map(|s| s.to_string()))
377        .unwrap_or_else(|| "snapshot.tar".to_string());
378
379    let final_path = target_dir.join(&file_name);
380    let part_path = target_dir.join(format!("{file_name}.part"));
381
382    let client = BlockingClient::builder().timeout(Duration::from_secs(30)).build()?;
383
384    let mut total_size: Option<u64> = None;
385    let mut last_error: Option<eyre::Error> = None;
386
387    for attempt in 1..=MAX_DOWNLOAD_RETRIES {
388        let existing_size = fs::metadata(&part_path).map(|m| m.len()).unwrap_or(0);
389
390        if let Some(total) = total_size &&
391            existing_size >= total
392        {
393            fs::rename(&part_path, &final_path)?;
394            info!(target: "reth::cli", "Download complete: {}", final_path.display());
395            return Ok((final_path, total));
396        }
397
398        if attempt > 1 {
399            info!(target: "reth::cli",
400                "Retry attempt {}/{} - resuming from {} bytes",
401                attempt, MAX_DOWNLOAD_RETRIES, existing_size
402            );
403        }
404
405        let mut request = client.get(url);
406        if existing_size > 0 {
407            request = request.header(RANGE, format!("bytes={existing_size}-"));
408            if attempt == 1 {
409                info!(target: "reth::cli", "Resuming download from {} bytes", existing_size);
410            }
411        }
412
413        let response = match request.send().and_then(|r| r.error_for_status()) {
414            Ok(r) => r,
415            Err(e) => {
416                last_error = Some(e.into());
417                if attempt < MAX_DOWNLOAD_RETRIES {
418                    info!(target: "reth::cli",
419                        "Download failed, retrying in {} seconds...", RETRY_BACKOFF_SECS
420                    );
421                    std::thread::sleep(Duration::from_secs(RETRY_BACKOFF_SECS));
422                }
423                continue;
424            }
425        };
426
427        let is_partial = response.status() == StatusCode::PARTIAL_CONTENT;
428
429        let size = if is_partial {
430            response
431                .headers()
432                .get("Content-Range")
433                .and_then(|v| v.to_str().ok())
434                .and_then(|v| v.split('/').next_back())
435                .and_then(|v| v.parse().ok())
436        } else {
437            response.content_length()
438        };
439
440        if total_size.is_none() {
441            total_size = size;
442        }
443
444        let current_total = total_size.ok_or_else(|| {
445            eyre::eyre!("Server did not provide Content-Length or Content-Range header")
446        })?;
447
448        let file = if is_partial && existing_size > 0 {
449            OpenOptions::new()
450                .append(true)
451                .open(&part_path)
452                .map_err(|e| fs::FsPathError::open(e, &part_path))?
453        } else {
454            fs::create_file(&part_path)?
455        };
456
457        let start_offset = if is_partial { existing_size } else { 0 };
458        let mut progress = DownloadProgress::new(current_total);
459        progress.downloaded = start_offset;
460
461        let mut writer = ProgressWriter { inner: BufWriter::new(file), progress };
462        let mut reader = response;
463
464        let copy_result = io::copy(&mut reader, &mut writer);
465        let flush_result = writer.inner.flush();
466        println!();
467
468        if let Err(e) = copy_result.and(flush_result) {
469            last_error = Some(e.into());
470            if attempt < MAX_DOWNLOAD_RETRIES {
471                info!(target: "reth::cli",
472                    "Download interrupted, retrying in {} seconds...", RETRY_BACKOFF_SECS
473                );
474                std::thread::sleep(Duration::from_secs(RETRY_BACKOFF_SECS));
475            }
476            continue;
477        }
478
479        fs::rename(&part_path, &final_path)?;
480        info!(target: "reth::cli", "Download complete: {}", final_path.display());
481        return Ok((final_path, current_total));
482    }
483
484    Err(last_error
485        .unwrap_or_else(|| eyre::eyre!("Download failed after {} attempts", MAX_DOWNLOAD_RETRIES)))
486}
487
488/// Fetches the snapshot from a remote URL with resume support, then extracts it.
489fn download_and_extract(url: &str, format: CompressionFormat, target_dir: &Path) -> Result<()> {
490    let (downloaded_path, total_size) = resumable_download(url, target_dir)?;
491
492    info!(target: "reth::cli", "Extracting snapshot...");
493    let file = fs::open(&downloaded_path)?;
494    extract_archive(file, total_size, format, target_dir)?;
495
496    fs::remove_file(&downloaded_path)?;
497    info!(target: "reth::cli", "Removed downloaded archive");
498
499    Ok(())
500}
501
502/// Downloads and extracts a snapshot, blocking until finished.
503///
504/// Supports both `file://` URLs for local files and HTTP(S) URLs for remote downloads.
505fn blocking_download_and_extract(url: &str, target_dir: &Path) -> Result<()> {
506    let format = CompressionFormat::from_url(url)?;
507
508    if let Ok(parsed_url) = Url::parse(url) &&
509        parsed_url.scheme() == "file"
510    {
511        let file_path = parsed_url
512            .to_file_path()
513            .map_err(|_| eyre::eyre!("Invalid file:// URL path: {}", url))?;
514        extract_from_file(&file_path, format, target_dir)
515    } else {
516        download_and_extract(url, format, target_dir)
517    }
518}
519
520async fn stream_and_extract(url: &str, target_dir: &Path) -> Result<()> {
521    let target_dir = target_dir.to_path_buf();
522    let url = url.to_string();
523    task::spawn_blocking(move || blocking_download_and_extract(&url, &target_dir)).await??;
524
525    Ok(())
526}
527
528// Builds default URL for latest mainnet archive snapshot using configured defaults
529async fn get_latest_snapshot_url(chain_id: u64) -> Result<String> {
530    let defaults = DownloadDefaults::get_global();
531    let base_url = match &defaults.default_chain_aware_base_url {
532        Some(url) => format!("{url}/{chain_id}"),
533        None => defaults.default_base_url.to_string(),
534    };
535    let latest_url = format!("{base_url}/latest.txt");
536    let filename = Client::new()
537        .get(latest_url)
538        .send()
539        .await?
540        .error_for_status()?
541        .text()
542        .await?
543        .trim()
544        .to_string();
545
546    Ok(format!("{base_url}/{filename}"))
547}
548
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder methods extend the built-in defaults instead of replacing them.
    #[test]
    fn test_download_defaults_builder() {
        let configured = DownloadDefaults::default()
            .with_snapshot("https://example.com/snapshots (example)")
            .with_base_url("https://example.com");

        // 2 defaults + 1 added
        assert_eq!(configured.available_snapshots.len(), 3);
        assert_eq!(configured.default_base_url, "https://example.com");
    }

    /// `with_snapshots` drops the built-in sources.
    #[test]
    fn test_download_defaults_replace_snapshots() {
        let custom_sources =
            vec![Cow::Borrowed("https://custom1.com"), Cow::Borrowed("https://custom2.com")];
        let configured = DownloadDefaults::default().with_snapshots(custom_sources);

        assert_eq!(configured.available_snapshots.len(), 2);
        assert_eq!(configured.available_snapshots[0], "https://custom1.com");
    }

    /// The generated help mentions every built-in source and local file support.
    #[test]
    fn test_long_help_generation() {
        let help = DownloadDefaults::default().long_help();

        for expected in ["Available snapshot sources:", "merkle.io", "publicnode.com", "file://"]
        {
            assert!(help.contains(expected));
        }
    }

    /// A custom help text is returned verbatim.
    #[test]
    fn test_long_help_override() {
        let custom_help = "This is custom help text for downloading snapshots.";
        let help = DownloadDefaults::default().with_long_help(custom_help).long_help();

        assert_eq!(help, custom_help);
        assert!(!help.contains("Available snapshot sources:"));
    }

    /// All builder methods can be chained in one expression.
    #[test]
    fn test_builder_chaining() {
        let configured = DownloadDefaults::default()
            .with_base_url("https://custom.example.com")
            .with_snapshot("https://snapshot1.com")
            .with_snapshot("https://snapshot2.com")
            .with_long_help("Custom help for snapshots");

        assert_eq!(configured.default_base_url, "https://custom.example.com");
        // 2 defaults + 2 added
        assert_eq!(configured.available_snapshots.len(), 4);
        assert_eq!(configured.long_help, Some("Custom help for snapshots".to_string()));
    }

    /// Format detection works for remote and local URLs and rejects unknown suffixes.
    #[test]
    fn test_compression_format_detection() {
        let lz4_urls =
            ["https://example.com/snapshot.tar.lz4", "file:///path/to/snapshot.tar.lz4"];
        for url in lz4_urls {
            assert!(matches!(CompressionFormat::from_url(url), Ok(CompressionFormat::Lz4)));
        }

        let zstd_urls =
            ["https://example.com/snapshot.tar.zst", "file:///path/to/snapshot.tar.zst"];
        for url in zstd_urls {
            assert!(matches!(CompressionFormat::from_url(url), Ok(CompressionFormat::Zstd)));
        }

        assert!(CompressionFormat::from_url("https://example.com/snapshot.tar.gz").is_err());
    }
}
628}