1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
use proc_macro::TokenStream;
use proc_macro2::{Ident, TokenStream as TokenStream2};
use quote::{format_ident, quote};
use syn::{parse_macro_input, Data, DeriveInput, Generics};

mod generator;
use generator::*;

mod enums;
use enums::*;

mod flags;
use flags::*;

mod structs;
use structs::*;

/// Whether a struct field gets a slot in the generated flags struct: `true` when its type has a
/// non-zero bit size (see `is_flag_type`) or the field carries an attribute whose path contains
/// `maybe_zero`.
type IsCompact = bool;
/// Name of a struct field; empty for the single field of a tuple struct.
type FieldName = String;
/// A field's type path: segment identifiers joined by `::`, with generic arguments dropped
/// (e.g. `Vec<B256>` is recorded as `Vec`).
type FieldType = String;
/// `Compact` has alternative functions that can be used as a workaround for type
/// specialization of fixed sized types.
///
/// Example: `Vec<B256>` vs `Vec<U256>`. The first does not
/// require the len of the element, while the latter one does.
type UseAlternative = bool;
/// Descriptor of one struct field: `(name, type, is_compact, use_alternative)`.
type StructFieldDescriptor = (FieldName, FieldType, IsCompact, UseAlternative);
/// Field/variant descriptors extracted from a derive input, in declaration order.
type FieldList = Vec<FieldTypes>;

/// One entry produced by `get_fields`, describing a piece of the type `Compact` is derived for.
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum FieldTypes {
    /// A struct field: a named field, or the single field of a tuple struct.
    StructField(StructFieldDescriptor),
    /// The name of an enum variant. When the variant wraps an unnamed field, an
    /// `EnumUnnamedField` entry immediately follows this one in the list.
    EnumVariant(String),
    /// Type of the single unnamed field of the preceding enum variant, plus whether the
    /// alternative (specialized) `Compact` functions should be used for it.
    EnumUnnamedField((FieldType, UseAlternative)),
}

/// Derives the `Compact` trait and its from/to implementations.
pub fn derive(input: TokenStream, is_zstd: bool) -> TokenStream {
    let mut output = quote! {};

    let DeriveInput { ident, data, generics, .. } = parse_macro_input!(input);

    let has_lifetime = has_lifetime(&generics);

    let fields = get_fields(&data);
    output.extend(generate_flag_struct(&ident, has_lifetime, &fields, is_zstd));
    output.extend(generate_from_to(&ident, has_lifetime, &fields, is_zstd));
    output.into()
}

pub fn has_lifetime(generics: &Generics) -> bool {
    generics.lifetimes().next().is_some()
}

/// Given a list of fields on a struct, extract their fields and types.
///
/// For structs, one [`FieldTypes::StructField`] is produced per field. For enums, every variant
/// produces a [`FieldTypes::EnumVariant`]. When the variant wraps an unnamed field, that entry is
/// followed by a [`FieldTypes::EnumUnnamedField`].
///
/// # Panics
///
/// Panics, which aborts the derive, for shapes `Compact` does not support:
/// - unit structs and unions;
/// - tuple structs or enum variants whose unnamed field count is not exactly one;
/// - enum variants with named fields.
///
/// `load_field` may also panic for unsupported field types.
pub fn get_fields(data: &Data) -> FieldList {
    let mut fields = vec![];

    match data {
        Data::Struct(data) => match data.fields {
            syn::Fields::Named(ref data_fields) => {
                for field in &data_fields.named {
                    load_field(field, &mut fields, false);
                }
                // Every named field must have produced exactly one descriptor.
                assert_eq!(fields.len(), data_fields.named.len(), "get_fields");
            }
            syn::Fields::Unnamed(ref data_fields) => {
                assert_eq!(
                    data_fields.unnamed.len(),
                    1,
                    "Compact only allows one unnamed field. Consider making it a struct."
                );
                load_field(&data_fields.unnamed[0], &mut fields, false);
            }
            syn::Fields::Unit => {
                unimplemented!("Compact cannot be derived for unit structs.")
            }
        },
        Data::Enum(data) => {
            for variant in &data.variants {
                fields.push(FieldTypes::EnumVariant(variant.ident.to_string()));

                match &variant.fields {
                    syn::Fields::Named(_) => {
                        panic!("Not allowed to have Enum Variants with multiple named fields. Make it a struct instead.")
                    }
                    syn::Fields::Unnamed(data_fields) => {
                        assert_eq!(
                            data_fields.unnamed.len(),
                            1,
                            "Compact only allows one unnamed field. Consider making it a struct."
                        );
                        load_field(&data_fields.unnamed[0], &mut fields, true);
                    }
                    // Unit variants are fully described by their `EnumVariant` entry.
                    syn::Fields::Unit => (),
                }
            }
        }
        Data::Union(_) => unimplemented!("Compact cannot be derived for unions."),
    }

    fields
}

