// reth_nippy_jar/compression/zstd.rs
use crate::{compression::Compression, NippyJarError};
use derive_more::Deref;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::{
fs::File,
io::{Read, Write},
sync::Arc,
};
use tracing::*;
use zstd::bulk::Compressor;
pub use zstd::{bulk::Decompressor, dict::DecoderDictionary};
/// Raw bytes of a zstd dictionary, as produced by dictionary training and as stored when
/// (de)serialized.
type RawDictionary = Vec<u8>;
/// Readiness of a [`Zstd`] instance for producing compressors.
#[derive(Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
pub enum ZstdState {
/// Dictionaries are expected but have not been generated yet. This is the default state and
/// the one [`Zstd::new`] picks when `use_dict` is `true`.
#[default]
PendingDictionary,
/// Compression can proceed: either dictionaries are not used at all, or they have been
/// prepared.
Ready,
}
/// Zstd compression settings together with the optional per-column dictionaries.
#[cfg_attr(test, derive(PartialEq))]
#[derive(Debug, Serialize, Deserialize)]
pub struct Zstd {
/// Whether compressors can be handed out yet (see [`ZstdState`]).
pub(crate) state: ZstdState,
/// Compression level passed to the zstd encoder by `compress_to`.
pub(crate) level: i32,
/// Whether per-column dictionaries are used.
pub use_dict: bool,
/// Maximum size handed to zstd when training a dictionary.
pub(crate) max_dict_size: usize,
/// Per-column dictionaries, if any. Serialized through `dictionaries_serde`, which always
/// deserializes them into their loaded (decoder) form.
#[serde(with = "dictionaries_serde")]
pub(crate) dictionaries: Option<Arc<ZstdDictionaries<'static>>>,
/// Number of columns; the number of dictionaries is expected to match it.
columns: usize,
}
impl Zstd {
pub const fn new(use_dict: bool, max_dict_size: usize, columns: usize) -> Self {
Self {
state: if use_dict { ZstdState::PendingDictionary } else { ZstdState::Ready },
level: 0,
use_dict,
max_dict_size,
dictionaries: None,
columns,
}
}
pub const fn with_level(mut self, level: i32) -> Self {
self.level = level;
self
}
pub fn decompressors(&self) -> Result<Vec<Decompressor<'_>>, NippyJarError> {
if let Some(dictionaries) = &self.dictionaries {
debug_assert!(dictionaries.len() == self.columns);
return dictionaries.decompressors()
}
Ok(vec![])
}
pub fn compressors(&self) -> Result<Option<Vec<Compressor<'_>>>, NippyJarError> {
match self.state {
ZstdState::PendingDictionary => Err(NippyJarError::CompressorNotReady),
ZstdState::Ready => {
if !self.use_dict {
return Ok(None)
}
if let Some(dictionaries) = &self.dictionaries {
debug!(target: "nippy-jar", count=?dictionaries.len(), "Generating ZSTD compressor dictionaries.");
return Ok(Some(dictionaries.compressors()?))
}
Ok(None)
}
}
}
pub fn compress_with_dictionary(
column_value: &[u8],
buffer: &mut Vec<u8>,
handle: &mut File,
compressor: Option<&mut Compressor<'_>>,
) -> Result<(), NippyJarError> {
if let Some(compressor) = compressor {
let mut multiplier = 1;
while let Err(err) = compressor.compress_to_buffer(column_value, buffer) {
buffer.reserve(column_value.len() * multiplier);
multiplier += 1;
if multiplier == 5 {
return Err(NippyJarError::Disconnect(err))
}
}
handle.write_all(buffer)?;
buffer.clear();
} else {
handle.write_all(column_value)?;
}
Ok(())
}
pub fn decompress_with_dictionary(
column_value: &[u8],
output: &mut Vec<u8>,
decompressor: &mut Decompressor<'_>,
) -> Result<(), NippyJarError> {
let previous_length = output.len();
unsafe {
output.set_len(output.capacity());
}
match decompressor.decompress_to_buffer(column_value, &mut output[previous_length..]) {
Ok(written) => {
unsafe {
output.set_len(previous_length + written);
}
Ok(())
}
Err(_) => {
unsafe {
output.set_len(previous_length);
}
Err(NippyJarError::OutputTooSmall)
}
}
}
}
impl Compression for Zstd {
    /// Decompresses `value` (without any dictionary) and appends the result to `dest`.
    fn decompress_to(&self, value: &[u8], dest: &mut Vec<u8>) -> Result<(), NippyJarError> {
        let mut reader = zstd::Decoder::with_dictionary(value, &[])?;
        reader.read_to_end(dest)?;
        Ok(())
    }

