1use crate::{segment::PrunePurpose, PruneSegment, PruneSegmentError};
2use alloy_primitives::BlockNumber;
3
/// Prune mode: selects which blocks of a prune segment may be removed relative to the chain tip.
///
/// With the `serde` feature, variants use lowercase names, e.g. in TOML `"full"`,
/// `{ distance = 10 }` or `{ before = 20 }`.
///
/// NOTE(review): this type also derives `reth_codecs::Compact`; that encoding presumably depends
/// on variant order and payloads, so avoid reordering or reshaping variants — confirm against
/// `reth_codecs` before changing them.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
#[cfg_attr(any(test, feature = "test-utils"), derive(arbitrary::Arbitrary))]
#[cfg_attr(any(test, feature = "reth-codec"), derive(reth_codecs::Compact))]
#[cfg_attr(any(test, feature = "reth-codec"), reth_codecs::add_arbitrary_tests(compact))]
#[cfg_attr(any(test, feature = "serde"), derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(any(test, feature = "serde"), serde(rename_all = "lowercase"))]
pub enum PruneMode {
    /// Prune everything the segment permits. Segments with a non-zero `min_blocks()` still keep
    /// that many of the most recent blocks when a prune target is computed.
    Full,
    /// Keep blocks close to the tip: a block is pruned only if it is strictly below
    /// `tip - distance`, and nothing is pruned while `distance` exceeds the tip.
    Distance(u64),
    /// Prune every block strictly below the given block number.
    Before(BlockNumber),
}
19
/// Test-only [`Default`]: yields [`PruneMode::Full`].
///
/// Compiled only for tests and the `test-utils` feature. It is presumably written by hand (and
/// clippy's `derivable_impls` suggestion silenced) so that it can stay behind this cfg gate
/// instead of becoming part of the type's always-available API.
#[cfg(any(test, feature = "test-utils"))]
#[allow(clippy::derivable_impls)]
impl Default for PruneMode {
    fn default() -> Self {
        // `Full` is the most permissive mode; tests use it as the baseline.
        Self::Full
    }
}
27
28impl PruneMode {
29 pub const fn before_inclusive(block_number: BlockNumber) -> Self {
33 Self::Before(block_number + 1)
34 }
35
36 pub fn prune_target_block(
39 &self,
40 tip: BlockNumber,
41 segment: PruneSegment,
42 purpose: PrunePurpose,
43 ) -> Result<Option<(BlockNumber, Self)>, PruneSegmentError> {
44 let min_blocks = segment.min_blocks();
45 let result = match self {
46 Self::Full if min_blocks == 0 => Some((tip, *self)),
47 Self::Full if min_blocks <= tip => Some((tip - min_blocks, *self)),
49 Self::Full => None, Self::Distance(distance) if *distance > tip => None, Self::Distance(distance) if *distance >= segment.min_blocks() => {
52 Some((tip - distance, *self))
53 }
54 Self::Before(n) if *n == tip + 1 && purpose.is_static_file() => Some((tip, *self)),
55 Self::Before(n) if *n > tip => None, Self::Before(n) => {
57 (tip - n >= segment.min_blocks()).then(|| ((*n).saturating_sub(1), *self))
58 }
59 _ => return Err(PruneSegmentError::Configuration(segment)),
60 };
61 Ok(result)
62 }
63
64 pub const fn should_prune(&self, block: BlockNumber, tip: BlockNumber) -> bool {
66 match self {
67 Self::Full => true,
68 Self::Distance(distance) => {
69 if *distance > tip {
70 return false
71 }
72 block < tip - *distance
73 }
74 Self::Before(n) => *n > block,
75 }
76 }
77
78 pub const fn is_full(&self) -> bool {
80 matches!(self, Self::Full)
81 }
82
83 pub const fn is_distance(&self) -> bool {
85 matches!(self, Self::Distance(_))
86 }
87
88 pub const fn next_pruned_block(&self, checkpoint: Option<BlockNumber>) -> Option<BlockNumber> {
103 let next = match checkpoint {
104 Some(c) => c + 1,
105 None => 0,
106 };
107
108 match self {
109 Self::Before(n) => {
110 if next < *n {
111 Some(next)
112 } else {
113 None
114 }
115 }
116 Self::Distance(_) | Self::Full => Some(next),
117 }
118 }
119}
120
#[cfg(test)]
mod tests {
    use crate::{PruneMode, PrunePurpose, PruneSegment, MINIMUM_UNWIND_SAFE_DISTANCE};
    use assert_matches::assert_matches;
    use serde::Deserialize;