/// Extracts the type path of `field` and forwards it to `load_field_from_segments`.
///
/// A single level of reference (e.g. `&'a T`) is looked through.
///
/// # Panics
///
/// Panics via `unimplemented!` for any other type shape, such as tuples, arrays, or nested
/// references.
fn load_field(field: &syn::Field, fields: &mut FieldList, is_enum: bool) {
    let target = match &field.ty {
        syn::Type::Reference(reference) => &*reference.elem,
        other => other,
    };

    if let syn::Type::Path(type_path) = target {
        load_field_from_segments(&type_path.path.segments, is_enum, fields, field);
    } else {
        unimplemented!("{:?}", &field.ident);
    }
}

/// Builds a descriptor for `field` from its type path `segments` and appends it to `fields`.
///
/// The type name is the path's segment identifiers joined by `::`; generic arguments are not
/// part of the name and are only inspected by `should_use_alt_impl`.
/// - Enum payloads (`is_enum == true`) become [`FieldTypes::EnumUnnamedField`].
/// - Struct fields become [`FieldTypes::StructField`]. They are marked compact when
///   `is_flag_type` accepts the type, or when the field has an attribute whose path contains
///   `maybe_zero`.
///
/// An empty path produces no descriptor.
fn load_field_from_segments(
    segments: &syn::punctuated::Punctuated<syn::PathSegment, syn::token::PathSep>,
    is_enum: bool,
    fields: &mut Vec<FieldTypes>,
    field: &syn::Field,
) {
    let last_segment = match segments.last() {
        Some(segment) => segment,
        // Nothing to describe for an empty path.
        None => return,
    };

    let ftype = segments
        .iter()
        .map(|segment| segment.ident.to_string())
        .collect::<Vec<_>>()
        .join("::");

    // Evaluated once, against the complete path name and the final segment (whose generic
    // arguments decide whether the specialized impl applies).
    let use_alt_impl: UseAlternative = should_use_alt_impl(&ftype, last_segment);

    if is_enum {
        fields.push(FieldTypes::EnumUnnamedField((ftype, use_alt_impl)));
    } else {
        let should_compact = is_flag_type(&ftype) ||
            field.attrs.iter().any(|attr| {
                attr.path().segments.iter().any(|path| path.ident == "maybe_zero")
            });

        fields.push(FieldTypes::StructField((
            // Unnamed (tuple-struct) fields have no ident and get an empty name.
            field.ident.as_ref().map(|i| i.to_string()).unwrap_or_default(),
            ftype,
            should_compact,
            use_alt_impl,
        )));
    }
}

/// Since there's no impl specialization in rust stable atm, once we find we have a
/// Vec/Option we try to find out if it's a Vec/Option of a fixed size data type, e.g. `Vec<B256>`.
///
/// If so, we use another impl to code/decode its data.
///
/// Returns `true` only when all of the following hold:
/// - `ftype` is exactly `Vec` or `Option`;
/// - the segment's last angle-bracketed generic argument is a single-segment type path;
/// - that path names one of the known fixed-size types.
fn should_use_alt_impl(ftype: &str, segment: &syn::PathSegment) -> bool {
    // Fixed-size types whose elements need no per-element length when encoded.
    const FIXED_SIZE_TYPES: &[&str] =
        &["B256", "Address", "Bloom", "TxHash", "BlockHash", "CompactPlaceholder"];

    if !matches!(ftype, "Vec" | "Option") {
        return false
    }

    let inner = match &segment.arguments {
        syn::PathArguments::AngleBracketed(args) => args.args.last(),
        _ => None,
    };

    match inner {
        Some(syn::GenericArgument::Type(syn::Type::Path(arg_path)))
            if arg_path.path.segments.len() == 1 =>
        {
            let name = arg_path.path.segments[0].ident.to_string();
            FIXED_SIZE_TYPES.contains(&name.as_str())
        }
        _ => false,
    }
}

