use crate::{
errors::{EthHandshakeError, EthStreamError},
message::{EthBroadcastMessage, ProtocolBroadcastMessage},
p2pstream::HANDSHAKE_TIMEOUT,
CanDisconnect, DisconnectReason, EthMessage, EthNetworkPrimitives, EthVersion, ProtocolMessage,
Status,
};
use alloy_primitives::bytes::{Bytes, BytesMut};
use futures::{ready, Sink, SinkExt, StreamExt};
use pin_project::pin_project;
use reth_eth_wire_types::NetworkPrimitives;
use reth_ethereum_forks::ForkFilter;
use reth_primitives_traits::GotExpected;
use std::{
pin::Pin,
task::{Context, Poll},
time::Duration,
};
use tokio::time::timeout;
use tokio_stream::Stream;
use tracing::{debug, trace};
/// Largest payload, in bytes, accepted for a single incoming `eth` message (10 MiB).
///
/// Checked against the raw frame length before decoding in [`EthStream`]'s `poll_next`.
pub const MAX_MESSAGE_SIZE: usize = 10 << 20;
/// Largest payload, in bytes, accepted for the peer's `Status` message during the
/// handshake (500 KiB).
pub(crate) const MAX_STATUS_SIZE: usize = 500 << 10;
/// A stream of raw `eth` frames that has not yet completed the `Status` handshake.
///
/// Call [`UnauthedEthStream::handshake`] (or one of its variants) to exchange `Status`
/// messages with the peer and obtain an [`EthStream`].
#[pin_project]
#[derive(Debug)]
pub struct UnauthedEthStream<S> {
    /// The wrapped transport.
    #[pin]
    inner: S,
}

impl<S> UnauthedEthStream<S> {
    /// Wraps `inner`; no I/O is performed.
    pub const fn new(inner: S) -> Self {
        Self { inner }
    }

