reth_codecs_derive/
arbitrary.rs
1use proc_macro::TokenStream;
2use proc_macro2::{Ident, TokenStream as TokenStream2};
3use quote::{quote, ToTokens};
4
5pub fn maybe_generate_tests(
11 args: TokenStream,
12 type_ident: &impl ToTokens,
13 mod_tests: &Ident,
14) -> TokenStream2 {
15 let mut default_cases = 256;
17
18 let mut traits = vec![];
19 let mut roundtrips = vec![];
20 let mut additional_tests = vec![];
21 let mut is_crate = false;
22
23 let mut iter = args.into_iter().peekable();
24
25 if let Some(arg) = iter.peek() {
27 if arg.to_string() == "crate" {
28 is_crate = true;
29 iter.next();
30 }
31 }
32
33 for arg in iter {
34 if arg.to_string() == "compact" {
35 let path = if is_crate {
36 quote! { use crate::Compact; }
37 } else {
38 quote! { use reth_codecs::Compact; }
39 };
40 traits.push(path);
41 roundtrips.push(quote! {
42 {
43 let mut buf = vec![];
44 let len = field.clone().to_compact(&mut buf);
45 let (decoded, _): (super::#type_ident, _) = Compact::from_compact(&buf, len);
46 assert_eq!(field, decoded, "maybe_generate_tests::compact");
47 }
48 });
49 } else if arg.to_string() == "rlp" {
50 traits.push(quote! { use alloy_rlp::{Encodable, Decodable}; });
51 roundtrips.push(quote! {
52 {
53 let mut buf = vec![];
54 let len = field.encode(&mut buf);
55 let mut b = &mut buf.as_slice();
56 let decoded: super::#type_ident = Decodable::decode(b).unwrap();
57 assert_eq!(field, decoded, "maybe_generate_tests::rlp");
58 assert!(b.is_empty(), "buffer was not consumed entirely");
60
61 }
62 });
63 additional_tests.push(quote! {
64
65 #[test]
66 fn malformed_rlp_header_check() {
67 use rand::RngCore;
68
69 let mut raw = vec![0u8; 1024];
71 rand::thread_rng().fill_bytes(&mut raw);
72 let mut unstructured = arbitrary::Unstructured::new(&raw[..]);
73 let val: Result<super::#type_ident, _> = arbitrary::Arbitrary::arbitrary(&mut unstructured);
74 if val.is_err() {
75 return
77 }
78 let val = val.unwrap();
79 let mut buf = vec![];
80 let len = val.encode(&mut buf);
81
82 let mut decode_buf = &mut buf.as_slice();
84 let mut header = alloy_rlp::Header::decode(decode_buf).expect("failed to decode header");
85 header.payload_length+=1;
86 let mut b = Vec::with_capacity(decode_buf.len());
87 header.encode(&mut b);
88 b.extend_from_slice(decode_buf);
89 let res: Result<super::#type_ident, _> = Decodable::decode(&mut b.as_ref());
90 assert!(res.is_err(), "malformed header was decoded");
91 }
92 });
93 } else if let Ok(num) = arg.to_string().parse() {
94 default_cases = num;
95 }
96 }
97
98 let mut tests = TokenStream2::default();
99 if !roundtrips.is_empty() {
100 tests = quote! {
101 #[allow(non_snake_case)]
102 #[cfg(test)]
103 mod #mod_tests {
104 #(#traits)*
105 use proptest_arbitrary_interop::arb;
106
107 #[test]
108 fn proptest() {
109 let mut config = proptest::prelude::ProptestConfig::with_cases(#default_cases as u32);
110
111 proptest::proptest!(config, |(field in arb::<super::#type_ident>())| {
112 #(#roundtrips)*
113 });
114 }
115
116 #(#additional_tests)*
117 }
118 }
119 }
120
121 tests
122}