reth_cli_commands/
download.rs

1use crate::common::EnvironmentArgs;
2use clap::Parser;
3use eyre::Result;
4use lz4::Decoder;
5use reqwest::{blocking::Client as BlockingClient, header::RANGE, Client, StatusCode};
6use reth_chainspec::{EthChainSpec, EthereumHardforks};
7use reth_cli::chainspec::ChainSpecParser;
8use reth_fs_util as fs;
9use std::{
10    borrow::Cow,
11    fs::OpenOptions,
12    io::{self, BufWriter, Read, Write},
13    path::{Path, PathBuf},
14    sync::{Arc, OnceLock},
15    time::{Duration, Instant},
16};
17use tar::Archive;
18use tokio::task;
19use tracing::info;
20use url::Url;
21use zstd::stream::read::Decoder as ZstdDecoder;
22
/// Units used when rendering byte counts, in ascending powers of 1024.
///
/// Includes `TB` so that multi-terabyte archives are not rendered as thousands of gigabytes.
const BYTE_UNITS: [&str; 5] = ["B", "KB", "MB", "GB", "TB"];
/// Default snapshot host used by `DownloadDefaults::default_download_defaults`.
const MERKLE_BASE_URL: &str = "https://downloads.merkle.io";
/// File-name suffix of lz4-compressed tar snapshots.
const EXTENSION_TAR_LZ4: &str = ".tar.lz4";
/// File-name suffix of zstd-compressed tar snapshots.
const EXTENSION_TAR_ZSTD: &str = ".tar.zst";
27
28/// Global static download defaults
29static DOWNLOAD_DEFAULTS: OnceLock<DownloadDefaults> = OnceLock::new();
30
31/// Download configuration defaults
32///
33/// Global defaults can be set via [`DownloadDefaults::try_init`].
34#[derive(Debug, Clone)]
35pub struct DownloadDefaults {
36    /// List of available snapshot sources
37    pub available_snapshots: Vec<Cow<'static, str>>,
38    /// Default base URL for snapshots
39    pub default_base_url: Cow<'static, str>,
40    /// Optional custom long help text that overrides the generated help
41    pub long_help: Option<String>,
42}
43
44impl DownloadDefaults {
45    /// Initialize the global download defaults with this configuration
46    pub fn try_init(self) -> Result<(), Self> {
47        DOWNLOAD_DEFAULTS.set(self)
48    }
49
50    /// Get a reference to the global download defaults
51    pub fn get_global() -> &'static DownloadDefaults {
52        DOWNLOAD_DEFAULTS.get_or_init(DownloadDefaults::default_download_defaults)
53    }
54
55    /// Default download configuration with defaults from merkle.io and publicnode
56    pub fn default_download_defaults() -> Self {
57        Self {
58            available_snapshots: vec![
59                Cow::Borrowed("https://www.merkle.io/snapshots (default, mainnet archive)"),
60                Cow::Borrowed("https://publicnode.com/snapshots (full nodes & testnets)"),
61            ],
62            default_base_url: Cow::Borrowed(MERKLE_BASE_URL),
63            long_help: None,
64        }
65    }
66
67    /// Generates the long help text for the download URL argument using these defaults.
68    ///
69    /// If a custom long_help is set, it will be returned. Otherwise, help text is generated
70    /// from the available_snapshots list.
71    pub fn long_help(&self) -> String {
72        if let Some(ref custom_help) = self.long_help {
73            return custom_help.clone();
74        }
75
76        let mut help = String::from(
77            "Specify a snapshot URL or let the command propose a default one.\n\nAvailable snapshot sources:\n",
78        );
79
80        for source in &self.available_snapshots {
81            help.push_str("- ");
82            help.push_str(source);
83            help.push('\n');
84        }
85
86        help.push_str(
87            "\nIf no URL is provided, the latest mainnet archive snapshot\nwill be proposed for download from ",
88        );
89        help.push_str(self.default_base_url.as_ref());
90        help.push_str(
91            ".\n\nLocal file:// URLs are also supported for extracting snapshots from disk.",
92        );
93        help
94    }
95
96    /// Add a snapshot source to the list
97    pub fn with_snapshot(mut self, source: impl Into<Cow<'static, str>>) -> Self {
98        self.available_snapshots.push(source.into());
99        self
100    }
101
102    /// Replace all snapshot sources
103    pub fn with_snapshots(mut self, sources: Vec<Cow<'static, str>>) -> Self {
104        self.available_snapshots = sources;
105        self
106    }
107
108    /// Set the default base URL, e.g. `https://downloads.merkle.io`.
109    pub fn with_base_url(mut self, url: impl Into<Cow<'static, str>>) -> Self {
110        self.default_base_url = url.into();
111        self
112    }
113
114    /// Builder: Set custom long help text, overriding the generated help
115    pub fn with_long_help(mut self, help: impl Into<String>) -> Self {
116        self.long_help = Some(help.into());
117        self
118    }
119}
120
121impl Default for DownloadDefaults {
122    fn default() -> Self {
123        Self::default_download_defaults()
124    }
125}
126
// CLI arguments of the snapshot download command.
// NOTE: plain `//` comments are used deliberately in this block: clap's derive turns `///` doc
// comments into user-visible help/about text, so adding any would change the CLI output.
#[derive(Debug, Parser)]
pub struct DownloadCommand<C: ChainSpecParser> {
    // Environment arguments; `execute` reads the datadir and the chain from here.
    #[command(flatten)]
    env: EnvironmentArgs<C>,

