use crate::metrics::PersistenceMetrics;
use alloy_eips::BlockNumHash;
use reth_chain_state::ExecutedBlock;
use reth_errors::ProviderError;
use reth_provider::{
providers::ProviderNodeTypes, writer::UnifiedStorageWriter, BlockHashReader,
ChainStateBlockWriter, DatabaseProviderFactory, ProviderFactory, StaticFileProviderFactory,
};
use reth_prune::{PrunerError, PrunerOutput, PrunerWithFactory};
use reth_stages_api::{MetricEvent, MetricEventsSender};
use std::{
sync::mpsc::{Receiver, SendError, Sender},
time::Instant,
};
use thiserror::Error;
use tokio::sync::oneshot;
use tracing::{debug, error};
/// Service that applies [`PersistenceAction`]s received over a channel to storage.
///
/// Actions are handled one at a time by `run`. After a non-empty batch of blocks has been saved,
/// the pruner is asked whether pruning is needed for the saved height and, if so, is run.
#[derive(Debug)]
pub struct PersistenceService<N: ProviderNodeTypes> {
/// Factory used to open read-write database providers and to obtain the static file provider.
provider: ProviderFactory<N>,
/// Receiving end of the channel on which actions arrive.
incoming: Receiver<PersistenceAction>,
/// Pruner that is run after blocks have been saved, when it reports that pruning is needed.
pruner: PrunerWithFactory<ProviderFactory<N>>,
/// Metrics recording how long saving, removing and pruning take.
metrics: PersistenceMetrics,
/// Channel on which `MetricEvent::SyncHeight` events are emitted after saves and removals.
sync_metrics_tx: MetricEventsSender,
}
impl<N: ProviderNodeTypes> PersistenceService<N> {
    /// Assembles a persistence service from its parts.
    ///
    /// The metrics handle is created via [`PersistenceMetrics::default`]; every other field is
    /// taken from the corresponding argument unchanged.
    pub fn new(
        provider: ProviderFactory<N>,
        incoming: Receiver<PersistenceAction>,
        pruner: PrunerWithFactory<ProviderFactory<N>>,
        sync_metrics_tx: MetricEventsSender,
    ) -> Self {
        let metrics = PersistenceMetrics::default();
        Self { provider, incoming, pruner, metrics, sync_metrics_tx }
    }

    /// Runs the pruner with `block_num` and returns whatever the pruner returned.
    ///
    /// The time spent inside the pruner is recorded in the `prune_before_duration_seconds`
    /// metric whether the run succeeded or failed.
    fn prune_before(&mut self, block_num: u64) -> Result<PrunerOutput, PrunerError> {
        debug!(target: "engine::persistence", ?block_num, "Running pruner");

        let timer = Instant::now();
        let outcome = self.pruner.run(block_num);
        let elapsed = timer.elapsed();

        self.metrics.prune_before_duration_seconds.record(elapsed);
        outcome
    }
}
impl<N: ProviderNodeTypes> PersistenceService<N> {
    /// Main loop of the service: receives actions and executes them one after another.
    ///
    /// The loop ends with `Ok(())` once receiving from the channel fails. Any error raised while
    /// handling an action is returned immediately, which also stops the service.
    pub fn run(mut self) -> Result<(), PersistenceError> {
        loop {
            // `recv` only fails when no sender is left, so there is nothing more to process.
            let action = match self.incoming.recv() {
                Ok(action) => action,
                Err(_) => return Ok(()),
            };

            match action {
                PersistenceAction::SaveBlocks(blocks, sender) => {
                    let saved_tip = self.on_save_blocks(blocks)?;

                    // The requester may have dropped its receiver; a failed reply is not an
                    // error for the service.
                    let _ = sender.send(saved_tip);

                    // Nothing was written for an empty batch, so there is no new height to
                    // report and no reason to prune.
                    if let Some(tip) = saved_tip {
                        let height = tip.number;
                        let _ = self.sync_metrics_tx.send(MetricEvent::SyncHeight { height });

                        if self.pruner.is_pruning_needed(height) {
                            // Only a pruner failure matters here; its output is discarded.
                            let _ = self.prune_before(height)?;
                        }
                    }
                }
                PersistenceAction::RemoveBlocksAbove(new_tip_num, sender) => {
                    let new_tip = self.on_remove_blocks_above(new_tip_num)?;

                    // Report the new height first, then answer the requester. Failures of
                    // either send are deliberately ignored.
                    let height = new_tip_num;
                    let _ = self.sync_metrics_tx.send(MetricEvent::SyncHeight { height });
                    let _ = sender.send(new_tip);
                }
                PersistenceAction::SaveSafeBlock(safe_block) => {
                    let rw = self.provider.database_provider_rw()?;
                    rw.save_safe_block_number(safe_block)?;
                    rw.commit()?;
                }
                PersistenceAction::SaveFinalizedBlock(finalized_block) => {
                    let rw = self.provider.database_provider_rw()?;
                    rw.save_finalized_block_number(finalized_block)?;
                    rw.commit()?;
                }
            }
        }
    }

