use core::sync::atomic::Ordering;
use std::{
collections::VecDeque,
future::Future,
net::SocketAddr,
pin::Pin,
sync::{atomic::AtomicU64, Arc},
task::{ready, Context, Poll},
time::{Duration, Instant},
};
use crate::{
message::{NewBlockMessage, PeerMessage, PeerResponse, PeerResponseResult},
session::{
conn::EthRlpxConnection,
handle::{ActiveSessionMessage, SessionCommand},
SessionId,
},
};
use alloy_primitives::Sealable;
use futures::{stream::Fuse, SinkExt, StreamExt};
use metrics::Gauge;
use reth_eth_wire::{
capability::RawCapabilityMessage,
errors::{EthHandshakeError, EthStreamError, P2PStreamError},
message::{EthBroadcastMessage, RequestPair},
Capabilities, DisconnectP2P, DisconnectReason, EthMessage, NetworkPrimitives,
};
use reth_metrics::common::mpsc::MeteredPollSender;
use reth_network_api::PeerRequest;
use reth_network_p2p::error::RequestError;
use reth_network_peers::PeerId;
use reth_network_types::session::config::INITIAL_REQUEST_TIMEOUT;
use reth_primitives_traits::Block;
use rustc_hash::FxHashMap;
use tokio::{
sync::{mpsc::error::TrySendError, oneshot},
time::Interval,
};
use tokio_stream::wrappers::ReceiverStream;
use tokio_util::sync::PollSender;
use tracing::{debug, trace};
/// Lower bound for the adaptive request timeout; the estimator never drops below this.
const MINIMUM_TIMEOUT: Duration = Duration::from_secs(2);
/// Upper bound for the adaptive request timeout; capped at the configured initial timeout.
const MAXIMUM_TIMEOUT: Duration = INITIAL_REQUEST_TIMEOUT;
/// Weight of a new RTT sample when blended into the smoothed timeout (EWMA factor).
const SAMPLE_IMPACT: f64 = 0.1;
/// Multiplier applied to a measured RTT sample when deriving a timeout from it.
const TIMEOUT_SCALING: u32 = 3;
/// An established session with a remote peer.
///
/// Drives the connection by forwarding commands from the session manager, exchanging eth
/// protocol messages with the peer, and reporting events back to the manager.
#[allow(dead_code)]
pub(crate) struct ActiveSession<N: NetworkPrimitives> {
    /// Monotonic counter used to assign ids to outgoing requests (see `next_id`).
    pub(crate) next_id: u64,
    /// The underlying eth/RLPx connection to the peer.
    pub(crate) conn: EthRlpxConnection<N>,
    /// Identifier of the remote node we're connected to.
    pub(crate) remote_peer_id: PeerId,
    /// Socket address of the remote peer.
    pub(crate) remote_addr: SocketAddr,
    /// All capabilities the remote peer announced.
    pub(crate) remote_capabilities: Arc<Capabilities>,
    /// Internal identifier of this session.
    pub(crate) session_id: SessionId,
    /// Stream of commands sent by the session manager.
    pub(crate) commands_rx: ReceiverStream<SessionCommand<N>>,
    /// Sink used to deliver events/messages to the session manager.
    pub(crate) to_session_manager: MeteredPollSender<ActiveSessionMessage<N>>,
    /// A message that could not be delivered to the manager (no capacity) and is retried later.
    pub(crate) pending_message_to_session: Option<ActiveSessionMessage<N>>,
    /// Internal requests from the local node that are delegated to the remote peer.
    pub(crate) internal_request_tx: Fuse<ReceiverStream<PeerRequest<N>>>,
    /// Requests we sent to the remote peer and are still awaiting a response for, keyed by
    /// request id.
    pub(crate) inflight_requests: FxHashMap<u64, InflightRequest<PeerRequest<N>>>,
    /// Requests the remote peer sent us; each holds the channel on which the local response
    /// will arrive.
    pub(crate) received_requests_from_remote: Vec<ReceivedRequest<N>>,
    /// Outgoing messages buffered until the connection sink has capacity.
    pub(crate) queued_outgoing: QueuedOutgoingMessages<N>,
    /// Adaptive request timeout in milliseconds (shared so it can be read/updated atomically;
    /// see `request_deadline` / `update_request_timeout`).
    pub(crate) internal_request_timeout: Arc<AtomicU64>,
    /// Interval at which timed-out requests are checked.
    pub(crate) internal_request_timeout_interval: Interval,
    /// If a request stays unanswered past its deadline for longer than this, the peer is
    /// considered in protocol breach and the session is terminated.
    pub(crate) protocol_breach_request_timeout: Duration,
    /// The final message to deliver to the manager before the session future completes.
    pub(crate) terminate_message:
        Option<(PollSender<ActiveSessionMessage<N>>, ActiveSessionMessage<N>)>,
}
impl<N: NetworkPrimitives> ActiveSession<N> {
fn is_disconnecting(&self) -> bool {
self.conn.inner().is_disconnecting()
}
fn next_id(&mut self) -> u64 {
let id = self.next_id;
self.next_id += 1;
id
}
pub fn shrink_to_fit(&mut self) {
self.received_requests_from_remote.shrink_to_fit();
self.queued_outgoing.shrink_to_fit();
}
fn on_incoming_message(&mut self, msg: EthMessage<N>) -> OnIncomingMessageOutcome<N> {
macro_rules! on_request {
($req:ident, $resp_item:ident, $req_item:ident) => {{
let RequestPair { request_id, message: request } = $req;
let (tx, response) = oneshot::channel();
let received = ReceivedRequest {
request_id,
rx: PeerResponse::$resp_item { response },
received: Instant::now(),
};
self.received_requests_from_remote.push(received);
self.try_emit_request(PeerMessage::EthRequest(PeerRequest::$req_item {
request,
response: tx,
}))
.into()
}};
}
macro_rules! on_response {
($resp:ident, $item:ident) => {{
let RequestPair { request_id, message } = $resp;
#[allow(clippy::collapsible_match)]
if let Some(req) = self.inflight_requests.remove(&request_id) {
match req.request {
RequestState::Waiting(PeerRequest::$item { response, .. }) => {
let _ = response.send(Ok(message));
self.update_request_timeout(req.timestamp, Instant::now());
}
RequestState::Waiting(request) => {
request.send_bad_response();
}
RequestState::TimedOut => {
self.update_request_timeout(req.timestamp, Instant::now());
}
}
} else {
self.on_bad_message();
}
OnIncomingMessageOutcome::Ok
}};
}
match msg {
message @ EthMessage::Status(_) => OnIncomingMessageOutcome::BadMessage {
error: EthStreamError::EthHandshakeError(EthHandshakeError::StatusNotInHandshake),
message,
},
EthMessage::NewBlockHashes(msg) => {
self.try_emit_broadcast(PeerMessage::NewBlockHashes(msg)).into()
}
EthMessage::NewBlock(msg) => {
let block =
NewBlockMessage { hash: msg.block.header().hash_slow(), block: Arc::new(*msg) };
self.try_emit_broadcast(PeerMessage::NewBlock(block)).into()
}
EthMessage::Transactions(msg) => {
self.try_emit_broadcast(PeerMessage::ReceivedTransaction(msg)).into()
}
EthMessage::NewPooledTransactionHashes66(msg) => {
self.try_emit_broadcast(PeerMessage::PooledTransactions(msg.into())).into()
}
EthMessage::NewPooledTransactionHashes68(msg) => {
if msg.hashes.len() != msg.types.len() || msg.hashes.len() != msg.sizes.len() {
return OnIncomingMessageOutcome::BadMessage {
error: EthStreamError::TransactionHashesInvalidLenOfFields {
hashes_len: msg.hashes.len(),
types_len: msg.types.len(),
sizes_len: msg.sizes.len(),
},
message: EthMessage::NewPooledTransactionHashes68(msg),
}
}
self.try_emit_broadcast(PeerMessage::PooledTransactions(msg.into())).into()
}
EthMessage::GetBlockHeaders(req) => {
on_request!(req, BlockHeaders, GetBlockHeaders)
}
EthMessage::BlockHeaders(resp) => {
on_response!(resp, GetBlockHeaders)
}
EthMessage::GetBlockBodies(req) => {
on_request!(req, BlockBodies, GetBlockBodies)
}
EthMessage::BlockBodies(resp) => {
on_response!(resp, GetBlockBodies)
}
EthMessage::GetPooledTransactions(req) => {
on_request!(req, PooledTransactions, GetPooledTransactions)
}
EthMessage::PooledTransactions(resp) => {
on_response!(resp, GetPooledTransactions)
}
EthMessage::GetNodeData(req) => {
on_request!(req, NodeData, GetNodeData)
}
EthMessage::NodeData(resp) => {
on_response!(resp, GetNodeData)
}
EthMessage::GetReceipts(req) => {
on_request!(req, Receipts, GetReceipts)
}
EthMessage::Receipts(resp) => {
on_response!(resp, GetReceipts)
}
}
}
fn on_internal_peer_request(&mut self, request: PeerRequest<N>, deadline: Instant) {
let request_id = self.next_id();
let msg = request.create_request_message(request_id);
self.queued_outgoing.push_back(msg.into());
let req = InflightRequest {
request: RequestState::Waiting(request),
timestamp: Instant::now(),
deadline,
};
self.inflight_requests.insert(request_id, req);
}
fn on_internal_peer_message(&mut self, msg: PeerMessage<N>) {
match msg {
PeerMessage::NewBlockHashes(msg) => {
self.queued_outgoing.push_back(EthMessage::NewBlockHashes(msg).into());
}
PeerMessage::NewBlock(msg) => {
self.queued_outgoing.push_back(EthBroadcastMessage::NewBlock(msg.block).into());
}
PeerMessage::PooledTransactions(msg) => {
if msg.is_valid_for_version(self.conn.version()) {
self.queued_outgoing.push_back(EthMessage::from(msg).into());
}
}
PeerMessage::EthRequest(req) => {
let deadline = self.request_deadline();
self.on_internal_peer_request(req, deadline);
}
PeerMessage::SendTransactions(msg) => {
self.queued_outgoing.push_back(EthBroadcastMessage::Transactions(msg).into());
}
PeerMessage::ReceivedTransaction(_) => {
unreachable!("Not emitted by network")
}
PeerMessage::Other(other) => {
debug!(target: "net::session", message_id=%other.id, "Ignoring unsupported message");
self.queued_outgoing.push_back(OutgoingMessage::Raw(other));
}
}
}
fn request_deadline(&self) -> Instant {
Instant::now() +
Duration::from_millis(self.internal_request_timeout.load(Ordering::Relaxed))
}
fn handle_outgoing_response(&mut self, id: u64, resp: PeerResponseResult<N>) {
match resp.try_into_message(id) {
Ok(msg) => {
self.queued_outgoing.push_back(msg.into());
}
Err(err) => {
debug!(target: "net", %err, "Failed to respond to received request");
}
}
}
#[allow(clippy::result_large_err)]
fn try_emit_broadcast(&self, message: PeerMessage<N>) -> Result<(), ActiveSessionMessage<N>> {
let Some(sender) = self.to_session_manager.inner().get_ref() else { return Ok(()) };
match sender
.try_send(ActiveSessionMessage::ValidMessage { peer_id: self.remote_peer_id, message })
{
Ok(_) => Ok(()),
Err(err) => {
trace!(
target: "net",
%err,
"no capacity for incoming broadcast",
);
match err {
TrySendError::Full(msg) => Err(msg),
TrySendError::Closed(_) => Ok(()),
}
}
}
}
#[allow(clippy::result_large_err)]
fn try_emit_request(&self, message: PeerMessage<N>) -> Result<(), ActiveSessionMessage<N>> {
let Some(sender) = self.to_session_manager.inner().get_ref() else { return Ok(()) };
match sender
.try_send(ActiveSessionMessage::ValidMessage { peer_id: self.remote_peer_id, message })
{
Ok(_) => Ok(()),
Err(err) => {
trace!(
target: "net",
%err,
"no capacity for incoming request",
);
match err {
TrySendError::Full(msg) => Err(msg),
TrySendError::Closed(_) => {
Ok(())
}
}
}
}
}
fn on_bad_message(&self) {
let Some(sender) = self.to_session_manager.inner().get_ref() else { return };
let _ = sender.try_send(ActiveSessionMessage::BadMessage { peer_id: self.remote_peer_id });
}
fn emit_disconnect(&mut self, cx: &mut Context<'_>) -> Poll<()> {
trace!(target: "net::session", remote_peer_id=?self.remote_peer_id, "emitting disconnect");
let msg = ActiveSessionMessage::Disconnected {
peer_id: self.remote_peer_id,
remote_addr: self.remote_addr,
};
self.terminate_message = Some((self.to_session_manager.inner().clone(), msg));
self.poll_terminate_message(cx).expect("message is set")
}
fn close_on_error(&mut self, error: EthStreamError, cx: &mut Context<'_>) -> Poll<()> {
let msg = ActiveSessionMessage::ClosedOnConnectionError {
peer_id: self.remote_peer_id,
remote_addr: self.remote_addr,
error,
};
self.terminate_message = Some((self.to_session_manager.inner().clone(), msg));
self.poll_terminate_message(cx).expect("message is set")
}
fn start_disconnect(&mut self, reason: DisconnectReason) -> Result<(), EthStreamError> {
self.conn
.inner_mut()
.start_disconnect(reason)
.map_err(P2PStreamError::from)
.map_err(Into::into)
}
fn poll_disconnect(&mut self, cx: &mut Context<'_>) -> Poll<()> {
debug_assert!(self.is_disconnecting(), "not disconnecting");
let _ = ready!(self.conn.poll_close_unpin(cx));
self.emit_disconnect(cx)
}
fn try_disconnect(&mut self, reason: DisconnectReason, cx: &mut Context<'_>) -> Poll<()> {
match self.start_disconnect(reason) {
Ok(()) => {
self.poll_disconnect(cx)
}
Err(err) => {
debug!(target: "net::session", %err, remote_peer_id=?self.remote_peer_id, "could not send disconnect");
self.close_on_error(err, cx)
}
}
}
#[must_use]
fn check_timed_out_requests(&mut self, now: Instant) -> bool {
for (id, req) in &mut self.inflight_requests {
if req.is_timed_out(now) {
if req.is_waiting() {
debug!(target: "net::session", ?id, remote_peer_id=?self.remote_peer_id, "timed out outgoing request");
req.timeout();
} else if now - req.timestamp > self.protocol_breach_request_timeout {
return true
}
}
}
false
}
fn update_request_timeout(&mut self, sent: Instant, received: Instant) {
let elapsed = received.saturating_duration_since(sent);
let current = Duration::from_millis(self.internal_request_timeout.load(Ordering::Relaxed));
let request_timeout = calculate_new_timeout(current, elapsed);
self.internal_request_timeout.store(request_timeout.as_millis() as u64, Ordering::Relaxed);
self.internal_request_timeout_interval = tokio::time::interval(request_timeout);
}
fn poll_terminate_message(&mut self, cx: &mut Context<'_>) -> Option<Poll<()>> {
let (mut tx, msg) = self.terminate_message.take()?;
match tx.poll_reserve(cx) {
Poll::Pending => {
self.terminate_message = Some((tx, msg));
return Some(Poll::Pending)
}
Poll::Ready(Ok(())) => {
let _ = tx.send_item(msg);
}
Poll::Ready(Err(_)) => {
}
}
Some(Poll::Ready(()))
}
}
impl<N: NetworkPrimitives> Future for ActiveSession<N> {
    type Output = ();

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.get_mut();

