reth_discv5/
filter.rs
1use std::collections::HashSet;
4
5use derive_more::Constructor;
6use itertools::Itertools;
7
/// Result of running an ENR through a filter.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FilterOutcome {
    /// The ENR passed the filter.
    Ok,
    /// The ENR was rejected by the filter.
    Ignore {
        /// Human-readable explanation of why the ENR was rejected.
        reason: String,
    },
}

impl FilterOutcome {
    /// Returns `true` for [`FilterOutcome::Ok`] and `false` for [`FilterOutcome::Ignore`].
    pub const fn is_ok(&self) -> bool {
        match self {
            Self::Ok => true,
            Self::Ignore { .. } => false,
        }
    }
}
26
27#[derive(Debug, Constructor, Clone, Copy, PartialEq, Eq, Hash)]
29pub struct MustIncludeKey {
30 key: &'static [u8],
32}
33
34impl MustIncludeKey {
35 pub fn filter(&self, enr: &discv5::Enr) -> FilterOutcome {
37 if enr.get_raw_rlp(self.key).is_none() {
38 return FilterOutcome::Ignore {
39 reason: format!("{} fork required", String::from_utf8_lossy(self.key)),
40 }
41 }
42 FilterOutcome::Ok
43 }
44}
45
46#[derive(Debug, Clone, Default)]
48pub struct MustNotIncludeKeys {
49 keys: HashSet<MustIncludeKey>,
50}
51
52impl MustNotIncludeKeys {
53 pub fn new(disallow_keys: &[&'static [u8]]) -> Self {
56 let mut keys = HashSet::with_capacity(disallow_keys.len());
57 for key in disallow_keys {
58 _ = keys.insert(MustIncludeKey::new(key));
59 }
60
61 Self { keys }
62 }
63}
64
65impl MustNotIncludeKeys {
66 pub fn filter(&self, enr: &discv5::Enr) -> FilterOutcome {
68 for key in &self.keys {
69 if matches!(key.filter(enr), FilterOutcome::Ok) {
70 return FilterOutcome::Ignore {
71 reason: format!(
72 "{} forks not allowed",
73 self.keys.iter().map(|key| String::from_utf8_lossy(key.key)).format(",")
74 ),
75 }
76 }
77 }
78
79 FilterOutcome::Ok
80 }
81
82 pub fn add_disallowed_keys(&mut self, keys: &[&'static [u8]]) {
84 for key in keys {
85 self.keys.insert(MustIncludeKey::new(key));
86 }
87 }
88}
89
#[cfg(test)]
mod tests {
    use super::*;
    use crate::NetworkStackId;
    use alloy_rlp::Bytes;
    use discv5::enr::{CombinedKey, Enr};

    #[test]
    fn must_not_include_key_filter() {
        // Rejects ENRs carrying either the `NetworkStackId::ETH` or `NetworkStackId::ETH2` key.
        let filter = MustNotIncludeKeys::new(&[NetworkStackId::ETH, NetworkStackId::ETH2]);

        // ENR carrying only the first disallowed key.
        let sk = CombinedKey::generate_secp256k1();
        let enr_1 = Enr::builder()
            .add_value_rlp(NetworkStackId::ETH, Bytes::from("cancun"))
            .build(&sk)
            .unwrap();

        // ENR carrying only the second disallowed key.
        let sk = CombinedKey::generate_secp256k1();
        let enr_2 = Enr::builder()
            .add_value_rlp(NetworkStackId::ETH2, Bytes::from("deneb"))
            .build(&sk)
            .unwrap();

        // ENR carrying neither disallowed key must pass.
        let sk = CombinedKey::generate_secp256k1();
        let enr_3 = Enr::builder().build(&sk).unwrap();

        assert!(matches!(filter.filter(&enr_1), FilterOutcome::Ignore { .. }));
        assert!(matches!(filter.filter(&enr_2), FilterOutcome::Ignore { .. }));
        assert!(filter.filter(&enr_3).is_ok());
    }

    #[test]
    fn must_include_key_filter() {
        let filter = MustIncludeKey::new(NetworkStackId::ETH);

        // ENR carrying the required key passes.
        let sk = CombinedKey::generate_secp256k1();
        let with_key = Enr::builder()
            .add_value_rlp(NetworkStackId::ETH, Bytes::from("cancun"))
            .build(&sk)
            .unwrap();

        // ENR lacking the required key is ignored.
        let sk = CombinedKey::generate_secp256k1();
        let without_key = Enr::builder().build(&sk).unwrap();

        assert!(filter.filter(&with_key).is_ok());
        assert!(matches!(filter.filter(&without_key), FilterOutcome::Ignore { .. }));
    }
}