use std::{fmt, str::FromStr};
use alloy_rlp::{Decodable, Encodable, Error as RlpError};
use bytes::BufMut;
use derive_more::Display;
use reth_codecs_derive::add_arbitrary_tests;
#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
#[error("Unknown eth protocol version: {0}")]
pub struct ParseVersionError(String);
/// Version of the `eth` protocol.
///
/// Each discriminant is the numeric version itself (e.g. `Eth68 = 68`), and the RLP and
/// `u8` conversions use that value directly. Variants are declared in ascending order, so
/// the derived `PartialOrd`/`Ord` order versions as 66 < 67 < 68 < 69.
#[repr(u8)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Display)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(any(test, feature = "arbitrary"), derive(arbitrary::Arbitrary))]
pub enum EthVersion {
    /// `eth/66`.
    Eth66 = 66,
    /// `eth/67`.
    Eth67 = 67,
    /// `eth/68`.
    Eth68 = 68,
    /// `eth/69`.
    Eth69 = 69,
}
impl EthVersion {
    /// The version treated as the most recent one by default.
    ///
    /// Note: [`EthVersion::Eth69`] is a known variant but is *not* the value of this
    /// constant.
    pub const LATEST: Self = Self::Eth68;

    /// Returns the total message count associated with this protocol version.
    ///
    /// eth/67 and eth/68 share a count two lower than eth/66. eth/69 is two lower again.
    pub const fn total_messages(&self) -> u8 {
        match *self {
            Self::Eth66 => 15,
            Self::Eth67 | Self::Eth68 => 13,
            Self::Eth69 => 11,
        }
    }

    /// Returns `true` if this is [`EthVersion::Eth66`].
    pub const fn is_eth66(&self) -> bool {
        matches!(*self, Self::Eth66)
    }

    /// Returns `true` if this is [`EthVersion::Eth67`].
    pub const fn is_eth67(&self) -> bool {
        matches!(*self, Self::Eth67)
    }

    /// Returns `true` if this is [`EthVersion::Eth68`].
    pub const fn is_eth68(&self) -> bool {
        matches!(*self, Self::Eth68)
    }

    /// Returns `true` if this is [`EthVersion::Eth69`].
    pub const fn is_eth69(&self) -> bool {
        matches!(*self, Self::Eth69)
    }
}
impl Encodable for EthVersion {
    /// Encodes the version as an RLP `u8` holding its numeric value (e.g. `68`).
    fn encode(&self, out: &mut dyn BufMut) {
        let raw: u8 = *self as u8;
        raw.encode(out)
    }

    /// Returns the RLP-encoded length of the numeric version byte.
    fn length(&self) -> usize {
        let raw: u8 = *self as u8;
        raw.length()
    }
}
impl Decodable for EthVersion {
    /// Decodes an RLP `u8` and validates it against the known `eth` versions.
    ///
    /// Unknown values are reported as `RlpError::Custom("invalid eth version")`; the
    /// underlying [`ParseVersionError`] is discarded.
    fn decode(buf: &mut &[u8]) -> alloy_rlp::Result<Self> {
        let raw = u8::decode(buf)?;
        match Self::try_from(raw) {
            Ok(version) => Ok(version),
            Err(_) => Err(RlpError::Custom("invalid eth version")),
        }
    }
}
impl TryFrom<&str> for EthVersion {
    type Error = ParseVersionError;

    /// Parses the exact decimal version string (`"66"` through `"69"`).
    ///
    /// No trimming or normalization is done. Any other input, including `""` or `" 66"`,
    /// is rejected, and the error carries the input unchanged.
    #[inline]
    fn try_from(s: &str) -> Result<Self, Self::Error> {
        let version = match s {
            "66" => Self::Eth66,
            "67" => Self::Eth67,
            "68" => Self::Eth68,
            "69" => Self::Eth69,
            unknown => return Err(ParseVersionError(unknown.to_string())),
        };
        Ok(version)
    }
}
impl TryFrom<u8> for EthVersion {
    type Error = ParseVersionError;

