use std::{collections::HashSet, fmt, str::FromStr};
use serde::{Deserialize, Serialize, Serializer};
use strum::{AsRefStr, EnumIter, IntoStaticStr, ParseError, VariantArray, VariantNames};
/// A set of [`RethRpcModule`]s, either named symbolically or listed explicitly.
///
/// The [`Default`] value is [`RpcModuleSelection::Standard`].
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub enum RpcModuleSelection {
    /// Every [`RethRpcModule`] variant.
    All,
    /// Exactly [`RpcModuleSelection::STANDARD_MODULES`]: `eth`, `net` and `web3`.
    #[default]
    Standard,
    /// An explicit set of modules. The set may be empty.
    Selection(HashSet<RethRpcModule>),
}
impl RpcModuleSelection {
    /// The modules that make up [`RpcModuleSelection::Standard`].
    pub const STANDARD_MODULES: [RethRpcModule; 3] =
        [RethRpcModule::Eth, RethRpcModule::Net, RethRpcModule::Web3];

    /// Returns a set holding every [`RethRpcModule`] variant.
    pub fn all_modules() -> HashSet<RethRpcModule> {
        RethRpcModule::all_variants().iter().copied().collect()
    }

    /// Returns [`Self::STANDARD_MODULES`] as a set.
    pub fn standard_modules() -> HashSet<RethRpcModule> {
        Self::STANDARD_MODULES.into_iter().collect()
    }

    /// Returns the default module set for IPC, which is every module.
    pub fn default_ipc_modules() -> HashSet<RethRpcModule> {
        Self::all_modules()
    }

    /// Builds a [`Self::Selection`] by converting each item into a [`RethRpcModule`].
    ///
    /// # Errors
    ///
    /// Returns the first conversion error; items after the failing one are not converted.
    pub fn try_from_selection<I, T>(selection: I) -> Result<Self, T::Error>
    where
        I: IntoIterator<Item = T>,
        T: TryInto<RethRpcModule>,
    {
        let mut modules: HashSet<RethRpcModule> = HashSet::new();
        for item in selection {
            modules.insert(item.try_into()?);
        }
        Ok(Self::Selection(modules))
    }

    /// Returns how many modules this selection covers.
    pub fn len(&self) -> usize {
        match self {
            Self::Selection(modules) => modules.len(),
            Self::Standard => Self::STANDARD_MODULES.len(),
            Self::All => RethRpcModule::variant_count(),
        }
    }

    /// Returns `true` only for a [`Self::Selection`] that holds no modules.
    ///
    /// [`Self::All`] and [`Self::Standard`] are never empty.
    pub fn is_empty(&self) -> bool {
        matches!(self, Self::Selection(modules) if modules.is_empty())
    }