        // If the session is terminating, finish delivering the final message first.
        if let Some(terminate) = this.poll_terminate_message(cx) {
            return terminate
        }

        if this.is_disconnecting() {
            return this.poll_disconnect(cx)
        }

        // Bounds how many incoming messages are processed per poll so a busy connection
        // cannot starve other tasks on the executor.
        let mut budget = 4;

        // Drain commands, internal requests and wire messages until no more progress is made.
        'main: loop {
            let mut progress = false;

            // Process commands from the session manager first, since they can terminate the
            // session.
            loop {
                match this.commands_rx.poll_next_unpin(cx) {
                    Poll::Pending => break,
                    Poll::Ready(None) => {
                        // The manager dropped the command channel: terminate the session.
                        return Poll::Ready(())
                    }
                    Poll::Ready(Some(cmd)) => {
                        progress = true;
                        match cmd {
                            SessionCommand::Disconnect { reason } => {
                                debug!(
                                    target: "net::session",
                                    ?reason,
                                    remote_peer_id=?this.remote_peer_id,
                                    "Received disconnect command for session"
                                );
                                let reason =
                                    reason.unwrap_or(DisconnectReason::DisconnectRequested);

                                return this.try_disconnect(reason, cx)
                            }
                            SessionCommand::Message(msg) => {
                                this.on_internal_peer_message(msg);
                            }
                        }
                    }
                }
            }

