1use crate::metrics::PersistenceMetrics;
2use alloy_eips::BlockNumHash;
3use crossbeam_channel::Sender as CrossbeamSender;
4use reth_chain_state::ExecutedBlock;
5use reth_errors::ProviderError;
6use reth_ethereum_primitives::EthPrimitives;
7use reth_primitives_traits::{FastInstant as Instant, NodePrimitives};
8use reth_provider::{
9 providers::ProviderNodeTypes, BlockExecutionWriter, BlockHashReader, ChainStateBlockWriter,
10 DBProvider, DatabaseProviderFactory, ProviderFactory, SaveBlocksMode,
11};
12use reth_prune::{PrunerError, PrunerWithFactory};
13use reth_stages_api::{MetricEvent, MetricEventsSender};
14use reth_tasks::spawn_os_thread;
15use std::{
16 sync::{
17 mpsc::{Receiver, SendError, Sender},
18 Arc,
19 },
20 thread::JoinHandle,
21 time::Duration,
22};
23use thiserror::Error;
24use tracing::{debug, error, instrument};
25
/// Outcome of a [`PersistenceAction::SaveBlocks`] or [`PersistenceAction::RemoveBlocksAbove`]
/// request, delivered back to the caller over the channel supplied with the action.
#[derive(Debug)]
pub struct PersistenceResult {
    /// For a save: number and hash of the last block in the batch (`None` for an empty batch).
    /// For a removal: number and hash of the new tip, or `None` if its hash was not found.
    pub last_block: Option<BlockNumHash>,
    /// Wall-clock time measured for a save, including any pruning run afterwards.
    /// Always `None` for removals.
    pub commit_duration: Option<Duration>,
}
34
/// Background service that applies [`PersistenceAction`]s to the database.
///
/// Actions arrive over `incoming` and are processed one at a time by
/// [`PersistenceService::run`] until the channel closes or an action fails.
#[derive(Debug)]
pub struct PersistenceService<N>
where
    N: ProviderNodeTypes,
{
    /// Factory used to open a read-write database provider for each operation.
    provider: ProviderFactory<N>,
    /// Receiving end of the action channel; the sending end lives in [`PersistenceHandle`].
    incoming: Receiver<PersistenceAction<N::Primitives>>,
    /// Pruner that is checked, and run if needed, after blocks are saved.
    pruner: PrunerWithFactory<ProviderFactory<N>>,
    /// Duration and batch-size metrics for persistence operations.
    metrics: PersistenceMetrics,
    /// Channel used to report the new sync height after saves and removals.
    sync_metrics_tx: MetricEventsSender,
    /// Finalized block number received via [`PersistenceAction::SaveFinalizedBlock`] that has
    /// not been written yet. It is written, clamped to the last saved block, while saving
    /// blocks.
    pending_finalized_block: Option<u64>,
    /// Safe block number received via [`PersistenceAction::SaveSafeBlock`] that has not been
    /// written yet. It is written, clamped to the last saved block, while saving blocks.
    pending_safe_block: Option<u64>,
}
64
65impl<N> PersistenceService<N>
66where
67 N: ProviderNodeTypes,
68{
69 pub fn new(
71 provider: ProviderFactory<N>,
72 incoming: Receiver<PersistenceAction<N::Primitives>>,
73 pruner: PrunerWithFactory<ProviderFactory<N>>,
74 sync_metrics_tx: MetricEventsSender,
75 ) -> Self {
76 Self {
77 provider,
78 incoming,
79 pruner,
80 metrics: PersistenceMetrics::default(),
81 sync_metrics_tx,
82 pending_finalized_block: None,
83 pending_safe_block: None,
84 }
85 }
86}
87
88impl<N> PersistenceService<N>
89where
90 N: ProviderNodeTypes,
91{
92 pub fn run(mut self) -> Result<(), PersistenceError> {
95 while let Ok(action) = self.incoming.recv() {
97 match action {
98 PersistenceAction::RemoveBlocksAbove(new_tip_num, sender) => {
99 let last_block = self.on_remove_blocks_above(new_tip_num)?;
100 let _ =
102 self.sync_metrics_tx.send(MetricEvent::SyncHeight { height: new_tip_num });
103 let _ = sender.send(PersistenceResult { last_block, commit_duration: None });
104 }
105 PersistenceAction::SaveBlocks(blocks, sender) => {
106 let result = self.on_save_blocks(blocks)?;
107 let result_number = result.last_block.map(|b| b.number);
108
109 let _ = sender.send(result);
110
111 if let Some(block_number) = result_number {
112 let _ = self
114 .sync_metrics_tx
115 .send(MetricEvent::SyncHeight { height: block_number });
116 }
117 }
118 PersistenceAction::SaveFinalizedBlock(finalized_block) => {
119 self.pending_finalized_block = Some(finalized_block);
120 }
121 PersistenceAction::SaveSafeBlock(safe_block) => {
122 self.pending_safe_block = Some(safe_block);
123 }
124 }
125 }
126 Ok(())
127 }
128
129 #[instrument(level = "debug", target = "engine::persistence", skip_all, fields(%new_tip_num))]
130 fn on_remove_blocks_above(
131 &self,
132 new_tip_num: u64,
133 ) -> Result<Option<BlockNumHash>, PersistenceError> {
134 debug!(target: "engine::persistence", ?new_tip_num, "Removing blocks");
135 let start_time = Instant::now();
136 let provider_rw = self.provider.database_provider_rw()?;
137
138 let new_tip_hash = provider_rw.block_hash(new_tip_num)?;
139 provider_rw.remove_block_and_execution_above(new_tip_num)?;
140 provider_rw.commit()?;
141
142 debug!(target: "engine::persistence", ?new_tip_num, ?new_tip_hash, "Removed blocks from disk");
143 self.metrics.remove_blocks_above_duration_seconds.record(start_time.elapsed());
144 Ok(new_tip_hash.map(|hash| BlockNumHash { hash, number: new_tip_num }))
145 }
146
147 #[instrument(level = "debug", target = "engine::persistence", skip_all, fields(block_count = blocks.len()))]
148 fn on_save_blocks(
149 &mut self,
150 blocks: Vec<ExecutedBlock<N::Primitives>>,
151 ) -> Result<PersistenceResult, PersistenceError> {
152 let first_block = blocks.first().map(|b| b.recovered_block.num_hash());
153 let last_block = blocks.last().map(|b| b.recovered_block.num_hash());
154 let block_count = blocks.len();
155
156 let pending_finalized = self.pending_finalized_block.take();
157 let pending_safe = self.pending_safe_block.take();
158
159 debug!(target: "engine::persistence", ?block_count, first=?first_block, last=?last_block, "Saving range of blocks");
160
161 let start_time = Instant::now();
162
163 if let Some(last) = last_block {
164 let provider_rw = self.provider.database_provider_rw()?;
165 provider_rw.save_blocks(blocks, SaveBlocksMode::Full)?;
166
167 if let Some(finalized) = pending_finalized {
168 provider_rw.save_finalized_block_number(finalized.min(last.number))?;
169 if finalized > last.number {
170 self.pending_finalized_block = Some(finalized);
171 }
172 }
173 if let Some(safe) = pending_safe {
174 provider_rw.save_safe_block_number(safe.min(last.number))?;
175 if safe > last.number {
176 self.pending_safe_block = Some(safe);
177 }
178 }
179
180 provider_rw.commit()?;
181 debug!(target: "engine::persistence", first=?first_block, last=?last_block, "Saved range of blocks");
182
183 if self.pruner.is_pruning_needed(last.number) {
189 debug!(target: "engine::persistence", block_num=?last.number, "Running pruner");
190 let prune_start = Instant::now();
191 let provider_rw = self.provider.database_provider_rw()?;
192 let _ = self.pruner.run_with_provider(&provider_rw, last.number)?;
193 provider_rw.commit()?;
194 debug!(target: "engine::persistence", tip=?last.number, "Finished pruning after saving blocks");
195 self.metrics.prune_before_duration_seconds.record(prune_start.elapsed());
196 }
197 }
198
199 let elapsed = start_time.elapsed();
200 self.metrics.save_blocks_batch_size.record(block_count as f64);
201 self.metrics.save_blocks_duration_seconds.record(elapsed);
202
203 Ok(PersistenceResult { last_block, commit_duration: Some(elapsed) })
204 }
205}
206
/// Errors produced while handling persistence actions.
///
/// Any of these ends [`PersistenceService::run`].
#[derive(Debug, Error)]
pub enum PersistenceError {
    /// The pruner failed while pruning after a block save.
    #[error(transparent)]
    PrunerError(#[from] PrunerError),

    /// A database provider operation failed.
    #[error(transparent)]
    ProviderError(#[from] ProviderError),
}
218
/// Work items accepted by the persistence service.
#[derive(Debug)]
pub enum PersistenceAction<N: NodePrimitives = EthPrimitives> {
    /// Writes the given blocks to the database and runs the pruner if it is needed.
    /// The resulting [`PersistenceResult`] is sent on the provided channel.
    SaveBlocks(Vec<ExecutedBlock<N>>, CrossbeamSender<PersistenceResult>),

    /// Removes every block above the given block number. A [`PersistenceResult`] describing
    /// the new tip is sent on the provided channel.
    RemoveBlocksAbove(u64, CrossbeamSender<PersistenceResult>),

    /// Updates the finalized block number. The service buffers it and writes it while
    /// handling a later [`PersistenceAction::SaveBlocks`].
    SaveFinalizedBlock(u64),

    /// Updates the safe block number. It is buffered the same way as
    /// [`PersistenceAction::SaveFinalizedBlock`].
    SaveSafeBlock(u64),
}
241
/// Cloneable handle for sending [`PersistenceAction`]s to a persistence service.
#[derive(Debug, Clone)]
pub struct PersistenceHandle<N: NodePrimitives = EthPrimitives> {
    /// Sending end of the service's action channel.
    sender: Sender<PersistenceAction<N>>,
    /// Shared guard that joins the service thread, if any, when the last clone is dropped.
    ///
    /// NOTE: struct fields are dropped in declaration order, so this handle's `sender` is
    /// released before the guard joins the thread. The service loop only exits once every
    /// sender is gone, so keep this field after `sender`.
    _service_guard: Arc<ServiceGuard>,
}
251
252impl<T: NodePrimitives> PersistenceHandle<T> {
253 pub fn new(sender: Sender<PersistenceAction<T>>) -> Self {
258 Self { sender, _service_guard: Arc::new(ServiceGuard(None)) }
259 }
260
261 pub fn spawn_service<N>(
267 provider_factory: ProviderFactory<N>,
268 pruner: PrunerWithFactory<ProviderFactory<N>>,
269 sync_metrics_tx: MetricEventsSender,
270 ) -> PersistenceHandle<N::Primitives>
271 where
272 N: ProviderNodeTypes,
273 {
274 let (db_service_tx, db_service_rx) = std::sync::mpsc::channel();
276
277 let db_service =
279 PersistenceService::new(provider_factory, db_service_rx, pruner, sync_metrics_tx);
280 let join_handle = spawn_os_thread("persistence", || {
281 if let Err(err) = db_service.run() {
282 error!(target: "engine::persistence", ?err, "Persistence service failed");
283 }
284 });
285
286 PersistenceHandle {
287 sender: db_service_tx,
288 _service_guard: Arc::new(ServiceGuard(Some(join_handle))),
289 }
290 }
291
292 pub fn send_action(
295 &self,
296 action: PersistenceAction<T>,
297 ) -> Result<(), SendError<PersistenceAction<T>>> {
298 self.sender.send(action)
299 }
300
301 pub fn save_blocks(
310 &self,
311 blocks: Vec<ExecutedBlock<T>>,
312 tx: CrossbeamSender<PersistenceResult>,
313 ) -> Result<(), SendError<PersistenceAction<T>>> {
314 self.send_action(PersistenceAction::SaveBlocks(blocks, tx))
315 }
316
317 pub fn save_finalized_block_number(
322 &self,
323 finalized_block: u64,
324 ) -> Result<(), SendError<PersistenceAction<T>>> {
325 self.send_action(PersistenceAction::SaveFinalizedBlock(finalized_block))
326 }
327
328 pub fn save_safe_block_number(
333 &self,
334 safe_block: u64,
335 ) -> Result<(), SendError<PersistenceAction<T>>> {
336 self.send_action(PersistenceAction::SaveSafeBlock(safe_block))
337 }
338
339 pub fn remove_blocks_above(
345 &self,
346 block_num: u64,
347 tx: CrossbeamSender<PersistenceResult>,
348 ) -> Result<(), SendError<PersistenceAction<T>>> {
349 self.send_action(PersistenceAction::RemoveBlocksAbove(block_num, tx))
350 }
351}
352
/// Holds the persistence thread's join handle, if there is one, and joins it when dropped.
struct ServiceGuard(Option<JoinHandle<()>>);
359
360impl std::fmt::Debug for ServiceGuard {
361 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
362 f.debug_tuple("ServiceGuard").field(&self.0.as_ref().map(|_| "...")).finish()
363 }
364}
365
366impl Drop for ServiceGuard {
367 fn drop(&mut self) {
368 if let Some(join_handle) = self.0.take() {
369 let _ = join_handle.join();
370 }
371 }
372}
373
#[cfg(test)]
mod tests {
    use super::*;
    use alloy_primitives::B256;
    use reth_chain_state::test_utils::TestBlockBuilder;
    use reth_exex_types::FinishedExExHeight;
    use reth_provider::test_utils::create_test_provider_factory;
    use reth_prune::Pruner;
    use tokio::sync::mpsc::unbounded_channel;