    /// Returns an iterator over the selected modules.
    ///
    /// [`Self::All`] and [`Self::Standard`] yield in a fixed order. [`Self::Selection`] yields
    /// in the backing set's iteration order.
    pub fn iter_selection(&self) -> Box<dyn Iterator<Item = RethRpcModule> + '_> {
        match self {
            Self::Selection(modules) => Box::new(modules.iter().copied()),
            Self::Standard => Box::new(Self::STANDARD_MODULES.into_iter()),
            Self::All => Box::new(RethRpcModule::all_variants().iter().copied()),
        }
    }

    /// Returns the selected modules as a newly allocated set.
    pub fn to_selection(&self) -> HashSet<RethRpcModule> {
        self.iter_selection().collect()
    }

    /// Consumes the selection and returns its modules as a set.
    ///
    /// A [`Self::Selection`] hands back its own set without copying.
    pub fn into_selection(self) -> HashSet<RethRpcModule> {
        match self {
            Self::Selection(modules) => modules,
            symbolic => symbolic.to_selection(),
        }
    }

    /// Returns `true` when the `http` and `ws` selections enable the same modules.
    ///
    /// A missing selection (`None`) counts as identical only to an empty selection.
    pub fn are_identical(http: Option<&Self>, ws: Option<&Self>) -> bool {
        match (http, ws) {
            (None, None) => true,
            (Some(only), None) | (None, Some(only)) => only.is_empty(),
            // `All` matches anything that covers every variant, without building sets.
            (Some(Self::All), Some(rhs)) | (Some(rhs), Some(Self::All)) => {
                rhs.len() == RethRpcModule::variant_count()
            }
            (Some(lhs), Some(rhs)) => lhs.to_selection() == rhs.to_selection(),
        }
    }

    /// Returns `true` if `module` is part of this selection.
    pub fn contains(&self, module: &RethRpcModule) -> bool {
        match self {
            Self::All => true,
            Self::Standard => Self::STANDARD_MODULES.iter().any(|candidate| candidate == module),
            Self::Selection(modules) => modules.contains(module),
        }
    }
}
impl From<&HashSet<RethRpcModule>> for RpcModuleSelection {
    /// Copies the borrowed set into a [`RpcModuleSelection::Selection`].
    fn from(modules: &HashSet<RethRpcModule>) -> Self {
        Self::Selection(modules.clone())
    }
}
impl From<HashSet<RethRpcModule>> for RpcModuleSelection {
fn from(s: HashSet<RethRpcModule>) -> Self {
Self::Selection(s)
}
}
impl From<&[RethRpcModule]> for RpcModuleSelection {
    /// Builds a [`RpcModuleSelection::Selection`] from the slice; duplicates collapse.
    fn from(modules: &[RethRpcModule]) -> Self {
        modules.iter().collect()
    }
}
impl From<Vec<RethRpcModule>> for RpcModuleSelection {
fn from(s: Vec<RethRpcModule>) -> Self {
Self::Selection(s.into_iter().collect())
}
}
impl<const N: usize> From<[RethRpcModule; N]> for RpcModuleSelection {
    /// Builds a [`RpcModuleSelection::Selection`] from the array; duplicates collapse.
    fn from(modules: [RethRpcModule; N]) -> Self {
        Self::Selection(modules.into_iter().collect())
    }
}
impl<'a> FromIterator<&'a RethRpcModule> for RpcModuleSelection {
fn from_iter<I>(iter: I) -> Self
where
I: IntoIterator<Item = &'a RethRpcModule>,
{
iter.into_iter().copied().collect()
}
}
impl FromIterator<RethRpcModule> for RpcModuleSelection {
fn from_iter<I>(iter: I) -> Self
where
I: IntoIterator<Item = RethRpcModule>,
{
Self::Selection(iter.into_iter().collect())
}
}
impl FromStr for RpcModuleSelection {
    type Err = ParseError;

