// reth_codecs_derive/arbitrary.rs

1use proc_macro::TokenStream;
2use proc_macro2::{Ident, TokenStream as TokenStream2};
3use quote::{quote, ToTokens};
4
5/// If `compact` or `rlp` is passed to `derive_arbitrary`, this function will generate the
6/// corresponding proptest roundtrip tests.
7///
8/// It accepts an optional integer number for the number of proptest cases. Otherwise, it will set
9/// it at 1000.
10pub fn maybe_generate_tests(
11    args: TokenStream,
12    type_ident: &impl ToTokens,
13    mod_tests: &Ident,
14) -> TokenStream2 {
15    // Same as proptest
16    let mut default_cases = 256;
17
18    let mut traits = vec![];
19    let mut roundtrips = vec![];
20    let mut additional_tests = vec![];
21    let mut is_crate = false;
22
23    let mut iter = args.into_iter().peekable();
24
25    // we check if there's a crate argument which is used from inside the codecs crate directly
26    if let Some(arg) = iter.peek() {
27        if arg.to_string() == "crate" {
28            is_crate = true;
29            iter.next();
30        }
31    }
32
33    for arg in iter {
34        if arg.to_string() == "compact" {
35            let path = if is_crate {
36                quote! { use crate::Compact; }
37            } else {
38                quote! { use reth_codecs::Compact; }
39            };
40            traits.push(path);
41            roundtrips.push(quote! {
42                {
43                    let mut buf = vec![];
44                    let len = field.clone().to_compact(&mut buf);
45                    let (decoded, _): (super::#type_ident, _) = Compact::from_compact(&buf, len);
46                    assert_eq!(field, decoded, "maybe_generate_tests::compact");
47                }
48            });
49        } else if arg.to_string() == "rlp" {
50            traits.push(quote! { use alloy_rlp::{Encodable, Decodable}; });
51            roundtrips.push(quote! {
52                {
53                    let mut buf = vec![];
54                    let len = field.encode(&mut buf);
55                    let mut b = &mut buf.as_slice();
56                    let decoded: super::#type_ident = Decodable::decode(b).unwrap();
57                    assert_eq!(field, decoded, "maybe_generate_tests::rlp");
58                    // ensure buffer is fully consumed by decode
59                    assert!(b.is_empty(), "buffer was not consumed entirely");
60
61                }
62            });
63            additional_tests.push(quote! {
64
65                #[test]
66                fn malformed_rlp_header_check() {
67                    use rand::RngCore;
68
69                    // get random instance of type
70                    let mut raw = vec![0u8; 1024];
71                    rand::thread_rng().fill_bytes(&mut raw);
72                    let mut unstructured = arbitrary::Unstructured::new(&raw[..]);
73                    let val: Result<super::#type_ident, _> = arbitrary::Arbitrary::arbitrary(&mut unstructured);
74                    if val.is_err() {
75                        // this can be flaky sometimes due to not enough data for iterator based types like Vec
76                        return
77                    }
78                    let val = val.unwrap();
79                    let mut buf = vec![];
80                    let len = val.encode(&mut buf);
81
82                    // malformed rlp-header check
83                    let mut decode_buf = &mut buf.as_slice();
84                    let mut header = alloy_rlp::Header::decode(decode_buf).expect("failed to decode header");
85                    header.payload_length+=1;
86                    let mut b = Vec::with_capacity(decode_buf.len());
87                    header.encode(&mut b);
88                    b.extend_from_slice(decode_buf);
89                    let res: Result<super::#type_ident, _> = Decodable::decode(&mut b.as_ref());
90                    assert!(res.is_err(), "malformed header was decoded");
91                }
92            });
93        } else if let Ok(num) = arg.to_string().parse() {
94            default_cases = num;
95        }
96    }
97
98    let mut tests = TokenStream2::default();
99    if !roundtrips.is_empty() {
100        tests = quote! {
101            #[allow(non_snake_case)]
102            #[cfg(test)]
103            mod #mod_tests {
104                #(#traits)*
105                use proptest_arbitrary_interop::arb;
106
107                #[test]
108                fn proptest() {
109                    let mut config = proptest::prelude::ProptestConfig::with_cases(#default_cases as u32);
110
111                    proptest::proptest!(config, |(field in arb::<super::#type_ident>())| {
112                        #(#roundtrips)*
113                    });
114                }
115
116                #(#additional_tests)*
117            }
118        }
119    }
120
121    tests
122}