reth_codecs_derive/compact/generator.rs

1//! Code generator for the `Compact` trait.
2
3use super::*;
4use crate::ZstdConfig;
5use convert_case::{Case, Casing};
6use syn::{Attribute, LitStr};
7
8/// Generates code to implement the `Compact` trait for a data type.
9pub fn generate_from_to(
10    ident: &Ident,
11    attrs: &[Attribute],
12    has_lifetime: bool,
13    fields: &FieldList,
14    zstd: Option<ZstdConfig>,
15) -> TokenStream2 {
16    let flags = format_ident!("{ident}Flags");
17
18    let reth_codecs = parse_reth_codecs_path(attrs).unwrap();
19
20    let to_compact = generate_to_compact(fields, ident, zstd.clone(), &reth_codecs);
21    let from_compact = generate_from_compact(fields, ident, zstd);
22
23    let snake_case_ident = ident.to_string().to_case(Case::Snake);
24
25    let fuzz = format_ident!("fuzz_test_{snake_case_ident}");
26    let test = format_ident!("fuzz_{snake_case_ident}");
27
28    let lifetime = if has_lifetime {
29        quote! { 'a }
30    } else {
31        quote! {}
32    };
33
34    let impl_compact = if has_lifetime {
35        quote! {
36           impl<#lifetime> #reth_codecs::Compact for #ident<#lifetime>
37        }
38    } else {
39        quote! {
40           impl #reth_codecs::Compact for #ident
41        }
42    };
43
44    let has_ref_fields = fields.iter().any(|field| {
45        if let FieldTypes::StructField(field) = field {
46            field.is_reference
47        } else {
48            false
49        }
50    });
51
52    let fn_from_compact = if has_ref_fields {
53        quote! { unimplemented!("from_compact not supported with ref structs") }
54    } else {
55        quote! {
56            let (flags, mut buf) = #flags::from(buf);
57            #from_compact
58        }
59    };
60
61    let fuzz_tests = if has_lifetime {
62        quote! {}
63    } else {
64        quote! {
65            #[cfg(test)]
66            #[allow(dead_code)]
67            #[test_fuzz::test_fuzz]
68            fn #fuzz(obj: #ident)  {
69                use #reth_codecs::Compact;
70                let mut buf = vec![];
71                let len = obj.clone().to_compact(&mut buf);
72                let (same_obj, buf) = #ident::from_compact(buf.as_ref(), len);
73                assert_eq!(obj, same_obj);
74            }
75
76            #[test]
77            #[allow(missing_docs)]
78            pub fn #test() {
79                #fuzz(#ident::default())
80            }
81        }
82    };
83
84    // Build function
85    quote! {
86        #fuzz_tests
87
88        #impl_compact {
89            fn to_compact<B>(&self, buf: &mut B) -> usize where B: #reth_codecs::__private::bytes::BufMut + AsMut<[u8]> {
90                let mut flags = #flags::default();
91                let mut total_length = 0;
92                #(#to_compact)*
93                total_length
94            }
95
96            fn from_compact(mut buf: &[u8], len: usize) -> (Self, &[u8]) {
97                #fn_from_compact
98            }
99        }
100    }
101}
102
103/// Generates code to implement the `Compact` trait method `to_compact`.
104fn generate_from_compact(
105    fields: &FieldList,
106    ident: &Ident,
107    zstd: Option<ZstdConfig>,
108) -> TokenStream2 {
109    let mut lines = vec![];
110    let mut known_types =
111        vec!["B256", "Address", "Bloom", "Vec", "TxHash", "BlockHash", "FixedBytes", "Cow"];
112
113    // Only types without `Bytes` should be added here. It's currently manually added, since
114    // it's hard to figure out with derive_macro which types have Bytes fields.
115    //
116    // This removes the requirement of the field to be placed last in the struct.
117    known_types.extend_from_slice(&["TxKind", "AccessList", "Signature", "CheckpointBlockRange"]);
118
119    // let mut handle = FieldListHandler::new(fields);
120    let is_enum = fields.iter().any(|field| matches!(field, FieldTypes::EnumVariant(_)));
121
122    if is_enum {
123        let enum_lines = EnumHandler::new(fields).generate_from(ident);
124
125        // Builds the object instantiation.
126        lines.push(quote! {
127            let obj = match flags.variant() {
128                #(#enum_lines)*
129                _ => unreachable!()
130            };
131        });
132    } else {
133        let mut struct_handler = StructHandler::new(fields);
134        lines.append(&mut struct_handler.generate_from(known_types.as_slice()));
135
136        // Builds the object instantiation.
137        if struct_handler.is_wrapper {
138            lines.push(quote! {
139                let obj = #ident(placeholder);
140            });
141        } else {
142            let fields = fields.iter().filter_map(|field| {
143                if let FieldTypes::StructField(field) = field {
144                    let ident = format_ident!("{}", field.name);
145                    return Some(quote! {
146                        #ident: #ident,
147                    })
148                }
149                None
150            });
151
152            lines.push(quote! {
153                let obj = #ident {
154                    #(#fields)*
155                };
156            });
157        }
158    }
159
160    // If the type has compression support, then check the `__zstd` flag. Otherwise, use the default
161    // code branch. However, even if it's a type with compression support, not all values are
162    // to be compressed (thus the zstd flag). Ideally only the bigger ones.
163    if let Some(zstd) = zstd {
164        let decompressor = zstd.decompressor;
165        quote! {
166            if flags.__zstd() != 0 {
167                #decompressor.with(|decompressor| {
168                    let decompressor = &mut decompressor.borrow_mut();
169                    let decompressed = decompressor.decompress(buf);
170                    let mut original_buf = buf;
171
172                    let mut buf: &[u8] = decompressed;
173                    #(#lines)*
174                    (obj, original_buf)
175                })
176            } else {
177                #(#lines)*
178                (obj, buf)
179            }
180        }
181    } else {
182        quote! {
183            #(#lines)*
184            (obj, buf)
185        }
186    }
187}
188
189/// Generates code to implement the `Compact` trait method `from_compact`.
190fn generate_to_compact(
191    fields: &FieldList,
192    ident: &Ident,
193    zstd: Option<ZstdConfig>,
194    reth_codecs: &syn::Path,
195) -> Vec<TokenStream2> {
196    let mut lines = vec![quote! {
197        let mut buffer = #reth_codecs::__private::bytes::BytesMut::new();
198    }];
199
200    let is_enum = fields.iter().any(|field| matches!(field, FieldTypes::EnumVariant(_)));
201
202    if is_enum {
203        let enum_lines = EnumHandler::new(fields).generate_to(ident);
204
205        lines.push(quote! {
206            flags.set_variant(match self {
207                #(#enum_lines)*
208            });
209        })
210    } else {
211        lines.append(&mut StructHandler::new(fields).generate_to());
212    }
213
214    // Just because a type supports compression, doesn't mean all its values are to be compressed.
215    // We skip the smaller ones, and thus require a flag` __zstd` to specify if this value is
216    // compressed or not.
217    if zstd.is_some() {
218        lines.push(quote! {
219            let mut zstd = buffer.len() > 7;
220            if zstd {
221                flags.set___zstd(1);
222            }
223        });
224    }
225
226    // Places the flag bits.
227    lines.push(quote! {
228        let flags = flags.into_bytes();
229        total_length += flags.len() + buffer.len();
230        buf.put_slice(&flags);
231    });
232
233    if let Some(zstd) = zstd {
234        let compressor = zstd.compressor;
235        lines.push(quote! {
236            if zstd {
237                #compressor.with(|compressor| {
238                    let mut compressor = compressor.borrow_mut();
239
240                    let compressed = compressor.compress(&buffer).expect("Failed to compress.");
241                    buf.put(compressed.as_slice());
242                });
243            } else {
244                buf.put(buffer);
245            }
246        });
247    } else {
248        lines.push(quote! {
249            buf.put(buffer);
250        })
251    }
252
253    lines
254}
255
256/// Function to extract the crate path from `reth_codecs(crate = "...")` attribute.
257pub(crate) fn parse_reth_codecs_path(attrs: &[Attribute]) -> syn::Result<syn::Path> {
258    // let default_crate_path: syn::Path = syn::parse_str("reth-codecs").unwrap();
259    let mut reth_codecs_path: syn::Path = syn::parse_quote!(reth_codecs);
260    for attr in attrs {
261        if attr.path().is_ident("reth_codecs") {
262            attr.parse_nested_meta(|meta| {
263                if meta.path.is_ident("crate") {
264                    let value = meta.value()?;
265                    let lit: LitStr = value.parse()?;
266                    reth_codecs_path = syn::parse_str(&lit.value())?;
267                    Ok(())
268                } else {
269                    Err(meta.error("unsupported attribute"))
270                }
271            })?;
272        }
273    }
274
275    Ok(reth_codecs_path)
276}