1use core::sync::atomic::Ordering;
4use std::{
5 collections::VecDeque,
6 future::Future,
7 net::SocketAddr,
8 pin::Pin,
9 sync::{atomic::AtomicU64, Arc},
10 task::{ready, Context, Poll},
11 time::{Duration, Instant},
12};
13
14use crate::{
15 message::{NewBlockMessage, PeerMessage, PeerResponse, PeerResponseResult},
16 session::{
17 conn::EthRlpxConnection,
18 handle::{ActiveSessionMessage, SessionCommand},
19 BlockRangeInfo, EthVersion, SessionId,
20 },
21};
22use alloy_eips::merge::EPOCH_SLOTS;
23use alloy_primitives::Sealable;
24use futures::{stream::Fuse, SinkExt, StreamExt};
25use metrics::Gauge;
26use reth_eth_wire::{
27 errors::{EthHandshakeError, EthStreamError},
28 message::{EthBroadcastMessage, MessageError, RequestPair},
29 Capabilities, DisconnectP2P, DisconnectReason, EthMessage, NetworkPrimitives, NewBlockPayload,
30};
31use reth_eth_wire_types::RawCapabilityMessage;
32use reth_metrics::common::mpsc::MeteredPollSender;
33use reth_network_api::PeerRequest;
34use reth_network_p2p::error::RequestError;
35use reth_network_peers::PeerId;
36use reth_network_types::session::config::INITIAL_REQUEST_TIMEOUT;
37use reth_primitives_traits::Block;
38use rustc_hash::FxHashMap;
39use tokio::{
40 sync::{mpsc::error::TrySendError, oneshot},
41 time::Interval,
42};
43use tokio_stream::wrappers::ReceiverStream;
44use tokio_util::sync::PollSender;
45use tracing::{debug, trace};
46
/// Period for `BlockRangeUpdate` checks: one epoch's worth of slots (`EPOCH_SLOTS * 12`
/// seconds, assuming 12-second slots).
///
/// NOTE(review): presumably used to build `ActiveSession::range_update_interval`; confirm at the
/// construction site.
pub(super) const RANGE_UPDATE_INTERVAL: Duration = Duration::from_secs(EPOCH_SLOTS * 12);

/// Lower bound for the adaptive internal request timeout (see `calculate_new_timeout`).
const MINIMUM_TIMEOUT: Duration = Duration::from_secs(2);

/// Upper bound for the adaptive internal request timeout; equal to the initial request timeout.
const MAXIMUM_TIMEOUT: Duration = INITIAL_REQUEST_TIMEOUT;
/// Weight given to a fresh RTT sample when the request timeout is recomputed.
const SAMPLE_IMPACT: f64 = 0.1;
/// Multiplier applied to an RTT sample before it contributes to the request timeout.
const TIMEOUT_SCALING: u32 = 3;