    /// Removes every block above `new_tip_num` and commits the unwind.
    ///
    /// The hash stored for `new_tip_num` is looked up before anything is removed. Returns the
    /// number/hash pair of the new tip, or `None` when no hash was found for that number. The
    /// elapsed time is recorded in `remove_blocks_above_duration_seconds` on success.
    fn on_remove_blocks_above(
        &self,
        new_tip_num: u64,
    ) -> Result<Option<BlockNumHash>, PersistenceError> {
        debug!(target: "engine::persistence", ?new_tip_num, "Removing blocks");
        let timer = Instant::now();

        let db_rw = self.provider.database_provider_rw()?;
        let static_files = self.provider.static_file_provider();
        let new_tip_hash = db_rw.block_hash(new_tip_num)?;

        // The writer is used as a temporary so its borrows end before both providers are moved
        // into the commit call.
        UnifiedStorageWriter::from(&db_rw, &static_files).remove_blocks_above(new_tip_num)?;
        UnifiedStorageWriter::commit_unwind(db_rw, static_files)?;

        debug!(target: "engine::persistence", ?new_tip_num, ?new_tip_hash, "Removed blocks from disk");
        self.metrics.remove_blocks_above_duration_seconds.record(timer.elapsed());

        let new_tip = new_tip_hash.map(|hash| BlockNumHash { number: new_tip_num, hash });
        Ok(new_tip)
    }

