use std::collections::HashSet;
use derive_more::Constructor;
use itertools::Itertools;
/// Verdict produced by the node-record filters in this module.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FilterOutcome {
    /// The record passed the filter.
    Ok,
    /// The record should be ignored.
    Ignore {
        /// Human-readable explanation of why the record was rejected.
        reason: String,
    },
}

impl FilterOutcome {
    /// Returns `true` for [`FilterOutcome::Ok`] and `false` for [`FilterOutcome::Ignore`].
    pub const fn is_ok(&self) -> bool {
        match self {
            Self::Ok => true,
            Self::Ignore { .. } => false,
        }
    }
}
/// Filter that only accepts node records carrying a kv-pair under a given key.
#[derive(Debug, Constructor, Clone, Copy, PartialEq, Eq, Hash)]
pub struct MustIncludeKey {
    /// Key that must be present in the node record.
    key: &'static [u8],
}

impl MustIncludeKey {
    /// Returns [`FilterOutcome::Ok`] if `enr` has a kv-pair under this filter's key, otherwise
    /// [`FilterOutcome::Ignore`] with a reason naming the key (decoded lossily as UTF-8).
    pub fn filter(&self, enr: &discv5::Enr) -> FilterOutcome {
        if enr.get_raw_rlp(self.key).is_some() {
            return FilterOutcome::Ok
        }
        let fork = String::from_utf8_lossy(self.key);
        FilterOutcome::Ignore { reason: format!("{fork} fork required") }
    }
}
/// Filter that rejects node records carrying a kv-pair under any key of a given set.
#[derive(Debug, Clone, Default)]
pub struct MustNotIncludeKeys {
    /// Disallowed keys; duplicates are collapsed by the set.
    keys: HashSet<MustIncludeKey>,
}

impl MustNotIncludeKeys {
    /// Creates a filter that disallows records with a kv-pair under any of `disallow_keys`.
    pub fn new(disallow_keys: &[&'static [u8]]) -> Self {
        let keys = disallow_keys.iter().map(|&key| MustIncludeKey::new(key)).collect();
        Self { keys }
    }
}
impl MustNotIncludeKeys {
    /// Returns [`FilterOutcome::Ignore`] if `enr` has a kv-pair under any disallowed key,
    /// otherwise [`FilterOutcome::Ok`].
    ///
    /// The ignore reason lists every disallowed key (decoded lossily as UTF-8), sorted by raw
    /// bytes so the message does not depend on the set's arbitrary iteration order.
    pub fn filter(&self, enr: &discv5::Enr) -> FilterOutcome {
        // Probe the record directly rather than via `MustIncludeKey::filter`: that path formats
        // an `Ignore` reason string for every *absent* key, which here would just be discarded,
        // and absent keys are the common case for records that end up accepted.
        let has_disallowed =
            self.keys.iter().any(|disallowed| enr.get_raw_rlp(disallowed.key).is_some());
        if !has_disallowed {
            return FilterOutcome::Ok
        }

        // `HashSet` iteration order is unspecified (and differs between instances), so sort
        // the keys to make the reason deterministic.
        let mut disallowed_keys: Vec<&[u8]> =
            self.keys.iter().map(|disallowed| disallowed.key).collect();
        disallowed_keys.sort_unstable();

        FilterOutcome::Ignore {
            reason: format!(
                "{} forks not allowed",
                disallowed_keys.iter().map(|key| String::from_utf8_lossy(key)).format(",")
            ),
        }
    }

    /// Adds `keys` to the set of disallowed keys; keys already in the set are kept as-is.
    pub fn add_disallowed_keys(&mut self, keys: &[&'static [u8]]) {
        self.keys.extend(keys.iter().map(|&key| MustIncludeKey::new(key)));
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::NetworkStackId;
    use alloy_rlp::Bytes;
    use discv5::enr::{CombinedKey, Enr};

    #[test]
    fn must_not_include_key_filter() {
        // Rejects records carrying either the `eth` or the `eth2` key.
        let filter = MustNotIncludeKeys::new(&[NetworkStackId::ETH, NetworkStackId::ETH2]);

        // Record carrying the `eth` key.
        let sk = CombinedKey::generate_secp256k1();
        let enr_1 = Enr::builder()
            .add_value_rlp(NetworkStackId::ETH, Bytes::from("cancun"))
            .build(&sk)
            .unwrap();

        // Record carrying the `eth2` key.
        let sk = CombinedKey::generate_secp256k1();
        let enr_2 = Enr::builder()
            .add_value_rlp(NetworkStackId::ETH2, Bytes::from("deneb"))
            .build(&sk)
            .unwrap();

        // Record carrying none of the disallowed keys; must pass the filter.
        let sk = CombinedKey::generate_secp256k1();
        let enr_3 = Enr::builder().build(&sk).unwrap();

        assert!(matches!(filter.filter(&enr_1), FilterOutcome::Ignore { .. }));
        assert!(matches!(filter.filter(&enr_2), FilterOutcome::Ignore { .. }));
        assert!(filter.filter(&enr_3).is_ok());
    }
}