            // Register new outgoing requests from the local node; all drained in this pass
            // share the same deadline.
            let deadline = this.request_deadline();

            while let Poll::Ready(Some(req)) = this.internal_request_tx.poll_next_unpin(cx) {
                progress = true;
                this.on_internal_peer_request(req, deadline);
            }

            // Advance all pending local responses to requests the remote peer sent us.
            // Iterated in reverse with `swap_remove` so completed entries are removed in O(1)
            // without invalidating the remaining indices.
            for idx in (0..this.received_requests_from_remote.len()).rev() {
                let mut req = this.received_requests_from_remote.swap_remove(idx);
                match req.rx.poll(cx) {
                    Poll::Pending => {
                        // Not ready yet; put it back for the next poll.
                        this.received_requests_from_remote.push(req);
                    }
                    Poll::Ready(resp) => {
                        this.handle_outgoing_response(req.request_id, resp);
                    }
                }
            }

            // Flush queued outgoing messages while the connection sink has capacity.
            while this.conn.poll_ready_unpin(cx).is_ready() {
                if let Some(msg) = this.queued_outgoing.pop_front() {
                    progress = true;
                    let res = match msg {
                        OutgoingMessage::Eth(msg) => this.conn.start_send_unpin(msg),
                        OutgoingMessage::Broadcast(msg) => this.conn.start_send_broadcast(msg),
                        OutgoingMessage::Raw(msg) => this.conn.start_send_raw(msg),
                    };
                    if let Err(err) = res {
                        debug!(target: "net::session", %err, remote_peer_id=?this.remote_peer_id, "failed to send message");
                        // A send error is fatal for the session.
                        return this.close_on_error(err, cx)
                    }
                } else {
                    // Nothing left to send.
                    break
                }
            }