/// Throttle threshold: once more than this many remote requests await a local response, or more
/// than this many responses sit in the outgoing queue, the session stops reading new messages
/// from the connection for the current poll.
const MAX_QUEUED_OUTGOING_RESPONSES: usize = 4;
78
/// Future that drives an established session with a remote peer.
///
/// Each poll handles, in this order:
/// - commands from the session manager;
/// - internal requests destined for the remote;
/// - pending responses to requests the remote sent us;
/// - queued outgoing messages;
/// - messages read from the connection.
///
/// Events are reported back through `to_session_manager`.
#[expect(dead_code)]
pub(crate) struct ActiveSession<N: NetworkPrimitives> {
    /// Id that will be assigned to the next request sent to the remote.
    pub(crate) next_id: u64,
    /// The underlying eth/RLPx connection.
    pub(crate) conn: EthRlpxConnection<N>,
    /// Id of the remote peer.
    pub(crate) remote_peer_id: PeerId,
    /// Socket address of the remote peer.
    pub(crate) remote_addr: SocketAddr,
    /// Capabilities of the remote peer.
    pub(crate) remote_capabilities: Arc<Capabilities>,
    /// Id of this session.
    pub(crate) session_id: SessionId,
    /// Commands (disconnect, messages to send) from the session manager.
    pub(crate) commands_rx: ReceiverStream<SessionCommand<N>>,
    /// Channel for events reported to the session manager.
    pub(crate) to_session_manager: MeteredPollSender<ActiveSessionMessage<N>>,
    /// A message for the session manager that could not be delivered yet because the channel
    /// was full.
    pub(crate) pending_message_to_session: Option<ActiveSessionMessage<N>>,
    /// Internal requests that should be sent to the remote.
    pub(crate) internal_request_rx: Fuse<ReceiverStream<PeerRequest<N>>>,
    /// Requests sent to the remote that still await a response, keyed by request id.
    pub(crate) inflight_requests: FxHashMap<u64, InflightRequest<PeerRequest<N>>>,
    /// Requests received from the remote whose responses are still being produced locally.
    pub(crate) received_requests_from_remote: Vec<ReceivedRequest<N>>,
    /// Messages waiting to be written to the connection.
    pub(crate) queued_outgoing: QueuedOutgoingMessages<N>,
    /// Current adaptive timeout for internal requests, in milliseconds.
    pub(crate) internal_request_timeout: Arc<AtomicU64>,
    /// Interval at which in-flight requests are checked for timeouts.
    pub(crate) internal_request_timeout_interval: Interval,
    /// How long after being issued a timed-out request may remain unanswered before it is
    /// reported as a protocol breach.
    pub(crate) protocol_breach_request_timeout: Duration,
    /// Final message for the session manager (with the sender used to deliver it); the session
    /// resolves once it has been handled.
    pub(crate) terminate_message:
        Option<(PollSender<ActiveSessionMessage<N>>, ActiveSessionMessage<N>)>,
    /// Block range advertised by the remote, refreshed on `BlockRangeUpdate` when present.
    pub(crate) range_info: Option<BlockRangeInfo>,
    /// Local block range, sent to the remote in `BlockRangeUpdate` messages.
    pub(crate) local_range_info: BlockRangeInfo,
    /// Interval driving outgoing `BlockRangeUpdate`s; `None` disables them.
    pub(crate) range_update_interval: Option<Interval>,
    /// Latest block number included in the last `BlockRangeUpdate` sent, if any.
    pub(crate) last_sent_latest_block: Option<u64>,
}
139
140impl<N: NetworkPrimitives> ActiveSession<N> {
141 fn is_disconnecting(&self) -> bool {
143 self.conn.inner().is_disconnecting()
144 }
145
146 const fn next_id(&mut self) -> u64 {
148 let id = self.next_id;
149 self.next_id += 1;
150 id
151 }
152
153 pub fn shrink_to_fit(&mut self) {
155 self.received_requests_from_remote.shrink_to_fit();
156 self.queued_outgoing.shrink_to_fit();
157 }
158
159 fn queued_response_count(&self) -> usize {
161 self.queued_outgoing.messages.iter().filter(|m| m.is_response()).count()
162 }
163
164 fn on_incoming_message(&mut self, msg: EthMessage<N>) -> OnIncomingMessageOutcome<N> {
168 macro_rules! on_request {
172 ($req:ident, $resp_item:ident, $req_item:ident) => {{
173 let RequestPair { request_id, message: request } = $req;
174 let (tx, response) = oneshot::channel();
175 let received = ReceivedRequest {
176 request_id,
177 rx: PeerResponse::$resp_item { response },
178 received: Instant::now(),
179 };
180 self.received_requests_from_remote.push(received);
181 self.try_emit_request(PeerMessage::EthRequest(PeerRequest::$req_item {
182 request,
183 response: tx,
184 }))
185 .into()
186 }};
187 }
188
189 macro_rules! on_response {
191 ($resp:ident, $item:ident) => {{
192 let RequestPair { request_id, message } = $resp;
193 if let Some(req) = self.inflight_requests.remove(&request_id) {
194 match req.request {
195 RequestState::Waiting(PeerRequest::$item { response, .. }) => {
196 trace!(peer_id=?self.remote_peer_id, ?request_id, "received response from peer");
197 let _ = response.send(Ok(message));
198 self.update_request_timeout(req.timestamp, Instant::now());
199 }
200 RequestState::Waiting(request) => {
201 request.send_bad_response();
202 }
203 RequestState::TimedOut => {
204 self.update_request_timeout(req.timestamp, Instant::now());
206 }
207 }
208 } else {
209 trace!(peer_id=?self.remote_peer_id, ?request_id, "received response to unknown request");
210 self.on_bad_message();
212 }
213
214 OnIncomingMessageOutcome::Ok
215 }};
216 }
217
218 match msg {
219 message @ EthMessage::Status(_) => OnIncomingMessageOutcome::BadMessage {
220 error: EthStreamError::EthHandshakeError(EthHandshakeError::StatusNotInHandshake),
221 message,
222 },
223 EthMessage::NewBlockHashes(msg) => {
224 self.try_emit_broadcast(PeerMessage::NewBlockHashes(msg)).into()
225 }
226 EthMessage::NewBlock(msg) => {
227 let block = NewBlockMessage {
228 hash: msg.block().header().hash_slow(),
229 block: Arc::new(*msg),
230 };
231 self.try_emit_broadcast(PeerMessage::NewBlock(block)).into()
232 }
233 EthMessage::Transactions(msg) => {
234 self.try_emit_broadcast(PeerMessage::ReceivedTransaction(msg)).into()
235 }
236 EthMessage::NewPooledTransactionHashes66(msg) => {
237 self.try_emit_broadcast(PeerMessage::PooledTransactions(msg.into())).into()
238 }
239 EthMessage::NewPooledTransactionHashes68(msg) => {
240 if msg.hashes.len() != msg.types.len() || msg.hashes.len() != msg.sizes.len() {
241 return OnIncomingMessageOutcome::BadMessage {
242 error: EthStreamError::TransactionHashesInvalidLenOfFields {
243 hashes_len: msg.hashes.len(),
244 types_len: msg.types.len(),
245 sizes_len: msg.sizes.len(),
246 },
247 message: EthMessage::NewPooledTransactionHashes68(msg),
248 }
249 }
250 self.try_emit_broadcast(PeerMessage::PooledTransactions(msg.into())).into()
251 }
252 EthMessage::GetBlockHeaders(req) => {
253 on_request!(req, BlockHeaders, GetBlockHeaders)
254 }
255 EthMessage::BlockHeaders(resp) => {
256 on_response!(resp, GetBlockHeaders)
257 }
258 EthMessage::GetBlockBodies(req) => {
259 on_request!(req, BlockBodies, GetBlockBodies)
260 }
261 EthMessage::BlockBodies(resp) => {
262 on_response!(resp, GetBlockBodies)
263 }
264 EthMessage::GetPooledTransactions(req) => {
265 on_request!(req, PooledTransactions, GetPooledTransactions)
266 }
267 EthMessage::PooledTransactions(resp) => {
268 on_response!(resp, GetPooledTransactions)
269 }
270 EthMessage::GetNodeData(req) => {
271 on_request!(req, NodeData, GetNodeData)
272 }
273 EthMessage::NodeData(resp) => {
274 on_response!(resp, GetNodeData)
275 }
276 EthMessage::GetReceipts(req) => {
277 if self.conn.version() >= EthVersion::Eth69 {
278 on_request!(req, Receipts69, GetReceipts69)
279 } else {
280 on_request!(req, Receipts, GetReceipts)
281 }
282 }
283 EthMessage::Receipts(resp) => {
284 on_response!(resp, GetReceipts)
285 }
286 EthMessage::Receipts69(resp) => {
287 on_response!(resp, GetReceipts69)
288 }
289 EthMessage::BlockRangeUpdate(msg) => {
290 if msg.earliest > msg.latest {
292 return OnIncomingMessageOutcome::BadMessage {
293 error: EthStreamError::InvalidMessage(MessageError::Other(format!(
294 "invalid block range: earliest ({}) > latest ({})",
295 msg.earliest, msg.latest
296 ))),
297 message: EthMessage::BlockRangeUpdate(msg),
298 };
299 }
300
301 if msg.latest_hash.is_zero() {
303 return OnIncomingMessageOutcome::BadMessage {
304 error: EthStreamError::InvalidMessage(MessageError::Other(
305 "invalid block range: latest_hash cannot be zero".to_string(),
306 )),
307 message: EthMessage::BlockRangeUpdate(msg),
308 };
309 }
310
311 if let Some(range_info) = self.range_info.as_ref() {
312 range_info.update(msg.earliest, msg.latest, msg.latest_hash);
313 }
314
315 OnIncomingMessageOutcome::Ok
316 }
317 EthMessage::Other(bytes) => self.try_emit_broadcast(PeerMessage::Other(bytes)).into(),
318 }
319 }
320
321 fn on_internal_peer_request(&mut self, request: PeerRequest<N>, deadline: Instant) {
323 let request_id = self.next_id();
324
325 trace!(?request, peer_id=?self.remote_peer_id, ?request_id, "sending request to peer");
326 let msg = request.create_request_message(request_id);
327 self.queued_outgoing.push_back(msg.into());
328 let req = InflightRequest {
329 request: RequestState::Waiting(request),
330 timestamp: Instant::now(),
331 deadline,
332 };
333 self.inflight_requests.insert(request_id, req);
334 }
335
336 fn on_internal_peer_message(&mut self, msg: PeerMessage<N>) {
338 match msg {
339 PeerMessage::NewBlockHashes(msg) => {
340 self.queued_outgoing.push_back(EthMessage::NewBlockHashes(msg).into());
341 }
342 PeerMessage::NewBlock(msg) => {
343 self.queued_outgoing.push_back(EthBroadcastMessage::NewBlock(msg.block).into());
344 }
345 PeerMessage::PooledTransactions(msg) => {
346 if msg.is_valid_for_version(self.conn.version()) {
347 self.queued_outgoing.push_back(EthMessage::from(msg).into());
348 } else {
349 debug!(target: "net", ?msg, version=?self.conn.version(), "Message is invalid for connection version, skipping");
350 }
351 }
352 PeerMessage::EthRequest(req) => {
353 let deadline = self.request_deadline();
354 self.on_internal_peer_request(req, deadline);
355 }
356 PeerMessage::SendTransactions(msg) => {
357 self.queued_outgoing.push_back(EthBroadcastMessage::Transactions(msg).into());
358 }
359 PeerMessage::BlockRangeUpdated(_) => {}
360 PeerMessage::ReceivedTransaction(_) => {
361 unreachable!("Not emitted by network")
362 }
363 PeerMessage::Other(other) => {
364 self.queued_outgoing.push_back(OutgoingMessage::Raw(other));
365 }
366 }
367 }
368
369 fn request_deadline(&self) -> Instant {
371 Instant::now() +
372 Duration::from_millis(self.internal_request_timeout.load(Ordering::Relaxed))
373 }
374
375 fn handle_outgoing_response(&mut self, id: u64, resp: PeerResponseResult<N>) {
379 match resp.try_into_message(id) {
380 Ok(msg) => {
381 self.queued_outgoing.push_back(msg.into());
382 }
383 Err(err) => {
384 debug!(target: "net", %err, "Failed to respond to received request");
385 }
386 }
387 }
388
389 #[expect(clippy::result_large_err)]
393 fn try_emit_broadcast(&self, message: PeerMessage<N>) -> Result<(), ActiveSessionMessage<N>> {
394 let Some(sender) = self.to_session_manager.inner().get_ref() else { return Ok(()) };
395
396 match sender
397 .try_send(ActiveSessionMessage::ValidMessage { peer_id: self.remote_peer_id, message })
398 {
399 Ok(_) => Ok(()),
400 Err(err) => {
401 trace!(
402 target: "net",
403 %err,
404 "no capacity for incoming broadcast",
405 );
406 match err {
407 TrySendError::Full(msg) => Err(msg),
408 TrySendError::Closed(_) => Ok(()),
409 }
410 }
411 }
412 }
413
414 #[expect(clippy::result_large_err)]
419 fn try_emit_request(&self, message: PeerMessage<N>) -> Result<(), ActiveSessionMessage<N>> {
420 let Some(sender) = self.to_session_manager.inner().get_ref() else { return Ok(()) };
421
422 match sender
423 .try_send(ActiveSessionMessage::ValidMessage { peer_id: self.remote_peer_id, message })
424 {
425 Ok(_) => Ok(()),
426 Err(err) => {
427 trace!(
428 target: "net",
429 %err,
430 "no capacity for incoming request",
431 );
432 match err {
433 TrySendError::Full(msg) => Err(msg),
434 TrySendError::Closed(_) => {
435 Ok(())
438 }
439 }
440 }
441 }
442 }
443
444 fn on_bad_message(&self) {
446 let Some(sender) = self.to_session_manager.inner().get_ref() else { return };
447 let _ = sender.try_send(ActiveSessionMessage::BadMessage { peer_id: self.remote_peer_id });
448 }
449
450 fn emit_disconnect(&mut self, cx: &mut Context<'_>) -> Poll<()> {
452 trace!(target: "net::session", remote_peer_id=?self.remote_peer_id, "emitting disconnect");
453 let msg = ActiveSessionMessage::Disconnected {
454 peer_id: self.remote_peer_id,
455 remote_addr: self.remote_addr,
456 };
457
458 self.terminate_message = Some((self.to_session_manager.inner().clone(), msg));
459 self.poll_terminate_message(cx).expect("message is set")
460 }
461
462 fn close_on_error(&mut self, error: EthStreamError, cx: &mut Context<'_>) -> Poll<()> {
464 let msg = ActiveSessionMessage::ClosedOnConnectionError {
465 peer_id: self.remote_peer_id,
466 remote_addr: self.remote_addr,
467 error,
468 };
469 self.terminate_message = Some((self.to_session_manager.inner().clone(), msg));
470 self.poll_terminate_message(cx).expect("message is set")
471 }
472
473 fn start_disconnect(&mut self, reason: DisconnectReason) -> Result<(), EthStreamError> {
475 Ok(self.conn.inner_mut().start_disconnect(reason)?)
476 }
477
478 fn poll_disconnect(&mut self, cx: &mut Context<'_>) -> Poll<()> {
480 debug_assert!(self.is_disconnecting(), "not disconnecting");
481
482 let _ = ready!(self.conn.poll_close_unpin(cx));
484 self.emit_disconnect(cx)
485 }
486
487 fn try_disconnect(&mut self, reason: DisconnectReason, cx: &mut Context<'_>) -> Poll<()> {
489 match self.start_disconnect(reason) {
490 Ok(()) => {
491 self.poll_disconnect(cx)
493 }
494 Err(err) => {
495 debug!(target: "net::session", %err, remote_peer_id=?self.remote_peer_id, "could not send disconnect");
496 self.close_on_error(err, cx)
497 }
498 }
499 }
500
501 #[must_use]
510 fn check_timed_out_requests(&mut self, now: Instant) -> bool {
511 for (id, req) in &mut self.inflight_requests {
512 if req.is_timed_out(now) {
513 if req.is_waiting() {
514 debug!(target: "net::session", ?id, remote_peer_id=?self.remote_peer_id, "timed out outgoing request");
515 req.timeout();
516 } else if now - req.timestamp > self.protocol_breach_request_timeout {
517 return true
518 }
519 }
520 }
521
522 false
523 }
524
525 fn update_request_timeout(&mut self, sent: Instant, received: Instant) {
527 let elapsed = received.saturating_duration_since(sent);
528
529 let current = Duration::from_millis(self.internal_request_timeout.load(Ordering::Relaxed));
530 let request_timeout = calculate_new_timeout(current, elapsed);
531 self.internal_request_timeout.store(request_timeout.as_millis() as u64, Ordering::Relaxed);
532 self.internal_request_timeout_interval = tokio::time::interval(request_timeout);
533 }
534
535 fn poll_terminate_message(&mut self, cx: &mut Context<'_>) -> Option<Poll<()>> {
537 let (mut tx, msg) = self.terminate_message.take()?;
538 match tx.poll_reserve(cx) {
539 Poll::Pending => {
540 self.terminate_message = Some((tx, msg));
541 return Some(Poll::Pending)
542 }
543 Poll::Ready(Ok(())) => {
544 let _ = tx.send_item(msg);
545 }
546 Poll::Ready(Err(_)) => {
547 }
549 }
550 Some(Poll::Ready(()))
552 }
553}
554
555impl<N: NetworkPrimitives> Future for ActiveSession<N> {
556 type Output = ();
557
558 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
559 let this = self.get_mut();
560
561 if let Some(terminate) = this.poll_terminate_message(cx) {
563 return terminate
564 }
565
566 if this.is_disconnecting() {
567 return this.poll_disconnect(cx)
568 }
569
570 let mut budget = 4;
576
577 'main: loop {
579 let mut progress = false;
580
581 loop {
583 match this.commands_rx.poll_next_unpin(cx) {
584 Poll::Pending => break,
585 Poll::Ready(None) => {
586 return Poll::Ready(())
589 }
590 Poll::Ready(Some(cmd)) => {
591 progress = true;
592 match cmd {
593 SessionCommand::Disconnect { reason } => {
594 debug!(
595 target: "net::session",
596 ?reason,
597 remote_peer_id=?this.remote_peer_id,
598 "Received disconnect command for session"
599 );
600 let reason =
601 reason.unwrap_or(DisconnectReason::DisconnectRequested);
602
603 return this.try_disconnect(reason, cx)
604 }
605 SessionCommand::Message(msg) => {
606 this.on_internal_peer_message(msg);
607 }
608 }
609 }
610 }
611 }
612
613 let deadline = this.request_deadline();
614
615 while let Poll::Ready(Some(req)) = this.internal_request_rx.poll_next_unpin(cx) {
616 progress = true;
617 this.on_internal_peer_request(req, deadline);
618 }
619
620 for idx in (0..this.received_requests_from_remote.len()).rev() {
623 let mut req = this.received_requests_from_remote.swap_remove(idx);
624 match req.rx.poll(cx) {
625 Poll::Pending => {
626 this.received_requests_from_remote.push(req);
628 }
629 Poll::Ready(resp) => {
630 this.handle_outgoing_response(req.request_id, resp);
631 }
632 }
633 }
634
635 while this.conn.poll_ready_unpin(cx).is_ready() {
637 if let Some(msg) = this.queued_outgoing.pop_front() {
638 progress = true;
639 let res = match msg {
640 OutgoingMessage::Eth(msg) => this.conn.start_send_unpin(msg),
641 OutgoingMessage::Broadcast(msg) => this.conn.start_send_broadcast(msg),
642 OutgoingMessage::Raw(msg) => this.conn.start_send_raw(msg),
643 };
644 if let Err(err) = res {
645 debug!(target: "net::session", %err, remote_peer_id=?this.remote_peer_id, "failed to send message");
646 return this.close_on_error(err, cx)
648 }
649 } else {
650 break
652 }
653 }
654
655 'receive: loop {
657 budget -= 1;
659 if budget == 0 {
660 cx.waker().wake_by_ref();
662 break 'main
663 }
664
665 if let Some(msg) = this.pending_message_to_session.take() {
669 match this.to_session_manager.poll_reserve(cx) {
670 Poll::Ready(Ok(_)) => {
671 let _ = this.to_session_manager.send_item(msg);
672 }
673 Poll::Ready(Err(_)) => return Poll::Ready(()),
674 Poll::Pending => {
675 this.pending_message_to_session = Some(msg);
676 break 'receive
677 }
678 };
679 }
680
681 if this.received_requests_from_remote.len() > MAX_QUEUED_OUTGOING_RESPONSES {
683 break 'receive
689 }
690
691 if this.queued_outgoing.messages.len() > MAX_QUEUED_OUTGOING_RESPONSES &&
693 this.queued_response_count() > MAX_QUEUED_OUTGOING_RESPONSES
694 {
695 break 'receive
702 }
703
704 match this.conn.poll_next_unpin(cx) {
705 Poll::Pending => break,
706 Poll::Ready(None) => {
707 if this.is_disconnecting() {
708 break
709 }
710 debug!(target: "net::session", remote_peer_id=?this.remote_peer_id, "eth stream completed");
711 return this.emit_disconnect(cx)
712 }
713 Poll::Ready(Some(res)) => {
714 match res {
715 Ok(msg) => {
716 trace!(target: "net::session", msg_id=?msg.message_id(), remote_peer_id=?this.remote_peer_id, "received eth message");
717 match this.on_incoming_message(msg) {
719 OnIncomingMessageOutcome::Ok => {
720 progress = true;
722 }
723 OnIncomingMessageOutcome::BadMessage { error, message } => {
724 debug!(target: "net::session", %error, msg=?message, remote_peer_id=?this.remote_peer_id, "received invalid protocol message");
725 return this.close_on_error(error, cx)
726 }
727 OnIncomingMessageOutcome::NoCapacity(msg) => {
728 this.pending_message_to_session = Some(msg);
730 }
731 }
732 }
733 Err(err) => {
734 debug!(target: "net::session", %err, remote_peer_id=?this.remote_peer_id, "failed to receive message");
735 return this.close_on_error(err, cx)
736 }
737 }
738 }
739 }
740 }
741
742 if !progress {
743 break 'main
744 }
745 }
746
747 if let Some(interval) = &mut this.range_update_interval {
748 while interval.poll_tick(cx).is_ready() {
750 let current_latest = this.local_range_info.latest();
751 let should_send = if let Some(last_sent) = this.last_sent_latest_block {
752 current_latest.saturating_sub(last_sent) >= EPOCH_SLOTS
754 } else {
755 true };
757
758 if should_send {
759 this.queued_outgoing.push_back(
760 EthMessage::BlockRangeUpdate(this.local_range_info.to_message()).into(),
761 );
762 this.last_sent_latest_block = Some(current_latest);
763 }
764 }
765 }
766
767 while this.internal_request_timeout_interval.poll_tick(cx).is_ready() {
768 if this.check_timed_out_requests(Instant::now()) &&
770 let Poll::Ready(Ok(_)) = this.to_session_manager.poll_reserve(cx)
771 {
772 let msg = ActiveSessionMessage::ProtocolBreach { peer_id: this.remote_peer_id };
773 this.pending_message_to_session = Some(msg);
774 }
775 }
776
777 this.shrink_to_fit();
778
779 Poll::Pending
780 }
781}
782
/// A request received from the remote whose response is being produced locally.
pub(crate) struct ReceivedRequest<N: NetworkPrimitives> {
    /// Id of the remote's request; used as the id of the response message.
    request_id: u64,
    /// Receiver that yields the locally produced response.
    rx: PeerResponse<N>,
    /// When the request was received (currently unused).
    #[expect(dead_code)]
    received: Instant,
}
793
/// A request sent to the remote that is awaiting its response.
pub(crate) struct InflightRequest<R> {
    /// The request itself, or a marker that it already timed out.
    request: RequestState<R>,
    /// When the request was issued (queued for sending).
    timestamp: Instant,
    /// Point in time after which the request counts as timed out.
    deadline: Instant,
}
803
804impl<N: NetworkPrimitives> InflightRequest<PeerRequest<N>> {
807 #[inline]
809 fn is_timed_out(&self, now: Instant) -> bool {
810 now > self.deadline
811 }
812
813 #[inline]
815 const fn is_waiting(&self) -> bool {
816 matches!(self.request, RequestState::Waiting(_))
817 }
818
819 fn timeout(&mut self) {
821 let mut req = RequestState::TimedOut;
822 std::mem::swap(&mut self.request, &mut req);
823
824 if let RequestState::Waiting(req) = req {
825 req.send_err_response(RequestError::Timeout);
826 }
827 }
828}
829
/// Result of handling a message received from the remote.
enum OnIncomingMessageOutcome<N: NetworkPrimitives> {
    /// The message was handled.
    Ok,
    /// The message violates the protocol; the session is closed with `error`.
    BadMessage { error: EthStreamError, message: EthMessage<N> },
    /// The message could not be forwarded because the session manager channel is full.
    NoCapacity(ActiveSessionMessage<N>),
}
839
840impl<N: NetworkPrimitives> From<Result<(), ActiveSessionMessage<N>>>
841 for OnIncomingMessageOutcome<N>
842{
843 fn from(res: Result<(), ActiveSessionMessage<N>>) -> Self {
844 match res {
845 Ok(_) => Self::Ok,
846 Err(msg) => Self::NoCapacity(msg),
847 }
848 }
849}
850
/// State of a request sent to the remote.
enum RequestState<R> {
    /// Still waiting for the response; holds the request so its caller can be answered.
    Waiting(R),
    /// Timed out locally; a late response may still arrive and is then only used for the RTT
    /// estimate.
    TimedOut,
}
857
/// A message queued for sending to the remote.
#[derive(Debug)]
pub(crate) enum OutgoingMessage<N: NetworkPrimitives> {
    /// A regular eth message, written with `start_send_unpin`.
    Eth(EthMessage<N>),
    /// A broadcast message, written with `start_send_broadcast`.
    Broadcast(EthBroadcastMessage<N>),
    /// A raw capability message, written with `start_send_raw`.
    Raw(RawCapabilityMessage),
}
868
869impl<N: NetworkPrimitives> OutgoingMessage<N> {
870 const fn is_response(&self) -> bool {
872 match self {
873 Self::Eth(msg) => msg.is_response(),
874 _ => false,
875 }
876 }
877}
878
/// Wraps a regular eth message for the outgoing queue.
impl<N: NetworkPrimitives> From<EthMessage<N>> for OutgoingMessage<N> {
    fn from(value: EthMessage<N>) -> Self {
        Self::Eth(value)
    }
}
884
/// Wraps a broadcast message for the outgoing queue.
impl<N: NetworkPrimitives> From<EthBroadcastMessage<N>> for OutgoingMessage<N> {
    fn from(value: EthBroadcastMessage<N>) -> Self {
        Self::Broadcast(value)
    }
}
890
891#[inline]
893fn calculate_new_timeout(current_timeout: Duration, estimated_rtt: Duration) -> Duration {
894 let new_timeout = estimated_rtt.mul_f64(SAMPLE_IMPACT) * TIMEOUT_SCALING;
895
896 let smoothened_timeout = current_timeout.mul_f64(1.0 - SAMPLE_IMPACT) + new_timeout;
898
899 smoothened_timeout.clamp(MINIMUM_TIMEOUT, MAXIMUM_TIMEOUT)
900}
901
/// FIFO queue of messages waiting to be written to the connection, mirrored by a gauge.
pub(crate) struct QueuedOutgoingMessages<N: NetworkPrimitives> {
    /// Queued messages, oldest first.
    messages: VecDeque<OutgoingMessage<N>>,
    /// Gauge incremented on every push and decremented on every successful pop.
    count: Gauge,
}
907
908impl<N: NetworkPrimitives> QueuedOutgoingMessages<N> {
909 pub(crate) const fn new(metric: Gauge) -> Self {
910 Self { messages: VecDeque::new(), count: metric }
911 }
912
913 pub(crate) fn push_back(&mut self, message: OutgoingMessage<N>) {
914 self.messages.push_back(message);
915 self.count.increment(1);
916 }
917
918 pub(crate) fn pop_front(&mut self) -> Option<OutgoingMessage<N>> {
919 self.messages.pop_front().inspect(|_| self.count.decrement(1))
920 }
921
922 pub(crate) fn shrink_to_fit(&mut self) {
923 self.messages.shrink_to_fit();
924 }
925}
926
927#[cfg(test)]
928mod tests {
929 use super::*;
930 use crate::session::{handle::PendingSessionEvent, start_pending_incoming_session};
931 use alloy_eips::eip2124::ForkFilter;
932 use reth_chainspec::MAINNET;
933 use reth_ecies::stream::ECIESStream;
934 use reth_eth_wire::{
935 handshake::EthHandshake, EthNetworkPrimitives, EthStream, GetBlockBodies,
936 HelloMessageWithProtocols, P2PStream, StatusBuilder, UnauthedEthStream, UnauthedP2PStream,
937 UnifiedStatus,
938 };
939 use reth_ethereum_forks::EthereumHardfork;
940 use reth_network_peers::pk2id;
941 use reth_network_types::session::config::PROTOCOL_BREACH_REQUEST_TIMEOUT;
942 use secp256k1::{SecretKey, SECP256K1};
943 use tokio::{
944 net::{TcpListener, TcpStream},
945 sync::mpsc,
946 };
947
948 fn eth_hello(server_key: &SecretKey) -> HelloMessageWithProtocols {
950 HelloMessageWithProtocols::builder(pk2id(&server_key.public_key(SECP256K1))).build()
951 }
952
953 struct SessionBuilder<N: NetworkPrimitives = EthNetworkPrimitives> {
954 _remote_capabilities: Arc<Capabilities>,
955 active_session_tx: mpsc::Sender<ActiveSessionMessage<N>>,
956 active_session_rx: ReceiverStream<ActiveSessionMessage<N>>,
957 to_sessions: Vec<mpsc::Sender<SessionCommand<N>>>,
958 secret_key: SecretKey,
959 local_peer_id: PeerId,
960 hello: HelloMessageWithProtocols,
961 status: UnifiedStatus,
962 fork_filter: ForkFilter,
963 next_id: usize,
964 }
965
966 impl<N: NetworkPrimitives> SessionBuilder<N> {
967 fn next_id(&mut self) -> SessionId {
968 let id = self.next_id;
969 self.next_id += 1;
970 SessionId(id)
971 }
972
973 fn with_client_stream<F, O>(
975 &self,
976 local_addr: SocketAddr,
977 f: F,
978 ) -> Pin<Box<dyn Future<Output = ()> + Send>>
979 where
980 F: FnOnce(EthStream<P2PStream<ECIESStream<TcpStream>>, N>) -> O + Send + 'static,
981 O: Future<Output = ()> + Send + Sync,
982 {
983 let mut status = self.status;
984 let fork_filter = self.fork_filter.clone();
985 let local_peer_id = self.local_peer_id;
986 let mut hello = self.hello.clone();
987 let key = SecretKey::new(&mut rand_08::thread_rng());
988 hello.id = pk2id(&key.public_key(SECP256K1));
989 Box::pin(async move {
990 let outgoing = TcpStream::connect(local_addr).await.unwrap();
991 let sink = ECIESStream::connect(outgoing, key, local_peer_id).await.unwrap();
992
993 let (p2p_stream, _) = UnauthedP2PStream::new(sink).handshake(hello).await.unwrap();
994
995 let eth_version = p2p_stream.shared_capabilities().eth_version().unwrap();
996 status.set_eth_version(eth_version);
997
998 let (client_stream, _) = UnauthedEthStream::new(p2p_stream)
999 .handshake(status, fork_filter)
1000 .await
1001 .unwrap();
1002 f(client_stream).await
1003 })
1004 }
1005
1006 async fn connect_incoming(&mut self, stream: TcpStream) -> ActiveSession<N> {
1007 let remote_addr = stream.local_addr().unwrap();
1008 let session_id = self.next_id();
1009 let (_disconnect_tx, disconnect_rx) = oneshot::channel();
1010 let (pending_sessions_tx, pending_sessions_rx) = mpsc::channel(1);
1011
1012 tokio::task::spawn(start_pending_incoming_session(
1013 Arc::new(EthHandshake::default()),
1014 disconnect_rx,
1015 session_id,
1016 stream,
1017 pending_sessions_tx,
1018 remote_addr,
1019 self.secret_key,
1020 self.hello.clone(),
1021 self.status,
1022 self.fork_filter.clone(),
1023 Default::default(),
1024 ));
1025
1026 let mut stream = ReceiverStream::new(pending_sessions_rx);
1027
1028 match stream.next().await.unwrap() {
1029 PendingSessionEvent::Established {
1030 session_id,
1031 remote_addr,
1032 peer_id,
1033 capabilities,
1034 conn,
1035 ..
1036 } => {
1037 let (_to_session_tx, messages_rx) = mpsc::channel(10);
1038 let (commands_to_session, commands_rx) = mpsc::channel(10);
1039 let poll_sender = PollSender::new(self.active_session_tx.clone());
1040
1041 self.to_sessions.push(commands_to_session);
1042
1043 ActiveSession {
1044 next_id: 0,
1045 remote_peer_id: peer_id,
1046 remote_addr,
1047 remote_capabilities: Arc::clone(&capabilities),
1048 session_id,
1049 commands_rx: ReceiverStream::new(commands_rx),
1050 to_session_manager: MeteredPollSender::new(
1051 poll_sender,
1052 "network_active_session",
1053 ),
1054 pending_message_to_session: None,
1055 internal_request_rx: ReceiverStream::new(messages_rx).fuse(),
1056 inflight_requests: Default::default(),
1057 conn,
1058 queued_outgoing: QueuedOutgoingMessages::new(Gauge::noop()),
1059 received_requests_from_remote: Default::default(),
1060 internal_request_timeout_interval: tokio::time::interval(
1061 INITIAL_REQUEST_TIMEOUT,
1062 ),
1063 internal_request_timeout: Arc::new(AtomicU64::new(
1064 INITIAL_REQUEST_TIMEOUT.as_millis() as u64,
1065 )),
1066 protocol_breach_request_timeout: PROTOCOL_BREACH_REQUEST_TIMEOUT,
1067 terminate_message: None,
1068 range_info: None,
1069 local_range_info: BlockRangeInfo::new(
1070 0,
1071 1000,
1072 alloy_primitives::B256::ZERO,
1073 ),
1074 range_update_interval: None,
1075 last_sent_latest_block: None,
1076 }
1077 }
1078 ev => {
1079 panic!("unexpected message {ev:?}")
1080 }
1081 }
1082 }
1083 }
1084
1085 impl Default for SessionBuilder {
1086 fn default() -> Self {
1087 let (active_session_tx, active_session_rx) = mpsc::channel(100);
1088
1089 let (secret_key, pk) = SECP256K1.generate_keypair(&mut rand_08::thread_rng());
1090 let local_peer_id = pk2id(&pk);
1091
1092 Self {
1093 next_id: 0,
1094 _remote_capabilities: Arc::new(Capabilities::from(vec![])),
1095 active_session_tx,
1096 active_session_rx: ReceiverStream::new(active_session_rx),
1097 to_sessions: vec![],
1098 hello: eth_hello(&secret_key),
1099 secret_key,
1100 local_peer_id,
1101 status: StatusBuilder::default().build(),
1102 fork_filter: MAINNET
1103 .hardfork_fork_filter(EthereumHardfork::Frontier)
1104 .expect("The Frontier fork filter should exist on mainnet"),
1105 }
1106 }
1107 }
1108
1109 #[tokio::test(flavor = "multi_thread")]
1110 async fn test_disconnect() {
1111 let mut builder = SessionBuilder::default();
1112
1113 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1114 let local_addr = listener.local_addr().unwrap();
1115
1116 let expected_disconnect = DisconnectReason::UselessPeer;
1117
1118 let fut = builder.with_client_stream(local_addr, move |mut client_stream| async move {
1119 let msg = client_stream.next().await.unwrap().unwrap_err();
1120 assert_eq!(msg.as_disconnected().unwrap(), expected_disconnect);
1121 });
1122
1123 tokio::task::spawn(async move {
1124 let (incoming, _) = listener.accept().await.unwrap();
1125 let mut session = builder.connect_incoming(incoming).await;
1126
1127 session.start_disconnect(expected_disconnect).unwrap();
1128 session.await
1129 });
1130
1131 fut.await;
1132 }
1133
1134 #[tokio::test(flavor = "multi_thread")]
1135 async fn handle_dropped_stream() {
1136 let mut builder = SessionBuilder::default();
1137
1138 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1139 let local_addr = listener.local_addr().unwrap();
1140
1141 let fut = builder.with_client_stream(local_addr, move |client_stream| async move {
1142 drop(client_stream);
1143 tokio::time::sleep(Duration::from_secs(1)).await
1144 });
1145
1146 let (tx, rx) = oneshot::channel();
1147
1148 tokio::task::spawn(async move {
1149 let (incoming, _) = listener.accept().await.unwrap();
1150 let session = builder.connect_incoming(incoming).await;
1151 session.await;
1152
1153 tx.send(()).unwrap();
1154 });
1155
1156 tokio::task::spawn(fut);
1157
1158 rx.await.unwrap();
1159 }
1160
1161 #[tokio::test(flavor = "multi_thread")]
1162 async fn test_send_many_messages() {
1163 reth_tracing::init_test_tracing();
1164 let mut builder = SessionBuilder::default();
1165
1166 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1167 let local_addr = listener.local_addr().unwrap();
1168
1169 let num_messages = 100;
1170
1171 let fut = builder.with_client_stream(local_addr, move |mut client_stream| async move {
1172 for _ in 0..num_messages {
1173 client_stream
1174 .send(EthMessage::NewPooledTransactionHashes66(Vec::new().into()))
1175 .await
1176 .unwrap();
1177 }
1178 });
1179
1180 let (tx, rx) = oneshot::channel();
1181
1182 tokio::task::spawn(async move {
1183 let (incoming, _) = listener.accept().await.unwrap();
1184 let session = builder.connect_incoming(incoming).await;
1185 session.await;
1186
1187 tx.send(()).unwrap();
1188 });
1189
1190 tokio::task::spawn(fut);
1191
1192 rx.await.unwrap();
1193 }
1194
1195 #[tokio::test(flavor = "multi_thread")]
1196 async fn test_request_timeout() {
1197 reth_tracing::init_test_tracing();
1198
1199 let mut builder = SessionBuilder::default();
1200
1201 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1202 let local_addr = listener.local_addr().unwrap();
1203
1204 let request_timeout = Duration::from_millis(100);
1205 let drop_timeout = Duration::from_millis(1500);
1206
1207 let fut = builder.with_client_stream(local_addr, move |client_stream| async move {
1208 let _client_stream = client_stream;
1209 tokio::time::sleep(drop_timeout * 60).await;
1210 });
1211 tokio::task::spawn(fut);
1212
1213 let (incoming, _) = listener.accept().await.unwrap();
1214 let mut session = builder.connect_incoming(incoming).await;
1215 session
1216 .internal_request_timeout
1217 .store(request_timeout.as_millis() as u64, Ordering::Relaxed);
1218 session.protocol_breach_request_timeout = drop_timeout;
1219 session.internal_request_timeout_interval =
1220 tokio::time::interval_at(tokio::time::Instant::now(), request_timeout);
1221 let (tx, rx) = oneshot::channel();
1222 let req = PeerRequest::GetBlockBodies { request: GetBlockBodies(vec![]), response: tx };
1223 session.on_internal_peer_request(req, Instant::now());
1224 tokio::spawn(session);
1225
1226 let err = rx.await.unwrap().unwrap_err();
1227 assert_eq!(err, RequestError::Timeout);
1228
1229 let msg = builder.active_session_rx.next().await.unwrap();
1231 match msg {
1232 ActiveSessionMessage::ProtocolBreach { .. } => {}
1233 ev => unreachable!("{ev:?}"),
1234 }
1235 }
1236
1237 #[tokio::test(flavor = "multi_thread")]
1238 async fn test_keep_alive() {
1239 let mut builder = SessionBuilder::default();
1240
1241 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1242 let local_addr = listener.local_addr().unwrap();
1243
1244 let fut = builder.with_client_stream(local_addr, move |mut client_stream| async move {
1245 let _ = tokio::time::timeout(Duration::from_secs(5), client_stream.next()).await;
1246 client_stream.into_inner().disconnect(DisconnectReason::UselessPeer).await.unwrap();
1247 });
1248
1249 let (tx, rx) = oneshot::channel();
1250
1251 tokio::task::spawn(async move {
1252 let (incoming, _) = listener.accept().await.unwrap();
1253 let session = builder.connect_incoming(incoming).await;
1254 session.await;
1255
1256 tx.send(()).unwrap();
1257 });
1258
1259 tokio::task::spawn(fut);
1260
1261 rx.await.unwrap();
1262 }
1263
1264 #[tokio::test(flavor = "multi_thread")]
1266 async fn test_send_at_capacity() {
1267 let mut builder = SessionBuilder::default();
1268
1269 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1270 let local_addr = listener.local_addr().unwrap();
1271
1272 let fut = builder.with_client_stream(local_addr, move |mut client_stream| async move {
1273 client_stream
1274 .send(EthMessage::NewPooledTransactionHashes68(Default::default()))
1275 .await
1276 .unwrap();
1277 let _ = tokio::time::timeout(Duration::from_secs(100), client_stream.next()).await;
1278 });
1279 tokio::task::spawn(fut);
1280
1281 let (incoming, _) = listener.accept().await.unwrap();
1282 let session = builder.connect_incoming(incoming).await;
1283
1284 let mut num_fill_messages = 0;
1286 loop {
1287 if builder
1288 .active_session_tx
1289 .try_send(ActiveSessionMessage::ProtocolBreach { peer_id: PeerId::random() })
1290 .is_err()
1291 {
1292 break
1293 }
1294 num_fill_messages += 1;
1295 }
1296
1297 tokio::task::spawn(async move {
1298 session.await;
1299 });
1300
1301 tokio::time::sleep(Duration::from_millis(100)).await;
1302
1303 for _ in 0..num_fill_messages {
1304 let message = builder.active_session_rx.next().await.unwrap();
1305 match message {
1306 ActiveSessionMessage::ProtocolBreach { .. } => {}
1307 ev => unreachable!("{ev:?}"),
1308 }
1309 }
1310
1311 let message = builder.active_session_rx.next().await.unwrap();
1312 match message {
1313 ActiveSessionMessage::ValidMessage {
1314 message: PeerMessage::PooledTransactions(_),
1315 ..
1316 } => {}
1317 _ => unreachable!(),
1318 }
1319 }
1320
1321 #[test]
1322 fn timeout_calculation_sanity_tests() {
1323 let rtt = Duration::from_secs(5);
1324 let timeout = rtt * TIMEOUT_SCALING;
1326
1327 assert_eq!(calculate_new_timeout(timeout, rtt), timeout);
1329
1330 assert!(calculate_new_timeout(timeout, rtt / 2) < timeout);
1332 assert!(calculate_new_timeout(timeout, rtt / 2) > timeout / 2);
1333 assert!(calculate_new_timeout(timeout, rtt * 2) > timeout);
1334 assert!(calculate_new_timeout(timeout, rtt * 2) < timeout * 2);
1335 }
1336}