    /// Consumes the wrapper and hands back the underlying stream.
    pub fn into_inner(self) -> S {
        let Self { inner } = self;
        inner
    }
}
impl<S, E> UnauthedEthStream<S>
where
    S: Stream<Item = Result<BytesMut, E>> + CanDisconnect<Bytes> + Unpin,
    EthStreamError: From<E> + From<<S as Sink<Bytes>>::Error>,
{
    /// Performs the `eth` `Status` handshake, bounded by [`HANDSHAKE_TIMEOUT`].
    ///
    /// On success returns the authenticated [`EthStream`] together with the peer's `Status`.
    ///
    /// # Errors
    ///
    /// Everything [`Self::handshake_without_timeout`] can return, plus
    /// [`EthStreamError::StreamTimeout`] if the handshake does not finish in time.
    pub async fn handshake<N: NetworkPrimitives>(
        self,
        status: Status,
        fork_filter: ForkFilter,
    ) -> Result<(EthStream<S, N>, Status), EthStreamError> {
        self.handshake_with_timeout(status, fork_filter, HANDSHAKE_TIMEOUT).await
    }

    /// Same as [`Self::handshake`], but bounded by the caller-supplied `timeout_limit`.
    ///
    /// # Errors
    ///
    /// Returns [`EthStreamError::StreamTimeout`] when `timeout_limit` elapses first, otherwise
    /// whatever [`Self::handshake_without_timeout`] returns.
    pub async fn handshake_with_timeout<N: NetworkPrimitives>(
        self,
        status: Status,
        fork_filter: ForkFilter,
        timeout_limit: Duration,
    ) -> Result<(EthStream<S, N>, Status), EthStreamError> {
        timeout(timeout_limit, Self::handshake_without_timeout(self, status, fork_filter))
            .await
            .map_err(|_| EthStreamError::StreamTimeout)?
    }

    /// Performs the `eth` `Status` handshake with no time bound.
    ///
    /// Sends our `status`, reads exactly one message from the peer, and validates it
    /// against `status` and `fork_filter`. On any failure the peer is disconnected first.
    ///
    /// # Errors
    ///
    /// - [`EthHandshakeError::NoResponse`] if the stream ends before the peer answers
    ///   (disconnect reason: `DisconnectRequested`).
    /// - [`EthStreamError::MessageTooBig`] if the reply exceeds [`MAX_STATUS_SIZE`]
    ///   (`ProtocolBreach`).
    /// - [`EthStreamError::InvalidMessage`] if the reply cannot be decoded with our protocol
    ///   version (`DisconnectRequested`).
    /// - [`EthHandshakeError::NonStatusMessageInHandshake`] if the reply is not a `Status`
    ///   (`ProtocolBreach`).
    /// - Any validation error from the peer's `Status`: genesis, version, chain,
    ///   total-difficulty size, or fork id (`ProtocolBreach`).
    ///
    /// Errors from sending, receiving, or disconnecting are propagated as-is.
    pub async fn handshake_without_timeout<N: NetworkPrimitives>(
        mut self,
        status: Status,
        fork_filter: ForkFilter,
    ) -> Result<(EthStream<S, N>, Status), EthStreamError> {
        trace!(
            %status,
            "sending eth status to peer"
        );
        self.inner
            .send(
                alloy_rlp::encode(ProtocolMessage::<N>::from(EthMessage::<N>::Status(status)))
                    .into(),
            )
            .await?;

        let their_msg_res = self.inner.next().await;
        let Some(their_msg) = their_msg_res else {
            // The peer closed the stream without answering our status.
            self.inner.disconnect(DisconnectReason::DisconnectRequested).await?;
            return Err(EthStreamError::EthHandshakeError(EthHandshakeError::NoResponse))
        };
        let their_msg = their_msg?;

        if their_msg.len() > MAX_STATUS_SIZE {
            self.inner.disconnect(DisconnectReason::ProtocolBreach).await?;
            return Err(EthStreamError::MessageTooBig(their_msg.len()))
        }

        // Decode using the version we advertised; a mismatching peer version is caught by
        // the validation step below.
        let version = status.version;
        let msg = match ProtocolMessage::<N>::decode_message(version, &mut their_msg.as_ref()) {
            Ok(m) => m,
            Err(err) => {
                debug!("decode error in eth handshake: msg={their_msg:x}");
                self.inner.disconnect(DisconnectReason::DisconnectRequested).await?;
                return Err(EthStreamError::InvalidMessage(err))
            }
        };

        let EthMessage::Status(resp) = msg.message else {
            self.inner.disconnect(DisconnectReason::ProtocolBreach).await?;
            return Err(EthStreamError::EthHandshakeError(
                EthHandshakeError::NonStatusMessageInHandshake,
            ))
        };

        trace!(
            status=%resp,
            "validating incoming eth status from peer"
        );
        if let Err(err) = Self::validate_peer_status(&status, &resp, &fork_filter) {
            self.inner.disconnect(DisconnectReason::ProtocolBreach).await?;
            return Err(err.into())
        }

        let stream = EthStream::new(version, self.inner);
        Ok((stream, resp))
    }

    /// Checks the peer's `theirs` status against our own `ours` status and `fork_filter`.
    ///
    /// The checks run in a fixed order: genesis, protocol version, chain, total-difficulty
    /// size, then fork id. The first failing check is returned. Every check inspects the
    /// PEER's values; our own status is only used as the expected reference.
    fn validate_peer_status(
        ours: &Status,
        theirs: &Status,
        fork_filter: &ForkFilter,
    ) -> Result<(), EthHandshakeError> {
        // Reject peers announcing a total difficulty wider than this many bits.
        const MAX_TOTAL_DIFFICULTY_BITS: usize = 100;

        if ours.genesis != theirs.genesis {
            return Err(EthHandshakeError::MismatchedGenesis(
                GotExpected { expected: ours.genesis, got: theirs.genesis }.into(),
            ))
        }
        if ours.version != theirs.version {
            return Err(EthHandshakeError::MismatchedProtocolVersion(GotExpected {
                got: theirs.version,
                expected: ours.version,
            }))
        }
        if ours.chain != theirs.chain {
            return Err(EthHandshakeError::MismatchedChain(GotExpected {
                got: theirs.chain,
                expected: ours.chain,
            }))
        }
        // Validate the peer-supplied TD; checking our own value here would let any peer
        // through and would wrongly reject every peer when our local TD is large.
        let td_bits = theirs.total_difficulty.bit_len();
        if td_bits > MAX_TOTAL_DIFFICULTY_BITS {
            return Err(EthHandshakeError::TotalDifficultyBitLenTooLarge {
                got: td_bits,
                maximum: MAX_TOTAL_DIFFICULTY_BITS,
            })
        }
        fork_filter.validate(theirs.forkid).map_err(EthHandshakeError::InvalidFork)
    }
}
/// An `eth` stream that encodes/decodes [`EthMessage`]s over a raw byte transport `S`.
///
/// Usually obtained from [`UnauthedEthStream::handshake`] once `Status` messages have been
/// exchanged, but it can also be constructed directly with [`EthStream::new`].
#[pin_project]
#[derive(Debug)]
pub struct EthStream<S, N = EthNetworkPrimitives> {
    /// Protocol version used to decode incoming messages.
    version: EthVersion,
    /// The wrapped byte transport.
    #[pin]
    inner: S,
    /// Ties the stream to its network primitive types without storing any.
    _pd: std::marker::PhantomData<N>,
}

