reth_db_api/models/
integer_list.rs
1use crate::{
4 table::{Compress, Decompress},
5 DatabaseError,
6};
7use bytes::BufMut;
8use core::fmt;
9use derive_more::Deref;
10use roaring::RoaringTreemap;
11
12#[derive(Clone, PartialEq, Default, Deref)]
23pub struct IntegerList(pub RoaringTreemap);
24
25impl fmt::Debug for IntegerList {
26 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
27 f.write_str("IntegerList")?;
28 f.debug_list().entries(self.0.iter()).finish()
29 }
30}
31
32impl IntegerList {
33 pub fn empty() -> Self {
35 Self(RoaringTreemap::new())
36 }
37
38 pub fn new(list: impl IntoIterator<Item = u64>) -> Result<Self, IntegerListError> {
42 RoaringTreemap::from_sorted_iter(list)
43 .map(Self)
44 .map_err(|_| IntegerListError::UnsortedInput)
45 }
46
47 #[inline]
53 #[track_caller]
54 pub fn new_pre_sorted(list: impl IntoIterator<Item = u64>) -> Self {
55 Self::new(list).expect("IntegerList must be pre-sorted and non-empty")
56 }
57
58 pub fn append(&mut self, list: impl IntoIterator<Item = u64>) -> Result<u64, IntegerListError> {
60 self.0.append(list).map_err(|_| IntegerListError::UnsortedInput)
61 }
62
63 pub fn push(&mut self, value: u64) -> Result<(), IntegerListError> {
65 self.0.push(value).then_some(()).ok_or(IntegerListError::UnsortedInput)
66 }
67
68 pub fn clear(&mut self) {
70 self.0.clear();
71 }
72
73 pub fn to_bytes(&self) -> Vec<u8> {
75 let mut vec = Vec::with_capacity(self.0.serialized_size());
76 self.0.serialize_into(&mut vec).expect("not able to encode IntegerList");
77 vec
78 }
79
80 pub fn to_mut_bytes<B: bytes::BufMut>(&self, buf: &mut B) {
82 self.0.serialize_into(buf.writer()).unwrap();
83 }
84
85 pub fn from_bytes(data: &[u8]) -> Result<Self, IntegerListError> {
87 RoaringTreemap::deserialize_from(data)
88 .map(Self)
89 .map_err(|_| IntegerListError::FailedToDeserialize)
90 }
91}
92
93impl serde::Serialize for IntegerList {
94 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
95 where
96 S: serde::Serializer,
97 {
98 use serde::ser::SerializeSeq;
99
100 let mut seq = serializer.serialize_seq(Some(self.len() as usize))?;
101 for e in &self.0 {
102 seq.serialize_element(&e)?;
103 }
104 seq.end()
105 }
106}
107
108struct IntegerListVisitor;
109
110impl<'de> serde::de::Visitor<'de> for IntegerListVisitor {
111 type Value = IntegerList;
112
113 fn expecting(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
114 f.write_str("a usize array")
115 }
116
117 fn visit_seq<E>(self, mut seq: E) -> Result<Self::Value, E::Error>
118 where
119 E: serde::de::SeqAccess<'de>,
120 {
121 let mut list = IntegerList::empty();
122 while let Some(item) = seq.next_element()? {
123 list.push(item).map_err(serde::de::Error::custom)?;
124 }
125 Ok(list)
126 }
127}
128
129impl<'de> serde::Deserialize<'de> for IntegerList {
130 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
131 where
132 D: serde::Deserializer<'de>,
133 {
134 deserializer.deserialize_byte_buf(IntegerListVisitor)
135 }
136}
137
138#[cfg(any(test, feature = "arbitrary"))]
139use arbitrary::{Arbitrary, Unstructured};
140
#[cfg(any(test, feature = "arbitrary"))]
impl<'a> Arbitrary<'a> for IntegerList {
    /// Generates a list from arbitrary `u64`s, sorted and deduplicated so that construction
    /// succeeds.
    fn arbitrary(u: &mut Unstructured<'a>) -> Result<Self, arbitrary::Error> {
        let mut nums: Vec<u64> = Vec::arbitrary(u)?;
        // `IntegerList::new` requires strictly increasing input. Sorting alone leaves duplicates,
        // which would make generation fail with `IncorrectFormat`.
        nums.sort_unstable();
        nums.dedup();
        Self::new(nums).map_err(|_| arbitrary::Error::IncorrectFormat)
    }
}
149
150#[derive(Debug, derive_more::Display, derive_more::Error)]
152pub enum IntegerListError {
153 #[display("the provided input is unsorted")]
155 UnsortedInput,
156 #[display("failed to deserialize data into type")]
158 FailedToDeserialize,
159}
160
161impl Compress for IntegerList {
162 type Compressed = Vec<u8>;
163
164 fn compress(self) -> Self::Compressed {
165 self.to_bytes()
166 }
167
168 fn compress_to_buf<B: bytes::BufMut + AsMut<[u8]>>(&self, buf: &mut B) {
169 self.to_mut_bytes(buf)
170 }
171}
172
173impl Decompress for IntegerList {
174 fn decompress(value: &[u8]) -> Result<Self, DatabaseError> {
175 Self::from_bytes(value).map_err(|_| DatabaseError::Decode)
176 }
177}
178
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn empty_list() {
        assert_eq!(IntegerList::empty().len(), 0);
        assert_eq!(IntegerList::new_pre_sorted(std::iter::empty()).len(), 0);
    }

    #[test]
    fn test_integer_list() {
        let original_list = [1, 2, 3];
        let ef_list = IntegerList::new(original_list).unwrap();
        assert_eq!(ef_list.iter().collect::<Vec<_>>(), original_list);
    }

    #[test]
    fn test_integer_list_serialization() {
        let original_list = [1, 2, 3];
        let ef_list = IntegerList::new(original_list).unwrap();

        let blist = ef_list.to_bytes();
        assert_eq!(IntegerList::from_bytes(&blist).unwrap(), ef_list)
    }

    /// Both out-of-order and repeated values are rejected by `new`.
    #[test]
    fn rejects_unsorted_and_duplicate_input() {
        assert!(matches!(IntegerList::new([3, 1]), Err(IntegerListError::UnsortedInput)));
        assert!(matches!(IntegerList::new([1, 1]), Err(IntegerListError::UnsortedInput)));
    }

    /// `push` and `append` only accept values above the current maximum.
    #[test]
    fn push_and_append_enforce_ordering() {
        let mut list = IntegerList::new([1, 5]).unwrap();
        assert!(list.push(6).is_ok());
        assert!(matches!(list.push(6), Err(IntegerListError::UnsortedInput)));
        assert!(matches!(list.push(2), Err(IntegerListError::UnsortedInput)));
        assert_eq!(list.append([7, 9]).unwrap(), 2);
        assert!(matches!(list.append([8]), Err(IntegerListError::UnsortedInput)));
        assert_eq!(list.iter().collect::<Vec<_>>(), [1, 5, 6, 7, 9]);
    }

    /// Both byte encoders agree and round-trip, including the empty list and extreme values.
    #[test]
    fn byte_encodings_round_trip() {
        let empty = IntegerList::empty();
        assert_eq!(IntegerList::from_bytes(&empty.to_bytes()).unwrap(), empty);

        let list = IntegerList::new([0, 42, u64::MAX]).unwrap();
        let mut buf = Vec::new();
        list.to_mut_bytes(&mut buf);
        assert_eq!(buf, list.to_bytes());
        assert_eq!(IntegerList::from_bytes(&buf).unwrap(), list);
    }

    #[test]
    fn from_bytes_rejects_truncated_input() {
        assert!(matches!(
            IntegerList::from_bytes(&[1, 2, 3]),
            Err(IntegerListError::FailedToDeserialize)
        ));
    }
}