reth_nippy_jar/compression/
zstd.rs
1use crate::{compression::Compression, NippyJarError};
2use derive_more::Deref;
3use serde::{Deserialize, Deserializer, Serialize, Serializer};
4use std::{
5 fs::File,
6 io::{Read, Write},
7 sync::Arc,
8};
9use tracing::*;
10use zstd::bulk::Compressor;
11pub use zstd::{bulk::Decompressor, dict::DecoderDictionary};
12
/// Raw, unparsed bytes of a single zstd dictionary.
type RawDictionary = Vec<u8>;
14
15#[derive(Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
17pub enum ZstdState {
18 #[default]
20 PendingDictionary,
21 Ready,
23}
24
25#[cfg_attr(test, derive(PartialEq))]
26#[derive(Debug, Serialize, Deserialize)]
27pub struct Zstd {
29 pub(crate) state: ZstdState,
31 pub(crate) level: i32,
33 pub use_dict: bool,
35 pub(crate) max_dict_size: usize,
37 #[serde(with = "dictionaries_serde")]
39 pub(crate) dictionaries: Option<Arc<ZstdDictionaries<'static>>>,
40 columns: usize,
42}
43
44impl Zstd {
45 pub const fn new(use_dict: bool, max_dict_size: usize, columns: usize) -> Self {
47 Self {
48 state: if use_dict { ZstdState::PendingDictionary } else { ZstdState::Ready },
49 level: 0,
50 use_dict,
51 max_dict_size,
52 dictionaries: None,
53 columns,
54 }
55 }
56
57 pub const fn with_level(mut self, level: i32) -> Self {
59 self.level = level;
60 self
61 }
62
63 pub fn decompressors(&self) -> Result<Vec<Decompressor<'_>>, NippyJarError> {
65 if let Some(dictionaries) = &self.dictionaries {
66 debug_assert!(dictionaries.len() == self.columns);
67 return dictionaries.decompressors()
68 }
69
70 Ok(vec![])
71 }
72
73 pub fn compressors(&self) -> Result<Option<Vec<Compressor<'_>>>, NippyJarError> {
75 match self.state {
76 ZstdState::PendingDictionary => Err(NippyJarError::CompressorNotReady),
77 ZstdState::Ready => {
78 if !self.use_dict {
79 return Ok(None)
80 }
81
82 if let Some(dictionaries) = &self.dictionaries {
83 debug!(target: "nippy-jar", count=?dictionaries.len(), "Generating ZSTD compressor dictionaries.");
84 return Ok(Some(dictionaries.compressors()?))
85 }
86 Ok(None)
87 }
88 }
89 }
90
91 pub fn compress_with_dictionary(
94 column_value: &[u8],
95 buffer: &mut Vec<u8>,
96 handle: &mut File,
97 compressor: Option<&mut Compressor<'_>>,
98 ) -> Result<(), NippyJarError> {
99 if let Some(compressor) = compressor {
100 let mut multiplier = 1;
105 while let Err(err) = compressor.compress_to_buffer(column_value, buffer) {
106 buffer.reserve(column_value.len() * multiplier);
107 multiplier += 1;
108 if multiplier == 5 {
109 return Err(NippyJarError::Disconnect(err))
110 }
111 }
112
113 handle.write_all(buffer)?;
114 buffer.clear();
115 } else {
116 handle.write_all(column_value)?;
117 }
118
119 Ok(())
120 }
121
122 pub fn decompress_with_dictionary(
124 column_value: &[u8],
125 output: &mut Vec<u8>,
126 decompressor: &mut Decompressor<'_>,
127 ) -> Result<(), NippyJarError> {
128 let previous_length = output.len();
129
130 unsafe {
132 output.set_len(output.capacity());
133 }
134
135 match decompressor.decompress_to_buffer(column_value, &mut output[previous_length..]) {
136 Ok(written) => {
137 unsafe {
140 output.set_len(previous_length + written);
141 }
142 Ok(())
143 }
144 Err(_) => {
145 unsafe {
147 output.set_len(previous_length);
148 }
149 Err(NippyJarError::OutputTooSmall)
150 }
151 }
152 }
153}
154
155impl Compression for Zstd {
156 fn decompress_to(&self, value: &[u8], dest: &mut Vec<u8>) -> Result<(), NippyJarError> {
157 let mut decoder = zstd::Decoder::with_dictionary(value, &[])?;
158 decoder.read_to_end(dest)?;
159 Ok(())
160 }
161
162 fn decompress(&self, value: &[u8]) -> Result<Vec<u8>, NippyJarError> {
163 let mut decompressed = Vec::with_capacity(value.len() * 2);
164 let mut decoder = zstd::Decoder::new(value)?;
165 decoder.read_to_end(&mut decompressed)?;
166 Ok(decompressed)
167 }
168
169 fn compress_to(&self, src: &[u8], dest: &mut Vec<u8>) -> Result<usize, NippyJarError> {
170 let before = dest.len();
171
172 let mut encoder = zstd::Encoder::new(dest, self.level)?;
173 encoder.write_all(src)?;
174
175 let dest = encoder.finish()?;
176
177 Ok(dest.len() - before)
178 }
179
180 fn compress(&self, src: &[u8]) -> Result<Vec<u8>, NippyJarError> {
181 let mut compressed = Vec::with_capacity(src.len());
182
183 self.compress_to(src, &mut compressed)?;
184
185 Ok(compressed)
186 }
187
188 fn is_ready(&self) -> bool {
189 matches!(self.state, ZstdState::Ready)
190 }
191
192 #[cfg(test)]
193 fn prepare_compression(
195 &mut self,
196 columns: Vec<impl IntoIterator<Item = Vec<u8>>>,
197 ) -> Result<(), NippyJarError> {
198 if !self.use_dict {
199 return Ok(())
200 }
201
202 if columns.len() != self.columns {
213 return Err(NippyJarError::ColumnLenMismatch(self.columns, columns.len()))
214 }
215
216 let mut dictionaries = Vec::with_capacity(columns.len());
217 for column in columns {
218 let mut sizes = vec![];
221 let data: Vec<_> = column
222 .into_iter()
223 .flat_map(|data| {
224 sizes.push(data.len());
225 data
226 })
227 .collect();
228
229 dictionaries.push(zstd::dict::from_continuous(&data, &sizes, self.max_dict_size)?);
230 }
231
232 debug_assert_eq!(dictionaries.len(), self.columns);
233
234 self.dictionaries = Some(Arc::new(ZstdDictionaries::new(dictionaries)));
235 self.state = ZstdState::Ready;
236
237 Ok(())
238 }
239}
240
241mod dictionaries_serde {
242 use super::*;
243
244 pub(crate) fn serialize<S>(
245 dictionaries: &Option<Arc<ZstdDictionaries<'static>>>,
246 serializer: S,
247 ) -> Result<S::Ok, S::Error>
248 where
249 S: Serializer,
250 {
251 match dictionaries {
252 Some(dicts) => serializer.serialize_some(dicts.as_ref()),
253 None => serializer.serialize_none(),
254 }
255 }
256
257 pub(crate) fn deserialize<'de, D>(
258 deserializer: D,
259 ) -> Result<Option<Arc<ZstdDictionaries<'static>>>, D::Error>
260 where
261 D: Deserializer<'de>,
262 {
263 let dictionaries: Option<Vec<RawDictionary>> = Option::deserialize(deserializer)?;
264 Ok(dictionaries.map(|dicts| Arc::new(ZstdDictionaries::load(dicts))))
265 }
266}
267
268#[cfg_attr(test, derive(PartialEq))]
270#[derive(Serialize, Deserialize, Deref)]
271pub(crate) struct ZstdDictionaries<'a>(Vec<ZstdDictionary<'a>>);
272
273impl std::fmt::Debug for ZstdDictionaries<'_> {
274 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
275 f.debug_struct("ZstdDictionaries").field("num", &self.len()).finish_non_exhaustive()
276 }
277}
278
279impl ZstdDictionaries<'_> {
280 #[cfg(test)]
281 pub(crate) fn new(raw: Vec<RawDictionary>) -> Self {
283 Self(raw.into_iter().map(ZstdDictionary::Raw).collect())
284 }
285
286 pub(crate) fn load(raw: Vec<RawDictionary>) -> Self {
288 Self(
289 raw.into_iter()
290 .map(|dict| ZstdDictionary::Loaded(DecoderDictionary::copy(&dict)))
291 .collect(),
292 )
293 }
294
295 pub(crate) fn decompressors(&self) -> Result<Vec<Decompressor<'_>>, NippyJarError> {
297 Ok(self
298 .iter()
299 .flat_map(|dict| {
300 dict.loaded()
301 .ok_or(NippyJarError::DictionaryNotLoaded)
302 .map(Decompressor::with_prepared_dictionary)
303 })
304 .collect::<Result<Vec<_>, _>>()?)
305 }
306
307 pub(crate) fn compressors(&self) -> Result<Vec<Compressor<'_>>, NippyJarError> {
309 Ok(self
310 .iter()
311 .flat_map(|dict| {
312 dict.raw()
313 .ok_or(NippyJarError::CompressorNotAllowed)
314 .map(|dict| Compressor::with_dictionary(0, dict))
315 })
316 .collect::<Result<Vec<_>, _>>()?)
317 }
318}
319
320pub(crate) enum ZstdDictionary<'a> {
323 #[allow(dead_code)]
324 Raw(RawDictionary),
325 Loaded(DecoderDictionary<'a>),
326}
327
328impl ZstdDictionary<'_> {
329 pub(crate) const fn raw(&self) -> Option<&RawDictionary> {
331 match self {
332 ZstdDictionary::Raw(dict) => Some(dict),
333 ZstdDictionary::Loaded(_) => None,
334 }
335 }
336
337 pub(crate) const fn loaded(&self) -> Option<&DecoderDictionary<'_>> {
339 match self {
340 ZstdDictionary::Raw(_) => None,
341 ZstdDictionary::Loaded(dict) => Some(dict),
342 }
343 }
344}
345
346impl<'de> Deserialize<'de> for ZstdDictionary<'_> {
347 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
348 where
349 D: Deserializer<'de>,
350 {
351 let dict = RawDictionary::deserialize(deserializer)?;
352 Ok(Self::Loaded(DecoderDictionary::copy(&dict)))
353 }
354}
355
356impl Serialize for ZstdDictionary<'_> {
357 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
358 where
359 S: Serializer,
360 {
361 match self {
362 ZstdDictionary::Raw(r) => r.serialize(serializer),
363 ZstdDictionary::Loaded(_) => unreachable!(),
364 }
365 }
366}
367
#[cfg(test)]
impl PartialEq for ZstdDictionary<'_> {
    /// Compares two raw dictionaries byte for byte.
    ///
    /// # Panics
    ///
    /// Panics if either side is not a [`ZstdDictionary::Raw`] entry.
    fn eq(&self, other: &Self) -> bool {
        match (self, other) {
            (Self::Raw(lhs), Self::Raw(rhs)) => lhs == rhs,
            _ => unimplemented!("`DecoderDictionary` can't be compared. So comparison should be done after decompressing a value."),
        }
    }
}
376}