reth_db_api/models/
integer_list.rs

1//! Implements [`Compress`] and [`Decompress`] for [`IntegerList`]
2
3use crate::{
4    table::{Compress, Decompress},
5    DatabaseError,
6};
7use bytes::BufMut;
8use core::fmt;
9use derive_more::Deref;
10use roaring::RoaringTreemap;
11
/// A data structure that uses Roaring Bitmaps to efficiently store a list of integers.
///
/// This structure provides excellent compression while allowing direct access to individual
/// elements without the need for full decompression.
///
/// Key features:
/// - Efficient compression: the underlying Roaring Bitmaps significantly reduce memory usage.
/// - Direct access: elements can be accessed or queried without needing to decode the entire list.
/// - [`RoaringTreemap`] backing: internally backed by [`RoaringTreemap`], which supports 64-bit
///   integers.
///
/// The wrapped treemap is reachable both through the public tuple field and through the derived
/// `Deref` implementation, so `&self` methods of [`RoaringTreemap`] can be called directly on an
/// [`IntegerList`].
#[derive(Clone, PartialEq, Default, Deref)]
pub struct IntegerList(pub RoaringTreemap);
24
25impl fmt::Debug for IntegerList {
26    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
27        f.write_str("IntegerList")?;
28        f.debug_list().entries(self.0.iter()).finish()
29    }
30}
31
32impl IntegerList {
33    /// Creates a new empty [`IntegerList`].
34    pub fn empty() -> Self {
35        Self(RoaringTreemap::new())
36    }
37
38    /// Creates an [`IntegerList`] from a list of integers.
39    ///
40    /// Returns an error if the list is not pre-sorted.
41    pub fn new(list: impl IntoIterator<Item = u64>) -> Result<Self, IntegerListError> {
42        RoaringTreemap::from_sorted_iter(list)
43            .map(Self)
44            .map_err(|_| IntegerListError::UnsortedInput)
45    }
46
47    /// Creates an [`IntegerList`] from a pre-sorted list of integers.
48    ///
49    /// # Panics
50    ///
51    /// Panics if the list is not pre-sorted.
52    #[inline]
53    #[track_caller]
54    pub fn new_pre_sorted(list: impl IntoIterator<Item = u64>) -> Self {
55        Self::new(list).expect("IntegerList must be pre-sorted and non-empty")
56    }
57
58    /// Appends a list of integers to the current list.
59    pub fn append(&mut self, list: impl IntoIterator<Item = u64>) -> Result<u64, IntegerListError> {
60        self.0.append(list).map_err(|_| IntegerListError::UnsortedInput)
61    }
62
63    /// Pushes a new integer to the list.
64    pub fn push(&mut self, value: u64) -> Result<(), IntegerListError> {
65        self.0.push(value).then_some(()).ok_or(IntegerListError::UnsortedInput)
66    }
67
68    /// Clears the list.
69    pub fn clear(&mut self) {
70        self.0.clear();
71    }
72
73    /// Serializes a [`IntegerList`] into a sequence of bytes.
74    pub fn to_bytes(&self) -> Vec<u8> {
75        let mut vec = Vec::with_capacity(self.0.serialized_size());
76        self.0.serialize_into(&mut vec).expect("not able to encode IntegerList");
77        vec
78    }
79
80    /// Serializes a [`IntegerList`] into a sequence of bytes.
81    pub fn to_mut_bytes<B: bytes::BufMut>(&self, buf: &mut B) {
82        self.0.serialize_into(buf.writer()).unwrap();
83    }
84
85    /// Deserializes a sequence of bytes into a proper [`IntegerList`].
86    pub fn from_bytes(data: &[u8]) -> Result<Self, IntegerListError> {
87        RoaringTreemap::deserialize_from(data)
88            .map(Self)
89            .map_err(|_| IntegerListError::FailedToDeserialize)
90    }
91}
92
93impl serde::Serialize for IntegerList {
94    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
95    where
96        S: serde::Serializer,
97    {
98        use serde::ser::SerializeSeq;
99
100        let mut seq = serializer.serialize_seq(Some(self.len() as usize))?;
101        for e in &self.0 {
102            seq.serialize_element(&e)?;
103        }
104        seq.end()
105    }
106}
107
108struct IntegerListVisitor;
109
110impl<'de> serde::de::Visitor<'de> for IntegerListVisitor {
111    type Value = IntegerList;
112
113    fn expecting(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
114        f.write_str("a usize array")
115    }
116
117    fn visit_seq<E>(self, mut seq: E) -> Result<Self::Value, E::Error>
118    where
119        E: serde::de::SeqAccess<'de>,
120    {
121        let mut list = IntegerList::empty();
122        while let Some(item) = seq.next_element()? {
123            list.push(item).map_err(serde::de::Error::custom)?;
124        }
125        Ok(list)
126    }
127}
128
129impl<'de> serde::Deserialize<'de> for IntegerList {
130    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
131    where
132        D: serde::Deserializer<'de>,
133    {
134        deserializer.deserialize_byte_buf(IntegerListVisitor)
135    }
136}
137
138#[cfg(any(test, feature = "arbitrary"))]
139use arbitrary::{Arbitrary, Unstructured};
140
#[cfg(any(test, feature = "arbitrary"))]
impl<'a> Arbitrary<'a> for IntegerList {
    /// Builds an arbitrary [`IntegerList`] from an arbitrary vector of integers.
    ///
    /// The vector is sorted and deduplicated before being handed to [`IntegerList::new`].
    /// Any remaining construction failure is reported as
    /// [`arbitrary::Error::IncorrectFormat`].
    fn arbitrary(u: &mut Unstructured<'a>) -> Result<Self, arbitrary::Error> {
        let mut nums: Vec<u64> = Vec::arbitrary(u)?;
        nums.sort_unstable();
        // The sorted constructor of the treemap expects deduplicated values; without this step
        // any input containing a repeated integer would be rejected.
        nums.dedup();
        Self::new(nums).map_err(|_| arbitrary::Error::IncorrectFormat)
    }
}
149
/// Errors produced when building or decoding an [`IntegerList`].
#[derive(Debug, derive_more::Display, derive_more::Error)]
pub enum IntegerListError {
    /// The provided input is unsorted.
    ///
    /// Produced by [`IntegerList::new`], [`IntegerList::append`] and [`IntegerList::push`] when
    /// the underlying treemap rejects the supplied values.
    #[display("the provided input is unsorted")]
    UnsortedInput,
    /// Failed to deserialize data into type.
    ///
    /// Produced by [`IntegerList::from_bytes`] when the bytes cannot be decoded.
    #[display("failed to deserialize data into type")]
    FailedToDeserialize,
}
160
161impl Compress for IntegerList {
162    type Compressed = Vec<u8>;
163
164    fn compress(self) -> Self::Compressed {
165        self.to_bytes()
166    }
167
168    fn compress_to_buf<B: bytes::BufMut + AsMut<[u8]>>(&self, buf: &mut B) {
169        self.to_mut_bytes(buf)
170    }
171}
172
173impl Decompress for IntegerList {
174    fn decompress(value: &[u8]) -> Result<Self, DatabaseError> {
175        Self::from_bytes(value).map_err(|_| DatabaseError::Decode)
176    }
177}
178
#[cfg(test)]
mod tests {
    use super::*;

    /// Both the explicit empty constructor and an empty pre-sorted input hold zero elements.
    #[test]
    fn empty_list() {
        let via_empty = IntegerList::empty();
        assert_eq!(via_empty.len(), 0);

        let via_pre_sorted = IntegerList::new_pre_sorted(std::iter::empty());
        assert_eq!(via_pre_sorted.len(), 0);
    }

    /// A list built from sorted input iterates over exactly that input.
    #[test]
    fn test_integer_list() {
        let expected = [1, 2, 3];
        let list = IntegerList::new(expected).unwrap();

        let collected: Vec<_> = list.iter().collect();
        assert_eq!(collected, expected);
    }

    /// Encoding to bytes and decoding again yields an equal list.
    #[test]
    fn test_integer_list_serialization() {
        let list = IntegerList::new([1, 2, 3]).unwrap();

        let encoded = list.to_bytes();
        let decoded = IntegerList::from_bytes(&encoded).unwrap();
        assert_eq!(decoded, list)
    }
}