    /// Spawns a persistence service backed by a fresh test provider factory.
    ///
    /// The pruner is built with an empty `vec![]`, presumably meaning no prune segments
    /// (TODO confirm against `Pruner::new_with_factory`).
    fn default_persistence_handle() -> PersistenceHandle<EthPrimitives> {
        let provider = create_test_provider_factory();

        let (_finished_exex_height_tx, finished_exex_height_rx) =
            tokio::sync::watch::channel(FinishedExExHeight::NoExExs);

        let pruner =
            Pruner::new_with_factory(provider.clone(), vec![], 5, 0, None, finished_exex_height_rx);

        // The receiver is bound so the metrics channel stays open for the service.
        let (sync_metrics_tx, _sync_metrics_rx) = unbounded_channel();
        PersistenceHandle::<EthPrimitives>::spawn_service(provider, pruner, sync_metrics_tx)
    }

    /// Saving an empty batch still produces a result, with no last block.
    #[test]
    fn test_save_blocks_empty() {
        reth_tracing::init_test_tracing();
        let handle = default_persistence_handle();

        let blocks = vec![];
        let (tx, rx) = crossbeam_channel::bounded(1);

        handle.save_blocks(blocks, tx).unwrap();

        let result = rx.recv().unwrap();
        assert!(result.last_block.is_none());
    }

