1use crate::{segment::PrunePurpose, PruneSegment, PruneSegmentError};
2use alloy_primitives::BlockNumber;
3
/// How many historical blocks a prune segment is allowed to remove, relative to the chain tip.
///
/// With the `serde` feature the variants are (de)serialized in lowercase, e.g. in TOML:
/// `"full"`, `{ distance = 10 }`, `{ before = 20 }`.
///
/// NOTE(review): this type also derives `Compact` (feature `reth-codec`); reordering or renaming
/// variants presumably changes the stored/serialized encoding — confirm before touching them.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
#[cfg_attr(any(test, feature = "test-utils"), derive(arbitrary::Arbitrary))]
#[cfg_attr(any(test, feature = "reth-codec"), derive(reth_codecs::Compact))]
#[cfg_attr(any(test, feature = "reth-codec"), reth_codecs::add_arbitrary_tests(compact))]
#[cfg_attr(any(test, feature = "serde"), derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(any(test, feature = "serde"), serde(rename_all = "lowercase"))]
pub enum PruneMode {
    /// Prune every block.
    Full,
    /// Prune blocks strictly older than `tip - N`, i.e. keep the last `N + 1` blocks
    /// (`tip - N ..= tip`). Nothing is pruned while `tip < N`.
    Distance(u64),
    /// Prune blocks whose number is strictly below the given block number; that block itself
    /// is kept.
    Before(BlockNumber),
}
19
// Only compiled for tests and the `test-utils` feature; production code must always pick a mode
// explicitly.
#[cfg(any(test, feature = "test-utils"))]
impl Default for PruneMode {
    /// Defaults to pruning everything, i.e. [`PruneMode::Full`].
    #[inline]
    fn default() -> Self {
        Self::Full
    }
}
26
27impl PruneMode {
28 pub const fn before_inclusive(block_number: BlockNumber) -> Self {
32 Self::Before(block_number + 1)
33 }
34
35 pub fn prune_target_block(
38 &self,
39 tip: BlockNumber,
40 segment: PruneSegment,
41 purpose: PrunePurpose,
42 ) -> Result<Option<(BlockNumber, Self)>, PruneSegmentError> {
43 let result = match self {
44 Self::Full if segment.min_blocks(purpose) == 0 => Some((tip, *self)),
45 Self::Distance(distance) if *distance > tip => None, Self::Distance(distance) if *distance >= segment.min_blocks(purpose) => {
47 Some((tip - distance, *self))
48 }
49 Self::Before(n) if *n == tip + 1 && purpose.is_static_file() => Some((tip, *self)),
50 Self::Before(n) if *n > tip => None, Self::Before(n) => {
52 (tip - n >= segment.min_blocks(purpose)).then(|| ((*n).saturating_sub(1), *self))
53 }
54 _ => return Err(PruneSegmentError::Configuration(segment)),
55 };
56 Ok(result)
57 }
58
59 pub const fn should_prune(&self, block: BlockNumber, tip: BlockNumber) -> bool {
61 match self {
62 Self::Full => true,
63 Self::Distance(distance) => {
64 if *distance > tip {
65 return false
66 }
67 block < tip - *distance
68 }
69 Self::Before(n) => *n > block,
70 }
71 }
72
73 pub const fn is_full(&self) -> bool {
75 matches!(self, Self::Full)
76 }
77
78 pub const fn is_distance(&self) -> bool {
80 matches!(self, Self::Distance(_))
81 }
82}
83
#[cfg(test)]
mod tests {
    use crate::{
        PruneMode, PrunePurpose, PruneSegment, PruneSegmentError, MINIMUM_PRUNING_DISTANCE,
    };
    use assert_matches::assert_matches;
    use serde::Deserialize;