    /// Pins `prune_target_block` for a segment with a minimum, plus the `Full` case for a
    /// segment that lets `Full` prune up to the tip.
    #[test]
    fn test_prune_target_block() {
        let tip = 20000;
        let segment = PruneSegment::AccountHistory;
        let min_blocks = segment.min_blocks();

        // (mode, expected target block)
        let cases = [
            (PruneMode::Full, Ok(Some(tip - min_blocks))),
            // Distance beyond the tip: nothing to prune yet.
            (PruneMode::Distance(tip + 1), Ok(None)),
            (PruneMode::Distance(min_blocks + 1), Ok(Some(tip - (min_blocks + 1)))),
            // `Before` past the tip: nothing to prune yet.
            (PruneMode::Before(tip + 1), Ok(None)),
            (
                PruneMode::Before(tip - MINIMUM_UNWIND_SAFE_DISTANCE),
                Ok(Some(tip - MINIMUM_UNWIND_SAFE_DISTANCE - 1)),
            ),
            (
                PruneMode::Before(tip - MINIMUM_UNWIND_SAFE_DISTANCE - 1),
                Ok(Some(tip - MINIMUM_UNWIND_SAFE_DISTANCE - 2)),
            ),
            // Too close to the tip for the segment minimum.
            (PruneMode::Before(tip - 1), Ok(None)),
        ];

        for (i, (mode, expected)) in cases.into_iter().enumerate() {
            let expected = expected.map(|target| target.map(|block| (block, mode)));
            assert_eq!(
                mode.prune_target_block(tip, segment, PrunePurpose::User),
                expected,
                "Test {} failed",
                i + 1,
            );
        }

        // This segment lets `Full` prune all the way up to the tip.
        let full_target = PruneMode::Full.prune_target_block(
            tip,
            PruneSegment::TransactionLookup,
            PrunePurpose::User,
        );
        assert_eq!(full_target, Ok(Some((tip, PruneMode::Full))));
    }

    /// Pins `should_prune` around the `Distance` and `Before` boundaries.
    #[test]
    fn test_should_prune() {
        let tip = 20000;

        // (mode, block, whether the block should be pruned)
        let cases = [
            (PruneMode::Distance(tip + 1), 1, false),
            (
                PruneMode::Distance(MINIMUM_UNWIND_SAFE_DISTANCE + 1),
                tip - MINIMUM_UNWIND_SAFE_DISTANCE - 1,
                false,
            ),
            (
                PruneMode::Distance(MINIMUM_UNWIND_SAFE_DISTANCE + 1),
                tip - MINIMUM_UNWIND_SAFE_DISTANCE - 2,
                true,
            ),
            (PruneMode::Before(tip + 1), 1, true),
            (PruneMode::Before(tip + 1), tip + 1, false),
        ];

        for (i, (mode, block, expected)) in cases.into_iter().enumerate() {
            assert_eq!(mode.should_prune(block, tip), expected, "Test {} failed", i + 1);
        }
    }

    /// Checks the lowercase TOML representation of each variant; `d` is deliberately absent to
    /// check that a missing key deserializes to `None`.
    #[test]
    fn prune_mode_deserialize() {
        #[derive(Debug, Deserialize)]
        struct Config {
            a: Option<PruneMode>,
            b: Option<PruneMode>,
            c: Option<PruneMode>,
            d: Option<PruneMode>,
        }

        let input = r#"
        a = "full"
        b = { distance = 10 }
        c = { before = 20 }
        "#;

        assert_matches!(
            toml::from_str::<Config>(input),
            Ok(Config {
                a: Some(PruneMode::Full),
                b: Some(PruneMode::Distance(10)),
                c: Some(PruneMode::Before(20)),
                d: None
            })
        );
    }
}