    /// Maps a numeric version byte (`66` through `69`) to its variant.
    ///
    /// Any other byte is rejected, and the error carries its decimal rendering.
    #[inline]
    fn try_from(value: u8) -> Result<Self, Self::Error> {
        let version = match value {
            66 => Self::Eth66,
            67 => Self::Eth67,
            68 => Self::Eth68,
            69 => Self::Eth69,
            unknown => return Err(ParseVersionError(unknown.to_string())),
        };
        Ok(version)
    }
}
impl FromStr for EthVersion {
    type Err = ParseVersionError;

    /// Parses a version string; this delegates entirely to the `TryFrom<&str>` impl.
    #[inline]
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        <Self as TryFrom<&str>>::try_from(s)
    }
}
impl From<EthVersion> for u8 {
    /// Returns the numeric version, which is the variant's discriminant (e.g. `68`).
    #[inline]
    fn from(version: EthVersion) -> Self {
        version as u8
    }
}
impl From<EthVersion> for &'static str {
    /// Returns the decimal version number as a static string (e.g. `"68"`).
    ///
    /// This is the same text accepted by the `&str` parsing conversions.
    #[inline]
    fn from(version: EthVersion) -> Self {
        match version {
            EthVersion::Eth66 => "66",
            EthVersion::Eth67 => "67",
            EthVersion::Eth68 => "68",
            EthVersion::Eth69 => "69",
        }
    }
}
/// Version of the base `p2p` protocol.
///
/// Defaults to [`ProtocolVersion::V5`]. Each discriminant is the numeric version, and the
/// RLP encoding and `Display` impls use that value.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(any(test, feature = "arbitrary"), derive(arbitrary::Arbitrary))]
#[add_arbitrary_tests(rlp)]
pub enum ProtocolVersion {
    /// `p2p` version 4.
    V4 = 4,
    /// `p2p` version 5 (the default).
    #[default]
    V5 = 5,
}
impl fmt::Display for ProtocolVersion {
    /// Renders the version as `v` followed by its number, e.g. `v5`.
    ///
    /// Width/fill flags on the formatter are not applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let number = *self as u8;
        write!(f, "v{number}")
    }
}
impl Encodable for ProtocolVersion {
    /// Encodes the version as an RLP `u8` holding its numeric value (`4` or `5`).
    fn encode(&self, out: &mut dyn BufMut) {
        let raw: u8 = *self as u8;
        raw.encode(out)
    }

    /// Returns the RLP-encoded length of the numeric version byte.
    fn length(&self) -> usize {
        let raw: u8 = *self as u8;
        raw.length()
    }
}
impl Decodable for ProtocolVersion {
    /// Decodes an RLP `u8` and maps it to a `p2p` protocol version.
    ///
    /// Only `4` and `5` are accepted. Any other value yields
    /// `RlpError::Custom("unknown p2p protocol version")`.
    fn decode(buf: &mut &[u8]) -> alloy_rlp::Result<Self> {
        let version = match u8::decode(buf)? {
            4 => Self::V4,
            5 => Self::V5,
            _ => return Err(RlpError::Custom("unknown p2p protocol version")),
        };
        Ok(version)
    }
}
#[cfg(test)]
mod tests {
    use super::{EthVersion, ParseVersionError, ProtocolVersion};
    use alloy_rlp::{Decodable, Encodable, Error as RlpError};
    use bytes::BytesMut;

    /// Every known `eth` version, in ascending order.
    const ALL_ETH_VERSIONS: [EthVersion; 4] =
        [EthVersion::Eth66, EthVersion::Eth67, EthVersion::Eth68, EthVersion::Eth69];

    #[test]
    fn test_eth_version_try_from_str() {
        assert_eq!(EthVersion::Eth66, EthVersion::try_from("66").unwrap());
        assert_eq!(EthVersion::Eth67, EthVersion::try_from("67").unwrap());
        assert_eq!(EthVersion::Eth68, EthVersion::try_from("68").unwrap());
        assert_eq!(EthVersion::Eth69, EthVersion::try_from("69").unwrap());
        assert_eq!(Err(ParseVersionError("70".to_string())), EthVersion::try_from("70"));
    }