    /// Saving a single block reports that block's hash as the last block.
    #[test]
    fn test_save_blocks_single_block() {
        reth_tracing::init_test_tracing();
        let handle = default_persistence_handle();
        let block_number = 0;
        let mut test_block_builder = TestBlockBuilder::eth();
        let executed =
            test_block_builder.get_executed_block_with_number(block_number, B256::random());
        let block_hash = executed.recovered_block().hash();

        let blocks = vec![executed];
        let (tx, rx) = crossbeam_channel::bounded(1);

        handle.save_blocks(blocks, tx).unwrap();

        let result = rx.recv_timeout(std::time::Duration::from_secs(10)).expect("test timed out");

        assert_eq!(block_hash, result.last_block.unwrap().hash);
    }

    /// Saving a multi-block batch reports the hash of the final block in the batch.
    #[test]
    fn test_save_blocks_multiple_blocks() {
        reth_tracing::init_test_tracing();
        let handle = default_persistence_handle();

        let mut test_block_builder = TestBlockBuilder::eth();
        let blocks = test_block_builder.get_executed_blocks(0..5).collect::<Vec<_>>();
        let last_hash = blocks.last().unwrap().recovered_block().hash();
        let (tx, rx) = crossbeam_channel::bounded(1);

        handle.save_blocks(blocks, tx).unwrap();
        let result = rx.recv().unwrap();
        assert_eq!(last_hash, result.last_block.unwrap().hash);
    }