impl<S, N> EthStream<S, N> {
    /// Wraps `inner`, decoding incoming messages as `version`.
    #[inline]
    pub const fn new(version: EthVersion, inner: S) -> Self {
        Self { inner, version, _pd: std::marker::PhantomData }
    }

    /// The negotiated protocol version.
    #[inline]
    pub const fn version(&self) -> EthVersion {
        self.version
    }

    /// Shared access to the wrapped transport.
    #[inline]
    pub const fn inner(&self) -> &S {
        &self.inner
    }

    /// Exclusive access to the wrapped transport.
    #[inline]
    pub fn inner_mut(&mut self) -> &mut S {
        &mut self.inner
    }

    /// Consumes the stream and returns the wrapped transport.
    #[inline]
    pub fn into_inner(self) -> S {
        let Self { inner, .. } = self;
        inner
    }
}
impl<S, E, N> EthStream<S, N>
where
    S: Sink<Bytes, Error = E> + Unpin,
    EthStreamError: From<E>,
    N: NetworkPrimitives,
{
    /// RLP-encodes a broadcast message and hands it to the underlying sink.
    ///
    /// Like [`Sink::start_send`], this only enqueues the item; the caller is responsible for
    /// readiness and flushing.
    ///
    /// # Errors
    ///
    /// Propagates the underlying sink's error, converted into [`EthStreamError`].
    pub fn start_send_broadcast(
        &mut self,
        item: EthBroadcastMessage<N>,
    ) -> Result<(), EthStreamError> {
        let encoded = alloy_rlp::encode(ProtocolBroadcastMessage::from(item));
        self.inner.start_send_unpin(Bytes::from(encoded)).map_err(Into::into)
    }
}
impl<S, E, N> Stream for EthStream<S, N>
where
    S: Stream<Item = Result<BytesMut, E>> + Unpin,
    EthStreamError: From<E>,
    N: NetworkPrimitives,
{
    type Item = Result<EthMessage<N>, EthStreamError>;

    /// Reads the next raw frame and decodes it as an [`EthMessage`] for this stream's version.
    ///
    /// Yields an error (without closing the stream) when:
    /// - the underlying stream yields an error;
    /// - the frame exceeds [`MAX_MESSAGE_SIZE`];
    /// - the frame cannot be decoded;
    /// - the frame is a `Status` message, which is only valid during the handshake.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.project();
        let res = ready!(this.inner.poll_next(cx));
        let bytes = match res {
            Some(Ok(bytes)) => bytes,
            Some(Err(err)) => return Poll::Ready(Some(Err(err.into()))),
            None => return Poll::Ready(None),
        };

        if bytes.len() > MAX_MESSAGE_SIZE {
            return Poll::Ready(Some(Err(EthStreamError::MessageTooBig(bytes.len()))))
        }

        let msg = match ProtocolMessage::decode_message(*this.version, &mut bytes.as_ref()) {
            Ok(m) => m,
            Err(err) => {
                // Keep the log line bounded: for large frames only the first and last 10
                // bytes are shown, both halves using the same zero-padded hex format.
                let msg = if bytes.len() > 50 {
                    format!("{:02x?}...{:02x?}", &bytes[..10], &bytes[bytes.len() - 10..])
                } else {
                    format!("{bytes:02x?}")
                };
                debug!(
                    version=?this.version,
                    %msg,
                    "failed to decode protocol message"
                );
                return Poll::Ready(Some(Err(EthStreamError::InvalidMessage(err))))
            }
        };

        if matches!(msg.message, EthMessage::Status(_)) {
            return Poll::Ready(Some(Err(EthStreamError::EthHandshakeError(
                EthHandshakeError::StatusNotInHandshake,
            ))))
        }

        Poll::Ready(Some(Ok(msg.message)))
    }
}
impl<S, N> Sink<EthMessage<N>> for EthStream<S, N>
where
    S: CanDisconnect<Bytes> + Unpin,
    EthStreamError: From<<S as Sink<Bytes>>::Error>,
    N: NetworkPrimitives,
{
    type Error = EthStreamError;

    /// Delegates readiness to the wrapped sink.
    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let inner = self.project().inner;
        inner.poll_ready(cx).map_err(EthStreamError::from)
    }

    /// RLP-encodes `item` and enqueues it on the wrapped sink.
    ///
    /// `Status` messages are refused: they may only be exchanged during the handshake.
    fn start_send(self: Pin<&mut Self>, item: EthMessage<N>) -> Result<(), Self::Error> {
        if let EthMessage::Status(_) = item {
            return Err(EthStreamError::EthHandshakeError(EthHandshakeError::StatusNotInHandshake))
        }
        let encoded = Bytes::from(alloy_rlp::encode(ProtocolMessage::from(item)));
        let inner = self.project().inner;
        inner.start_send(encoded).map_err(EthStreamError::from)
    }

    /// Delegates flushing to the wrapped sink.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let inner = self.project().inner;
        inner.poll_flush(cx).map_err(EthStreamError::from)
    }

    /// Delegates closing to the wrapped sink.
    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let inner = self.project().inner;
        inner.poll_close(cx).map_err(EthStreamError::from)
    }
}
impl<S, N> CanDisconnect<EthMessage<N>> for EthStream<S, N>
where
    S: CanDisconnect<Bytes> + Send,
    EthStreamError: From<<S as Sink<Bytes>>::Error>,
    N: NetworkPrimitives,
{
    /// Forwards the disconnect request, with its `reason`, to the wrapped transport.
    async fn disconnect(&mut self, reason: DisconnectReason) -> Result<(), EthStreamError> {
        self.inner.disconnect(reason).await?;
        Ok(())
    }
}
#[cfg(test)]
// Tests for the eth handshake and message round-trips over several transports:
// plain framed TCP (PassthroughCodec), ECIES, and a full ECIES + P2P stack.
mod tests {
use super::UnauthedEthStream;
use crate::{
broadcast::BlockHashNumber,
errors::{EthHandshakeError, EthStreamError},
hello::DEFAULT_TCP_PORT,
p2pstream::UnauthedP2PStream,
EthMessage, EthStream, EthVersion, HelloMessageWithProtocols, PassthroughCodec,
ProtocolVersion, Status,
};
use alloy_chains::NamedChain;
use alloy_primitives::{B256, U256};
use futures::{SinkExt, StreamExt};
use reth_ecies::stream::ECIESStream;
use reth_eth_wire_types::EthNetworkPrimitives;
use reth_ethereum_forks::{ForkFilter, Head};
use reth_network_peers::pk2id;
use secp256k1::{SecretKey, SECP256K1};
use std::time::Duration;
use tokio::net::{TcpListener, TcpStream};
use tokio_util::codec::Decoder;
// Both peers use the identical status; each side must complete the handshake and
// receive a status equal to its own.
#[tokio::test]
async fn can_handshake() {
let genesis = B256::random();
let fork_filter = ForkFilter::new(Head::default(), genesis, 0, Vec::new());
let status = Status {
version: EthVersion::Eth67,
chain: NamedChain::Mainnet.into(),
total_difficulty: U256::ZERO,
blockhash: B256::random(),
genesis,
forkid: fork_filter.current(),
};
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let local_addr = listener.local_addr().unwrap();
let status_clone = status;
let fork_filter_clone = fork_filter.clone();
// Server side: accept one connection and run the handshake on it.
let handle = tokio::spawn(async move {
let (incoming, _) = listener.accept().await.unwrap();
let stream = PassthroughCodec::default().framed(incoming);
let (_, their_status) = UnauthedEthStream::new(stream)
.handshake::<EthNetworkPrimitives>(status_clone, fork_filter_clone)
.await
.unwrap();
assert_eq!(their_status, status_clone);
});
// Client side.
let outgoing = TcpStream::connect(local_addr).await.unwrap();
let sink = PassthroughCodec::default().framed(outgoing);
let (_, their_status) = UnauthedEthStream::new(sink)
.handshake::<EthNetworkPrimitives>(status, fork_filter)
.await
.unwrap();
assert_eq!(their_status, status);
handle.await.unwrap();
}
// A total difficulty of 2^100 - 1 has a bit length of exactly 100, which is at the
// limit and must be accepted by both sides.
#[tokio::test]
async fn pass_handshake_on_low_td_bitlen() {
let genesis = B256::random();
let fork_filter = ForkFilter::new(Head::default(), genesis, 0, Vec::new());
let status = Status {
version: EthVersion::Eth67,
chain: NamedChain::Mainnet.into(),
total_difficulty: U256::from(2).pow(U256::from(100)) - U256::from(1),
blockhash: B256::random(),
genesis,
forkid: fork_filter.current(),
};
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let local_addr = listener.local_addr().unwrap();
let status_clone = status;
let fork_filter_clone = fork_filter.clone();
let handle = tokio::spawn(async move {
let (incoming, _) = listener.accept().await.unwrap();
let stream = PassthroughCodec::default().framed(incoming);
let (_, their_status) = UnauthedEthStream::new(stream)
.handshake::<EthNetworkPrimitives>(status_clone, fork_filter_clone)
.await
.unwrap();
assert_eq!(their_status, status_clone);
});
let outgoing = TcpStream::connect(local_addr).await.unwrap();
let sink = PassthroughCodec::default().framed(outgoing);
let (_, their_status) = UnauthedEthStream::new(sink)
.handshake::<EthNetworkPrimitives>(status, fork_filter)
.await
.unwrap();
assert_eq!(their_status, status);
handle.await.unwrap();
}
// A total difficulty of 2^100 has a bit length of 101, one over the limit. Both peers
// announce it, and both handshakes are expected to fail with
// TotalDifficultyBitLenTooLarge { got: 101, maximum: 100 }.
#[tokio::test]
async fn fail_handshake_on_high_td_bitlen() {
let genesis = B256::random();
let fork_filter = ForkFilter::new(Head::default(), genesis, 0, Vec::new());
let status = Status {
version: EthVersion::Eth67,
chain: NamedChain::Mainnet.into(),
total_difficulty: U256::from(2).pow(U256::from(100)),
blockhash: B256::random(),
genesis,
forkid: fork_filter.current(),
};
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let local_addr = listener.local_addr().unwrap();
let status_clone = status;
let fork_filter_clone = fork_filter.clone();
let handle = tokio::spawn(async move {
let (incoming, _) = listener.accept().await.unwrap();
let stream = PassthroughCodec::default().framed(incoming);
let handshake_res = UnauthedEthStream::new(stream)
.handshake::<EthNetworkPrimitives>(status_clone, fork_filter_clone)
.await;
assert!(matches!(
handshake_res,
Err(EthStreamError::EthHandshakeError(
EthHandshakeError::TotalDifficultyBitLenTooLarge { got: 101, maximum: 100 }
))
));
});
let outgoing = TcpStream::connect(local_addr).await.unwrap();
let sink = PassthroughCodec::default().framed(outgoing);
let handshake_res = UnauthedEthStream::new(sink)
.handshake::<EthNetworkPrimitives>(status, fork_filter)
.await;
assert!(matches!(
handshake_res,
Err(EthStreamError::EthHandshakeError(
EthHandshakeError::TotalDifficultyBitLenTooLarge { got: 101, maximum: 100 }
))
));
handle.await.unwrap();
}
// EthStream used directly over framed TCP, with no handshake: a NewBlockHashes message
// sent by the client must be decoded unchanged by the server.
#[tokio::test]
async fn can_write_and_read_cleartext() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let local_addr = listener.local_addr().unwrap();
let test_msg: EthMessage = EthMessage::NewBlockHashes(
vec![
BlockHashNumber { hash: B256::random(), number: 5 },
BlockHashNumber { hash: B256::random(), number: 6 },
]
.into(),
);
let test_msg_clone = test_msg.clone();
let handle = tokio::spawn(async move {
let (incoming, _) = listener.accept().await.unwrap();
let stream = PassthroughCodec::default().framed(incoming);
let mut stream = EthStream::new(EthVersion::Eth67, stream);
let message = stream.next().await.unwrap().unwrap();
assert_eq!(message, test_msg_clone);
});
let outgoing = TcpStream::connect(local_addr).await.unwrap();
let sink = PassthroughCodec::default().framed(outgoing);
let mut client_stream = EthStream::new(EthVersion::Eth67, sink);
client_stream.send(test_msg).await.unwrap();
handle.await.unwrap();
}
// Same round-trip as the cleartext test, but over an ECIES-encrypted connection.
#[tokio::test]
async fn can_write_and_read_ecies() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let local_addr = listener.local_addr().unwrap();
let server_key = SecretKey::new(&mut rand::thread_rng());
let test_msg: EthMessage = EthMessage::NewBlockHashes(
vec![
BlockHashNumber { hash: B256::random(), number: 5 },
BlockHashNumber { hash: B256::random(), number: 6 },
]
.into(),
);
let test_msg_clone = test_msg.clone();
let handle = tokio::spawn(async move {
let (incoming, _) = listener.accept().await.unwrap();
let stream = ECIESStream::incoming(incoming, server_key).await.unwrap();
let mut stream = EthStream::new(EthVersion::Eth67, stream);
let message = stream.next().await.unwrap().unwrap();
assert_eq!(message, test_msg_clone);
});
// The client needs the server's node id (derived from its public key) to connect.
let server_id = pk2id(&server_key.public_key(SECP256K1));
let client_key = SecretKey::new(&mut rand::thread_rng());
let outgoing = TcpStream::connect(local_addr).await.unwrap();
let outgoing = ECIESStream::connect(outgoing, client_key, server_id).await.unwrap();
let mut client_stream = EthStream::new(EthVersion::Eth67, outgoing);
client_stream.send(test_msg).await.unwrap();
handle.await.unwrap();
}
// Full stack: ECIES, then the P2P hello handshake, then the eth status handshake, then a
// NewBlockHashes message from the client to the server.
#[tokio::test(flavor = "multi_thread")]
async fn ethstream_over_p2p() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let local_addr = listener.local_addr().unwrap();
let server_key = SecretKey::new(&mut rand::thread_rng());
let test_msg: EthMessage = EthMessage::NewBlockHashes(
vec![
BlockHashNumber { hash: B256::random(), number: 5 },
BlockHashNumber { hash: B256::random(), number: 6 },
]
.into(),
);
let genesis = B256::random();
let fork_filter = ForkFilter::new(Head::default(), genesis, 0, Vec::new());
let status = Status {
version: EthVersion::Eth67,
chain: NamedChain::Mainnet.into(),
total_difficulty: U256::ZERO,
blockhash: B256::random(),
genesis,
forkid: fork_filter.current(),
};
let status_copy = status;
let fork_filter_clone = fork_filter.clone();
let test_msg_clone = test_msg.clone();
let handle = tokio::spawn(async move {
let (incoming, _) = listener.accept().await.unwrap();
let stream = ECIESStream::incoming(incoming, server_key).await.unwrap();
let server_hello = HelloMessageWithProtocols {
protocol_version: ProtocolVersion::V5,
client_version: "bitcoind/1.0.0".to_string(),
protocols: vec![EthVersion::Eth67.into()],
port: DEFAULT_TCP_PORT,
id: pk2id(&server_key.public_key(SECP256K1)),
};
let unauthed_stream = UnauthedP2PStream::new(stream);
let (p2p_stream, _) = unauthed_stream.handshake(server_hello).await.unwrap();
let (mut eth_stream, _) = UnauthedEthStream::new(p2p_stream)
.handshake(status_copy, fork_filter_clone)
.await
.unwrap();
let message = eth_stream.next().await.unwrap().unwrap();
assert_eq!(message, test_msg_clone);
});
let server_id = pk2id(&server_key.public_key(SECP256K1));
let client_key = SecretKey::new(&mut rand::thread_rng());
let outgoing = TcpStream::connect(local_addr).await.unwrap();
let sink = ECIESStream::connect(outgoing, client_key, server_id).await.unwrap();
let client_hello = HelloMessageWithProtocols {
protocol_version: ProtocolVersion::V5,
client_version: "bitcoind/1.0.0".to_string(),
protocols: vec![EthVersion::Eth67.into()],
port: DEFAULT_TCP_PORT,
id: pk2id(&client_key.public_key(SECP256K1)),
};
let unauthed_stream = UnauthedP2PStream::new(sink);
let (p2p_stream, _) = unauthed_stream.handshake(client_hello).await.unwrap();
let (mut client_stream, _) =
UnauthedEthStream::new(p2p_stream).handshake(status, fork_filter).await.unwrap();
client_stream.send(test_msg).await.unwrap();
handle.await.unwrap();
}
// The server task sleeps 11s before accepting, so no status arrives within the client's
// 1s limit and handshake_with_timeout must return StreamTimeout. The server task handle
// is intentionally never awaited.
#[tokio::test]
async fn handshake_should_timeout() {
let genesis = B256::random();
let fork_filter = ForkFilter::new(Head::default(), genesis, 0, Vec::new());
let status = Status {
version: EthVersion::Eth67,
chain: NamedChain::Mainnet.into(),
total_difficulty: U256::ZERO,
blockhash: B256::random(),
genesis,
forkid: fork_filter.current(),
};
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let local_addr = listener.local_addr().unwrap();
let status_clone = status;
let fork_filter_clone = fork_filter.clone();
let _handle = tokio::spawn(async move {
tokio::time::sleep(Duration::from_secs(11)).await;
let (incoming, _) = listener.accept().await.unwrap();
let stream = PassthroughCodec::default().framed(incoming);
let (_, their_status) = UnauthedEthStream::new(stream)
.handshake::<EthNetworkPrimitives>(status_clone, fork_filter_clone)
.await
.unwrap();
assert_eq!(their_status, status_clone);
});
let outgoing = TcpStream::connect(local_addr).await.unwrap();
let sink = PassthroughCodec::default().framed(outgoing);
let handshake_result = UnauthedEthStream::new(sink)
.handshake_with_timeout::<EthNetworkPrimitives>(
status,
fork_filter,
Duration::from_secs(1),
)
.await;
// Compared via Display text rather than by pattern on the error value.
assert!(
matches!(handshake_result, Err(e) if e.to_string() == EthStreamError::StreamTimeout.to_string())
);
}
}