    #[test]
    fn test_eth_version_from_str() {
        assert_eq!(EthVersion::Eth66, "66".parse().unwrap());
        assert_eq!(EthVersion::Eth67, "67".parse().unwrap());
        assert_eq!(EthVersion::Eth68, "68".parse().unwrap());
        assert_eq!(EthVersion::Eth69, "69".parse().unwrap());
        assert_eq!(Err(ParseVersionError("70".to_string())), "70".parse::<EthVersion>());
    }

    #[test]
    fn test_eth_version_rlp_encode() {
        for version in ALL_ETH_VERSIONS {
            let mut encoded = BytesMut::new();
            version.encode(&mut encoded);
            assert_eq!(encoded.len(), 1);
            assert_eq!(encoded[0], version as u8);
        }
    }

    #[test]
    fn test_eth_version_rlp_decode() {
        let test_cases = [
            (66_u8, Ok(EthVersion::Eth66)),
            (67_u8, Ok(EthVersion::Eth67)),
            (68_u8, Ok(EthVersion::Eth68)),
            (69_u8, Ok(EthVersion::Eth69)),
            (70_u8, Err(RlpError::Custom("invalid eth version"))),
            (65_u8, Err(RlpError::Custom("invalid eth version"))),
        ];
        for (input, expected) in test_cases {
            let mut encoded = BytesMut::new();
            input.encode(&mut encoded);
            let mut slice = encoded.as_ref();
            let result = EthVersion::decode(&mut slice);
            assert_eq!(result, expected);
        }
    }

    #[test]
    fn test_eth_version_total_messages() {
        assert_eq!(EthVersion::Eth66.total_messages(), 15);
        assert_eq!(EthVersion::Eth67.total_messages(), 13);
        assert_eq!(EthVersion::Eth68.total_messages(), 13);
        assert_eq!(EthVersion::Eth69.total_messages(), 11);
    }

    /// The `u8` and `&'static str` conversions must round-trip through the fallible
    /// parsers, and the string form must equal the decimal rendering of the byte.
    #[test]
    fn test_eth_version_u8_and_str_round_trip() {
        for version in ALL_ETH_VERSIONS {
            let byte = u8::from(version);
            assert_eq!(EthVersion::try_from(byte), Ok(version));

            let text: &'static str = version.into();
            assert_eq!(text, byte.to_string());
            assert_eq!(EthVersion::try_from(text), Ok(version));
        }
        assert_eq!(EthVersion::try_from(0_u8), Err(ParseVersionError("0".to_string())));
    }

    #[test]
    fn test_protocol_version_display_and_default() {
        assert_eq!(ProtocolVersion::V4.to_string(), "v4");
        assert_eq!(ProtocolVersion::V5.to_string(), "v5");
        assert_eq!(ProtocolVersion::default(), ProtocolVersion::V5);
    }

    #[test]
    fn test_protocol_version_rlp_encode() {
        for version in [ProtocolVersion::V4, ProtocolVersion::V5] {
            let mut encoded = BytesMut::new();
            version.encode(&mut encoded);
            assert_eq!(encoded.len(), 1);
            assert_eq!(encoded[0], version as u8);
            assert_eq!(version.length(), 1);
        }
    }

    /// The hand-written decoder must reject every byte other than 4 and 5.
    #[test]
    fn test_protocol_version_rlp_decode() {
        let test_cases = [
            (4_u8, Ok(ProtocolVersion::V4)),
            (5_u8, Ok(ProtocolVersion::V5)),
            (3_u8, Err(RlpError::Custom("unknown p2p protocol version"))),
            (6_u8, Err(RlpError::Custom("unknown p2p protocol version"))),
        ];
        for (input, expected) in test_cases {
            let mut encoded = BytesMut::new();
            input.encode(&mut encoded);
            let mut slice = encoded.as_ref();
            let result = ProtocolVersion::decode(&mut slice);
            assert_eq!(result, expected);
        }
    }
}