    /// Custom URL to download the snapshot from
    // The long help comes from the global `DownloadDefaults`. `get_global` initializes them on
    // first use, so custom defaults must be installed with `DownloadDefaults::try_init` before
    // this help is generated, or they will not show up (and `try_init` will then fail).
    #[arg(long, short, long_help = DownloadDefaults::get_global().long_help())]
    url: Option<String>,
}
136
137impl<C: ChainSpecParser<ChainSpec: EthChainSpec + EthereumHardforks>> DownloadCommand<C> {
138    pub async fn execute<N>(self) -> Result<()> {
139        let data_dir = self.env.datadir.resolve_datadir(self.env.chain.chain());
140        fs::create_dir_all(&data_dir)?;
141
142        let url = match self.url {
143            Some(url) => url,
144            None => {
145                let url = get_latest_snapshot_url().await?;
146                info!(target: "reth::cli", "Using default snapshot URL: {}", url);
147                url
148            }
149        };
150
151        info!(target: "reth::cli",
152            chain = %self.env.chain.chain(),
153            dir = ?data_dir.data_dir(),
154            url = %url,
155            "Starting snapshot download and extraction"
156        );
157
158        stream_and_extract(&url, data_dir.data_dir()).await?;
159        info!(target: "reth::cli", "Snapshot downloaded and extracted successfully");
160
161        Ok(())
162    }
163}
164
165impl<C: ChainSpecParser> DownloadCommand<C> {
166    /// Returns the underlying chain being used to run this command
167    pub fn chain_spec(&self) -> Option<&Arc<C::ChainSpec>> {
168        Some(&self.env.chain)
169    }
170}
171
172// Monitor process status and display progress every 100ms
173// to avoid overwhelming stdout
174struct DownloadProgress {
175    downloaded: u64,
176    total_size: u64,
177    last_displayed: Instant,
178    started_at: Instant,
179}
180
181impl DownloadProgress {
182    /// Creates new progress tracker with given total size
183    fn new(total_size: u64) -> Self {
184        let now = Instant::now();
185        Self { downloaded: 0, total_size, last_displayed: now, started_at: now }
186    }
187
188    /// Converts bytes to human readable format (B, KB, MB, GB)
189    fn format_size(size: u64) -> String {
190        let mut size = size as f64;
191        let mut unit_index = 0;
192
193        while size >= 1024.0 && unit_index < BYTE_UNITS.len() - 1 {
194            size /= 1024.0;
195            unit_index += 1;
196        }
197
198        format!("{:.2} {}", size, BYTE_UNITS[unit_index])
199    }
200
201    /// Format duration as human readable string
202    fn format_duration(duration: Duration) -> String {
203        let secs = duration.as_secs();
204        if secs < 60 {
205            format!("{secs}s")
206        } else if secs < 3600 {
207            format!("{}m {}s", secs / 60, secs % 60)
208        } else {
209            format!("{}h {}m", secs / 3600, (secs % 3600) / 60)
210        }
211    }
212
213    /// Updates progress bar
214    fn update(&mut self, chunk_size: u64) -> Result<()> {
215        self.downloaded += chunk_size;
216
217        // Only update display at most 10 times per second for efficiency
218        if self.last_displayed.elapsed() >= Duration::from_millis(100) {
219            let formatted_downloaded = Self::format_size(self.downloaded);
220            let formatted_total = Self::format_size(self.total_size);
221            let progress = (self.downloaded as f64 / self.total_size as f64) * 100.0;
222
223            // Calculate ETA based on current speed
224            let elapsed = self.started_at.elapsed();
225            let eta = if self.downloaded > 0 {
226                let remaining = self.total_size.saturating_sub(self.downloaded);
227                let speed = self.downloaded as f64 / elapsed.as_secs_f64();
228                if speed > 0.0 {
229                    Duration::from_secs_f64(remaining as f64 / speed)
230                } else {
231                    Duration::ZERO
232                }
233            } else {
234                Duration::ZERO
235            };
236            let eta_str = Self::format_duration(eta);
237
238            // Pad with spaces to clear any previous longer line
239            print!(
240                "\rDownloading and extracting... {progress:.2}% ({formatted_downloaded} / {formatted_total}) ETA: {eta_str}     ",
241            );
242            io::stdout().flush()?;
243            self.last_displayed = Instant::now();
244        }
245
246        Ok(())
247    }
248}
249
250/// Adapter to track progress while reading
251struct ProgressReader<R> {
252    reader: R,
253    progress: DownloadProgress,
254}
255
256impl<R: Read> ProgressReader<R> {
257    fn new(reader: R, total_size: u64) -> Self {
258        Self { reader, progress: DownloadProgress::new(total_size) }
259    }
260}
261
262impl<R: Read> Read for ProgressReader<R> {
263    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
264        let bytes = self.reader.read(buf)?;
265        if bytes > 0 &&
266            let Err(e) = self.progress.update(bytes as u64)
267        {
268            return Err(io::Error::other(e));
269        }
270        Ok(bytes)
271    }
272}
273
274/// Supported compression formats for snapshots
275#[derive(Debug, Clone, Copy)]
276enum CompressionFormat {
277    Lz4,
278    Zstd,
279}
280
281impl CompressionFormat {
282    /// Detect compression format from file extension
283    fn from_url(url: &str) -> Result<Self> {
284        let path =
285            Url::parse(url).map(|u| u.path().to_string()).unwrap_or_else(|_| url.to_string());
286
287        if path.ends_with(EXTENSION_TAR_LZ4) {
288            Ok(Self::Lz4)
289        } else if path.ends_with(EXTENSION_TAR_ZSTD) {
290            Ok(Self::Zstd)
291        } else {
292            Err(eyre::eyre!(
293                "Unsupported file format. Expected .tar.lz4 or .tar.zst, got: {}",
294                path
295            ))
296        }
297    }
298}
299
300/// Extracts a compressed tar archive to the target directory with progress tracking.
301fn extract_archive<R: Read>(
302    reader: R,
303    total_size: u64,
304    format: CompressionFormat,
305    target_dir: &Path,
306) -> Result<()> {
307    let progress_reader = ProgressReader::new(reader, total_size);
308
309    match format {
310        CompressionFormat::Lz4 => {
311            let decoder = Decoder::new(progress_reader)?;
312            Archive::new(decoder).unpack(target_dir)?;
313        }
314        CompressionFormat::Zstd => {
315            let decoder = ZstdDecoder::new(progress_reader)?;
316            Archive::new(decoder).unpack(target_dir)?;
317        }
318    }
319
320    info!(target: "reth::cli", "Extraction complete.");
321    Ok(())
322}
323
324/// Extracts a snapshot from a local file.
325fn extract_from_file(path: &Path, format: CompressionFormat, target_dir: &Path) -> Result<()> {
326    let file = std::fs::File::open(path)?;
327    let total_size = file.metadata()?.len();
328    extract_archive(file, total_size, format, target_dir)
329}
330
331const MAX_DOWNLOAD_RETRIES: u32 = 10;
332const RETRY_BACKOFF_SECS: u64 = 5;
333
334/// Wrapper that tracks download progress while writing data.
335/// Used with [`io::copy`] to display progress during downloads.
336struct ProgressWriter<W> {
337    inner: W,
338    progress: DownloadProgress,
339}
340
341impl<W: Write> Write for ProgressWriter<W> {
342    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
343        let n = self.inner.write(buf)?;
344        let _ = self.progress.update(n as u64);
345        Ok(n)
346    }
347
348    fn flush(&mut self) -> io::Result<()> {
349        self.inner.flush()
350    }
351}
352
353/// Downloads a file with resume support using HTTP Range requests.
354/// Automatically retries on failure, resuming from where it left off.
355/// Returns the path to the downloaded file and its total size.
356fn resumable_download(url: &str, target_dir: &Path) -> Result<(PathBuf, u64)> {
357    let file_name = Url::parse(url)
358        .ok()
359        .and_then(|u| u.path_segments()?.next_back().map(|s| s.to_string()))
360        .unwrap_or_else(|| "snapshot.tar".to_string());
361
362    let final_path = target_dir.join(&file_name);
363    let part_path = target_dir.join(format!("{file_name}.part"));
364
365    let client = BlockingClient::builder().timeout(Duration::from_secs(30)).build()?;
366
367    let mut total_size: Option<u64> = None;
368    let mut last_error: Option<eyre::Error> = None;
369
370    for attempt in 1..=MAX_DOWNLOAD_RETRIES {
371        let existing_size = fs::metadata(&part_path).map(|m| m.len()).unwrap_or(0);
372
373        if let Some(total) = total_size &&
374            existing_size >= total
375        {
376            fs::rename(&part_path, &final_path)?;
377            info!(target: "reth::cli", "Download complete: {}", final_path.display());
378            return Ok((final_path, total));
379        }
380
381        if attempt > 1 {
382            info!(target: "reth::cli",
383                "Retry attempt {}/{} - resuming from {} bytes",
384                attempt, MAX_DOWNLOAD_RETRIES, existing_size
385            );
386        }
387
388        let mut request = client.get(url);
389        if existing_size > 0 {
390            request = request.header(RANGE, format!("bytes={existing_size}-"));
391            if attempt == 1 {
392                info!(target: "reth::cli", "Resuming download from {} bytes", existing_size);
393            }
394        }
395
396        let response = match request.send().and_then(|r| r.error_for_status()) {
397            Ok(r) => r,
398            Err(e) => {
399                last_error = Some(e.into());
400                if attempt < MAX_DOWNLOAD_RETRIES {
401                    info!(target: "reth::cli",
402                        "Download failed, retrying in {} seconds...", RETRY_BACKOFF_SECS
403                    );
404                    std::thread::sleep(Duration::from_secs(RETRY_BACKOFF_SECS));
405                }
406                continue;
407            }
408        };
409
410        let is_partial = response.status() == StatusCode::PARTIAL_CONTENT;
411
412        let size = if is_partial {
413            response
414                .headers()
415                .get("Content-Range")
416                .and_then(|v| v.to_str().ok())
417                .and_then(|v| v.split('/').next_back())
418                .and_then(|v| v.parse().ok())
419        } else {
420            response.content_length()
421        };
422
423        if total_size.is_none() {
424            total_size = size;
425        }
426
427        let current_total = total_size.ok_or_else(|| {
428            eyre::eyre!("Server did not provide Content-Length or Content-Range header")
429        })?;
430
431        let file = if is_partial && existing_size > 0 {
432            OpenOptions::new()
433                .append(true)
434                .open(&part_path)
435                .map_err(|e| fs::FsPathError::open(e, &part_path))?
436        } else {
437            fs::create_file(&part_path)?
438        };
439
440        let start_offset = if is_partial { existing_size } else { 0 };
441        let mut progress = DownloadProgress::new(current_total);
442        progress.downloaded = start_offset;
443
444        let mut writer = ProgressWriter { inner: BufWriter::new(file), progress };
445        let mut reader = response;
446
447        let copy_result = io::copy(&mut reader, &mut writer);
448        let flush_result = writer.inner.flush();
449        println!();
450
451        if let Err(e) = copy_result.and(flush_result) {
452            last_error = Some(e.into());
453            if attempt < MAX_DOWNLOAD_RETRIES {
454                info!(target: "reth::cli",
455                    "Download interrupted, retrying in {} seconds...", RETRY_BACKOFF_SECS
456                );
457                std::thread::sleep(Duration::from_secs(RETRY_BACKOFF_SECS));
458            }
459            continue;
460        }
461
462        fs::rename(&part_path, &final_path)?;
463        info!(target: "reth::cli", "Download complete: {}", final_path.display());
464        return Ok((final_path, current_total));
465    }
466
467    Err(last_error
468        .unwrap_or_else(|| eyre::eyre!("Download failed after {} attempts", MAX_DOWNLOAD_RETRIES)))
469}
470
471/// Fetches the snapshot from a remote URL with resume support, then extracts it.
472fn download_and_extract(url: &str, format: CompressionFormat, target_dir: &Path) -> Result<()> {
473    let (downloaded_path, total_size) = resumable_download(url, target_dir)?;
474
475    info!(target: "reth::cli", "Extracting snapshot...");
476    let file = fs::open(&downloaded_path)?;
477    extract_archive(file, total_size, format, target_dir)?;
478
479    fs::remove_file(&downloaded_path)?;
480    info!(target: "reth::cli", "Removed downloaded archive");
481
482    Ok(())
483}
484
485/// Downloads and extracts a snapshot, blocking until finished.
486///
487/// Supports both `file://` URLs for local files and HTTP(S) URLs for remote downloads.
488fn blocking_download_and_extract(url: &str, target_dir: &Path) -> Result<()> {
489    let format = CompressionFormat::from_url(url)?;
490
491    if let Ok(parsed_url) = Url::parse(url) &&
492        parsed_url.scheme() == "file"
493    {
494        let file_path = parsed_url
495            .to_file_path()
496            .map_err(|_| eyre::eyre!("Invalid file:// URL path: {}", url))?;
497        extract_from_file(&file_path, format, target_dir)
498    } else {
499        download_and_extract(url, format, target_dir)
500    }
501}
502
503async fn stream_and_extract(url: &str, target_dir: &Path) -> Result<()> {
504    let target_dir = target_dir.to_path_buf();
505    let url = url.to_string();
506    task::spawn_blocking(move || blocking_download_and_extract(&url, &target_dir)).await??;
507
508    Ok(())
509}
510
511// Builds default URL for latest mainnet archive snapshot using configured defaults
512async fn get_latest_snapshot_url() -> Result<String> {
513    let base_url = &DownloadDefaults::get_global().default_base_url;
514    let latest_url = format!("{base_url}/latest.txt");
515    let filename = Client::new()
516        .get(latest_url)
517        .send()
518        .await?
519        .error_for_status()?
520        .text()
521        .await?
522        .trim()
523        .to_string();
524
525    Ok(format!("{base_url}/{filename}"))
526}
527
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_download_defaults_builder() {
        let defaults = DownloadDefaults::default()
            .with_snapshot("https://example.com/snapshots (example)")
            .with_base_url("https://example.com");

        assert_eq!(defaults.default_base_url, "https://example.com");
        assert_eq!(defaults.available_snapshots.len(), 3); // 2 defaults + 1 added
    }

    #[test]
    fn test_download_defaults_replace_snapshots() {
        let defaults = DownloadDefaults::default().with_snapshots(vec![
            Cow::Borrowed("https://custom1.com"),
            Cow::Borrowed("https://custom2.com"),
        ]);

        assert_eq!(defaults.available_snapshots.len(), 2);
        assert_eq!(defaults.available_snapshots[0], "https://custom1.com");
    }

    #[test]
    fn test_long_help_generation() {
        let defaults = DownloadDefaults::default();
        let help = defaults.long_help();

        assert!(help.contains("Available snapshot sources:"));
        assert!(help.contains("merkle.io"));
        assert!(help.contains("publicnode.com"));
        assert!(help.contains("file://"));
    }

    #[test]
    fn test_long_help_override() {
        let custom_help = "This is custom help text for downloading snapshots.";
        let defaults = DownloadDefaults::default().with_long_help(custom_help);

        let help = defaults.long_help();
        assert_eq!(help, custom_help);
        assert!(!help.contains("Available snapshot sources:"));
    }

    #[test]
    fn test_builder_chaining() {
        let defaults = DownloadDefaults::default()
            .with_base_url("https://custom.example.com")
            .with_snapshot("https://snapshot1.com")
            .with_snapshot("https://snapshot2.com")
            .with_long_help("Custom help for snapshots");

        assert_eq!(defaults.default_base_url, "https://custom.example.com");
        assert_eq!(defaults.available_snapshots.len(), 4); // 2 defaults + 2 added
        assert_eq!(defaults.long_help, Some("Custom help for snapshots".to_string()));
    }

    #[test]
    fn test_compression_format_detection() {
        assert!(matches!(
            CompressionFormat::from_url("https://example.com/snapshot.tar.lz4"),
            Ok(CompressionFormat::Lz4)
        ));
        assert!(matches!(
            CompressionFormat::from_url("https://example.com/snapshot.tar.zst"),
            Ok(CompressionFormat::Zstd)
        ));
        assert!(matches!(
            CompressionFormat::from_url("file:///path/to/snapshot.tar.lz4"),
            Ok(CompressionFormat::Lz4)
        ));
        assert!(matches!(
            CompressionFormat::from_url("file:///path/to/snapshot.tar.zst"),
            Ok(CompressionFormat::Zstd)
        ));
        assert!(CompressionFormat::from_url("https://example.com/snapshot.tar.gz").is_err());
    }

    #[test]
    fn test_compression_format_ignores_query_string() {
        // Only the URL path is inspected, so signed/tokenized URLs are still recognised.
        assert!(matches!(
            CompressionFormat::from_url("https://example.com/snapshot.tar.zst?sig=abc"),
            Ok(CompressionFormat::Zstd)
        ));
        // The suffix must be at the very end of the path.
        assert!(CompressionFormat::from_url("https://example.com/snapshot.tar.lz4.sha256").is_err());
    }

    #[test]
    fn test_format_size() {
        assert_eq!(DownloadProgress::format_size(0), "0.00 B");
        assert_eq!(DownloadProgress::format_size(1023), "1023.00 B");
        assert_eq!(DownloadProgress::format_size(1536), "1.50 KB");
        assert_eq!(DownloadProgress::format_size(5 * 1024 * 1024 * 1024), "5.00 GB");
    }

    #[test]
    fn test_format_duration() {
        assert_eq!(DownloadProgress::format_duration(Duration::from_secs(59)), "59s");
        assert_eq!(DownloadProgress::format_duration(Duration::from_secs(61)), "1m 1s");
        assert_eq!(DownloadProgress::format_duration(Duration::from_secs(3600)), "1h 0m");
        assert_eq!(DownloadProgress::format_duration(Duration::from_secs(3661)), "1h 1m");
    }
}