1use crate::metrics::PersistenceMetrics;
2use alloy_eips::BlockNumHash;
3use crossbeam_channel::Sender as CrossbeamSender;
4use reth_chain_state::ExecutedBlock;
5use reth_errors::ProviderError;
6use reth_ethereum_primitives::EthPrimitives;
7use reth_primitives_traits::NodePrimitives;
8use reth_provider::{
9 providers::ProviderNodeTypes, BlockExecutionWriter, BlockHashReader, ChainStateBlockWriter,
10 DBProvider, DatabaseProviderFactory, ProviderFactory, SaveBlocksMode,
11};
12use reth_prune::{PrunerError, PrunerWithFactory};
13use reth_stages_api::{MetricEvent, MetricEventsSender};
14use reth_tasks::spawn_os_thread;
15use std::{
16 sync::{
17 mpsc::{Receiver, SendError, Sender},
18 Arc,
19 },
20 thread::JoinHandle,
21 time::Instant,
22};
23use thiserror::Error;
24use tracing::{debug, error, instrument};
25
26#[derive(Debug)]
34pub struct PersistenceService<N>
35where
36 N: ProviderNodeTypes,
37{
38 provider: ProviderFactory<N>,
40 incoming: Receiver<PersistenceAction<N::Primitives>>,
42 pruner: PrunerWithFactory<ProviderFactory<N>>,
44 metrics: PersistenceMetrics,
46 sync_metrics_tx: MetricEventsSender,
48 pending_finalized_block: Option<u64>,
51 pending_safe_block: Option<u64>,
54}
55
56impl<N> PersistenceService<N>
57where
58 N: ProviderNodeTypes,
59{
60 pub fn new(
62 provider: ProviderFactory<N>,
63 incoming: Receiver<PersistenceAction<N::Primitives>>,
64 pruner: PrunerWithFactory<ProviderFactory<N>>,
65 sync_metrics_tx: MetricEventsSender,
66 ) -> Self {
67 Self {
68 provider,
69 incoming,
70 pruner,
71 metrics: PersistenceMetrics::default(),
72 sync_metrics_tx,
73 pending_finalized_block: None,
74 pending_safe_block: None,
75 }
76 }
77}
78
79impl<N> PersistenceService<N>
80where
81 N: ProviderNodeTypes,
82{
83 pub fn run(mut self) -> Result<(), PersistenceError> {
86 while let Ok(action) = self.incoming.recv() {
88 match action {
89 PersistenceAction::RemoveBlocksAbove(new_tip_num, sender) => {
90 let result = self.on_remove_blocks_above(new_tip_num)?;
91 let _ =
93 self.sync_metrics_tx.send(MetricEvent::SyncHeight { height: new_tip_num });
94 let _ = sender.send(result);
96 }
97 PersistenceAction::SaveBlocks(blocks, sender) => {
98 let result = self.on_save_blocks(blocks)?;
99 let result_number = result.map(|r| r.number);
100
101 let _ = sender.send(result);
103
104 if let Some(block_number) = result_number {
105 let _ = self
107 .sync_metrics_tx
108 .send(MetricEvent::SyncHeight { height: block_number });
109 }
110 }
111 PersistenceAction::SaveFinalizedBlock(finalized_block) => {
112 self.pending_finalized_block = Some(finalized_block);
113 }
114 PersistenceAction::SaveSafeBlock(safe_block) => {
115 self.pending_safe_block = Some(safe_block);
116 }
117 }
118 }
119 Ok(())
120 }
121
122 #[instrument(level = "debug", target = "engine::persistence", skip_all, fields(new_tip_num))]
123 fn on_remove_blocks_above(
124 &self,
125 new_tip_num: u64,
126 ) -> Result<Option<BlockNumHash>, PersistenceError> {
127 debug!(target: "engine::persistence", ?new_tip_num, "Removing blocks");
128 let start_time = Instant::now();
129 let provider_rw = self.provider.database_provider_rw()?;
130
131 let new_tip_hash = provider_rw.block_hash(new_tip_num)?;
132 provider_rw.remove_block_and_execution_above(new_tip_num)?;
133 provider_rw.commit()?;
134
135 debug!(target: "engine::persistence", ?new_tip_num, ?new_tip_hash, "Removed blocks from disk");
136 self.metrics.remove_blocks_above_duration_seconds.record(start_time.elapsed());
137 Ok(new_tip_hash.map(|hash| BlockNumHash { hash, number: new_tip_num }))
138 }
139
140 #[instrument(level = "debug", target = "engine::persistence", skip_all, fields(block_count = blocks.len()))]
141 fn on_save_blocks(
142 &mut self,
143 blocks: Vec<ExecutedBlock<N::Primitives>>,
144 ) -> Result<Option<BlockNumHash>, PersistenceError> {
145 let first_block = blocks.first().map(|b| b.recovered_block.num_hash());
146 let last_block = blocks.last().map(|b| b.recovered_block.num_hash());
147 let block_count = blocks.len();
148
149 let pending_finalized = self.pending_finalized_block.take();
150 let pending_safe = self.pending_safe_block.take();
151
152 debug!(target: "engine::persistence", ?block_count, first=?first_block, last=?last_block, "Saving range of blocks");
153
154 let start_time = Instant::now();
155
156 if let Some(last) = last_block {
157 let provider_rw = self.provider.database_provider_rw()?;
158 provider_rw.save_blocks(blocks, SaveBlocksMode::Full)?;
159
160 if let Some(finalized) = pending_finalized {
161 provider_rw.save_finalized_block_number(finalized)?;
162 }
163 if let Some(safe) = pending_safe {
164 provider_rw.save_safe_block_number(safe)?;
165 }
166
167 if self.pruner.is_pruning_needed(last.number) {
168 debug!(target: "engine::persistence", block_num=?last.number, "Running pruner");
169 let prune_start = Instant::now();
170 let _ = self.pruner.run_with_provider(&provider_rw, last.number)?;
171 self.metrics.prune_before_duration_seconds.record(prune_start.elapsed());
172 }
173
174 provider_rw.commit()?;
175 }
176
177 debug!(target: "engine::persistence", first=?first_block, last=?last_block, "Saved range of blocks");
178
179 self.metrics.save_blocks_batch_size.record(block_count as f64);
180 self.metrics.save_blocks_duration_seconds.record(start_time.elapsed());
181
182 Ok(last_block)
183 }
184}
185
/// Errors that can stop the persistence service.
///
/// Both variants are transparent wrappers: `Display` and `source` are forwarded to the wrapped
/// error, and the `From` conversions let handlers use `?` on either error type.
#[derive(Debug, Error)]
pub enum PersistenceError {
    /// An error raised while running the pruner.
    #[error(transparent)]
    PrunerError(#[from] PrunerError),

    /// An error raised by a database provider operation.
    #[error(transparent)]
    ProviderError(#[from] ProviderError),
}
197
/// A request handled by the persistence service.
#[derive(Debug)]
pub enum PersistenceAction<N: NodePrimitives = EthPrimitives> {
    /// Save the given executed blocks.
    ///
    /// The service replies on the sender with the number/hash pair of the last block in the
    /// list, or `None` if the list was empty.
    SaveBlocks(Vec<ExecutedBlock<N>>, CrossbeamSender<Option<BlockNumHash>>),

    /// Remove block and execution data above the given block number.
    ///
    /// The service replies on the sender with the number/hash pair of the new tip, or `None` if
    /// no hash was found for that block number.
    RemoveBlocksAbove(u64, CrossbeamSender<Option<BlockNumHash>>),

    /// Record the given finalized block number. The service only buffers the value when it
    /// handles this action; no reply is sent.
    SaveFinalizedBlock(u64),

    /// Record the given safe block number. The service only buffers the value when it handles
    /// this action; no reply is sent.
    SaveSafeBlock(u64),
}
220
/// Cloneable handle used to send [`PersistenceAction`]s to the persistence service.
#[derive(Debug, Clone)]
pub struct PersistenceHandle<N: NodePrimitives = EthPrimitives> {
    /// Sending end of the service's action channel.
    sender: Sender<PersistenceAction<N>>,
    /// Guard shared by all clones of this handle; dropping the last reference runs the guard's
    /// `Drop`, which joins the service thread if one is attached.
    ///
    /// Declared after `sender` on purpose: fields drop in declaration order, so this handle's
    /// sender is released before the guard can block on the join.
    _service_guard: Arc<ServiceGuard>,
}
230
231impl<T: NodePrimitives> PersistenceHandle<T> {
232 pub fn new(sender: Sender<PersistenceAction<T>>) -> Self {
237 Self { sender, _service_guard: Arc::new(ServiceGuard(None)) }
238 }
239
240 pub fn spawn_service<N>(
246 provider_factory: ProviderFactory<N>,
247 pruner: PrunerWithFactory<ProviderFactory<N>>,
248 sync_metrics_tx: MetricEventsSender,
249 ) -> PersistenceHandle<N::Primitives>
250 where
251 N: ProviderNodeTypes,
252 {
253 let (db_service_tx, db_service_rx) = std::sync::mpsc::channel();
255
256 let db_service =
258 PersistenceService::new(provider_factory, db_service_rx, pruner, sync_metrics_tx);
259 let join_handle = spawn_os_thread("persistence", || {
260 if let Err(err) = db_service.run() {
261 error!(target: "engine::persistence", ?err, "Persistence service failed");
262 }
263 });
264
265 PersistenceHandle {
266 sender: db_service_tx,
267 _service_guard: Arc::new(ServiceGuard(Some(join_handle))),
268 }
269 }
270
271 pub fn send_action(
274 &self,
275 action: PersistenceAction<T>,
276 ) -> Result<(), SendError<PersistenceAction<T>>> {
277 self.sender.send(action)
278 }
279
280 pub fn save_blocks(
289 &self,
290 blocks: Vec<ExecutedBlock<T>>,
291 tx: CrossbeamSender<Option<BlockNumHash>>,
292 ) -> Result<(), SendError<PersistenceAction<T>>> {
293 self.send_action(PersistenceAction::SaveBlocks(blocks, tx))
294 }
295
296 pub fn save_finalized_block_number(
301 &self,
302 finalized_block: u64,
303 ) -> Result<(), SendError<PersistenceAction<T>>> {
304 self.send_action(PersistenceAction::SaveFinalizedBlock(finalized_block))
305 }
306
307 pub fn save_safe_block_number(
312 &self,
313 safe_block: u64,
314 ) -> Result<(), SendError<PersistenceAction<T>>> {
315 self.send_action(PersistenceAction::SaveSafeBlock(safe_block))
316 }
317
318 pub fn remove_blocks_above(
324 &self,
325 block_num: u64,
326 tx: CrossbeamSender<Option<BlockNumHash>>,
327 ) -> Result<(), SendError<PersistenceAction<T>>> {
328 self.send_action(PersistenceAction::RemoveBlocksAbove(block_num, tx))
329 }
330}
331
/// Owns an optional thread join handle and joins that thread when dropped.
struct ServiceGuard(Option<JoinHandle<()>>);

impl std::fmt::Debug for ServiceGuard {
    /// Prints whether a handle is present without printing the handle itself.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Stand-in for the join handle: `Some("...")` when present, `None` otherwise.
        let placeholder = self.0.is_some().then_some("...");
        f.debug_tuple("ServiceGuard").field(&placeholder).finish()
    }
}

impl Drop for ServiceGuard {
    /// Blocks until the guarded thread, if any, has finished.
    fn drop(&mut self) {
        let Some(handle) = self.0.take() else {
            return;
        };
        // `join` returns `Err` if the thread panicked; that outcome is deliberately ignored.
        let _ = handle.join();
    }
}
352
#[cfg(test)]
mod tests {
    use super::*;
    use alloy_primitives::B256;
    use reth_chain_state::test_utils::TestBlockBuilder;
    use reth_exex_types::FinishedExExHeight;
    use reth_provider::test_utils::create_test_provider_factory;
    use reth_prune::Pruner;
    use tokio::sync::mpsc::unbounded_channel;