    /// Consecutive save requests over adjacent ranges each report their own last block.
    #[test]
    fn test_save_blocks_multiple_calls() {
        reth_tracing::init_test_tracing();
        let handle = default_persistence_handle();

        let ranges = [0..1, 1..2, 2..4, 4..5];
        let mut test_block_builder = TestBlockBuilder::eth();
        for range in ranges {
            let blocks = test_block_builder.get_executed_blocks(range).collect::<Vec<_>>();
            let last_hash = blocks.last().unwrap().recovered_block().hash();
            let (tx, rx) = crossbeam_channel::bounded(1);

            handle.save_blocks(blocks, tx).unwrap();

            let result = rx.recv().unwrap();
            assert_eq!(last_hash, result.last_block.unwrap().hash);
        }
    }

    /// Exercises the RocksDB account-history helpers directly, without the persistence
    /// service. History for blocks 0..20 is seeded in the `u64::MAX` shard, then 20..25 is
    /// appended, then history is pruned to block 14. The surviving entries must be 15..25.
    #[test]
    fn test_save_blocks_then_prune_preserves_new_history() {
        use reth_db::{models::ShardedKey, tables, BlockNumberList};
        use reth_provider::RocksDBProviderFactory;

        reth_tracing::init_test_tracing();

        let provider_factory = create_test_provider_factory();
        let tracked_addr = alloy_primitives::Address::from([0xBE; 20]);

        let rocksdb = provider_factory.rocksdb_provider();
        // Seed: history entries for blocks 0..20, stored under the `u64::MAX` shard key.
        {
            let mut batch = rocksdb.batch();
            let initial_blocks: Vec<u64> = (0..20).collect();
            let shard = BlockNumberList::new_pre_sorted(initial_blocks.iter().copied());
            batch
                .put::<tables::AccountsHistory>(ShardedKey::new(tracked_addr, u64::MAX), &shard)
                .unwrap();
            batch.commit().unwrap();
        }

        // Append newer history, as a block save would.
        let mut batch1 = rocksdb.batch();
        batch1.append_account_history_shard(tracked_addr, 20..25u64).unwrap();
        batch1.commit().unwrap();

        // Prune history up to and including block 14.
        let mut batch2 = rocksdb.batch();
        batch2.prune_account_history_to(tracked_addr, 14).unwrap();
        batch2.commit().unwrap();

        let shards = rocksdb.account_history_shards(tracked_addr).unwrap();
        // Flatten every remaining shard into a single ordered list of block numbers.
        let entries: Vec<u64> = shards.iter().flat_map(|(_, list)| list.iter()).collect();
        let expected: Vec<u64> = (15..25).collect();
        assert_eq!(entries, expected, "new entries 20..25 must survive pruning");
    }
}