    /// Decompresses `value` into a freshly allocated vector.
    fn decompress(&self, value: &[u8]) -> Result<Vec<u8>, NippyJarError> {
        // Start with twice the compressed size; `read_to_end` grows the vector if needed.
        let mut plain = Vec::with_capacity(value.len() * 2);
        let mut reader = zstd::Decoder::new(value)?;
        reader.read_to_end(&mut plain)?;
        Ok(plain)
    }

    /// Compresses `src` at the configured level, appending to `dest`.
    ///
    /// Returns the number of bytes that were appended.
    fn compress_to(&self, src: &[u8], dest: &mut Vec<u8>) -> Result<usize, NippyJarError> {
        let initial_len = dest.len();

        let mut writer = zstd::Encoder::new(dest, self.level)?;
        writer.write_all(src)?;
        // `finish` hands the destination back once the frame has been completed.
        let filled = writer.finish()?;

        Ok(filled.len() - initial_len)
    }

    /// Compresses `src` into a freshly allocated vector.
    fn compress(&self, src: &[u8]) -> Result<Vec<u8>, NippyJarError> {
        let mut packed = Vec::with_capacity(src.len());
        self.compress_to(src, &mut packed)?;
        Ok(packed)
    }

    /// Whether the state is [`ZstdState::Ready`].
    fn is_ready(&self) -> bool {
        self.state == ZstdState::Ready
    }

    /// Trains one dictionary per column from the given values and marks the compressor as
    /// ready.
    ///
    /// Does nothing when dictionaries are disabled.
    ///
    /// # Errors
    ///
    /// Returns [`NippyJarError::ColumnLenMismatch`] when the number of supplied columns differs
    /// from the configured one, and propagates dictionary training errors.
    #[cfg(test)]
    fn prepare_compression(
        &mut self,
        columns: Vec<impl IntoIterator<Item = Vec<u8>>>,
    ) -> Result<(), NippyJarError> {
        if !self.use_dict {
            return Ok(())
        }

        let supplied = columns.len();
        if supplied != self.columns {
            return Err(NippyJarError::ColumnLenMismatch(self.columns, supplied))
        }

        let dictionaries = columns
            .into_iter()
            .map(|column| {
                // zstd trains from one contiguous buffer plus the length of every sample in it.
                let mut sample_sizes = Vec::new();
                let mut samples = Vec::new();
                for value in column {
                    sample_sizes.push(value.len());
                    samples.extend(value);
                }
                zstd::dict::from_continuous(&samples, &sample_sizes, self.max_dict_size)
            })
            .collect::<Result<Vec<_>, _>>()?;

        debug_assert_eq!(dictionaries.len(), self.columns);

        self.dictionaries = Some(Arc::new(ZstdDictionaries::new(dictionaries)));
        self.state = ZstdState::Ready;

        Ok(())
    }
}
/// Serde helpers for the `dictionaries` field of [`Zstd`].
mod dictionaries_serde {
    use super::*;

