1use crate::{segment::PrunePurpose, PruneSegment, PruneSegmentError};
2use alloy_primitives::BlockNumber;
3
/// Strategy deciding which blocks of a prune segment are removed.
///
/// With the `serde` feature, variants use lowercase names, e.g. in TOML:
/// `"full"`, `{ distance = 10 }` or `{ before = 20 }`.
///
/// NOTE(review): the `Compact` codec derive presumably encodes variants by position —
/// confirm before reordering or inserting variants.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
#[cfg_attr(any(test, feature = "test-utils"), derive(arbitrary::Arbitrary))]
#[cfg_attr(any(test, feature = "reth-codec"), derive(reth_codecs::Compact))]
#[cfg_attr(any(test, feature = "reth-codec"), reth_codecs::add_arbitrary_tests(compact))]
#[cfg_attr(any(test, feature = "serde"), derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(any(test, feature = "serde"), serde(rename_all = "lowercase"))]
pub enum PruneMode {
    /// Prune every block (only accepted for segments whose minimum retention is zero).
    Full,
    /// Keep blocks close to the tip: a block is pruned when it is strictly below
    /// `tip - N`; nothing is pruned while `N` exceeds the tip.
    Distance(u64),
    /// Prune every block strictly below the given block number; that block itself is kept.
    Before(BlockNumber),
}
19
// `Default` is only wanted by tests and test utilities, so it is written by hand behind
// this `cfg` instead of being derived on the enum for every build.
#[allow(clippy::derivable_impls)]
#[cfg(any(test, feature = "test-utils"))]
impl Default for PruneMode {
    /// The default mode prunes everything: [`PruneMode::Full`].
    fn default() -> Self {
        Self::Full
    }
}
27
28impl PruneMode {
29 pub const fn before_inclusive(block_number: BlockNumber) -> Self {
33 Self::Before(block_number + 1)
34 }
35
36 pub fn prune_target_block(
39 &self,
40 tip: BlockNumber,
41 segment: PruneSegment,
42 purpose: PrunePurpose,
43 ) -> Result<Option<(BlockNumber, Self)>, PruneSegmentError> {
44 let result = match self {
45 Self::Full if segment.min_blocks(purpose) == 0 => Some((tip, *self)),
46 Self::Distance(distance) if *distance > tip => None, Self::Distance(distance) if *distance >= segment.min_blocks(purpose) => {
48 Some((tip - distance, *self))
49 }
50 Self::Before(n) if *n == tip + 1 && purpose.is_static_file() => Some((tip, *self)),
51 Self::Before(n) if *n > tip => None, Self::Before(n) => {
53 (tip - n >= segment.min_blocks(purpose)).then(|| ((*n).saturating_sub(1), *self))
54 }
55 _ => return Err(PruneSegmentError::Configuration(segment)),
56 };
57 Ok(result)
58 }
59
60 pub const fn should_prune(&self, block: BlockNumber, tip: BlockNumber) -> bool {
62 match self {
63 Self::Full => true,
64 Self::Distance(distance) => {
65 if *distance > tip {
66 return false
67 }
68 block < tip - *distance
69 }
70 Self::Before(n) => *n > block,
71 }
72 }
73
74 pub const fn is_full(&self) -> bool {
76 matches!(self, Self::Full)
77 }
78
79 pub const fn is_distance(&self) -> bool {
81 matches!(self, Self::Distance(_))
82 }
83}
84
#[cfg(test)]
mod tests {
    use crate::{
        PruneMode, PrunePurpose, PruneSegment, PruneSegmentError, MINIMUM_PRUNING_DISTANCE,
    };
    use assert_matches::assert_matches;
    use serde::Deserialize;

    /// Table-driven checks of `prune_target_block` on the `Receipts` segment with the `User`
    /// purpose, followed by a `Full` check on `TransactionLookup`.
    #[test]
    fn test_prune_target_block() {
        let tip = 20000;
        let segment = PruneSegment::Receipts;
        let user_min_blocks = segment.min_blocks(PrunePurpose::User);

        // (mode, expected target block); the mode is attached to the target in the loop.
        let cases = vec![
            // `Full` is rejected for a segment that must retain blocks.
            (PruneMode::Full, Err(PruneSegmentError::Configuration(segment))),
            // Distance beyond the tip: nothing to prune yet.
            (PruneMode::Distance(tip + 1), Ok(None)),
            (PruneMode::Distance(user_min_blocks + 1), Ok(Some(tip - (user_min_blocks + 1)))),
            // Cutoff ahead of the tip: nothing to prune yet.
            (PruneMode::Before(tip + 1), Ok(None)),
            (
                PruneMode::Before(tip - MINIMUM_PRUNING_DISTANCE),
                Ok(Some(tip - MINIMUM_PRUNING_DISTANCE - 1)),
            ),
            (
                PruneMode::Before(tip - MINIMUM_PRUNING_DISTANCE - 1),
                Ok(Some(tip - MINIMUM_PRUNING_DISTANCE - 2)),
            ),
            // Cutoff too close to the tip for the minimum: `None`, not an error.
            (PruneMode::Before(tip - 1), Ok(None)),
        ];

        for (case_number, (mode, expected)) in (1..).zip(cases) {
            let expected = expected.map(|target| target.map(|block| (block, mode)));
            assert_eq!(
                mode.prune_target_block(tip, segment, PrunePurpose::User),
                expected,
                "Test {} failed",
                case_number,
            );
        }

        // `Full` on `TransactionLookup` succeeds and targets the tip itself.
        assert_eq!(
            PruneMode::Full.prune_target_block(
                tip,
                PruneSegment::TransactionLookup,
                PrunePurpose::User
            ),
            Ok(Some((tip, PruneMode::Full))),
        );
    }

    /// Checks the pruning boundary of `should_prune` for `Distance` and `Before` modes.
    #[test]
    fn test_should_prune() {
        let tip = 20000;

        // (mode, block, expected verdict)
        let cases = vec![
            (PruneMode::Distance(tip + 1), 1, false),
            (
                PruneMode::Distance(MINIMUM_PRUNING_DISTANCE + 1),
                tip - MINIMUM_PRUNING_DISTANCE - 1,
                false,
            ),
            (
                PruneMode::Distance(MINIMUM_PRUNING_DISTANCE + 1),
                tip - MINIMUM_PRUNING_DISTANCE - 2,
                true,
            ),
            (PruneMode::Before(tip + 1), 1, true),
            (PruneMode::Before(tip + 1), tip + 1, false),
        ];

        for (case_number, (mode, block, expected)) in (1..).zip(cases) {
            assert_eq!(mode.should_prune(block, tip), expected, "Test {} failed", case_number);
        }
    }

    /// Checks the lowercase serde representation of every variant, and that an absent key
    /// deserializes to `None`.
    #[test]
    fn prune_mode_deserialize() {
        #[derive(Debug, Deserialize)]
        struct Config {
            a: Option<PruneMode>,
            b: Option<PruneMode>,
            c: Option<PruneMode>,
            d: Option<PruneMode>,
        }

        // `d` is intentionally omitted.
        let toml_str = r#"
        a = "full"
        b = { distance = 10 }
        c = { before = 20 }
        "#;
        let parsed = toml::from_str::<Config>(toml_str);

        assert_matches!(
            parsed,
            Ok(Config {
                a: Some(PruneMode::Full),
                b: Some(PruneMode::Distance(10)),
                c: Some(PruneMode::Before(20)),
                d: None
            })
        );
    }
}