reth_nippy_jar/compression/
zstd.rs

1use crate::{compression::Compression, NippyJarError};
2use derive_more::Deref;
3use serde::{Deserialize, Deserializer, Serialize, Serializer};
4use std::{
5    fs::File,
6    io::{Read, Write},
7    sync::Arc,
8};
9use tracing::*;
10use zstd::bulk::Compressor;
11pub use zstd::{bulk::Decompressor, dict::DecoderDictionary};
12
13type RawDictionary = Vec<u8>;
14
/// Represents the state of a Zstandard compression operation.
#[derive(Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
pub enum ZstdState {
    /// The compressor is pending a dictionary. Compressor creation fails in this state.
    #[default]
    PendingDictionary,
    /// The compressor is ready to perform compression.
    Ready,
}
24
#[cfg_attr(test, derive(PartialEq))]
#[derive(Debug, Serialize, Deserialize)]
/// Zstd compression structure. Supports a compression dictionary per column.
pub struct Zstd {
    /// State. Should be ready before compressing.
    pub(crate) state: ZstdState,
    /// Compression level. A level of `0` uses zstd's default (currently `3`).
    pub(crate) level: i32,
    /// Uses custom dictionaries to compress data.
    pub use_dict: bool,
    /// Max size of a dictionary, in bytes, passed to the zstd dictionary trainer.
    pub(crate) max_dict_size: usize,
    /// List of column dictionaries. `None` until trained (writer) or deserialized (reader).
    /// Serialized as raw bytes; deserialized straight into loaded decoder dictionaries.
    #[serde(with = "dictionaries_serde")]
    pub(crate) dictionaries: Option<Arc<ZstdDictionaries<'static>>>,
    /// Number of columns to compress. One dictionary is expected per column.
    columns: usize,
}
43
impl Zstd {
    /// Creates new [`Zstd`].
    ///
    /// When `use_dict` is set, the instance starts in [`ZstdState::PendingDictionary`] and must
    /// have per-column dictionaries prepared before compressing; otherwise it is immediately
    /// [`ZstdState::Ready`].
    pub const fn new(use_dict: bool, max_dict_size: usize, columns: usize) -> Self {
        Self {
            state: if use_dict { ZstdState::PendingDictionary } else { ZstdState::Ready },
            // `0` defers to zstd's default level (currently `3`).
            level: 0,
            use_dict,
            max_dict_size,
            dictionaries: None,
            columns,
        }
    }

    /// Sets the compression level for the Zstd compression instance.
    pub const fn with_level(mut self, level: i32) -> Self {
        self.level = level;
        self
    }

    /// Creates a list of [`Decompressor`] if using dictionaries.
    ///
    /// Returns an empty list when no dictionaries have been set.
    pub fn decompressors(&self) -> Result<Vec<Decompressor<'_>>, NippyJarError> {
        if let Some(dictionaries) = &self.dictionaries {
            // One dictionary is expected per column.
            debug_assert!(dictionaries.len() == self.columns);
            return dictionaries.decompressors()
        }