            // Read incoming messages from the wire, limited by the remaining budget.
            'receive: loop {
                // Ensure there is budget left before handling another message.
                budget -= 1;
                if budget == 0 {
                    // Out of budget: yield, but make sure this task is woken again to
                    // continue where it left off.
                    cx.waker().wake_by_ref();
                    break 'main
                }

                // Retry a message that previously failed to reach the manager due to a full
                // channel before reading anything new.
                if let Some(msg) = this.pending_message_to_session.take() {
                    match this.to_session_manager.poll_reserve(cx) {
                        Poll::Ready(Ok(_)) => {
                            let _ = this.to_session_manager.send_item(msg);
                        }
                        Poll::Ready(Err(_)) => return Poll::Ready(()),
                        Poll::Pending => {
                            // Still no capacity; keep the message buffered and stop reading
                            // so we don't accumulate unbounded pending messages.
                            this.pending_message_to_session = Some(msg);
                            break 'receive
                        }
                    };
                }

                match this.conn.poll_next_unpin(cx) {
                    Poll::Pending => break,
                    Poll::Ready(None) => {
                        if this.is_disconnecting() {
                            break
                        }
                        // Stream ended without a disconnect handshake.
                        debug!(target: "net::session", remote_peer_id=?this.remote_peer_id, "eth stream completed");
                        return this.emit_disconnect(cx)
                    }
                    Poll::Ready(Some(res)) => {
                        match res {
                            Ok(msg) => {
                                trace!(target: "net::session", msg_id=?msg.message_id(), remote_peer_id=?this.remote_peer_id, "received eth message");
                                match this.on_incoming_message(msg) {
                                    OnIncomingMessageOutcome::Ok => {
                                        // Handled successfully.
                                        progress = true;
                                    }
                                    OnIncomingMessageOutcome::BadMessage { error, message } => {
                                        debug!(target: "net::session", %error, msg=?message, remote_peer_id=?this.remote_peer_id, "received invalid protocol message");
                                        return this.close_on_error(error, cx)
                                    }
                                    OnIncomingMessageOutcome::NoCapacity(msg) => {
                                        // Buffer the message and retry delivery on the next
                                        // receive iteration.
                                        this.pending_message_to_session = Some(msg);
                                        continue 'receive
                                    }
                                }
                            }
                            Err(err) => {
                                debug!(target: "net::session", %err, remote_peer_id=?this.remote_peer_id, "failed to receive message");
                                return this.close_on_error(err, cx)
                            }
                        }
                    }
                }
            }

            if !progress {
                break 'main
            }
        }

        // Periodically check in-flight requests for timeouts and protocol breaches.
        while this.internal_request_timeout_interval.poll_tick(cx).is_ready() {
            if this.check_timed_out_requests(Instant::now()) {
                // A request breached the protocol timeout: report it to the manager (staged
                // so the next poll delivers it).
                if let Poll::Ready(Ok(_)) = this.to_session_manager.poll_reserve(cx) {
                    let msg = ActiveSessionMessage::ProtocolBreach { peer_id: this.remote_peer_id };
                    this.pending_message_to_session = Some(msg);
                }
            }
        }

        this.shrink_to_fit();

        Poll::Pending
    }
}
/// Tracks a request received from the remote peer that is awaiting a local response.
pub(crate) struct ReceivedRequest<N: NetworkPrimitives> {
    /// Protocol request id assigned by the remote peer; echoed back in the response.
    request_id: u64,
    /// The channel on which the local response will arrive.
    rx: PeerResponse<N>,
    /// Timestamp when the request was received (currently unused).
    #[allow(dead_code)]
    received: Instant,
}
/// A request we sent to the remote peer that is still awaiting a response.
pub(crate) struct InflightRequest<R> {
    /// The request itself, or a marker that it already timed out.
    request: RequestState<R>,
    /// Instant the request was sent; used for RTT measurement.
    timestamp: Instant,
    /// Instant after which the request is considered timed out.
    deadline: Instant,
}
impl<N: NetworkPrimitives> InflightRequest<PeerRequest<N>> {
    /// Whether the deadline for this request has already passed.
    #[inline]
    fn is_timed_out(&self, now: Instant) -> bool {
        now > self.deadline
    }

