use alloy_primitives::{Address, BlockNumber, B256};
use bytes::Buf;
use reth_codecs::{add_arbitrary_tests, Compact};
use reth_trie_common::{hash_builder::HashBuilderState, StoredSubNode};
use serde::{Deserialize, Serialize};
use std::ops::RangeInclusive;
use super::StageId;
/// Checkpoint data for a Merkle (trie) computation, persisted through the hand-written
/// [`Compact`] implementation for this type, which encodes the fields in declaration order.
///
/// NOTE(review): judging by the field names this allows an interrupted trie-root computation
/// to resume; confirm against the Merkle stage implementation.
#[derive(Default, Debug, Clone, PartialEq)]
pub struct MerkleCheckpoint {
/// Block number recorded with the checkpoint (presumably the block being computed towards).
pub target_block: BlockNumber,
/// Last account key reached (presumably the hashed key to resume after — confirm).
pub last_account_key: B256,
/// Saved trie-walker stack; the `Compact` encoding stores its length as a `u16`.
pub walker_stack: Vec<StoredSubNode>,
/// Saved hash builder state.
pub state: HashBuilderState,
}
impl MerkleCheckpoint {
    /// Assembles a [`MerkleCheckpoint`] from its four parts, taking ownership of each.
    pub const fn new(
        target_block: BlockNumber,
        last_account_key: B256,
        walker_stack: Vec<StoredSubNode>,
        state: HashBuilderState,
    ) -> Self {
        Self {
            state,
            walker_stack,
            last_account_key,
            target_block,
        }
    }
}
impl Compact for MerkleCheckpoint {
    /// Encodes the checkpoint into `buf` as:
    ///
    /// `target_block` (u64, big-endian per `BufMut::put_u64`) | `last_account_key` (32 bytes) |
    /// walker stack length (u16, big-endian) | each walker stack entry | hash builder state.
    ///
    /// Returns the sum of the fixed-width field sizes and the lengths reported by the nested
    /// `to_compact` calls.
    ///
    /// # Panics
    ///
    /// Panics if the walker stack holds more than `u16::MAX` entries, because its length could
    /// not be represented in the `u16` prefix.
    fn to_compact<B>(&self, buf: &mut B) -> usize
    where
        B: bytes::BufMut + AsMut<[u8]>,
    {
        let mut len = 0;

        buf.put_u64(self.target_block);
        len += 8;

        buf.put_slice(self.last_account_key.as_slice());
        len += self.last_account_key.len();

        // The stack length is stored as a `u16` prefix. A plain `as` cast would silently
        // truncate and yield bytes that `from_compact` decodes into a different (corrupt)
        // checkpoint, so an oversized stack is treated as a broken invariant instead.
        let walker_stack_len = u16::try_from(self.walker_stack.len())
            .expect("merkle checkpoint walker stack length must fit in u16");
        buf.put_u16(walker_stack_len);
        len += 2;

        for item in &self.walker_stack {
            len += item.to_compact(buf);
        }

        len += self.state.to_compact(buf);
        len
    }