    /// Longest time a test waits for the persistence service to answer a request.
    const REPLY_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);

    /// Spawns a persistence service backed by a test provider factory and returns its handle.
    fn default_persistence_handle() -> PersistenceHandle<EthPrimitives> {
        let provider = create_test_provider_factory();

        let (_finished_exex_height_tx, finished_exex_height_rx) =
            tokio::sync::watch::channel(FinishedExExHeight::NoExExs);

        let pruner =
            Pruner::new_with_factory(provider.clone(), vec![], 5, 0, None, finished_exex_height_rx);

        let (sync_metrics_tx, _sync_metrics_rx) = unbounded_channel();
        PersistenceHandle::<EthPrimitives>::spawn_service(provider, pruner, sync_metrics_tx)
    }

    /// Waits for the service's reply, failing the test instead of hanging if none arrives
    /// within [`REPLY_TIMEOUT`].
    fn recv_reply(rx: &crossbeam_channel::Receiver<Option<BlockNumHash>>) -> Option<BlockNumHash> {
        rx.recv_timeout(REPLY_TIMEOUT).expect("persistence service did not reply")
    }

    /// Saving an empty batch replies with `None`.
    #[test]
    fn test_save_blocks_empty() {
        reth_tracing::init_test_tracing();
        let handle = default_persistence_handle();

        let blocks = vec![];
        let (tx, rx) = crossbeam_channel::bounded(1);

        handle.save_blocks(blocks, tx).unwrap();

        let hash = recv_reply(&rx);
        assert_eq!(hash, None);
    }