/// Given the field type in a string format, return the amount of bits necessary to save its maximum
/// length.
///
/// Any type not listed in the table has a bit size of `0`, so it gets no slot in the flags struct.
pub fn get_bit_size(ftype: &str) -> u8 {
    const BIT_SIZES: &[(&str, u8)] = &[
        ("TransactionKind", 1),
        ("TxKind", 1),
        ("bool", 1),
        ("Option", 1),
        ("Signature", 1),
        ("TxType", 2),
        ("u64", 4),
        ("BlockNumber", 4),
        ("TxNumber", 4),
        ("ChainId", 4),
        ("NumTransactions", 4),
        ("u128", 5),
        ("U256", 6),
    ];

    BIT_SIZES
        .iter()
        .find(|(name, _)| *name == ftype)
        .map_or(0, |&(_, bits)| bits)
}

/// Given the field type in a string format, checks if its type should be added to the
/// `StructFlags`, i.e. whether `get_bit_size` reserves any bits for it.
pub fn is_flag_type(ftype: &str) -> bool {
    let reserved_bits = get_bit_size(ftype);
    reserved_bits != 0
}

// Unit tests for the derive machinery: a sample struct is run through `get_fields` and both
// code generators, and the result is compared with the expected generated items.
#[cfg(test)]
mod tests {
    use super::*;
    use similar_asserts::assert_eq;
    use syn::parse2;