    /// Decodes a checkpoint in the layout written by [`Self::to_compact`], returning it together
    /// with the unread remainder of `buf`. The `_len` hint is ignored.
    ///
    /// # Panics
    ///
    /// Panics if `buf` is too short for the fixed-width header (`Buf::get_u64`/`get_u16` and
    /// the 32-byte key slice all panic on insufficient input).
    fn from_compact(mut buf: &[u8], _len: usize) -> (Self, &[u8]) {
        let target_block = buf.get_u64();

        let last_account_key = B256::from_slice(&buf[..32]);
        buf.advance(32);

        let walker_stack_len = buf.get_u16() as usize;
        let mut walker_stack = Vec::with_capacity(walker_stack_len);
        for _ in 0..walker_stack_len {
            let (item, rest) = StoredSubNode::from_compact(buf, 0);
            walker_stack.push(item);
            buf = rest;
        }

        let (state, buf) = HashBuilderState::from_compact(buf, 0);
        (Self { target_block, last_account_key, walker_stack, state }, buf)
    }
}
/// Stage-specific checkpoint for the account hashing stage, stored as
/// [`StageUnitCheckpoint::Account`].
///
/// NOTE(review): field order presumably defines the derived `Compact` encoding; do not reorder
/// fields without a migration.
#[derive(Default, Debug, Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Compact)]
#[cfg_attr(test, derive(arbitrary::Arbitrary))]
#[add_arbitrary_tests(compact)]
pub struct AccountHashingCheckpoint {
/// Account address recorded by the stage (presumably the last one processed, `None` before
/// any — confirm against the stage implementation).
pub address: Option<Address>,
/// Block range covered; written by [`StageUnitCheckpoint::set_block_range`].
pub block_range: CheckpointBlockRange,
/// Entity progress counters, exposed through [`StageCheckpoint::entities`].
pub progress: EntitiesCheckpoint,
}
/// Stage-specific checkpoint for the storage hashing stage, stored as
/// [`StageUnitCheckpoint::Storage`].
///
/// NOTE(review): field order presumably defines the derived `Compact` encoding; do not reorder
/// fields without a migration.
#[derive(Default, Debug, Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Compact)]
#[cfg_attr(test, derive(arbitrary::Arbitrary))]
#[add_arbitrary_tests(compact)]
pub struct StorageHashingCheckpoint {
/// Account address recorded by the stage (presumably the last one processed — confirm).
pub address: Option<Address>,
/// Storage key recorded by the stage (presumably the last slot processed — confirm).
pub storage: Option<B256>,
/// Block range covered; written by [`StageUnitCheckpoint::set_block_range`].
pub block_range: CheckpointBlockRange,
/// Entity progress counters, exposed through [`StageCheckpoint::entities`].
pub progress: EntitiesCheckpoint,
}
/// Stage-specific checkpoint for the execution stage, stored as
/// [`StageUnitCheckpoint::Execution`].
///
/// NOTE(review): field order presumably defines the derived `Compact` encoding; do not reorder.
#[derive(Default, Debug, Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Compact)]
#[cfg_attr(test, derive(arbitrary::Arbitrary))]
#[add_arbitrary_tests(compact)]
pub struct ExecutionCheckpoint {
/// Block range covered; written by [`StageUnitCheckpoint::set_block_range`].
pub block_range: CheckpointBlockRange,
/// Entity progress counters, exposed through [`StageCheckpoint::entities`].
pub progress: EntitiesCheckpoint,
}
/// Stage-specific checkpoint for the headers stage, stored as
/// [`StageUnitCheckpoint::Headers`].
///
/// NOTE(review): field order presumably defines the derived `Compact` encoding; do not reorder.
#[derive(Default, Debug, Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Compact)]
#[cfg_attr(test, derive(arbitrary::Arbitrary))]
#[add_arbitrary_tests(compact)]
pub struct HeadersCheckpoint {
/// Block range covered. Note: [`StageUnitCheckpoint::set_block_range`] does not update this
/// variant's range (it returns `None` for `Headers`).
pub block_range: CheckpointBlockRange,
/// Entity progress counters, exposed through [`StageCheckpoint::entities`].
pub progress: EntitiesCheckpoint,
}
/// Stage-specific checkpoint shared by the account and storage history indexing stages
/// (`StageId::IndexAccountHistory` / `StageId::IndexStorageHistory`), stored as
/// [`StageUnitCheckpoint::IndexHistory`].
///
/// NOTE(review): field order presumably defines the derived `Compact` encoding; do not reorder.
#[derive(Default, Debug, Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Compact)]
#[cfg_attr(test, derive(arbitrary::Arbitrary))]
#[add_arbitrary_tests(compact)]
pub struct IndexHistoryCheckpoint {
/// Block range covered; written by [`StageUnitCheckpoint::set_block_range`].
pub block_range: CheckpointBlockRange,
/// Entity progress counters, exposed through [`StageCheckpoint::entities`].
pub progress: EntitiesCheckpoint,
}
/// Progress counter pair: how many entities a stage has processed out of a total.
///
/// NOTE(review): field order presumably defines the derived `Compact` encoding; do not reorder.
#[derive(Debug, Default, PartialEq, Eq, Clone, Copy, Serialize, Deserialize, Compact)]
#[cfg_attr(test, derive(arbitrary::Arbitrary))]
#[add_arbitrary_tests(compact)]
pub struct EntitiesCheckpoint {
/// Number of entities processed so far (the numerator in `fmt_percentage`).
pub processed: u64,
/// Total number of entities (the denominator in `fmt_percentage`; zero yields `None`).
pub total: u64,
}
impl EntitiesCheckpoint {
    /// Renders `processed / total` as a percentage string with exactly two decimals,
    /// e.g. `"33.33%"`.
    ///
    /// The value is floored to hundredths rather than rounded (2 of 3 renders as `"66.66%"`),
    /// and is not clamped, so `processed > total` yields more than `100%`.
    ///
    /// Returns `None` when `total` is zero, since no ratio can be formed.
    pub fn fmt_percentage(&self) -> Option<String> {
        (self.total != 0).then(|| {
            let percent = 100.0 * self.processed as f64 / self.total as f64;
            let floored_to_hundredths = (percent * 100.0).floor() / 100.0;
            format!("{:.2}%", floored_to_hundredths)
        })
    }
}
/// A block range `from..=to`, as produced by the `From<RangeInclusive<BlockNumber>>`
/// conversions (start maps to `from`, end maps to `to`).
///
/// NOTE(review): field order presumably defines the derived `Compact` encoding; do not reorder.
#[derive(Default, Debug, Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Compact)]
#[cfg_attr(test, derive(arbitrary::Arbitrary))]
#[add_arbitrary_tests(compact)]
pub struct CheckpointBlockRange {
/// First block of the range (inclusive).
pub from: BlockNumber,
/// Last block of the range (inclusive).
pub to: BlockNumber,
}
impl From<RangeInclusive<BlockNumber>> for CheckpointBlockRange {
    /// Maps `start..=end` to `CheckpointBlockRange { from: start, to: end }`; no ordering
    /// check is performed.
    fn from(range: RangeInclusive<BlockNumber>) -> Self {
        let (from, to) = range.into_inner();
        Self { from, to }
    }
}
impl From<&RangeInclusive<BlockNumber>> for CheckpointBlockRange {
    /// Copies the bounds of a borrowed `start..=end` into
    /// `CheckpointBlockRange { from: start, to: end }` without consuming the range.
    fn from(range: &RangeInclusive<BlockNumber>) -> Self {
        let (&from, &to) = (range.start(), range.end());
        Self { from, to }
    }
}
/// Checkpoint of a pipeline stage: a block number plus optional stage-specific data.
///
/// NOTE(review): field order presumably defines the derived `Compact` encoding; do not reorder.
#[derive(Debug, Default, PartialEq, Eq, Clone, Copy, Serialize, Deserialize, Compact)]
#[cfg_attr(test, derive(arbitrary::Arbitrary))]
#[add_arbitrary_tests(compact)]
pub struct StageCheckpoint {
/// Block number of the checkpoint (presumably the last block the stage has processed).
pub block_number: BlockNumber,
/// Optional stage-specific checkpoint; `None` by default (see `StageCheckpoint::new`).
pub stage_checkpoint: Option<StageUnitCheckpoint>,
}
impl StageCheckpoint {
    /// Creates a checkpoint at `block_number` with no stage-specific checkpoint.
    pub fn new(block_number: BlockNumber) -> Self {
        Self { block_number, ..Default::default() }
    }