    #[test]
    fn test_prune_target_block() {
        let tip = 20000;
        let segment = PruneSegment::Receipts;

        let tests = vec![
            // `Full` is rejected for this segment.
            (PruneMode::Full, Err(PruneSegmentError::Configuration(segment))),
            // Distance larger than the tip: nothing to prune yet.
            (PruneMode::Distance(tip + 1), Ok(None)),
            // Distance just above the segment minimum: target is `tip - distance`.
            (
                PruneMode::Distance(segment.min_blocks(PrunePurpose::User) + 1),
                Ok(Some(tip - (segment.min_blocks(PrunePurpose::User) + 1))),
            ),
            // First kept block is beyond the tip: nothing to prune yet.
            (PruneMode::Before(tip + 1), Ok(None)),
            // `Before(n)` targets `n - 1` inclusively.
            (
                PruneMode::Before(tip - MINIMUM_PRUNING_DISTANCE),
                Ok(Some(tip - MINIMUM_PRUNING_DISTANCE - 1)),
            ),
            (
                PruneMode::Before(tip - MINIMUM_PRUNING_DISTANCE - 1),
                Ok(Some(tip - MINIMUM_PRUNING_DISTANCE - 2)),
            ),
            // Too few blocks would remain: nothing to prune.
            (PruneMode::Before(tip - 1), Ok(None)),
        ];

        for (index, (mode, expected_result)) in tests.into_iter().enumerate() {
            assert_eq!(
                mode.prune_target_block(tip, segment, PrunePurpose::User),
                expected_result.map(|r| r.map(|b| (b, mode))),
                "Test {} failed",
                index + 1,
            );
        }

        // `Full` is accepted for a segment that allows it.
        assert_eq!(
            PruneMode::Full.prune_target_block(tip, PruneSegment::Transactions, PrunePurpose::User),
            Ok(Some((tip, PruneMode::Full))),
        );
    }

    #[test]
    fn test_should_prune() {
        let tip = 20000;
        let should_prune = true;

        let tests = vec![
            (PruneMode::Distance(tip + 1), 1, !should_prune),
            (
                PruneMode::Distance(MINIMUM_PRUNING_DISTANCE + 1),
                tip - MINIMUM_PRUNING_DISTANCE - 1,
                !should_prune,
            ),
            (
                PruneMode::Distance(MINIMUM_PRUNING_DISTANCE + 1),
                tip - MINIMUM_PRUNING_DISTANCE - 2,
                should_prune,
            ),
            (PruneMode::Before(tip + 1), 1, should_prune),
            (PruneMode::Before(tip + 1), tip + 1, !should_prune),
        ];

        for (index, (mode, block, expected_result)) in tests.into_iter().enumerate() {
            assert_eq!(mode.should_prune(block, tip), expected_result, "Test {} failed", index + 1,);
        }
    }

    #[test]
    fn test_helpers() {
        // `before_inclusive(n)` keeps everything from `n + 1` onwards.
        assert_eq!(PruneMode::before_inclusive(10), PruneMode::Before(11));

        assert!(PruneMode::Full.is_full());
        assert!(!PruneMode::Distance(1).is_full());
        assert!(!PruneMode::Before(1).is_full());

        assert!(PruneMode::Distance(1).is_distance());
        assert!(!PruneMode::Full.is_distance());
        assert!(!PruneMode::Before(1).is_distance());

        // `Full` prunes every block regardless of the tip.
        assert!(PruneMode::Full.should_prune(0, 0));
        assert!(PruneMode::Full.should_prune(u64::MAX, 0));
    }

    #[test]
    fn prune_mode_deserialize() {
        #[derive(Debug, Deserialize)]
        struct Config {
            a: Option<PruneMode>,
            b: Option<PruneMode>,
            c: Option<PruneMode>,
            d: Option<PruneMode>,
        }

        // `d` is absent on purpose to check that a missing key deserializes to `None`.
        let toml_str = r#"
        a = "full"
        b = { distance = 10 }
        c = { before = 20 }
    "#;

        assert_matches!(
            toml::from_str(toml_str),
            Ok(Config {
                a: Some(PruneMode::Full),
                b: Some(PruneMode::Distance(10)),
                c: Some(PruneMode::Before(20)),
                d: None
            })
        );
    }
}