    /// Writes `blocks` to storage and commits them.
    ///
    /// Returns the number/hash pair of the last block in `blocks`, or `None` when `blocks` is
    /// empty, in which case no provider is opened and nothing is written. The elapsed time is
    /// recorded in `save_blocks_duration_seconds` on success, including for an empty batch.
    fn on_save_blocks(
        &self,
        blocks: Vec<ExecutedBlock>,
    ) -> Result<Option<BlockNumHash>, PersistenceError> {
        debug!(target: "engine::persistence", first=?blocks.first().map(|b| b.block.num_hash()), last=?blocks.last().map(|b| b.block.num_hash()), "Saving range of blocks");
        let timer = Instant::now();

        let tip = blocks.last().map(|executed| {
            let sealed = executed.block();
            BlockNumHash { number: sealed.number, hash: sealed.hash() }
        });

        if !blocks.is_empty() {
            let db_rw = self.provider.database_provider_rw()?;
            let static_files = self.provider.static_file_provider();

            UnifiedStorageWriter::from(&db_rw, &static_files).save_blocks(&blocks)?;
            UnifiedStorageWriter::commit(db_rw, static_files)?;
        }

        self.metrics.save_blocks_duration_seconds.record(timer.elapsed());
        Ok(tip)
    }
}
/// Errors that can be returned by the persistence service.
///
/// Both variants are transparent wrappers: their display output and source are those of the
/// wrapped error, and each can be created from the wrapped type via `From`.
#[derive(Debug, Error)]
pub enum PersistenceError {
/// An error returned by the pruner.
#[error(transparent)]
PrunerError(#[from] PrunerError),
/// An error returned by a provider (database or static files).
#[error(transparent)]
ProviderError(#[from] ProviderError),
}
/// A request that can be sent to the persistence service.
#[derive(Debug)]
pub enum PersistenceAction {
/// Save the given executed blocks.
///
/// The service replies on the oneshot sender with the number and hash of the last block in
/// the list, or `None` if the list was empty.
SaveBlocks(Vec<ExecutedBlock>, oneshot::Sender<Option<BlockNumHash>>),
/// Remove all blocks above the given block number.
///
/// The service replies on the oneshot sender with the number and hash of the new tip, or
/// `None` if no hash was found for that block number.
RemoveBlocksAbove(u64, oneshot::Sender<Option<BlockNumHash>>),
/// Store the given number as the finalized block number. No reply is sent.
SaveFinalizedBlock(u64),
/// Store the given number as the safe block number. No reply is sent.
SaveSafeBlock(u64),
}
/// Cloneable handle used to send [`PersistenceAction`]s to the persistence service.
#[derive(Debug, Clone)]
pub struct PersistenceHandle {
/// Sending end of the channel on which actions are submitted.
sender: Sender<PersistenceAction>,
}
impl PersistenceHandle {
    /// Wraps an existing action sender in a handle.
    pub const fn new(sender: Sender<PersistenceAction>) -> Self {
        Self { sender }
    }

    /// Creates a [`PersistenceService`], starts it on a dedicated OS thread named
    /// `"Persistence Service"`, and returns a handle connected to it.
    ///
    /// If the service loop returns an error, the error is logged from the service thread.
    ///
    /// # Panics
    ///
    /// Panics if the thread cannot be spawned.
    pub fn spawn_service<N: ProviderNodeTypes>(
        provider_factory: ProviderFactory<N>,
        pruner: PrunerWithFactory<ProviderFactory<N>>,
        sync_metrics_tx: MetricEventsSender,
    ) -> Self {
        let (action_tx, action_rx) = std::sync::mpsc::channel();
        let handle = Self::new(action_tx);

        let service = PersistenceService::new(provider_factory, action_rx, pruner, sync_metrics_tx);

        let thread = std::thread::Builder::new().name("Persistence Service".to_string());
        thread
            .spawn(|| {
                let outcome = service.run();
                if let Err(err) = outcome {
                    error!(target: "engine::persistence", ?err, "Persistence service failed");
                }
            })
            .unwrap();

        handle
    }

    /// Sends `action` to the service.
    ///
    /// # Errors
    ///
    /// Returns the action wrapped in a [`SendError`] if the receiving end has been dropped.
    pub fn send_action(
        &self,
        action: PersistenceAction,
    ) -> Result<(), SendError<PersistenceAction>> {
        self.sender.send(action)
    }

    /// Asks the service to save `blocks`; the result is delivered through `tx`.
    ///
    /// # Errors
    ///
    /// Fails if the action cannot be sent, see [`Self::send_action`].
    pub fn save_blocks(
        &self,
        blocks: Vec<ExecutedBlock>,
        tx: oneshot::Sender<Option<BlockNumHash>>,
    ) -> Result<(), SendError<PersistenceAction>> {
        let action = PersistenceAction::SaveBlocks(blocks, tx);
        self.send_action(action)
    }

    /// Asks the service to store `finalized_block` as the finalized block number.
    ///
    /// # Errors
    ///
    /// Fails if the action cannot be sent, see [`Self::send_action`].
    pub fn save_finalized_block_number(
        &self,
        finalized_block: u64,
    ) -> Result<(), SendError<PersistenceAction>> {
        let action = PersistenceAction::SaveFinalizedBlock(finalized_block);
        self.send_action(action)
    }

    /// Asks the service to store `safe_block` as the safe block number.
    ///
    /// # Errors
    ///
    /// Fails if the action cannot be sent, see [`Self::send_action`].
    pub fn save_safe_block_number(
        &self,
        safe_block: u64,
    ) -> Result<(), SendError<PersistenceAction>> {
        let action = PersistenceAction::SaveSafeBlock(safe_block);
        self.send_action(action)
    }

    /// Asks the service to remove all blocks above `block_num`; the result is delivered
    /// through `tx`.
    ///
    /// # Errors
    ///
    /// Fails if the action cannot be sent, see [`Self::send_action`].
    pub fn remove_blocks_above(
        &self,
        block_num: u64,
        tx: oneshot::Sender<Option<BlockNumHash>>,
    ) -> Result<(), SendError<PersistenceAction>> {
        let action = PersistenceAction::RemoveBlocksAbove(block_num, tx);
        self.send_action(action)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use alloy_primitives::B256;
    use reth_chain_state::test_utils::TestBlockBuilder;
    use reth_exex_types::FinishedExExHeight;
    use reth_provider::test_utils::create_test_provider_factory;
    use reth_prune::Pruner;
    use tokio::sync::mpsc::unbounded_channel;

    /// Spawns a persistence service on top of a test provider factory and returns its handle.
    fn default_persistence_handle() -> PersistenceHandle {
        let factory = create_test_provider_factory();

        // The unused channel halves keep underscore names (not `_`) so they are dropped only
        // when this function returns.
        let (_finished_exex_height_tx, exex_height_rx) =
            tokio::sync::watch::channel(FinishedExExHeight::NoExExs);
        let (metrics_tx, _sync_metrics_rx) = unbounded_channel();

        let pruner = Pruner::new_with_factory(factory.clone(), vec![], 5, 0, None, exex_height_rx);

        PersistenceHandle::spawn_service(factory, pruner, metrics_tx)
    }

    /// Saving an empty batch replies with `None`.
    #[tokio::test]
    async fn test_save_blocks_empty() {
        reth_tracing::init_test_tracing();
        let handle = default_persistence_handle();

        let (tx, rx) = oneshot::channel();
        handle.save_blocks(Vec::new(), tx).unwrap();

        let reply = rx.await.unwrap();
        assert_eq!(reply, None);
    }

    /// Saving one block replies with that block's hash.
    #[tokio::test]
    async fn test_save_blocks_single_block() {
        reth_tracing::init_test_tracing();
        let handle = default_persistence_handle();

        let mut builder = TestBlockBuilder::default();
        let executed = builder.get_executed_block_with_number(0, B256::random());
        let block_hash = executed.block().hash();

        let (tx, rx) = oneshot::channel();
        handle.save_blocks(vec![executed], tx).unwrap();

        let deadline = std::time::Duration::from_secs(10);
        let saved = tokio::time::timeout(deadline, rx)
            .await
            .expect("test timed out")
            .expect("channel closed unexpectedly")
            .expect("no hash returned");

        assert_eq!(block_hash, saved.hash);
    }

    /// Saving several blocks at once replies with the hash of the last one.
    #[tokio::test]
    async fn test_save_blocks_multiple_blocks() {
        reth_tracing::init_test_tracing();
        let handle = default_persistence_handle();

        let mut builder = TestBlockBuilder::default();
        let blocks: Vec<_> = builder.get_executed_blocks(0..5).collect();
        let last_hash = blocks.last().unwrap().block().hash();

        let (tx, rx) = oneshot::channel();
        handle.save_blocks(blocks, tx).unwrap();

        let saved = rx.await.unwrap().unwrap();
        assert_eq!(last_hash, saved.hash);
    }

    /// Consecutive save requests each reply with the hash of the last block of their own batch.
    #[tokio::test]
    async fn test_save_blocks_multiple_calls() {
        reth_tracing::init_test_tracing();
        let handle = default_persistence_handle();

        let mut builder = TestBlockBuilder::default();
        for range in [0..1, 1..2, 2..4, 4..5] {
            let blocks: Vec<_> = builder.get_executed_blocks(range).collect();
            let last_hash = blocks.last().unwrap().block().hash();

            let (tx, rx) = oneshot::channel();
            handle.save_blocks(blocks, tx).unwrap();

            let saved = rx.await.unwrap().unwrap();
            assert_eq!(last_hash, saved.hash);
        }
    }
}