    /// Returns this checkpoint with its block number replaced by `block_number`.
    pub const fn with_block_number(mut self, block_number: BlockNumber) -> Self {
        self.block_number = block_number;
        self
    }

    /// Returns this checkpoint with a fresh stage-specific checkpoint covering `from..=to`,
    /// for stages whose checkpoint tracks a block range.
    ///
    /// For the execution, account hashing, storage hashing and (account/storage) history
    /// indexing stages, any existing stage-specific checkpoint is replaced by a default one of
    /// the matching variant (progress reset to zero) whose block range is `from..=to`. For every
    /// other stage, `self` is returned unchanged.
    pub fn with_block_range(mut self, stage_id: &StageId, from: u64, to: u64) -> Self {
        let mut checkpoint = match stage_id {
            StageId::Execution => StageUnitCheckpoint::Execution(ExecutionCheckpoint::default()),
            StageId::AccountHashing => {
                StageUnitCheckpoint::Account(AccountHashingCheckpoint::default())
            }
            StageId::StorageHashing => {
                StageUnitCheckpoint::Storage(StorageHashingCheckpoint::default())
            }
            StageId::IndexStorageHistory | StageId::IndexAccountHistory => {
                StageUnitCheckpoint::IndexHistory(IndexHistoryCheckpoint::default())
            }
            _ => return self,
        };
        // `StageUnitCheckpoint` is `Copy`: mutating it through e.g. `Option::map` would only
        // change a temporary copy and drop the range. Set the range on the value itself, then
        // store it.
        checkpoint.set_block_range(from, to);
        self.stage_checkpoint = Some(checkpoint);
        self
    }