    /// Whether we are still waiting for a response to this request.
    #[inline]
    const fn is_waiting(&self) -> bool {
        matches!(self.request, RequestState::Waiting(_))
    }

    /// Marks the request as timed out and notifies the local requester with a timeout error.
    fn timeout(&mut self) {
        // Swap in the timed-out marker; if we were still waiting, fail the pending request.
        if let RequestState::Waiting(req) =
            std::mem::replace(&mut self.request, RequestState::TimedOut)
        {
            req.send_err_response(RequestError::Timeout);
        }
    }
}
/// Outcome of handling a message read from the connection.
enum OnIncomingMessageOutcome<N: NetworkPrimitives> {
    /// The message was handled successfully.
    Ok,
    /// The message violated the protocol; carries the error and the offending message.
    BadMessage { error: EthStreamError, message: EthMessage<N> },
    /// The channel to the session manager was full; the message is handed back for retry.
    NoCapacity(ActiveSessionMessage<N>),
}
impl<N: NetworkPrimitives> From<Result<(), ActiveSessionMessage<N>>>
    for OnIncomingMessageOutcome<N>
{
    /// Maps a successful emit to [`Self::Ok`] and a failed one (message handed back) to
    /// [`Self::NoCapacity`].
    fn from(res: Result<(), ActiveSessionMessage<N>>) -> Self {
        res.map_or_else(Self::NoCapacity, |_| Self::Ok)
    }
}
/// State of an in-flight request.
enum RequestState<R> {
    /// Still waiting for the response.
    Waiting(R),
    /// The request timed out; the requester has already been notified.
    TimedOut,
}
/// A message queued to be sent to the remote peer, in one of the supported wire forms.
pub(crate) enum OutgoingMessage<N: NetworkPrimitives> {
    /// A regular eth protocol message.
    Eth(EthMessage<N>),
    /// A broadcast message (e.g. new block, transactions) sharing its payload via `Arc`.
    Broadcast(EthBroadcastMessage<N>),
    /// A raw capability message forwarded without further interpretation.
    Raw(RawCapabilityMessage),
}
impl<N: NetworkPrimitives> From<EthMessage<N>> for OutgoingMessage<N> {
fn from(value: EthMessage<N>) -> Self {
Self::Eth(value)
}
}
impl<N: NetworkPrimitives> From<EthBroadcastMessage<N>> for OutgoingMessage<N> {
fn from(value: EthBroadcastMessage<N>) -> Self {
Self::Broadcast(value)
}
}
/// Derives an updated request timeout from the current timeout and a fresh RTT sample.
///
/// Works like an exponentially weighted moving average: the scaled sample contributes
/// `SAMPLE_IMPACT` of the result, the previous timeout the remainder, and the blend is
/// clamped to `[MINIMUM_TIMEOUT, MAXIMUM_TIMEOUT]`.
#[inline]
fn calculate_new_timeout(current_timeout: Duration, estimated_rtt: Duration) -> Duration {
    // Scale the RTT sample up to a timeout and weight it by the sample impact.
    let weighted_sample = estimated_rtt.mul_f64(SAMPLE_IMPACT) * TIMEOUT_SCALING;
    // Blend with the previous timeout and keep the result within the allowed bounds.
    let blended = current_timeout.mul_f64(1.0 - SAMPLE_IMPACT) + weighted_sample;
    blended.clamp(MINIMUM_TIMEOUT, MAXIMUM_TIMEOUT)
}
/// A FIFO queue of outgoing messages whose length is mirrored into a metrics gauge.
pub(crate) struct QueuedOutgoingMessages<N: NetworkPrimitives> {
    /// The buffered messages, oldest first.
    messages: VecDeque<OutgoingMessage<N>>,
    /// Gauge tracking the current number of queued messages.
    count: Gauge,
}
impl<N: NetworkPrimitives> QueuedOutgoingMessages<N> {
    /// Creates an empty queue that reports its length through the given gauge.
    pub(crate) const fn new(metric: Gauge) -> Self {
        Self { messages: VecDeque::new(), count: metric }
    }