    /// Saving one block replies with that block's hash.
    #[test]
    fn test_save_blocks_single_block() {
        reth_tracing::init_test_tracing();
        let handle = default_persistence_handle();
        let block_number = 0;
        let mut test_block_builder = TestBlockBuilder::eth();
        let executed =
            test_block_builder.get_executed_block_with_number(block_number, B256::random());
        let block_hash = executed.recovered_block().hash();

        let blocks = vec![executed];
        let (tx, rx) = crossbeam_channel::bounded(1);

        handle.save_blocks(blocks, tx).unwrap();

        let BlockNumHash { hash: actual_hash, number: _ } =
            recv_reply(&rx).expect("no hash returned");

        assert_eq!(block_hash, actual_hash);
    }

    /// Saving several blocks at once replies with the hash of the last one.
    #[test]
    fn test_save_blocks_multiple_blocks() {
        reth_tracing::init_test_tracing();
        let handle = default_persistence_handle();

        let mut test_block_builder = TestBlockBuilder::eth();
        let blocks = test_block_builder.get_executed_blocks(0..5).collect::<Vec<_>>();
        let last_hash = blocks.last().unwrap().recovered_block().hash();
        let (tx, rx) = crossbeam_channel::bounded(1);

        handle.save_blocks(blocks, tx).unwrap();
        let BlockNumHash { hash: actual_hash, number: _ } =
            recv_reply(&rx).expect("no hash returned");
        assert_eq!(last_hash, actual_hash);
    }

    /// Consecutive save requests each reply with the hash of their own last block.
    #[test]
    fn test_save_blocks_multiple_calls() {
        reth_tracing::init_test_tracing();
        let handle = default_persistence_handle();

        let ranges = [0..1, 1..2, 2..4, 4..5];
        let mut test_block_builder = TestBlockBuilder::eth();
        for range in ranges {
            let blocks = test_block_builder.get_executed_blocks(range).collect::<Vec<_>>();
            let last_hash = blocks.last().unwrap().recovered_block().hash();
            let (tx, rx) = crossbeam_channel::bounded(1);

            handle.save_blocks(blocks, tx).unwrap();

            let BlockNumHash { hash: actual_hash, number: _ } =
                recv_reply(&rx).expect("no hash returned");
            assert_eq!(last_hash, actual_hash);
        }
    }
}