1use crate::common::EnvironmentArgs;
2use clap::Parser;
3use eyre::Result;
4use lz4::Decoder;
5use reqwest::{blocking::Client as BlockingClient, header::RANGE, Client, StatusCode};
6use reth_chainspec::{EthChainSpec, EthereumHardforks};
7use reth_cli::chainspec::ChainSpecParser;
8use reth_fs_util as fs;
9use std::{
10 borrow::Cow,
11 fs::OpenOptions,
12 io::{self, BufWriter, Read, Write},
13 path::{Path, PathBuf},
14 sync::{Arc, OnceLock},
15 time::{Duration, Instant},
16};
17use tar::Archive;
18use tokio::task;
19use tracing::info;
20use url::Url;
21use zstd::stream::read::Decoder as ZstdDecoder;
22
/// Units used when rendering byte counts, in ascending powers of 1024.
const BYTE_UNITS: [&str; 4] = ["B", "KB", "MB", "GB"];
/// Default base URL from which `latest.txt` and snapshot archives are fetched.
const MERKLE_BASE_URL: &str = "https://downloads.merkle.io";
/// Path suffix of tar archives compressed with LZ4.
const EXTENSION_TAR_LZ4: &str = ".tar.lz4";
/// Path suffix of tar archives compressed with Zstandard.
const EXTENSION_TAR_ZSTD: &str = ".tar.zst";

/// Process-wide download defaults. Set once via `DownloadDefaults::try_init`, otherwise
/// initialized lazily with `DownloadDefaults::default_download_defaults` on first access.
static DOWNLOAD_DEFAULTS: OnceLock<DownloadDefaults> = OnceLock::new();
30
/// Configurable defaults for the `download` command.
///
/// Covers which snapshot sources are listed in `--help`, where the default snapshot is
/// resolved from when no URL is given, and optional fully custom help text.
#[derive(Debug, Clone)]
pub struct DownloadDefaults {
    /// Human-readable snapshot source descriptions listed in the generated long help.
    pub available_snapshots: Vec<Cow<'static, str>>,
    /// Base URL used to resolve `latest.txt` and the snapshot file when no URL is given.
    pub default_base_url: Cow<'static, str>,
    /// Optional base URL that takes precedence over `default_base_url`; the chain id is
    /// appended as an extra path segment (`{url}/{chain_id}`).
    pub default_chain_aware_base_url: Option<Cow<'static, str>>,
    /// Custom long help text; when set it replaces the generated help entirely.
    pub long_help: Option<String>,
}
51
52impl DownloadDefaults {
53 pub fn try_init(self) -> Result<(), Self> {
55 DOWNLOAD_DEFAULTS.set(self)
56 }
57
58 pub fn get_global() -> &'static DownloadDefaults {
60 DOWNLOAD_DEFAULTS.get_or_init(DownloadDefaults::default_download_defaults)
61 }
62
63 pub fn default_download_defaults() -> Self {
65 Self {
66 available_snapshots: vec![
67 Cow::Borrowed("https://www.merkle.io/snapshots (default, mainnet archive)"),
68 Cow::Borrowed("https://publicnode.com/snapshots (full nodes & testnets)"),
69 ],
70 default_base_url: Cow::Borrowed(MERKLE_BASE_URL),
71 default_chain_aware_base_url: None,
72 long_help: None,
73 }
74 }
75
76 pub fn long_help(&self) -> String {
81 if let Some(ref custom_help) = self.long_help {
82 return custom_help.clone();
83 }
84
85 let mut help = String::from(
86 "Specify a snapshot URL or let the command propose a default one.\n\nAvailable snapshot sources:\n",
87 );
88
89 for source in &self.available_snapshots {
90 help.push_str("- ");
91 help.push_str(source);
92 help.push('\n');
93 }
94
95 help.push_str(
96 "\nIf no URL is provided, the latest archive snapshot for the selected chain\nwill be proposed for download from ",
97 );
98 help.push_str(
99 self.default_chain_aware_base_url.as_deref().unwrap_or(&self.default_base_url),
100 );
101 help.push_str(
102 ".\n\nLocal file:// URLs are also supported for extracting snapshots from disk.",
103 );
104 help
105 }
106
107 pub fn with_snapshot(mut self, source: impl Into<Cow<'static, str>>) -> Self {
109 self.available_snapshots.push(source.into());
110 self
111 }
112
113 pub fn with_snapshots(mut self, sources: Vec<Cow<'static, str>>) -> Self {
115 self.available_snapshots = sources;
116 self
117 }
118
119 pub fn with_base_url(mut self, url: impl Into<Cow<'static, str>>) -> Self {
121 self.default_base_url = url.into();
122 self
123 }
124
125 pub fn with_chain_aware_base_url(mut self, url: impl Into<Cow<'static, str>>) -> Self {
127 self.default_chain_aware_base_url = Some(url.into());
128 self
129 }
130
131 pub fn with_long_help(mut self, help: impl Into<String>) -> Self {
133 self.long_help = Some(help.into());
134 self
135 }
136}
137
138impl Default for DownloadDefaults {
139 fn default() -> Self {
140 Self::default_download_defaults()
141 }
142}
143
// CLI command that downloads a snapshot archive (or reads a local `file://` archive) and
// extracts it into the node's data directory.
//
// NOTE(review): plain `//` comments are used on purpose here — `///` doc comments on a clap
// `Parser` derive would become user-visible `about`/`help` text.
#[derive(Debug, Parser)]
pub struct DownloadCommand<C: ChainSpecParser> {
    // Shared environment arguments (chain spec, data directory, ...) flattened into this
    // command.
    #[command(flatten)]
    env: EnvironmentArgs<C>,