        Ok(vec![])
    }

    /// If using dictionaries, creates a list of [`Compressor`].
    ///
    /// Returns `Ok(None)` when dictionaries are disabled or absent, and
    /// [`NippyJarError::CompressorNotReady`] while a dictionary is still pending.
    pub fn compressors(&self) -> Result<Option<Vec<Compressor<'_>>>, NippyJarError> {
        match self.state {
            ZstdState::PendingDictionary => Err(NippyJarError::CompressorNotReady),
            ZstdState::Ready => {
                if !self.use_dict {
                    return Ok(None)
                }

                if let Some(dictionaries) = &self.dictionaries {
                    debug!(target: "nippy-jar", count=?dictionaries.len(), "Generating ZSTD compressor dictionaries.");
                    return Ok(Some(dictionaries.compressors()?))
                }
                Ok(None)
            }
        }
    }

    /// Compresses a value using a dictionary. Reserves additional capacity for `buffer` if
    /// necessary.
    ///
    /// With a `compressor`, compressed bytes are staged in `buffer` (cleared afterwards) and
    /// written to `handle`; without one, `column_value` is written out verbatim.
    pub fn compress_with_dictionary(
        column_value: &[u8],
        buffer: &mut Vec<u8>,
        handle: &mut File,
        compressor: Option<&mut Compressor<'_>>,
    ) -> Result<(), NippyJarError> {
        if let Some(compressor) = compressor {
            // Compressor requires the destination buffer to be big enough to write, otherwise it
            // fails. However, we don't know how big it will be. If data is small
            // enough, the compressed buffer will actually be larger. We keep retrying.
            // If we eventually fail, it probably means it's another kind of error.
            let mut multiplier = 1;
            while let Err(err) = compressor.compress_to_buffer(column_value, buffer) {
                // Grow by one extra input-length per retry; give up after 4 attempts.
                buffer.reserve(column_value.len() * multiplier);
                multiplier += 1;
                if multiplier == 5 {
                    return Err(NippyJarError::Disconnect(err))
                }
            }

            handle.write_all(buffer)?;
            buffer.clear();
        } else {
            // No dictionary compressor: store the value uncompressed.
            handle.write_all(column_value)?;
        }

        Ok(())
    }

    /// Appends a decompressed value using a dictionary to a user provided buffer.
    ///
    /// On failure, `output` is restored to its previous length and
    /// [`NippyJarError::OutputTooSmall`] is returned so the caller can grow it and retry.
    pub fn decompress_with_dictionary(
        column_value: &[u8],
        output: &mut Vec<u8>,
        decompressor: &mut Decompressor<'_>,
    ) -> Result<(), NippyJarError> {
        let previous_length = output.len();

        // SAFETY: We're setting len to the existing capacity.
        // NOTE(review): this exposes the spare capacity's uninitialized bytes as initialized
        // `u8`s, which is strictly UB even if `decompress_to_buffer` only ever writes them —
        // consider `Vec::spare_capacity_mut` instead; confirm against zstd's buffer contract.
        unsafe {
            output.set_len(output.capacity());
        }

        match decompressor.decompress_to_buffer(column_value, &mut output[previous_length..]) {
            Ok(written) => {
                // SAFETY: `decompress_to_buffer` can only write if there's enough capacity.
                // Therefore, it shouldn't write more than our capacity.
                unsafe {
                    output.set_len(previous_length + written);
                }
                Ok(())
            }
            Err(_) => {
                // SAFETY: we are resetting it to the previous value.
                unsafe {
                    output.set_len(previous_length);
                }
                Err(NippyJarError::OutputTooSmall)
            }
        }
    }
}
154
impl Compression for Zstd {
    /// Decompresses a dictionary-less zstd frame, appending the result to `dest`.
    fn decompress_to(&self, value: &[u8], dest: &mut Vec<u8>) -> Result<(), NippyJarError> {
        // An empty dictionary slice is equivalent to plain (no-dictionary) decoding.
        let mut decoder = zstd::Decoder::with_dictionary(value, &[])?;
        decoder.read_to_end(dest)?;
        Ok(())
    }

    /// Decompresses `value` into a freshly allocated buffer.
    fn decompress(&self, value: &[u8]) -> Result<Vec<u8>, NippyJarError> {
        // Heuristic pre-allocation: assume roughly 2x expansion.
        let mut decompressed = Vec::with_capacity(value.len() * 2);
        let mut decoder = zstd::Decoder::new(value)?;
        decoder.read_to_end(&mut decompressed)?;
        Ok(decompressed)
    }

    /// Compresses `src` with `self.level`, appending to `dest` and returning the number of
    /// compressed bytes written.
    fn compress_to(&self, src: &[u8], dest: &mut Vec<u8>) -> Result<usize, NippyJarError> {
        let before = dest.len();

        let mut encoder = zstd::Encoder::new(dest, self.level)?;
        encoder.write_all(src)?;

        // `finish` flushes the frame and hands the inner writer back.
        let dest = encoder.finish()?;

        Ok(dest.len() - before)
    }

