// reth_stages/stages/s3/mod.rs
1mod downloader;
2pub use downloader::{fetch, Metadata};
3use downloader::{DownloaderError, S3DownloaderResponse};
4
5mod filelist;
6use filelist::DOWNLOAD_FILE_LIST;
7
8use reth_db_api::transaction::DbTxMut;
9use reth_provider::{
10 DBProvider, StageCheckpointReader, StageCheckpointWriter, StaticFileProviderFactory,
11};
12use reth_stages_api::{
13 ExecInput, ExecOutput, Stage, StageCheckpoint, StageError, StageId, UnwindInput, UnwindOutput,
14};
15use reth_static_file_types::StaticFileSegment;
16use std::{
17 path::PathBuf,
18 task::{ready, Context, Poll},
19};
20use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver};
21
/// Identifier reported by [`S3Stage`]: a custom [`StageId::Other`] named `"S3"`.
const S3_STAGE_ID: StageId = StageId::Other("S3");
24
/// Stage that fetches the static files listed in `DOWNLOAD_FILE_LIST` from `url` into
/// `static_file_directory`, using a background download task.
///
/// NOTE(review): the derived `Default` yields an empty directory, an empty URL and
/// `max_concurrent_requests == 0`; confirm callers always configure real values.
#[derive(Default, Debug)]
#[non_exhaustive]
pub struct S3Stage {
    /// Local directory the files are downloaded into; files already present here are skipped.
    static_file_directory: PathBuf,
    /// Base URL; each file is requested from `{url}/{filename}`.
    url: String,
    /// Forwarded unchanged to `fetch` for every file.
    /// NOTE(review): presumably the number of parallel requests per download — confirm in
    /// `downloader`.
    max_concurrent_requests: u64,
    /// Progress channel of the currently running download task, if any. It is cleared once the
    /// task reports that it is done.
    fetch_rx: Option<UnboundedReceiver<Result<S3DownloaderResponse, DownloaderError>>>,
}
38
39impl<Provider> Stage<Provider> for S3Stage
40where
41 Provider: DBProvider<Tx: DbTxMut>
42 + StaticFileProviderFactory
43 + StageCheckpointReader
44 + StageCheckpointWriter,
45{
46 fn id(&self) -> StageId {
47 S3_STAGE_ID
48 }
49
50 fn poll_execute_ready(
51 &mut self,
52 cx: &mut Context<'_>,
53 input: ExecInput,
54 ) -> Poll<Result<(), StageError>> {
55 loop {
56 if let Some(rx) = &mut self.fetch_rx {
58 let mut is_done = false;
60
61 let response = match ready!(rx.poll_recv(cx)) {
62 Some(Ok(response)) => {
63 is_done = response.is_done();
64 Ok(())
65 }
66 Some(Err(_)) => todo!(), None => Err(StageError::ChannelClosed),
68 };
69
70 if is_done {
71 self.fetch_rx = None;
72 }
73
74 return Poll::Ready(response)
75 }
76
77 if let Some(fetch_rx) = self.maybe_spawn_fetch(input) {
79 self.fetch_rx = Some(fetch_rx);
80
81 continue
83 }
84
85 return Poll::Ready(Ok(()))
87 }
88 }
89
90 fn execute(&mut self, provider: &Provider, input: ExecInput) -> Result<ExecOutput, StageError>
91 where
92 Provider: DBProvider<Tx: DbTxMut>
93 + StaticFileProviderFactory
94 + StageCheckpointReader
95 + StageCheckpointWriter,
96 {
97 provider.static_file_provider().initialize_index()?;
99
100 Ok(ExecOutput { checkpoint: StageCheckpoint::new(input.target()), done: true })
117 }
118
119 fn unwind(
120 &mut self,
121 _provider: &Provider,
122 input: UnwindInput,
123 ) -> Result<UnwindOutput, StageError> {
124 Ok(UnwindOutput { checkpoint: StageCheckpoint::new(input.unwind_to) })
126 }
127}
128
129impl S3Stage {
130 fn maybe_spawn_fetch(
137 &self,
138 input: ExecInput,
139 ) -> Option<UnboundedReceiver<Result<S3DownloaderResponse, DownloaderError>>> {
140 let checkpoint = input.checkpoint();
141 let mut requests = vec![];
145 for block_range_files in &DOWNLOAD_FILE_LIST {
146 let (_, block_range) =
147 StaticFileSegment::parse_filename(block_range_files[0].0).expect("qed");
148
149 if block_range.end() <= checkpoint.block_number {
150 continue
151 }
152
153 let mut block_range_requests = vec![];
154 for (filename, file_hash) in block_range_files {
155 if self.static_file_directory.join(filename).exists() {
158 continue
160 }
161
162 block_range_requests.push((filename, file_hash));
163 }
164
165 requests.push((block_range, block_range_requests));
166 }
167
168 if requests.is_empty() {
170 return None
171 }
172
173 let static_file_directory = self.static_file_directory.clone();
174 let url = self.url.clone();
175 let max_concurrent_requests = self.max_concurrent_requests;
176
177 let (fetch_tx, fetch_rx) = unbounded_channel();
178 tokio::spawn(async move {
179 let mut requests_iter = requests.into_iter().peekable();
180
181 while let Some((_, file_requests)) = requests_iter.next() {
182 for (filename, file_hash) in file_requests {
183 if let Err(err) = fetch(
184 filename,
185 &static_file_directory,
186 &format!("{}/{filename}", url),
187 max_concurrent_requests,
188 Some(*file_hash),
189 )
190 .await
191 {
192 let _ = fetch_tx.send(Err(err));
193 return
194 }
195 }
196
197 let response = if requests_iter.peek().is_none() {
198 S3DownloaderResponse::Done
199 } else {
200 S3DownloaderResponse::AddedNewRange
201 };
202
203 let _ = fetch_tx.send(Ok(response));
204 }
205 });
206
207 Some(fetch_rx)
208 }
209}
210
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::{
        ExecuteStageTestRunner, StageTestRunner, TestRunnerError, TestStageDB,
        UnwindStageTestRunner,
    };
    use reth_primitives_traits::SealedHeader;
    use reth_testing_utils::generators::{self, random_header, random_header_range};

    /// Test harness for [`S3Stage`], backed by a default [`TestStageDB`].
    #[derive(Default)]
    struct S3TestRunner {
        db: TestStageDB,
    }

    impl StageTestRunner for S3TestRunner {
        type S = S3Stage;

        fn db(&self) -> &TestStageDB {
            &self.db
        }

        fn stage(&self) -> Self::S {
            Default::default()
        }
    }

    impl ExecuteStageTestRunner for S3TestRunner {
        type Seed = Vec<SealedHeader>;

        /// Inserts a random header at the checkpoint block. When the target lies beyond it, also
        /// inserts a linked range of random headers up to and including the target.
        fn seed_execution(&mut self, input: ExecInput) -> Result<Self::Seed, TestRunnerError> {
            let first_block = input.checkpoint().block_number;
            let mut rng = generators::rng();

            let head = random_header(&mut rng, first_block, None);
            self.db.insert_headers_with_td(std::iter::once(&head))?;

            let end_exclusive = input.target.unwrap_or_default() + 1;

            // Nothing past the head to seed; the head itself is not returned in this case.
            if end_exclusive <= first_block + 1 {
                return Ok(Vec::new())
            }

            let tail = random_header_range(&mut rng, first_block + 1..end_exclusive, head.hash());
            self.db.insert_headers_with_td(tail.iter())?;

            Ok(std::iter::once(head).chain(tail).collect())
        }

        /// Checks that a produced output is complete and sits at the input target.
        fn validate_execution(
            &self,
            input: ExecInput,
            output: Option<ExecOutput>,
        ) -> Result<(), TestRunnerError> {
            let Some(output) = output else { return Ok(()) };

            assert!(output.done, "stage should always be done");
            assert_eq!(
                output.checkpoint.block_number,
                input.target(),
                "stage progress should always match progress of previous stage"
            );
            Ok(())
        }
    }

    impl UnwindStageTestRunner for S3TestRunner {
        fn validate_unwind(&self, _input: UnwindInput) -> Result<(), TestRunnerError> {
            Ok(())
        }
    }

    /// Every chunk of the download list must start with a parseable static file name.
    #[test]
    fn parse_files() {
        DOWNLOAD_FILE_LIST.iter().for_each(|block_range_files| {
            let _ = StaticFileSegment::parse_filename(block_range_files[0].0).expect("qed");
        });
    }
}