    // Snapshot URL to fetch. When omitted, `execute` resolves the latest snapshot for the
    // selected chain. The long help text is taken from the global `DownloadDefaults`.
    #[arg(long, short, long_help = DownloadDefaults::get_global().long_help())]
    url: Option<String>,
}
153
154impl<C: ChainSpecParser<ChainSpec: EthChainSpec + EthereumHardforks>> DownloadCommand<C> {
155 pub async fn execute<N>(self) -> Result<()> {
156 let data_dir = self.env.datadir.resolve_datadir(self.env.chain.chain());
157 fs::create_dir_all(&data_dir)?;
158
159 let url = match self.url {
160 Some(url) => url,
161 None => {
162 let url = get_latest_snapshot_url(self.env.chain.chain().id()).await?;
163 info!(target: "reth::cli", "Using default snapshot URL: {}", url);
164 url
165 }
166 };
167
168 info!(target: "reth::cli",
169 chain = %self.env.chain.chain(),
170 dir = ?data_dir.data_dir(),
171 url = %url,
172 "Starting snapshot download and extraction"
173 );
174
175 stream_and_extract(&url, data_dir.data_dir()).await?;
176 info!(target: "reth::cli", "Snapshot downloaded and extracted successfully");
177
178 Ok(())
179 }
180}
181
182impl<C: ChainSpecParser> DownloadCommand<C> {
183 pub fn chain_spec(&self) -> Option<&Arc<C::ChainSpec>> {
185 Some(&self.env.chain)
186 }
187}
188
189struct DownloadProgress {
192 downloaded: u64,
193 total_size: u64,
194 last_displayed: Instant,
195 started_at: Instant,
196}
197
198impl DownloadProgress {
199 fn new(total_size: u64) -> Self {
201 let now = Instant::now();
202 Self { downloaded: 0, total_size, last_displayed: now, started_at: now }
203 }
204
205 fn format_size(size: u64) -> String {
207 let mut size = size as f64;
208 let mut unit_index = 0;
209
210 while size >= 1024.0 && unit_index < BYTE_UNITS.len() - 1 {
211 size /= 1024.0;
212 unit_index += 1;
213 }
214
215 format!("{:.2} {}", size, BYTE_UNITS[unit_index])
216 }
217
218 fn format_duration(duration: Duration) -> String {
220 let secs = duration.as_secs();
221 if secs < 60 {
222 format!("{secs}s")
223 } else if secs < 3600 {
224 format!("{}m {}s", secs / 60, secs % 60)
225 } else {
226 format!("{}h {}m", secs / 3600, (secs % 3600) / 60)
227 }
228 }
229
230 fn update(&mut self, chunk_size: u64) -> Result<()> {
232 self.downloaded += chunk_size;
233
234 if self.last_displayed.elapsed() >= Duration::from_millis(100) {
236 let formatted_downloaded = Self::format_size(self.downloaded);
237 let formatted_total = Self::format_size(self.total_size);
238 let progress = (self.downloaded as f64 / self.total_size as f64) * 100.0;
239
240 let elapsed = self.started_at.elapsed();
242 let eta = if self.downloaded > 0 {
243 let remaining = self.total_size.saturating_sub(self.downloaded);
244 let speed = self.downloaded as f64 / elapsed.as_secs_f64();
245 if speed > 0.0 {
246 Duration::from_secs_f64(remaining as f64 / speed)
247 } else {
248 Duration::ZERO
249 }
250 } else {
251 Duration::ZERO
252 };
253 let eta_str = Self::format_duration(eta);
254
255 print!(
257 "\rDownloading and extracting... {progress:.2}% ({formatted_downloaded} / {formatted_total}) ETA: {eta_str} ",
258 );
259 io::stdout().flush()?;
260 self.last_displayed = Instant::now();
261 }
262
263 Ok(())
264 }
265}
266
/// `Read` adapter that forwards reads to an inner reader and reports the size of every
/// non-empty read to a `DownloadProgress` tracker.
struct ProgressReader<R> {
    /// Underlying source of bytes.
    reader: R,
    /// Progress tracker updated with the number of bytes returned by each read.
    progress: DownloadProgress,
}
272
273impl<R: Read> ProgressReader<R> {
274 fn new(reader: R, total_size: u64) -> Self {
275 Self { reader, progress: DownloadProgress::new(total_size) }
276 }
277}
278
279impl<R: Read> Read for ProgressReader<R> {
280 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
281 let bytes = self.reader.read(buf)?;
282 if bytes > 0 &&
283 let Err(e) = self.progress.update(bytes as u64)
284 {
285 return Err(io::Error::other(e));
286 }
287 Ok(bytes)
288 }
289}
290
/// Compression codec of a supported snapshot tar archive, detected from the URL/path suffix.
#[derive(Debug, Clone, Copy)]
enum CompressionFormat {
    /// `.tar.lz4` archive.
    Lz4,
    /// `.tar.zst` archive.
    Zstd,
}
297
298impl CompressionFormat {
299 fn from_url(url: &str) -> Result<Self> {
301 let path =
302 Url::parse(url).map(|u| u.path().to_string()).unwrap_or_else(|_| url.to_string());
303
304 if path.ends_with(EXTENSION_TAR_LZ4) {
305 Ok(Self::Lz4)
306 } else if path.ends_with(EXTENSION_TAR_ZSTD) {
307 Ok(Self::Zstd)
308 } else {
309 Err(eyre::eyre!(
310 "Unsupported file format. Expected .tar.lz4 or .tar.zst, got: {}",
311 path
312 ))
313 }
314 }
315}
316
317fn extract_archive<R: Read>(
319 reader: R,
320 total_size: u64,
321 format: CompressionFormat,
322 target_dir: &Path,
323) -> Result<()> {
324 let progress_reader = ProgressReader::new(reader, total_size);
325
326 match format {
327 CompressionFormat::Lz4 => {
328 let decoder = Decoder::new(progress_reader)?;
329 Archive::new(decoder).unpack(target_dir)?;
330 }
331 CompressionFormat::Zstd => {
332 let decoder = ZstdDecoder::new(progress_reader)?;
333 Archive::new(decoder).unpack(target_dir)?;
334 }
335 }
336
337 info!(target: "reth::cli", "Extraction complete.");
338 Ok(())
339}
340
341fn extract_from_file(path: &Path, format: CompressionFormat, target_dir: &Path) -> Result<()> {
343 let file = std::fs::File::open(path)?;
344 let total_size = file.metadata()?.len();
345 extract_archive(file, total_size, format, target_dir)
346}
347
/// Maximum number of HTTP attempts `resumable_download` makes before giving up.
const MAX_DOWNLOAD_RETRIES: u32 = 10;
/// Seconds to wait between consecutive download attempts.
const RETRY_BACKOFF_SECS: u64 = 5;
350
/// `Write` adapter that forwards writes to an inner writer and reports the number of bytes
/// accepted by each write to a `DownloadProgress` tracker.
struct ProgressWriter<W> {
    /// Underlying destination (a `BufWriter` over the `.part` file during downloads).
    inner: W,
    /// Progress tracker updated after every successful write.
    progress: DownloadProgress,
}
357
358impl<W: Write> Write for ProgressWriter<W> {
359 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
360 let n = self.inner.write(buf)?;
361 let _ = self.progress.update(n as u64);
362 Ok(n)
363 }
364
365 fn flush(&mut self) -> io::Result<()> {
366 self.inner.flush()
367 }
368}
369
370fn resumable_download(url: &str, target_dir: &Path) -> Result<(PathBuf, u64)> {
374 let file_name = Url::parse(url)
375 .ok()
376 .and_then(|u| u.path_segments()?.next_back().map(|s| s.to_string()))
377 .unwrap_or_else(|| "snapshot.tar".to_string());
378
379 let final_path = target_dir.join(&file_name);
380 let part_path = target_dir.join(format!("{file_name}.part"));
381
382 let client = BlockingClient::builder().timeout(Duration::from_secs(30)).build()?;
383
384 let mut total_size: Option<u64> = None;
385 let mut last_error: Option<eyre::Error> = None;
386
387 for attempt in 1..=MAX_DOWNLOAD_RETRIES {
388 let existing_size = fs::metadata(&part_path).map(|m| m.len()).unwrap_or(0);
389
390 if let Some(total) = total_size &&
391 existing_size >= total
392 {
393 fs::rename(&part_path, &final_path)?;
394 info!(target: "reth::cli", "Download complete: {}", final_path.display());
395 return Ok((final_path, total));
396 }
397
398 if attempt > 1 {
399 info!(target: "reth::cli",
400 "Retry attempt {}/{} - resuming from {} bytes",
401 attempt, MAX_DOWNLOAD_RETRIES, existing_size
402 );
403 }
404
405 let mut request = client.get(url);
406 if existing_size > 0 {
407 request = request.header(RANGE, format!("bytes={existing_size}-"));
408 if attempt == 1 {
409 info!(target: "reth::cli", "Resuming download from {} bytes", existing_size);
410 }
411 }
412
413 let response = match request.send().and_then(|r| r.error_for_status()) {
414 Ok(r) => r,
415 Err(e) => {
416 last_error = Some(e.into());
417 if attempt < MAX_DOWNLOAD_RETRIES {
418 info!(target: "reth::cli",
419 "Download failed, retrying in {} seconds...", RETRY_BACKOFF_SECS
420 );
421 std::thread::sleep(Duration::from_secs(RETRY_BACKOFF_SECS));
422 }
423 continue;
424 }
425 };
426
427 let is_partial = response.status() == StatusCode::PARTIAL_CONTENT;
428
429 let size = if is_partial {
430 response
431 .headers()
432 .get("Content-Range")
433 .and_then(|v| v.to_str().ok())
434 .and_then(|v| v.split('/').next_back())
435 .and_then(|v| v.parse().ok())
436 } else {
437 response.content_length()
438 };
439
440 if total_size.is_none() {
441 total_size = size;
442 }
443
444 let current_total = total_size.ok_or_else(|| {
445 eyre::eyre!("Server did not provide Content-Length or Content-Range header")
446 })?;
447
448 let file = if is_partial && existing_size > 0 {
449 OpenOptions::new()
450 .append(true)
451 .open(&part_path)
452 .map_err(|e| fs::FsPathError::open(e, &part_path))?
453 } else {
454 fs::create_file(&part_path)?
455 };
456
457 let start_offset = if is_partial { existing_size } else { 0 };
458 let mut progress = DownloadProgress::new(current_total);
459 progress.downloaded = start_offset;
460
461 let mut writer = ProgressWriter { inner: BufWriter::new(file), progress };
462 let mut reader = response;
463
464 let copy_result = io::copy(&mut reader, &mut writer);
465 let flush_result = writer.inner.flush();
466 println!();
467
468 if let Err(e) = copy_result.and(flush_result) {
469 last_error = Some(e.into());
470 if attempt < MAX_DOWNLOAD_RETRIES {
471 info!(target: "reth::cli",
472 "Download interrupted, retrying in {} seconds...", RETRY_BACKOFF_SECS
473 );
474 std::thread::sleep(Duration::from_secs(RETRY_BACKOFF_SECS));
475 }
476 continue;
477 }
478
479 fs::rename(&part_path, &final_path)?;
480 info!(target: "reth::cli", "Download complete: {}", final_path.display());
481 return Ok((final_path, current_total));
482 }
483
484 Err(last_error
485 .unwrap_or_else(|| eyre::eyre!("Download failed after {} attempts", MAX_DOWNLOAD_RETRIES)))
486}
487
488fn download_and_extract(url: &str, format: CompressionFormat, target_dir: &Path) -> Result<()> {
490 let (downloaded_path, total_size) = resumable_download(url, target_dir)?;
491
492 info!(target: "reth::cli", "Extracting snapshot...");
493 let file = fs::open(&downloaded_path)?;
494 extract_archive(file, total_size, format, target_dir)?;
495
496 fs::remove_file(&downloaded_path)?;
497 info!(target: "reth::cli", "Removed downloaded archive");
498
499 Ok(())
500}
501
502fn blocking_download_and_extract(url: &str, target_dir: &Path) -> Result<()> {
506 let format = CompressionFormat::from_url(url)?;
507
508 if let Ok(parsed_url) = Url::parse(url) &&
509 parsed_url.scheme() == "file"
510 {
511 let file_path = parsed_url
512 .to_file_path()
513 .map_err(|_| eyre::eyre!("Invalid file:// URL path: {}", url))?;
514 extract_from_file(&file_path, format, target_dir)
515 } else {
516 download_and_extract(url, format, target_dir)
517 }
518}
519
520async fn stream_and_extract(url: &str, target_dir: &Path) -> Result<()> {
521 let target_dir = target_dir.to_path_buf();
522 let url = url.to_string();
523 task::spawn_blocking(move || blocking_download_and_extract(&url, &target_dir)).await??;
524
525 Ok(())
526}
527
528async fn get_latest_snapshot_url(chain_id: u64) -> Result<String> {
530 let defaults = DownloadDefaults::get_global();
531 let base_url = match &defaults.default_chain_aware_base_url {
532 Some(url) => format!("{url}/{chain_id}"),
533 None => defaults.default_base_url.to_string(),
534 };
535 let latest_url = format!("{base_url}/latest.txt");
536 let filename = Client::new()
537 .get(latest_url)
538 .send()
539 .await?
540 .error_for_status()?
541 .text()
542 .await?
543 .trim()
544 .to_string();
545
546 Ok(format!("{base_url}/{filename}"))
547}
548
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_download_defaults_builder() {
        let defaults = DownloadDefaults::default()
            .with_snapshot("https://example.com/snapshots (example)")
            .with_base_url("https://example.com");

        assert_eq!(defaults.default_base_url, "https://example.com");
        // Two built-in sources plus the one added above.
        assert_eq!(defaults.available_snapshots.len(), 3);
    }

    #[test]
    fn test_download_defaults_replace_snapshots() {
        let defaults = DownloadDefaults::default().with_snapshots(vec![
            Cow::Borrowed("https://custom1.com"),
            Cow::Borrowed("https://custom2.com"),
        ]);

        assert_eq!(defaults.available_snapshots.len(), 2);
        assert_eq!(defaults.available_snapshots[0], "https://custom1.com");
    }

    #[test]
    fn test_long_help_generation() {
        let defaults = DownloadDefaults::default();
        let help = defaults.long_help();

        assert!(help.contains("Available snapshot sources:"));
        assert!(help.contains("merkle.io"));
        assert!(help.contains("publicnode.com"));
        assert!(help.contains("file://"));
    }

    #[test]
    fn test_long_help_prefers_chain_aware_base_url() {
        let defaults = DownloadDefaults::default()
            .with_base_url("https://plain.example.com")
            .with_chain_aware_base_url("https://chain.example.com");
        let help = defaults.long_help();

        assert!(help.contains("will be proposed for download from https://chain.example.com."));
        assert!(!help.contains("https://plain.example.com"));
    }

    #[test]
    fn test_long_help_override() {
        let custom_help = "This is custom help text for downloading snapshots.";
        let defaults = DownloadDefaults::default().with_long_help(custom_help);

        let help = defaults.long_help();
        assert_eq!(help, custom_help);
        assert!(!help.contains("Available snapshot sources:"));
    }

    #[test]
    fn test_builder_chaining() {
        let defaults = DownloadDefaults::default()
            .with_base_url("https://custom.example.com")
            .with_snapshot("https://snapshot1.com")
            .with_snapshot("https://snapshot2.com")
            .with_long_help("Custom help for snapshots");

        assert_eq!(defaults.default_base_url, "https://custom.example.com");
        // Two built-in sources plus the two added above.
        assert_eq!(defaults.available_snapshots.len(), 4);
        assert_eq!(defaults.long_help, Some("Custom help for snapshots".to_string()));
    }

    #[test]
    fn test_compression_format_detection() {
        assert!(matches!(
            CompressionFormat::from_url("https://example.com/snapshot.tar.lz4"),
            Ok(CompressionFormat::Lz4)
        ));
        assert!(matches!(
            CompressionFormat::from_url("https://example.com/snapshot.tar.zst"),
            Ok(CompressionFormat::Zstd)
        ));
        assert!(matches!(
            CompressionFormat::from_url("file:///path/to/snapshot.tar.lz4"),
            Ok(CompressionFormat::Lz4)
        ));
        assert!(matches!(
            CompressionFormat::from_url("file:///path/to/snapshot.tar.zst"),
            Ok(CompressionFormat::Zstd)
        ));
        assert!(CompressionFormat::from_url("https://example.com/snapshot.tar.gz").is_err());
    }

    #[test]
    fn test_compression_format_ignores_query_string() {
        assert!(matches!(
            CompressionFormat::from_url("https://example.com/snapshot.tar.zst?token=abc"),
            Ok(CompressionFormat::Zstd)
        ));
        assert!(CompressionFormat::from_url("https://example.com/snapshot.tar.lz4.part").is_err());
    }

    #[test]
    fn test_format_size() {
        assert_eq!(DownloadProgress::format_size(0), "0.00 B");
        assert_eq!(DownloadProgress::format_size(1536), "1.50 KB");
        assert_eq!(DownloadProgress::format_size(3 * 1024 * 1024 * 1024), "3.00 GB");
        // Sizes beyond the largest unit stay in GB instead of indexing past the unit table.
        assert_eq!(DownloadProgress::format_size(2 * 1024u64.pow(4)), "2048.00 GB");
    }

    #[test]
    fn test_format_duration() {
        assert_eq!(DownloadProgress::format_duration(Duration::from_secs(59)), "59s");
        assert_eq!(DownloadProgress::format_duration(Duration::from_secs(61)), "1m 1s");
        assert_eq!(DownloadProgress::format_duration(Duration::from_secs(3661)), "1h 1m");
    }
}