    /// Appends a message to the back of the queue and bumps the gauge.
    pub(crate) fn push_back(&mut self, message: OutgoingMessage<N>) {
        self.messages.push_back(message);
        self.count.increment(1);
    }

    /// Removes and returns the oldest message, decrementing the gauge when one was present.
    pub(crate) fn pop_front(&mut self) -> Option<OutgoingMessage<N>> {
        let popped = self.messages.pop_front();
        if popped.is_some() {
            self.count.decrement(1);
        }
        popped
    }

    /// Releases excess capacity held by the internal buffer.
    pub(crate) fn shrink_to_fit(&mut self) {
        self.messages.shrink_to_fit();
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::session::{handle::PendingSessionEvent, start_pending_incoming_session};
    use reth_chainspec::MAINNET;
    use reth_ecies::stream::ECIESStream;
    use reth_eth_wire::{
        EthNetworkPrimitives, EthStream, GetBlockBodies, HelloMessageWithProtocols, P2PStream,
        Status, StatusBuilder, UnauthedEthStream, UnauthedP2PStream,
    };
    use reth_network_peers::pk2id;
    use reth_network_types::session::config::PROTOCOL_BREACH_REQUEST_TIMEOUT;
    use reth_primitives::{EthereumHardfork, ForkFilter};
    use secp256k1::{SecretKey, SECP256K1};
    use tokio::{
        net::{TcpListener, TcpStream},
        sync::mpsc,
    };

    /// Returns a default `Hello` message for the node identified by `server_key`.
    fn eth_hello(server_key: &SecretKey) -> HelloMessageWithProtocols {
        HelloMessageWithProtocols::builder(pk2id(&server_key.public_key(SECP256K1))).build()
    }

    /// Test helper that wires up everything needed to establish incoming sessions:
    /// handshake parameters, the manager-side channels, and per-session command channels.
    struct SessionBuilder<N: NetworkPrimitives = EthNetworkPrimitives> {
        _remote_capabilities: Arc<Capabilities>,
        // Manager-side channel pair: sessions send events here.
        active_session_tx: mpsc::Sender<ActiveSessionMessage<N>>,
        active_session_rx: ReceiverStream<ActiveSessionMessage<N>>,
        // Command senders for every session created so far.
        to_sessions: Vec<mpsc::Sender<SessionCommand<N>>>,
        secret_key: SecretKey,
        local_peer_id: PeerId,
        hello: HelloMessageWithProtocols,
        status: Status,
        fork_filter: ForkFilter,
        // Counter for unique session ids.
        next_id: usize,
    }

    impl<N: NetworkPrimitives> SessionBuilder<N> {
        /// Returns the next unique session id.
        fn next_id(&mut self) -> SessionId {
            let id = self.next_id;
            self.next_id += 1;
            SessionId(id)
        }

        /// Returns a future that connects to `local_addr`, performs the p2p and eth
        /// handshakes with a fresh key, and then runs `f` with the authenticated client
        /// stream.
        fn with_client_stream<F, O>(
            &self,
            local_addr: SocketAddr,
            f: F,
        ) -> Pin<Box<dyn Future<Output = ()> + Send>>
        where
            F: FnOnce(EthStream<P2PStream<ECIESStream<TcpStream>>, N>) -> O + Send + 'static,
            O: Future<Output = ()> + Send + Sync,
        {
            let status = self.status;
            let fork_filter = self.fork_filter.clone();
            let local_peer_id = self.local_peer_id;
            let mut hello = self.hello.clone();
            // The client uses its own ephemeral key, distinct from the server's.
            let key = SecretKey::new(&mut rand::thread_rng());
            hello.id = pk2id(&key.public_key(SECP256K1));
            Box::pin(async move {
                let outgoing = TcpStream::connect(local_addr).await.unwrap();
                let sink = ECIESStream::connect(outgoing, key, local_peer_id).await.unwrap();

                let (p2p_stream, _) = UnauthedP2PStream::new(sink).handshake(hello).await.unwrap();

                let (client_stream, _) = UnauthedEthStream::new(p2p_stream)
                    .handshake(status, fork_filter)
                    .await
                    .unwrap();
                f(client_stream).await
            })
        }

        /// Accepts the incoming `stream`, drives the pending-session handshake, and builds
        /// an [`ActiveSession`] from the established connection.
        async fn connect_incoming(&mut self, stream: TcpStream) -> ActiveSession<N> {
            let remote_addr = stream.local_addr().unwrap();
            let session_id = self.next_id();
            let (_disconnect_tx, disconnect_rx) = oneshot::channel();
            let (pending_sessions_tx, pending_sessions_rx) = mpsc::channel(1);

            tokio::task::spawn(start_pending_incoming_session(
                disconnect_rx,
                session_id,
                stream,
                pending_sessions_tx,
                remote_addr,
                self.secret_key,
                self.hello.clone(),
                self.status,
                self.fork_filter.clone(),
                Default::default(),
            ));

            let mut stream = ReceiverStream::new(pending_sessions_rx);

            match stream.next().await.unwrap() {
                PendingSessionEvent::Established {
                    session_id,
                    remote_addr,
                    peer_id,
                    capabilities,
                    conn,
                    ..
                } => {
                    let (_to_session_tx, messages_rx) = mpsc::channel(10);
                    let (commands_to_session, commands_rx) = mpsc::channel(10);
                    let poll_sender = PollSender::new(self.active_session_tx.clone());

                    self.to_sessions.push(commands_to_session);

                    ActiveSession {
                        next_id: 0,
                        remote_peer_id: peer_id,
                        remote_addr,
                        remote_capabilities: Arc::clone(&capabilities),
                        session_id,
                        commands_rx: ReceiverStream::new(commands_rx),
                        to_session_manager: MeteredPollSender::new(
                            poll_sender,
                            "network_active_session",
                        ),
                        pending_message_to_session: None,
                        internal_request_tx: ReceiverStream::new(messages_rx).fuse(),
                        inflight_requests: Default::default(),
                        conn,
                        queued_outgoing: QueuedOutgoingMessages::new(Gauge::noop()),
                        received_requests_from_remote: Default::default(),
                        internal_request_timeout_interval: tokio::time::interval(
                            INITIAL_REQUEST_TIMEOUT,
                        ),
                        internal_request_timeout: Arc::new(AtomicU64::new(
                            INITIAL_REQUEST_TIMEOUT.as_millis() as u64,
                        )),
                        protocol_breach_request_timeout: PROTOCOL_BREACH_REQUEST_TIMEOUT,
                        terminate_message: None,
                    }
                }
                ev => {
                    panic!("unexpected message {ev:?}")
                }
            }
        }
    }

    impl Default for SessionBuilder {
        fn default() -> Self {
            let (active_session_tx, active_session_rx) = mpsc::channel(100);

            let (secret_key, pk) = SECP256K1.generate_keypair(&mut rand::thread_rng());
            let local_peer_id = pk2id(&pk);

            Self {
                next_id: 0,
                _remote_capabilities: Arc::new(Capabilities::from(vec![])),
                active_session_tx,
                active_session_rx: ReceiverStream::new(active_session_rx),
                to_sessions: vec![],
                hello: eth_hello(&secret_key),
                secret_key,
                local_peer_id,
                status: StatusBuilder::default().build(),
                fork_filter: MAINNET
                    .hardfork_fork_filter(EthereumHardfork::Frontier)
                    .expect("The Frontier fork filter should exist on mainnet"),
            }
        }
    }

    /// The client observes the disconnect reason when the session initiates a disconnect.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_disconnect() {
        let mut builder = SessionBuilder::default();

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();

        let expected_disconnect = DisconnectReason::UselessPeer;

        let fut = builder.with_client_stream(local_addr, move |mut client_stream| async move {
            // The next stream item should be the disconnect error with the expected reason.
            let msg = client_stream.next().await.unwrap().unwrap_err();
            assert_eq!(msg.as_disconnected().unwrap(), expected_disconnect);
        });

        tokio::task::spawn(async move {
            let (incoming, _) = listener.accept().await.unwrap();
            let mut session = builder.connect_incoming(incoming).await;

            session.start_disconnect(expected_disconnect).unwrap();
            session.await
        });

        fut.await;
    }