    /// Snapshot test for the generated flags struct and `Compact` impl.
    ///
    /// The sample struct covers:
    /// - flag-sized fields (`u64`, `U256`, `bool`, `Option`);
    /// - non-flag `Vec` fields;
    /// - the specialized (`specialized_*`) paths taken for `Option<B256>` and `Vec<Address>`.
    #[test]
    fn gen() {
        let f_struct = quote! {
             #[derive(Debug, PartialEq, Clone)]
             pub struct TestStruct {
                 f_u64: u64,
                 f_u256: U256,
                 f_bool_t: bool,
                 f_bool_f: bool,
                 f_option_none: Option<U256>,
                 f_option_some: Option<B256>,
                 f_option_some_u64: Option<u64>,
                 f_vec_empty: Vec<U256>,
                 f_vec_some: Vec<Address>,
             }
        };

        // Generate code that will impl the `Compact` trait.
        let mut output = quote! {};
        let DeriveInput { ident, data, .. } = parse2(f_struct).unwrap();
        let fields = get_fields(&data);
        output.extend(generate_flag_struct(&ident, false, &fields, false));
        output.extend(generate_from_to(&ident, false, &fields, false));

        // Expected output in a TokenStream format. Commas matter!
        let should_output = quote! {
            impl TestStruct {
                #[doc = "Used bytes by [`TestStructFlags`]"]
                pub const fn bitflag_encoded_bytes() -> usize {
                    2u8 as usize
                }
            }

            pub use TestStruct_flags::TestStructFlags;

            #[allow(non_snake_case)]
            mod TestStruct_flags {
                use bytes::Buf;
                use modular_bitfield::prelude::*;
                #[doc = "Fieldset that facilitates compacting the parent type. Used bytes: 2 | Unused bits: 1"]
                #[bitfield]
                #[derive(Clone, Copy, Debug, Default)]
                pub struct TestStructFlags {
                    pub f_u64_len: B4,
                    pub f_u256_len: B6,
                    pub f_bool_t_len: B1,
                    pub f_bool_f_len: B1,
                    pub f_option_none_len: B1,
                    pub f_option_some_len: B1,
                    pub f_option_some_u64_len: B1,
                    #[skip]
                    unused: B1,
                }
                impl TestStructFlags {
                    #[doc = r" Deserializes this fieldset and returns it, alongside the original slice in an advanced position."]
                    pub fn from(mut buf: &[u8]) -> (Self, &[u8]) {
                        (
                            TestStructFlags::from_bytes([buf.get_u8(), buf.get_u8(),]),
                            buf
                        )
                    }
                }
            }
            #[cfg(test)]
            #[allow(dead_code)]
            #[test_fuzz::test_fuzz]
            fn fuzz_test_test_struct(obj: TestStruct) {
                let mut buf = vec![];
                let len = obj.clone().to_compact(&mut buf);
                let (same_obj, buf) = TestStruct::from_compact(buf.as_ref(), len);
                assert_eq!(obj, same_obj);
            }
            #[test]
            #[allow(missing_docs)]
            pub fn fuzz_test_struct() {
                fuzz_test_test_struct(TestStruct::default())
            }
            impl Compact for TestStruct {
                fn to_compact<B>(&self, buf: &mut B) -> usize where B: bytes::BufMut + AsMut<[u8]> {
                    let mut flags = TestStructFlags::default();
                    let mut total_length = 0;
                    let mut buffer = bytes::BytesMut::new();
                    let f_u64_len = self.f_u64.to_compact(&mut buffer);
                    flags.set_f_u64_len(f_u64_len as u8);
                    let f_u256_len = self.f_u256.to_compact(&mut buffer);
                    flags.set_f_u256_len(f_u256_len as u8);
                    let f_bool_t_len = self.f_bool_t.to_compact(&mut buffer);
                    flags.set_f_bool_t_len(f_bool_t_len as u8);
                    let f_bool_f_len = self.f_bool_f.to_compact(&mut buffer);
                    flags.set_f_bool_f_len(f_bool_f_len as u8);
                    let f_option_none_len = self.f_option_none.to_compact(&mut buffer);
                    flags.set_f_option_none_len(f_option_none_len as u8);
                    let f_option_some_len = self.f_option_some.specialized_to_compact(&mut buffer);
                    flags.set_f_option_some_len(f_option_some_len as u8);
                    let f_option_some_u64_len = self.f_option_some_u64.to_compact(&mut buffer);
                    flags.set_f_option_some_u64_len(f_option_some_u64_len as u8);
                    let f_vec_empty_len = self.f_vec_empty.to_compact(&mut buffer);
                    let f_vec_some_len = self.f_vec_some.specialized_to_compact(&mut buffer);
                    let flags = flags.into_bytes();
                    total_length += flags.len() + buffer.len();
                    buf.put_slice(&flags);
                    buf.put(buffer);
                    total_length
                }
                fn from_compact(mut buf: &[u8], len: usize) -> (Self, &[u8]) {
                    let (flags, mut buf) = TestStructFlags::from(buf);
                    let (f_u64, new_buf) = u64::from_compact(buf, flags.f_u64_len() as usize);
                    buf = new_buf;
                    let (f_u256, new_buf) = U256::from_compact(buf, flags.f_u256_len() as usize);
                    buf = new_buf;
                    let (f_bool_t, new_buf) = bool::from_compact(buf, flags.f_bool_t_len() as usize);
                    buf = new_buf;
                    let (f_bool_f, new_buf) = bool::from_compact(buf, flags.f_bool_f_len() as usize);
                    buf = new_buf;
                    let (f_option_none, new_buf) = Option::from_compact(buf, flags.f_option_none_len() as usize);
                    buf = new_buf;
                    let (f_option_some, new_buf) = Option::specialized_from_compact(buf, flags.f_option_some_len() as usize);
                    buf = new_buf;
                    let (f_option_some_u64, new_buf) = Option::from_compact(buf, flags.f_option_some_u64_len() as usize);
                    buf = new_buf;
                    let (f_vec_empty, new_buf) = Vec::from_compact(buf, buf.len());
                    buf = new_buf;
                    let (f_vec_some, new_buf) = Vec::specialized_from_compact(buf, buf.len());
                    buf = new_buf;
                    let obj = TestStruct {
                        f_u64: f_u64,
                        f_u256: f_u256,
                        f_bool_t: f_bool_t,
                        f_bool_f: f_bool_f,
                        f_option_none: f_option_none,
                        f_option_some: f_option_some,
                        f_option_some_u64: f_option_some_u64,
                        f_vec_empty: f_vec_empty,
                        f_vec_some: f_vec_some,
                    };
                    (obj, buf)
                }
            }
        };

        // Compare both sides as parsed `syn::File`s rather than as raw token streams.
        assert_eq!(
            syn::parse2::<syn::File>(output).unwrap(),
            syn::parse2::<syn::File>(should_output).unwrap()
        );
    }
}