reth_stages/stages/s3/
mod.rs

1mod downloader;
2pub use downloader::{fetch, Metadata};
3use downloader::{DownloaderError, S3DownloaderResponse};
4
5mod filelist;
6use filelist::DOWNLOAD_FILE_LIST;
7
8use reth_db_api::transaction::DbTxMut;
9use reth_provider::{
10    DBProvider, StageCheckpointReader, StageCheckpointWriter, StaticFileProviderFactory,
11};
12use reth_stages_api::{
13    ExecInput, ExecOutput, Stage, StageCheckpoint, StageError, StageId, UnwindInput, UnwindOutput,
14};
15use reth_static_file_types::StaticFileSegment;
16use std::{
17    path::PathBuf,
18    task::{ready, Context, Poll},
19};
20use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver};
21
/// [`StageId`] under which the S3 stage identifies itself (returned by its `Stage::id`
/// implementation).
const S3_STAGE_ID: StageId = StageId::Other("S3");
24
/// The S3 stage.
///
/// Fetches missing static files from a remote server into the local static file directory.
/// The downloads run in a spawned tokio task; progress is reported back to the stage over an
/// unbounded channel that is polled from `poll_execute_ready`.
#[derive(Default, Debug)]
#[non_exhaustive]
pub struct S3Stage {
    /// Static file directory.
    ///
    /// Used both to check whether a file is already present and as the download destination.
    static_file_directory: PathBuf,
    /// Remote server URL.
    ///
    /// Each file is requested from `{url}/{filename}`.
    url: String,
    /// Maximum number of connections per download.
    ///
    /// Forwarded unchanged to the downloader's `fetch` for every file.
    max_concurrent_requests: u64,
    /// Channel to receive the downloaded ranges from the fetch task.
    ///
    /// `None` while no fetch task has been started; reset to `None` once a response reporting
    /// completion has been received.
    fetch_rx: Option<UnboundedReceiver<Result<S3DownloaderResponse, DownloaderError>>>,
}
38
39impl<Provider> Stage<Provider> for S3Stage
40where
41    Provider: DBProvider<Tx: DbTxMut>
42        + StaticFileProviderFactory
43        + StageCheckpointReader
44        + StageCheckpointWriter,
45{
46    fn id(&self) -> StageId {
47        S3_STAGE_ID
48    }
49
50    fn poll_execute_ready(
51        &mut self,
52        cx: &mut Context<'_>,
53        input: ExecInput,
54    ) -> Poll<Result<(), StageError>> {
55        loop {
56            // We are currently fetching and may have downloaded ranges that we can process.
57            if let Some(rx) = &mut self.fetch_rx {
58                // Whether we have downloaded all the required files.
59                let mut is_done = false;
60
61                let response = match ready!(rx.poll_recv(cx)) {
62                    Some(Ok(response)) => {
63                        is_done = response.is_done();
64                        Ok(())
65                    }
66                    Some(Err(_)) => todo!(), // TODO: DownloaderError -> StageError
67                    None => Err(StageError::ChannelClosed),
68                };
69
70                if is_done {
71                    self.fetch_rx = None;
72                }
73
74                return Poll::Ready(response)
75            }
76
77            // Spawns the downloader task if there are any missing files
78            if let Some(fetch_rx) = self.maybe_spawn_fetch(input) {
79                self.fetch_rx = Some(fetch_rx);
80
81                // Polls fetch_rx & registers waker
82                continue
83            }
84
85            // No files to be downloaded
86            return Poll::Ready(Ok(()))
87        }
88    }
89
90    fn execute(&mut self, provider: &Provider, input: ExecInput) -> Result<ExecOutput, StageError>
91    where
92        Provider: DBProvider<Tx: DbTxMut>
93            + StaticFileProviderFactory
94            + StageCheckpointReader
95            + StageCheckpointWriter,
96    {
97        // Re-initializes the provider to detect the new additions
98        provider.static_file_provider().initialize_index()?;
99
100        // TODO logic for appending tx_block
101
102        // let (_, _to_block) = input.next_block_range().into_inner();
103        // let static_file_provider = provider.static_file_provider();
104        // let mut _tx_block_cursor =
105        // provider.tx_ref().cursor_write::<tables::TransactionBlocks>()?;
106
107        // tx_block_cursor.append(indice.last_tx_num(), &block_number)?;
108
109        // let checkpoint = StageCheckpoint { block_number: highest_block, stage_checkpoint: None };
110        // provider.save_stage_checkpoint(StageId::Bodies, checkpoint)?;
111        // provider.save_stage_checkpoint(S3_STAGE_ID, checkpoint)?;
112
113        // // TODO: verify input.target according to s3 stage specifications
114        // let done = highest_block == to_block;
115
116        Ok(ExecOutput { checkpoint: StageCheckpoint::new(input.target()), done: true })
117    }
118
119    fn unwind(
120        &mut self,
121        _provider: &Provider,
122        input: UnwindInput,
123    ) -> Result<UnwindOutput, StageError> {
124        // TODO
125        Ok(UnwindOutput { checkpoint: StageCheckpoint::new(input.unwind_to) })
126    }
127}
128
129impl S3Stage {
130    /// It will only spawn a task to fetch files from the remote server, it there are any missing
131    /// static files.
132    ///
133    /// Every time a block range is ready with all the necessary files, it sends a
134    /// [`S3DownloaderResponse`] to `self.fetch_rx`. If it's the last requested block range, the
135    /// response will have `is_done` set to true.
136    fn maybe_spawn_fetch(
137        &self,
138        input: ExecInput,
139    ) -> Option<UnboundedReceiver<Result<S3DownloaderResponse, DownloaderError>>> {
140        let checkpoint = input.checkpoint();
141        // TODO: input target can only be certain numbers. eg. 499_999 , 999_999 etc.
142
143        // Create a list of all the missing files per block range that need to be downloaded.
144        let mut requests = vec![];
145        for block_range_files in &DOWNLOAD_FILE_LIST {
146            let (_, block_range) =
147                StaticFileSegment::parse_filename(block_range_files[0].0).expect("qed");
148
149            if block_range.end() <= checkpoint.block_number {
150                continue
151            }
152
153            let mut block_range_requests = vec![];
154            for (filename, file_hash) in block_range_files {
155                // If the file already exists, then we are resuming a previously interrupted stage
156                // run.
157                if self.static_file_directory.join(filename).exists() {
158                    // TODO: check hash if the file already exists
159                    continue
160                }
161
162                block_range_requests.push((filename, file_hash));
163            }
164
165            requests.push((block_range, block_range_requests));
166        }
167
168        // Return None, if we have downloaded all the files that are required.
169        if requests.is_empty() {
170            return None
171        }
172
173        let static_file_directory = self.static_file_directory.clone();
174        let url = self.url.clone();
175        let max_concurrent_requests = self.max_concurrent_requests;
176
177        let (fetch_tx, fetch_rx) = unbounded_channel();
178        tokio::spawn(async move {
179            let mut requests_iter = requests.into_iter().peekable();
180
181            while let Some((_, file_requests)) = requests_iter.next() {
182                for (filename, file_hash) in file_requests {
183                    if let Err(err) = fetch(
184                        filename,
185                        &static_file_directory,
186                        &format!("{}/{filename}", url),
187                        max_concurrent_requests,
188                        Some(*file_hash),
189                    )
190                    .await
191                    {
192                        let _ = fetch_tx.send(Err(err));
193                        return
194                    }
195                }
196
197                let response = if requests_iter.peek().is_none() {
198                    S3DownloaderResponse::Done
199                } else {
200                    S3DownloaderResponse::AddedNewRange
201                };
202
203                let _ = fetch_tx.send(Ok(response));
204            }
205        });
206
207        Some(fetch_rx)
208    }
209}
210
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::{
        ExecuteStageTestRunner, StageTestRunner, TestRunnerError, TestStageDB,
        UnwindStageTestRunner,
    };
    use reth_primitives_traits::SealedHeader;
    use reth_testing_utils::{
        generators,
        generators::{random_header, random_header_range},
    };

    // stage_test_suite_ext!(S3TestRunner, s3);

    /// Test runner for [`S3Stage`] backed by a [`TestStageDB`].
    #[derive(Default)]
    struct S3TestRunner {
        db: TestStageDB,
    }

    impl StageTestRunner for S3TestRunner {
        type S = S3Stage;

        fn db(&self) -> &TestStageDB {
            &self.db
        }

        /// Builds the stage under test from its default configuration.
        fn stage(&self) -> Self::S {
            S3Stage::default()
        }
    }

    impl ExecuteStageTestRunner for S3TestRunner {
        type Seed = Vec<SealedHeader>;

        /// Inserts a random header at the checkpoint block followed by a random header chain
        /// up to and including the target, and returns all inserted headers in order.
        fn seed_execution(&mut self, input: ExecInput) -> Result<Self::Seed, TestRunnerError> {
            let first_block = input.checkpoint().block_number;
            let mut rng = generators::rng();

            let head = random_header(&mut rng, first_block, None);
            self.db.insert_headers_with_td(std::iter::once(&head))?;

            // use previous progress as seed size
            let tail_end = input.target.unwrap_or_default() + 1;
            let tail_start = first_block + 1;

            // Nothing lies between the checkpoint and the target.
            if tail_start >= tail_end {
                return Ok(Vec::new())
            }

            let tail = random_header_range(&mut rng, tail_start..tail_end, head.hash());
            self.db.insert_headers_with_td(tail.iter())?;

            let mut seeded = Vec::with_capacity(tail.len() + 1);
            seeded.push(head);
            seeded.extend(tail);
            Ok(seeded)
        }

        /// Checks that a produced output is done and sits exactly at the input target.
        fn validate_execution(
            &self,
            input: ExecInput,
            output: Option<ExecOutput>,
        ) -> Result<(), TestRunnerError> {
            let Some(output) = output else { return Ok(()) };

            assert!(output.done, "stage should always be done");
            assert_eq!(
                output.checkpoint.block_number,
                input.target(),
                "stage progress should always match progress of previous stage"
            );
            Ok(())
        }
    }

    impl UnwindStageTestRunner for S3TestRunner {
        /// Unwinding does not change anything yet, so there is nothing to validate.
        fn validate_unwind(&self, _input: UnwindInput) -> Result<(), TestRunnerError> {
            Ok(())
        }
    }

    /// The first filename of every block range in the download list must be parseable.
    #[test]
    fn parse_files() {
        for range_files in &DOWNLOAD_FILE_LIST {
            let _ = StaticFileSegment::parse_filename(range_files[0].0).expect("qed");
        }
    }
}