    /// The session future completes when the remote side drops the connection.
    #[tokio::test(flavor = "multi_thread")]
    async fn handle_dropped_stream() {
        let mut builder = SessionBuilder::default();

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();

        let fut = builder.with_client_stream(local_addr, move |client_stream| async move {
            drop(client_stream);
            tokio::time::sleep(Duration::from_secs(1)).await
        });

        let (tx, rx) = oneshot::channel();

        tokio::task::spawn(async move {
            let (incoming, _) = listener.accept().await.unwrap();
            let session = builder.connect_incoming(incoming).await;
            session.await;

            tx.send(()).unwrap();
        });

        tokio::task::spawn(fut);

        rx.await.unwrap();
    }

    /// The session keeps making progress while the peer floods it with broadcasts
    /// (exercises the per-poll receive budget).
    #[tokio::test(flavor = "multi_thread")]
    async fn test_send_many_messages() {
        reth_tracing::init_test_tracing();
        let mut builder = SessionBuilder::default();

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();

        let num_messages = 100;

        let fut = builder.with_client_stream(local_addr, move |mut client_stream| async move {
            for _ in 0..num_messages {
                client_stream
                    .send(EthMessage::NewPooledTransactionHashes66(Vec::new().into()))
                    .await
                    .unwrap();
            }
        });

        let (tx, rx) = oneshot::channel();

        tokio::task::spawn(async move {
            let (incoming, _) = listener.accept().await.unwrap();
            let session = builder.connect_incoming(incoming).await;
            session.await;

            tx.send(()).unwrap();
        });

        tokio::task::spawn(fut);

        rx.await.unwrap();
    }