    /// Serializes the dictionaries as an optional value, looking through the [`Arc`].
    pub(crate) fn serialize<S>(
        dictionaries: &Option<Arc<ZstdDictionaries<'static>>>,
        serializer: S,
    ) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        if let Some(dicts) = dictionaries {
            serializer.serialize_some(dicts.as_ref())
        } else {
            serializer.serialize_none()
        }
    }

    /// Deserializes optional raw dictionaries and loads them into their decoder form.
    pub(crate) fn deserialize<'de, D>(
        deserializer: D,
    ) -> Result<Option<Arc<ZstdDictionaries<'static>>>, D::Error>
    where
        D: Deserializer<'de>,
    {
        let raw = Option::<Vec<RawDictionary>>::deserialize(deserializer)?;
        Ok(raw.map(ZstdDictionaries::load).map(Arc::new))
    }
}
/// List of [`ZstdDictionary`] values; dereferences to the inner vector.
#[cfg_attr(test, derive(PartialEq))]
#[derive(Serialize, Deserialize, Deref)]
pub(crate) struct ZstdDictionaries<'a>(Vec<ZstdDictionary<'a>>);
impl std::fmt::Debug for ZstdDictionaries<'_> {
    /// Prints only the number of dictionaries, not their contents.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let count = self.len();
        let mut out = f.debug_struct("ZstdDictionaries");
        out.field("num", &count);
        out.finish_non_exhaustive()
    }
}
impl ZstdDictionaries<'_> {
    /// Wraps raw dictionaries, keeping them in their raw form.
    #[cfg(test)]
    pub(crate) fn new(raw: Vec<RawDictionary>) -> Self {
        Self(raw.into_iter().map(ZstdDictionary::Raw).collect())
    }

    /// Turns raw dictionaries into their loaded (decoder) form.
    pub(crate) fn load(raw: Vec<RawDictionary>) -> Self {
        Self(
            raw.into_iter()
                .map(|dict| ZstdDictionary::Loaded(DecoderDictionary::copy(&dict)))
                .collect(),
        )
    }

    /// Creates a [`Decompressor`] for every dictionary, in order.
    ///
    /// # Errors
    ///
    /// Returns [`NippyJarError::DictionaryNotLoaded`] if any dictionary is not in its loaded
    /// form, and propagates errors from creating a decompressor.
    pub(crate) fn decompressors(&self) -> Result<Vec<Decompressor<'_>>, NippyJarError> {
        // `map` + `?` rather than `flat_map`: flattening a `Result` silently discards its
        // `Err`, which would hide the error and yield fewer decompressors than dictionaries.
        self.iter()
            .map(|dict| -> Result<_, NippyJarError> {
                let loaded = dict.loaded().ok_or(NippyJarError::DictionaryNotLoaded)?;
                Ok(Decompressor::with_prepared_dictionary(loaded)?)
            })
            .collect()
    }

    /// Creates a [`Compressor`] for every dictionary, in order.
    ///
    /// # Errors
    ///
    /// Returns [`NippyJarError::CompressorNotAllowed`] if any dictionary is not in its raw
    /// form, and propagates errors from creating a compressor.
    pub(crate) fn compressors(&self) -> Result<Vec<Compressor<'_>>, NippyJarError> {
        // See `decompressors` for why this is `map` + `?` and not `flat_map`.
        // NOTE(review): the compression level is fixed to 0 here, independent of `Zstd::level`
        // — confirm this is intended.
        self.iter()
            .map(|dict| -> Result<_, NippyJarError> {
                let raw = dict.raw().ok_or(NippyJarError::CompressorNotAllowed)?;
                Ok(Compressor::with_dictionary(0, raw)?)
            })
            .collect()
    }
}
/// A single zstd dictionary, either as raw bytes or loaded for decoding.
pub(crate) enum ZstdDictionary<'a> {
/// Raw dictionary bytes; the only form that can be serialized or used to build a compressor.
// NOTE(review): presumably `dead_code` is allowed because this variant is only constructed
// by the test-only `ZstdDictionaries::new` — confirm.
#[allow(dead_code)]
Raw(RawDictionary),
/// Dictionary prepared for decompression; the only form that can build a decompressor.
Loaded(DecoderDictionary<'a>),
}
impl ZstdDictionary<'_> {
    /// Returns the raw bytes if this is the `Raw` variant, `None` otherwise.
    pub(crate) const fn raw(&self) -> Option<&RawDictionary> {
        if let ZstdDictionary::Raw(bytes) = self {
            Some(bytes)
        } else {
            None
        }
    }

    /// Returns the decoder dictionary if this is the `Loaded` variant, `None` otherwise.
    pub(crate) const fn loaded(&self) -> Option<&DecoderDictionary<'_>> {
        if let ZstdDictionary::Loaded(prepared) = self {
            Some(prepared)
        } else {
            None
        }
    }
}
impl<'de> Deserialize<'de> for ZstdDictionary<'_> {
    /// Reads the raw dictionary bytes and immediately converts them into the `Loaded` form.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        RawDictionary::deserialize(deserializer)
            .map(|bytes| Self::Loaded(DecoderDictionary::copy(&bytes)))
    }
}
impl Serialize for ZstdDictionary<'_> {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
match self {
ZstdDictionary::Raw(r) => r.serialize(serializer),
ZstdDictionary::Loaded(_) => unreachable!(),
}
}
}
#[cfg(test)]
impl PartialEq for ZstdDictionary<'_> {
    /// Compares two raw dictionaries byte-wise; any comparison involving a loaded dictionary
    /// is not supported and panics.
    fn eq(&self, other: &Self) -> bool {
        match (self, other) {
            (Self::Raw(lhs), Self::Raw(rhs)) => lhs == rhs,
            _ => unimplemented!("`DecoderDictionary` can't be compared. So comparison should be done after decompressing a value."),
        }
    }
}