    /// Returns the entity progress counters of the stage-specific checkpoint, or `None` when
    /// there is no stage-specific checkpoint. Every variant carries an [`EntitiesCheckpoint`].
    pub fn entities(&self) -> Option<EntitiesCheckpoint> {
        let stage_checkpoint = self.stage_checkpoint?;

        match stage_checkpoint {
            StageUnitCheckpoint::Account(AccountHashingCheckpoint {
                progress: entities, ..
            }) |
            StageUnitCheckpoint::Storage(StorageHashingCheckpoint {
                progress: entities, ..
            }) |
            StageUnitCheckpoint::Entities(entities) |
            StageUnitCheckpoint::Execution(ExecutionCheckpoint { progress: entities, .. }) |
            StageUnitCheckpoint::Headers(HeadersCheckpoint { progress: entities, .. }) |
            StageUnitCheckpoint::IndexHistory(IndexHistoryCheckpoint {
                progress: entities,
                ..
            }) => Some(entities),
        }
    }
}
/// Stage-specific checkpoint payload held in [`StageCheckpoint::stage_checkpoint`].
///
/// The type is `Copy`, so mutating a copy (e.g. inside `Option::map`) does not affect the
/// stored value. `Default` is only implemented in test builds.
///
/// NOTE(review): variant order presumably defines the derived `Compact` discriminant; do not
/// reorder variants without a migration.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize, Compact)]
#[cfg_attr(test, derive(arbitrary::Arbitrary))]
#[add_arbitrary_tests(compact)]
pub enum StageUnitCheckpoint {
/// Account hashing stage checkpoint.
Account(AccountHashingCheckpoint),
/// Storage hashing stage checkpoint.
Storage(StorageHashingCheckpoint),
/// Generic entity-count progress with no block range.
Entities(EntitiesCheckpoint),
/// Execution stage checkpoint.
Execution(ExecutionCheckpoint),
/// Headers stage checkpoint.
Headers(HeadersCheckpoint),
/// Account/storage history indexing stage checkpoint.
IndexHistory(IndexHistoryCheckpoint),
}
impl StageUnitCheckpoint {
    /// Overwrites the block range with `from..=to` and returns the range it replaced.
    ///
    /// Only the `Account`, `Storage`, `Execution` and `IndexHistory` variants are updated;
    /// for `Entities` and `Headers` nothing changes and `None` is returned.
    pub fn set_block_range(&mut self, from: u64, to: u64) -> Option<CheckpointBlockRange> {
        let range = match self {
            Self::Account(AccountHashingCheckpoint { block_range, .. }) |
            Self::Storage(StorageHashingCheckpoint { block_range, .. }) |
            Self::Execution(ExecutionCheckpoint { block_range, .. }) |
            Self::IndexHistory(IndexHistoryCheckpoint { block_range, .. }) => block_range,
            Self::Entities(_) | Self::Headers(_) => return None,
        };

        let previous = *range;
        *range = CheckpointBlockRange { from, to };
        Some(previous)
    }
}
#[cfg(test)]
impl Default for StageUnitCheckpoint {
    /// Test-only default: the `Account` variant wrapping a default account hashing checkpoint.
    fn default() -> Self {
        let inner: AccountHashingCheckpoint = Default::default();
        Self::Account(inner)
    }
}
/// Generates typed accessors on `StageCheckpoint` for `StageUnitCheckpoint` variants.
///
/// Each entry is `(index, Variant, CheckpointType, getter, builder)`:
/// - `getter(&self) -> Option<CheckpointType>` returns the inner checkpoint when
///   `stage_checkpoint` holds `StageUnitCheckpoint::Variant`, otherwise `None`;
/// - `builder(self, CheckpointType) -> Self` stores `StageUnitCheckpoint::Variant(..)`.
///
/// The getter and builder names may each be preceded by any number of attributes (typically
/// `///` doc comments). `index` is accepted for readability but not used in the expansion.
macro_rules! stage_unit_checkpoints {
    ($((
        $index:expr,
        $enum_variant:tt,
        $checkpoint_ty:ty,
        $(#[$fn_get_attr:meta])*
        $fn_get_name:ident,
        $(#[$fn_build_attr:meta])*
        $fn_build_name:ident
    )),+) => {
        impl StageCheckpoint {
            $(
                $(#[$fn_get_attr])*
                pub const fn $fn_get_name(&self) -> Option<$checkpoint_ty> {
                    match self.stage_checkpoint {
                        Some(StageUnitCheckpoint::$enum_variant(checkpoint)) => Some(checkpoint),
                        _ => None,
                    }
                }

                $(#[$fn_build_attr])*
                pub const fn $fn_build_name(mut self, checkpoint: $checkpoint_ty) -> Self {
                    self.stage_checkpoint = Some(StageUnitCheckpoint::$enum_variant(checkpoint));
                    self
                }
            )+
        }
    };
}
// Typed getters/builders on `StageCheckpoint` for each `StageUnitCheckpoint` variant. Each
// function name is preceded by a `///` doc comment (an `#[doc = ...]` attribute), which the
// macro's matcher expects before every generated function.
stage_unit_checkpoints!(
    (
        0,
        Account,
        AccountHashingCheckpoint,
        /// Returns the account hashing stage checkpoint, if any.
        account_hashing_stage_checkpoint,
        /// Sets the stage checkpoint to account hashing.
        with_account_hashing_stage_checkpoint
    ),
    (
        1,
        Storage,
        StorageHashingCheckpoint,
        /// Returns the storage hashing stage checkpoint, if any.
        storage_hashing_stage_checkpoint,
        /// Sets the stage checkpoint to storage hashing.
        with_storage_hashing_stage_checkpoint
    ),
    (
        2,
        Entities,
        EntitiesCheckpoint,
        /// Returns the entities stage checkpoint, if any.
        entities_stage_checkpoint,
        /// Sets the stage checkpoint to entities.
        with_entities_stage_checkpoint
    ),
    (
        3,
        Execution,
        ExecutionCheckpoint,
        /// Returns the execution stage checkpoint, if any.
        execution_stage_checkpoint,
        /// Sets the stage checkpoint to execution.
        with_execution_stage_checkpoint
    ),
    (
        4,
        Headers,
        HeadersCheckpoint,
        /// Returns the headers stage checkpoint, if any.
        headers_stage_checkpoint,
        /// Sets the stage checkpoint to headers.
        with_headers_stage_checkpoint
    ),
    (
        5,
        IndexHistory,
        IndexHistoryCheckpoint,
        /// Returns the index history stage checkpoint, if any.
        index_history_stage_checkpoint,
        /// Sets the stage checkpoint to index history.
        with_index_history_stage_checkpoint
    )
);
#[cfg(test)]
mod tests {
    use super::*;
    use rand::Rng;

    /// A `MerkleCheckpoint` encoded with `to_compact` must decode back to an equal value.
    #[test]
    fn merkle_checkpoint_roundtrip() {
        let mut rng = rand::thread_rng();
        let checkpoint = MerkleCheckpoint {
            target_block: rng.gen(),
            last_account_key: rng.gen(),
            walker_stack: vec![StoredSubNode {
                key: B256::random_with(&mut rng).to_vec(),
                nibble: Some(rng.gen()),
                node: None,
            }],
            state: HashBuilderState::default(),
        };

        let mut buf = Vec::new();
        let encoded = checkpoint.to_compact(&mut buf);
        let (decoded, _) = MerkleCheckpoint::from_compact(&buf, encoded);
        assert_eq!(decoded, checkpoint);
    }

    /// `fmt_percentage` floors to two decimals (never rounds up) and returns `None` when the
    /// total is zero.
    #[test]
    fn entities_checkpoint_fmt_percentage() {
        let fmt = |processed: u64, total: u64| {
            EntitiesCheckpoint { processed, total }.fmt_percentage()
        };
        assert_eq!(fmt(0, 0), None);
        assert_eq!(fmt(5, 0), None);
        assert_eq!(fmt(1, 3).as_deref(), Some("33.33%"));
        assert_eq!(fmt(2, 3).as_deref(), Some("66.66%"));
        assert_eq!(fmt(5, 5).as_deref(), Some("100.00%"));
    }

    /// Both `From` conversions map an inclusive range's start/end onto `from`/`to`.
    #[test]
    fn checkpoint_block_range_from_range_inclusive() {
        let range: RangeInclusive<BlockNumber> = 3..=9;
        let expected = CheckpointBlockRange { from: 3, to: 9 };
        assert_eq!(CheckpointBlockRange::from(&range), expected);
        assert_eq!(CheckpointBlockRange::from(range), expected);
    }
}