    /// An unanswered request first times out (requester gets `RequestError::Timeout`) and
    /// eventually escalates to a `ProtocolBreach` event.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_request_timeout() {
        reth_tracing::init_test_tracing();

        let mut builder = SessionBuilder::default();

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();

        let request_timeout = Duration::from_millis(100);
        let drop_timeout = Duration::from_millis(1500);

        let fut = builder.with_client_stream(local_addr, move |client_stream| async move {
            // Keep the connection open but never answer any request.
            let _client_stream = client_stream;
            tokio::time::sleep(drop_timeout * 60).await;
        });
        tokio::task::spawn(fut);

        let (incoming, _) = listener.accept().await.unwrap();
        let mut session = builder.connect_incoming(incoming).await;
        // Shorten the timeouts so the test completes quickly.
        session
            .internal_request_timeout
            .store(request_timeout.as_millis() as u64, Ordering::Relaxed);
        session.protocol_breach_request_timeout = drop_timeout;
        session.internal_request_timeout_interval =
            tokio::time::interval_at(tokio::time::Instant::now(), request_timeout);
        let (tx, rx) = oneshot::channel();
        // Deadline of `Instant::now()` means the request is immediately considered timed out.
        let req = PeerRequest::GetBlockBodies { request: GetBlockBodies(vec![]), response: tx };
        session.on_internal_peer_request(req, Instant::now());
        tokio::spawn(session);

        let err = rx.await.unwrap().unwrap_err();
        assert_eq!(err, RequestError::Timeout);

        // Ensure the `protocol_breach_request_timeout` is reached and reported.
        let msg = builder.active_session_rx.next().await.unwrap();
        match msg {
            ActiveSessionMessage::ProtocolBreach { .. } => {}
            ev => unreachable!("{ev:?}"),
        }
    }

    /// An otherwise idle session stays alive until the client disconnects deliberately.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_keep_alive() {
        let mut builder = SessionBuilder::default();

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();

        let fut = builder.with_client_stream(local_addr, move |mut client_stream| async move {
            let _ = tokio::time::timeout(Duration::from_secs(5), client_stream.next()).await;
            client_stream.into_inner().disconnect(DisconnectReason::UselessPeer).await.unwrap();
        });

        let (tx, rx) = oneshot::channel();

        tokio::task::spawn(async move {
            let (incoming, _) = listener.accept().await.unwrap();
            let session = builder.connect_incoming(incoming).await;
            session.await;

            tx.send(()).unwrap();
        });

        tokio::task::spawn(fut);

        rx.await.unwrap();
    }

    /// A message received while the manager channel is full is buffered and delivered once
    /// capacity frees up, instead of being dropped.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_send_at_capacity() {
        let mut builder = SessionBuilder::default();

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();

        let fut = builder.with_client_stream(local_addr, move |mut client_stream| async move {
            client_stream
                .send(EthMessage::NewPooledTransactionHashes68(Default::default()))
                .await
                .unwrap();
            let _ = tokio::time::timeout(Duration::from_secs(100), client_stream.next()).await;
        });
        tokio::task::spawn(fut);

        let (incoming, _) = listener.accept().await.unwrap();
        let session = builder.connect_incoming(incoming).await;

        // Fill the manager channel to capacity so the session cannot deliver immediately.
        let mut num_fill_messages = 0;
        loop {
            if builder
                .active_session_tx
                .try_send(ActiveSessionMessage::ProtocolBreach { peer_id: PeerId::random() })
                .is_err()
            {
                break
            }
            num_fill_messages += 1;
        }

        tokio::task::spawn(async move {
            session.await;
        });

        tokio::time::sleep(Duration::from_millis(100)).await;

        // Drain the filler messages…
        for _ in 0..num_fill_messages {
            let message = builder.active_session_rx.next().await.unwrap();
            match message {
                ActiveSessionMessage::ProtocolBreach { .. } => {}
                ev => unreachable!("{ev:?}"),
            }
        }

        // …then the buffered peer message must come through.
        let message = builder.active_session_rx.next().await.unwrap();
        match message {
            ActiveSessionMessage::ValidMessage {
                message: PeerMessage::PooledTransactions(_),
                ..
            } => {}
            _ => unreachable!(),
        }
    }

    /// Sanity checks for the EWMA timeout update: a steady RTT keeps the timeout stable,
    /// faster/slower samples move it in the right direction without overshooting.
    #[test]
    fn timeout_calculation_sanity_tests() {
        let rtt = Duration::from_secs(5);
        // An RTT that matches the current timeout / TIMEOUT_SCALING leaves it unchanged.
        let timeout = rtt * TIMEOUT_SCALING;
        assert_eq!(calculate_new_timeout(timeout, rtt), timeout);
        // A faster response lowers the timeout, but not by more than the sample weight allows.
        assert!(calculate_new_timeout(timeout, rtt / 2) < timeout);
        assert!(calculate_new_timeout(timeout, rtt / 2) > timeout / 2);
        // A slower response raises the timeout, again bounded by the sample weight.
        assert!(calculate_new_timeout(timeout, rtt * 2) > timeout);
        assert!(calculate_new_timeout(timeout, rtt * 2) < timeout * 2);
    }
}