    /// Compresses `src` into a freshly allocated buffer.
    fn compress(&self, src: &[u8]) -> Result<Vec<u8>, NippyJarError> {
        let mut compressed = Vec::with_capacity(src.len());

        self.compress_to(src, &mut compressed)?;

        Ok(compressed)
    }

    /// Returns `true` once dictionaries (if requested) have been prepared.
    fn is_ready(&self) -> bool {
        matches!(self.state, ZstdState::Ready)
    }

    #[cfg(test)]
    /// If using it with dictionaries, prepares a dictionary for each column.
    ///
    /// Trains one dictionary per column from the column's full data set and transitions the
    /// state to [`ZstdState::Ready`]. Errors if `columns.len()` doesn't match `self.columns`.
    fn prepare_compression(
        &mut self,
        columns: Vec<impl IntoIterator<Item = Vec<u8>>>,
    ) -> Result<(), NippyJarError> {
        if !self.use_dict {
            return Ok(())
        }

        // There's a per 2GB hard limit on each column data set for training
        // REFERENCE: https://github.com/facebook/zstd/blob/dev/programs/zstd.1.md#dictionary-builder
        // ```
        // -M#, --memory=#: Limit the amount of sample data loaded for training (default: 2 GB).
        // Note that the default (2 GB) is also the maximum. This parameter can be useful in
        // situations where the training set size is not well controlled and could be potentially
        // very large. Since speed of the training process is directly correlated to the size of the
        // training sample set, a smaller sample set leads to faster training.`
        // ```

        if columns.len() != self.columns {
            return Err(NippyJarError::ColumnLenMismatch(self.columns, columns.len()))
        }

        let mut dictionaries = Vec::with_capacity(columns.len());
        for column in columns {
            // ZSTD requires all training data to be continuous in memory, alongside the size of
            // each entry
            let mut sizes = vec![];
            let data: Vec<_> = column
                .into_iter()
                .flat_map(|data| {
                    sizes.push(data.len());
                    data
                })
                .collect();

            dictionaries.push(zstd::dict::from_continuous(&data, &sizes, self.max_dict_size)?);
        }

        debug_assert_eq!(dictionaries.len(), self.columns);

        self.dictionaries = Some(Arc::new(ZstdDictionaries::new(dictionaries)));
        self.state = ZstdState::Ready;

        Ok(())
    }
}
240
241mod dictionaries_serde {
242    use super::*;
243
244    pub(crate) fn serialize<S>(
245        dictionaries: &Option<Arc<ZstdDictionaries<'static>>>,
246        serializer: S,
247    ) -> Result<S::Ok, S::Error>
248    where
249        S: Serializer,
250    {
251        match dictionaries {
252            Some(dicts) => serializer.serialize_some(dicts.as_ref()),
253            None => serializer.serialize_none(),
254        }
255    }
256
257    pub(crate) fn deserialize<'de, D>(
258        deserializer: D,
259    ) -> Result<Option<Arc<ZstdDictionaries<'static>>>, D::Error>
260    where
261        D: Deserializer<'de>,
262    {
263        let dictionaries: Option<Vec<RawDictionary>> = Option::deserialize(deserializer)?;
264        Ok(dictionaries.map(|dicts| Arc::new(ZstdDictionaries::load(dicts))))
265    }
266}
267
/// List of [`ZstdDictionary`], one per column. Derefs to the inner `Vec`.
#[cfg_attr(test, derive(PartialEq))]
#[derive(Serialize, Deserialize, Deref)]
pub(crate) struct ZstdDictionaries<'a>(Vec<ZstdDictionary<'a>>);
272
273impl std::fmt::Debug for ZstdDictionaries<'_> {
274    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
275        f.debug_struct("ZstdDictionaries").field("num", &self.len()).finish_non_exhaustive()
276    }
277}
278
279impl ZstdDictionaries<'_> {
280    #[cfg(test)]
281    /// Creates [`ZstdDictionaries`].
282    pub(crate) fn new(raw: Vec<RawDictionary>) -> Self {
283        Self(raw.into_iter().map(ZstdDictionary::Raw).collect())
284    }
285
286    /// Loads a list [`RawDictionary`] into a list of [`ZstdDictionary::Loaded`].
287    pub(crate) fn load(raw: Vec<RawDictionary>) -> Self {
288        Self(
289            raw.into_iter()
290                .map(|dict| ZstdDictionary::Loaded(DecoderDictionary::copy(&dict)))
291                .collect(),
292        )
293    }
294
295    /// Creates a list of decompressors from a list of [`ZstdDictionary::Loaded`].
296    pub(crate) fn decompressors(&self) -> Result<Vec<Decompressor<'_>>, NippyJarError> {
297        Ok(self
298            .iter()
299            .flat_map(|dict| {
300                dict.loaded()
301                    .ok_or(NippyJarError::DictionaryNotLoaded)
302                    .map(Decompressor::with_prepared_dictionary)
303            })
304            .collect::<Result<Vec<_>, _>>()?)
305    }
306
307    /// Creates a list of compressors from a list of [`ZstdDictionary::Raw`].
308    pub(crate) fn compressors(&self) -> Result<Vec<Compressor<'_>>, NippyJarError> {
309        Ok(self
310            .iter()
311            .flat_map(|dict| {
312                dict.raw()
313                    .ok_or(NippyJarError::CompressorNotAllowed)
314                    .map(|dict| Compressor::with_dictionary(0, dict))
315            })
316            .collect::<Result<Vec<_>, _>>()?)
317    }
318}
319
/// A Zstd dictionary. It's created and serialized with [`ZstdDictionary::Raw`], and deserialized as
/// [`ZstdDictionary::Loaded`].
pub(crate) enum ZstdDictionary<'a> {
    /// Raw dictionary bytes, as produced by training (writer side).
    #[allow(dead_code)]
    Raw(RawDictionary),
    /// A dictionary pre-processed for decompression (reader side).
    Loaded(DecoderDictionary<'a>),
}
327
328impl ZstdDictionary<'_> {
329    /// Returns a reference to the expected `RawDictionary`
330    pub(crate) const fn raw(&self) -> Option<&RawDictionary> {
331        match self {
332            ZstdDictionary::Raw(dict) => Some(dict),
333            ZstdDictionary::Loaded(_) => None,
334        }
335    }
336
337    /// Returns a reference to the expected `DecoderDictionary`
338    pub(crate) const fn loaded(&self) -> Option<&DecoderDictionary<'_>> {
339        match self {
340            ZstdDictionary::Raw(_) => None,
341            ZstdDictionary::Loaded(dict) => Some(dict),
342        }
343    }
344}
345
346impl<'de> Deserialize<'de> for ZstdDictionary<'_> {
347    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
348    where
349        D: Deserializer<'de>,
350    {
351        let dict = RawDictionary::deserialize(deserializer)?;
352        Ok(Self::Loaded(DecoderDictionary::copy(&dict)))
353    }
354}
355
356impl Serialize for ZstdDictionary<'_> {
357    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
358    where
359        S: Serializer,
360    {
361        match self {
362            ZstdDictionary::Raw(r) => r.serialize(serializer),
363            ZstdDictionary::Loaded(_) => unreachable!(),
364        }
365    }
366}
367
#[cfg(test)]
impl PartialEq for ZstdDictionary<'_> {
    /// Test-only equality: comparable only while both sides are still raw bytes.
    fn eq(&self, other: &Self) -> bool {
        match (self, other) {
            (Self::Raw(a), Self::Raw(b)) => a == b,
            _ => unimplemented!("`DecoderDictionary` can't be compared. So comparison should be done after decompressing a value."),
        }
    }
}