    /// Parses a comma-separated list of module names.
    ///
    /// - An empty or whitespace-only string yields an empty [`RpcModuleSelection::Selection`].
    /// - If the first entry is `all`, the result is [`RpcModuleSelection::All`]. If it is
    ///   `none`, the result is an empty selection. Both keywords are ASCII case-insensitive,
    ///   and any entries after them are ignored.
    /// - Otherwise each entry, with surrounding whitespace trimmed, must be a
    ///   [`RethRpcModule`] name.
    ///
    /// # Errors
    ///
    /// Returns [`ParseError::VariantNotFound`] if any entry is not a known module name. This
    /// includes empty entries, such as the one produced by a trailing comma.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // Trim the whole input first so whitespace-only input behaves like "" instead of being
        // parsed as a single empty (and therefore unknown) module name.
        let s = s.trim();
        if s.is_empty() {
            return Ok(Self::Selection(Default::default()))
        }
        let mut modules = s.split(',').map(str::trim).peekable();
        // `str::split` always yields at least one item, so this `ok_or` is purely defensive.
        let first = modules.peek().copied().ok_or(ParseError::VariantNotFound)?;
        // `eq_ignore_ascii_case` avoids the allocation `to_lowercase` would make.
        if first.eq_ignore_ascii_case("all") {
            Ok(Self::All)
        } else if first.eq_ignore_ascii_case("none") {
            Ok(Self::Selection(Default::default()))
        } else {
            Self::try_from_selection(modules)
        }
    }
}
impl fmt::Display for RpcModuleSelection {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"[{}]",
self.iter_selection().map(|s| s.to_string()).collect::<Vec<_>>().join(", ")
)
}
}
/// An individual RPC module that can be enabled.
///
/// Variant order matters: [`RethRpcModule::all_variants`], [`RethRpcModule::modules`] and
/// [`RethRpcModule::all_variant_names`] all yield variants in the order declared here.
///
/// Names are produced by strum in kebab-case and accepted by serde in snake_case. The two
/// coincide only because every variant is currently a single word. Keep them in sync if a
/// multi-word variant is ever added.
#[derive(
    AsRefStr,
    Clone,
    Copy,
    Debug,
    Deserialize,
    EnumIter,
    Eq,
    Hash,
    IntoStaticStr,
    PartialEq,
    VariantArray,
    VariantNames,
)]
#[serde(rename_all = "snake_case")]
#[strum(serialize_all = "kebab-case")]
pub enum RethRpcModule {
    /// The `admin` module.
    Admin,
    /// The `debug` module.
    Debug,
    /// The `eth` module.
    Eth,
    /// The `net` module.
    Net,
    /// The `trace` module.
    Trace,
    /// The `txpool` module.
    Txpool,
    /// The `web3` module.
    Web3,
    /// The `rpc` module.
    Rpc,
    /// The `reth` module.
    Reth,
    /// The `ots` module.
    Ots,
    /// The `flashbots` module.
    Flashbots,
}
impl RethRpcModule {
pub const fn variant_count() -> usize {
<Self as VariantArray>::VARIANTS.len()
}
pub const fn all_variant_names() -> &'static [&'static str] {
<Self as VariantNames>::VARIANTS
}
pub const fn all_variants() -> &'static [Self] {
<Self as VariantArray>::VARIANTS
}
pub fn modules() -> impl IntoIterator<Item = Self> {
use strum::IntoEnumIterator;
Self::iter()
}
#[inline]
pub fn as_str(&self) -> &'static str {
self.into()
}
}
impl FromStr for RethRpcModule {
type Err = ParseError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
Ok(match s {
"admin" => Self::Admin,
"debug" => Self::Debug,
"eth" => Self::Eth,
"net" => Self::Net,
"trace" => Self::Trace,
"txpool" => Self::Txpool,
"web3" => Self::Web3,
"rpc" => Self::Rpc,
"reth" => Self::Reth,
"ots" => Self::Ots,
"flashbots" => Self::Flashbots,
_ => return Err(ParseError::VariantNotFound),
})
}
}
impl TryFrom<&str> for RethRpcModule {
type Error = ParseError;
fn try_from(s: &str) -> Result<Self, <Self as TryFrom<&str>>::Error> {
FromStr::from_str(s)
}
}
impl fmt::Display for RethRpcModule {
    /// Writes the module name (e.g. `eth`), honouring the formatter's width/fill/alignment.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name: &str = self.as_ref();
        f.pad(name)
    }
}
impl Serialize for RethRpcModule {
    /// Serializes the module as its strum (kebab-case) name, e.g. `"eth"`.
    ///
    /// The derived `Deserialize` expects snake_case names. The two agree only because every
    /// current variant is a single word.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let name: &str = self.as_ref();
        serializer.serialize_str(name)
    }
}
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_all_modules() {
        let all_modules = RpcModuleSelection::all_modules();
        assert_eq!(all_modules.len(), RethRpcModule::variant_count());
    }

    #[test]
    fn test_standard_modules() {
        let standard_modules = RpcModuleSelection::standard_modules();
        let expected_modules: HashSet<RethRpcModule> =
            HashSet::from([RethRpcModule::Eth, RethRpcModule::Net, RethRpcModule::Web3]);
        assert_eq!(standard_modules, expected_modules);
    }

    #[test]
    fn test_default_ipc_modules() {
        let default_ipc_modules = RpcModuleSelection::default_ipc_modules();
        assert_eq!(default_ipc_modules, RpcModuleSelection::all_modules());
    }

    #[test]
    fn test_try_from_selection_success() {
        let selection = vec!["eth", "admin"];
        let config = RpcModuleSelection::try_from_selection(selection).unwrap();
        assert_eq!(config, RpcModuleSelection::from([RethRpcModule::Eth, RethRpcModule::Admin]));
    }

    #[test]
    fn test_rpc_module_selection_len() {
        let all_modules = RpcModuleSelection::All;
        let standard = RpcModuleSelection::Standard;
        let selection = RpcModuleSelection::from([RethRpcModule::Eth, RethRpcModule::Admin]);
        assert_eq!(all_modules.len(), RethRpcModule::variant_count());
        assert_eq!(standard.len(), 3);
        assert_eq!(selection.len(), 2);
    }

    #[test]
    fn test_rpc_module_selection_is_empty() {
        let empty_selection = RpcModuleSelection::from(HashSet::new());
        assert!(empty_selection.is_empty());
        let non_empty_selection = RpcModuleSelection::from([RethRpcModule::Eth]);
        assert!(!non_empty_selection.is_empty());
    }

    #[test]
    fn test_rpc_module_selection_iter_selection() {
        let all_modules = RpcModuleSelection::All;
        let standard = RpcModuleSelection::Standard;
        let selection = RpcModuleSelection::from([RethRpcModule::Eth, RethRpcModule::Admin]);
        assert_eq!(all_modules.iter_selection().count(), RethRpcModule::variant_count());
        assert_eq!(standard.iter_selection().count(), 3);
        assert_eq!(selection.iter_selection().count(), 2);
    }

    #[test]
    fn test_rpc_module_selection_to_selection() {
        let all_modules = RpcModuleSelection::All;
        let standard = RpcModuleSelection::Standard;
        let selection = RpcModuleSelection::from([RethRpcModule::Eth, RethRpcModule::Admin]);
        assert_eq!(all_modules.to_selection(), RpcModuleSelection::all_modules());
        assert_eq!(standard.to_selection(), RpcModuleSelection::standard_modules());
        assert_eq!(
            selection.to_selection(),
            HashSet::from([RethRpcModule::Eth, RethRpcModule::Admin])
        );
    }

    #[test]
    fn test_rpc_module_selection_into_selection() {
        assert_eq!(RpcModuleSelection::All.into_selection(), RpcModuleSelection::all_modules());
        assert_eq!(
            RpcModuleSelection::Standard.into_selection(),
            RpcModuleSelection::standard_modules()
        );
        let explicit = HashSet::from([RethRpcModule::Trace, RethRpcModule::Ots]);
        assert_eq!(RpcModuleSelection::from(explicit.clone()).into_selection(), explicit);
    }

    #[test]
    fn test_rpc_module_selection_contains() {
        let all_modules = RpcModuleSelection::All;
        let standard = RpcModuleSelection::Standard;
        for &module in RethRpcModule::all_variants() {
            assert!(all_modules.contains(&module));
            assert_eq!(
                standard.contains(&module),
                RpcModuleSelection::STANDARD_MODULES.contains(&module)
            );
        }
        let selection = RpcModuleSelection::from([RethRpcModule::Eth, RethRpcModule::Admin]);
        assert!(selection.contains(&RethRpcModule::Admin));
        assert!(!selection.contains(&RethRpcModule::Net));
        assert!(!RpcModuleSelection::from(HashSet::new()).contains(&RethRpcModule::Eth));
    }

    #[test]
    fn test_rpc_module_selection_default_is_standard() {
        assert_eq!(RpcModuleSelection::default(), RpcModuleSelection::Standard);
    }

    #[test]
    fn test_rpc_module_selection_display_fixed_order() {
        // Only cases whose order does not depend on `HashSet` iteration order.
        assert_eq!(RpcModuleSelection::Standard.to_string(), "[eth, net, web3]");
        assert_eq!(RpcModuleSelection::from([RethRpcModule::Eth]).to_string(), "[eth]");
        assert_eq!(RpcModuleSelection::from(HashSet::new()).to_string(), "[]");
        let expected_all = format!("[{}]", RethRpcModule::all_variant_names().join(", "));
        assert_eq!(RpcModuleSelection::All.to_string(), expected_all);
    }

    #[test]
    fn test_rpc_module_selection_are_identical() {
        let all_modules = RpcModuleSelection::All;
        assert!(RpcModuleSelection::are_identical(Some(&all_modules), Some(&all_modules)));
        assert!(RpcModuleSelection::are_identical(None, None));
        let selection1 = RpcModuleSelection::from([RethRpcModule::Eth, RethRpcModule::Admin]);
        let selection2 = RpcModuleSelection::from([RethRpcModule::Eth, RethRpcModule::Admin]);
        assert!(RpcModuleSelection::are_identical(Some(&selection1), Some(&selection2)));
        let standard = RpcModuleSelection::Standard;
        assert!(!RpcModuleSelection::are_identical(Some(&all_modules), Some(&standard)));
        let empty_selection = RpcModuleSelection::Selection(HashSet::new());
        assert!(RpcModuleSelection::are_identical(None, Some(&empty_selection)));
        assert!(RpcModuleSelection::are_identical(Some(&empty_selection), None));
        let non_empty_selection = RpcModuleSelection::from([RethRpcModule::Eth]);
        assert!(!RpcModuleSelection::are_identical(None, Some(&non_empty_selection)));
        assert!(!RpcModuleSelection::are_identical(Some(&non_empty_selection), None));
        let partial_selection = RpcModuleSelection::from([RethRpcModule::Eth, RethRpcModule::Net]);
        assert!(!RpcModuleSelection::are_identical(Some(&all_modules), Some(&partial_selection)));
        let full_selection =
            RpcModuleSelection::from(RethRpcModule::modules().into_iter().collect::<HashSet<_>>());
        assert!(RpcModuleSelection::are_identical(Some(&all_modules), Some(&full_selection)));
        let selection3 = RpcModuleSelection::from([RethRpcModule::Eth, RethRpcModule::Net]);
        let selection4 = RpcModuleSelection::from([RethRpcModule::Eth, RethRpcModule::Web3]);
        assert!(!RpcModuleSelection::are_identical(Some(&selection3), Some(&selection4)));
        let matching_standard =
            RpcModuleSelection::from([RethRpcModule::Eth, RethRpcModule::Net, RethRpcModule::Web3]);
        assert!(RpcModuleSelection::are_identical(Some(&standard), Some(&matching_standard)));
        let non_matching_standard =
            RpcModuleSelection::from([RethRpcModule::Eth, RethRpcModule::Net]);
        assert!(!RpcModuleSelection::are_identical(Some(&standard), Some(&non_matching_standard)));
    }

    #[test]
    fn test_rpc_module_selection_from_str() {
        let result = RpcModuleSelection::from_str("");
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), RpcModuleSelection::Selection(Default::default()));
        let result = RpcModuleSelection::from_str("all");
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), RpcModuleSelection::All);
        let result = RpcModuleSelection::from_str("All");
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), RpcModuleSelection::All);
        let result = RpcModuleSelection::from_str("ALL");
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), RpcModuleSelection::All);
        let result = RpcModuleSelection::from_str("none");
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), RpcModuleSelection::Selection(Default::default()));
        let result = RpcModuleSelection::from_str("None");
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), RpcModuleSelection::Selection(Default::default()));
        let result = RpcModuleSelection::from_str("NONE");
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), RpcModuleSelection::Selection(Default::default()));
        let result = RpcModuleSelection::from_str("eth,admin");
        assert!(result.is_ok());
        let expected_selection =
            RpcModuleSelection::from([RethRpcModule::Eth, RethRpcModule::Admin]);
        assert_eq!(result.unwrap(), expected_selection);
        let result = RpcModuleSelection::from_str(" eth , admin ");
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), expected_selection);
        let result = RpcModuleSelection::from_str("invalid,unknown");
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), ParseError::VariantNotFound);
        let result = RpcModuleSelection::from_str("eth");
        assert!(result.is_ok());
        let expected_selection = RpcModuleSelection::from([RethRpcModule::Eth]);
        assert_eq!(result.unwrap(), expected_selection);
        let result = RpcModuleSelection::from_str("unknown");
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), ParseError::VariantNotFound);
    }

    #[test]
    fn test_reth_rpc_module_name_round_trip() {
        // `FromStr` is a hand-written table while `as_str`/`Display` come from strum; this
        // catches the two drifting apart when a variant is added or renamed.
        for &module in RethRpcModule::all_variants() {
            assert_eq!(RethRpcModule::from_str(module.as_str()), Ok(module));
            assert_eq!(RethRpcModule::try_from(module.as_str()), Ok(module));
            assert_eq!(module.to_string(), module.as_str());
        }
        assert_eq!(RethRpcModule::from_str("Eth"), Err(ParseError::VariantNotFound));
        assert_eq!(RethRpcModule::from_str(""), Err(ParseError::VariantNotFound));
    }

    #[test]
    fn test_reth_rpc_module_variant_names_match_as_str() {
        let names: Vec<&str> = RethRpcModule::all_variants().iter().map(|m| m.as_str()).collect();
        assert_eq!(RethRpcModule::all_variant_names(), names.as_slice());
        assert_eq!(RethRpcModule::all_variant_names().len(), RethRpcModule::variant_count());
        assert_eq!(RethRpcModule::modules().into_iter().count(), RethRpcModule::variant_count());
    }
}