1use core::sync::atomic::Ordering;
4use std::{
5 collections::VecDeque,
6 future::Future,
7 net::SocketAddr,
8 pin::Pin,
9 sync::{atomic::AtomicU64, Arc},
10 task::{ready, Context, Poll},
11 time::{Duration, Instant},
12};
13
14use crate::{
15 message::{NewBlockMessage, PeerMessage, PeerResponse, PeerResponseResult},
16 session::{
17 conn::EthRlpxConnection,
18 handle::{ActiveSessionMessage, SessionCommand},
19 BlockRangeInfo, EthVersion, SessionId,
20 },
21};
22use alloy_eips::merge::EPOCH_SLOTS;
23use alloy_primitives::Sealable;
24use futures::{stream::Fuse, SinkExt, StreamExt};
25use metrics::Gauge;
26use reth_eth_wire::{
27 errors::{EthHandshakeError, EthStreamError},
28 message::{EthBroadcastMessage, MessageError},
29 Capabilities, DisconnectP2P, DisconnectReason, EthMessage, NetworkPrimitives, NewBlockPayload,
30};
31use reth_eth_wire_types::{message::RequestPair, RawCapabilityMessage};
32use reth_metrics::common::mpsc::MeteredPollSender;
33use reth_network_api::PeerRequest;
34use reth_network_p2p::error::RequestError;
35use reth_network_peers::PeerId;
36use reth_network_types::session::config::INITIAL_REQUEST_TIMEOUT;
37use reth_primitives_traits::Block;
38use rustc_hash::FxHashMap;
39use tokio::{
40 sync::{mpsc::error::TrySendError, oneshot},
41 time::Interval,
42};
43use tokio_stream::wrappers::ReceiverStream;
44use tokio_util::sync::PollSender;
45use tracing::{debug, trace};
46
47pub(super) const RANGE_UPDATE_INTERVAL: Duration = Duration::from_secs(EPOCH_SLOTS * 12);
53
54const MINIMUM_TIMEOUT: Duration = Duration::from_secs(2);
58
59const MAXIMUM_TIMEOUT: Duration = INITIAL_REQUEST_TIMEOUT;
61const SAMPLE_IMPACT: f64 = 0.1;
63const TIMEOUT_SCALING: u32 = 3;
65
66const MAX_QUEUED_OUTGOING_RESPONSES: usize = 4;
78
/// An established session with a remote peer.
///
/// The session is driven by polling it as a [`Future`]: it exchanges messages with the remote
/// over `conn` and with the session manager over `commands_rx` / `to_session_manager`.
// NOTE(review): `dead_code` is expected because some fields are not read by this module.
#[expect(dead_code)]
pub(crate) struct ActiveSession<N: NetworkPrimitives> {
    /// Counter used to assign ids to requests sent to the remote peer.
    pub(crate) next_id: u64,
    /// The connection to the remote peer.
    pub(crate) conn: EthRlpxConnection<N>,
    /// Identifier of the remote peer.
    pub(crate) remote_peer_id: PeerId,
    /// Socket address of the remote peer.
    pub(crate) remote_addr: SocketAddr,
    /// The remote peer's capabilities.
    pub(crate) remote_capabilities: Arc<Capabilities>,
    /// Identifier of this session.
    pub(crate) session_id: SessionId,
    /// Commands from the session manager (disconnect requests and messages to send).
    pub(crate) commands_rx: ReceiverStream<SessionCommand<N>>,
    /// Channel used to deliver [`ActiveSessionMessage`]s to the session manager.
    pub(crate) to_session_manager: MeteredPollSender<ActiveSessionMessage<N>>,
    /// A message for the session manager that could not be delivered because the channel was
    /// full; delivery is retried on the next pass of the receive loop.
    pub(crate) pending_message_to_session: Option<ActiveSessionMessage<N>>,
    /// Requests from the local node that should be sent to the remote peer.
    pub(crate) internal_request_rx: Fuse<ReceiverStream<PeerRequest<N>>>,
    /// Requests sent to the remote peer that are awaiting a response, keyed by request id.
    pub(crate) inflight_requests: FxHashMap<u64, InflightRequest<PeerRequest<N>>>,
    /// Requests received from the remote peer whose responses are still being produced locally.
    pub(crate) received_requests_from_remote: Vec<ReceivedRequest<N>>,
    /// Messages waiting to be written to the connection.
    pub(crate) queued_outgoing: QueuedOutgoingMessages<N>,
    /// Current request timeout in milliseconds, adapted from observed response times.
    pub(crate) internal_request_timeout: Arc<AtomicU64>,
    /// Interval at which in-flight requests are checked for timeouts.
    pub(crate) internal_request_timeout_interval: Interval,
    /// A request that already timed out and is still unanswered after this long (measured from
    /// when it was sent) is reported as a protocol breach.
    pub(crate) protocol_breach_request_timeout: Duration,
    /// Final message for the session manager that must be delivered before the session
    /// future resolves.
    pub(crate) terminate_message:
        Option<(PollSender<ActiveSessionMessage<N>>, ActiveSessionMessage<N>)>,
    /// Block range announced by the remote peer, updated on `BlockRangeUpdate`, if tracked.
    pub(crate) range_info: Option<BlockRangeInfo>,
    /// Local block range that is announced to the remote peer.
    pub(crate) local_range_info: BlockRangeInfo,
    /// Timer for announcing the local block range; `None` disables announcements.
    pub(crate) range_update_interval: Option<Interval>,
    /// Latest block number contained in the most recently queued `BlockRangeUpdate`.
    pub(crate) last_sent_latest_block: Option<u64>,
}
139
140impl<N: NetworkPrimitives> ActiveSession<N> {
141 fn is_disconnecting(&self) -> bool {
143 self.conn.inner().is_disconnecting()
144 }
145
146 const fn next_id(&mut self) -> u64 {
148 let id = self.next_id;
149 self.next_id += 1;
150 id
151 }
152
153 pub fn shrink_to_fit(&mut self) {
155 self.received_requests_from_remote.shrink_to_fit();
156 self.queued_outgoing.shrink_to_fit();
157 }
158
159 fn queued_response_count(&self) -> usize {
161 self.queued_outgoing.messages.iter().filter(|m| m.is_response()).count()
162 }
163
164 fn on_incoming_message(&mut self, msg: EthMessage<N>) -> OnIncomingMessageOutcome<N> {
168 macro_rules! on_request {
172 ($req:ident, $resp_item:ident, $req_item:ident) => {{
173 let RequestPair { request_id, message: request } = $req;
174 let (tx, response) = oneshot::channel();
175 let received = ReceivedRequest {
176 request_id,
177 rx: PeerResponse::$resp_item { response },
178 received: Instant::now(),
179 };
180 self.received_requests_from_remote.push(received);
181 self.try_emit_request(PeerMessage::EthRequest(PeerRequest::$req_item {
182 request,
183 response: tx,
184 }))
185 .into()
186 }};
187 }
188
189 macro_rules! on_response {
191 ($resp:ident, $item:ident) => {{
192 let RequestPair { request_id, message } = $resp;
193 if let Some(req) = self.inflight_requests.remove(&request_id) {
194 match req.request {
195 RequestState::Waiting(PeerRequest::$item { response, .. }) => {
196 trace!(peer_id=?self.remote_peer_id, ?request_id, "received response from peer");
197 let _ = response.send(Ok(message));
198 self.update_request_timeout(req.timestamp, Instant::now());
199 }
200 RequestState::Waiting(request) => {
201 request.send_bad_response();
202 }
203 RequestState::TimedOut => {
204 self.update_request_timeout(req.timestamp, Instant::now());
206 }
207 }
208 } else {
209 trace!(peer_id=?self.remote_peer_id, ?request_id, "received response to unknown request");
210 self.on_bad_message();
212 }
213
214 OnIncomingMessageOutcome::Ok
215 }};
216 }
217
218 match msg {
219 message @ EthMessage::Status(_) => OnIncomingMessageOutcome::BadMessage {
220 error: EthStreamError::EthHandshakeError(EthHandshakeError::StatusNotInHandshake),
221 message,
222 },
223 EthMessage::NewBlockHashes(msg) => {
224 self.try_emit_broadcast(PeerMessage::NewBlockHashes(msg)).into()
225 }
226 EthMessage::NewBlock(msg) => {
227 let block = NewBlockMessage {
228 hash: msg.block().header().hash_slow(),
229 block: Arc::new(*msg),
230 };
231 self.try_emit_broadcast(PeerMessage::NewBlock(block)).into()
232 }
233 EthMessage::Transactions(msg) => {
234 self.try_emit_broadcast(PeerMessage::ReceivedTransaction(msg)).into()
235 }
236 EthMessage::NewPooledTransactionHashes66(msg) => {
237 self.try_emit_broadcast(PeerMessage::PooledTransactions(msg.into())).into()
238 }
239 EthMessage::NewPooledTransactionHashes68(msg) => {
240 self.try_emit_broadcast(PeerMessage::PooledTransactions(msg.into())).into()
241 }
242 EthMessage::GetBlockHeaders(req) => {
243 on_request!(req, BlockHeaders, GetBlockHeaders)
244 }
245 EthMessage::BlockHeaders(resp) => {
246 on_response!(resp, GetBlockHeaders)
247 }
248 EthMessage::GetBlockBodies(req) => {
249 on_request!(req, BlockBodies, GetBlockBodies)
250 }
251 EthMessage::BlockBodies(resp) => {
252 on_response!(resp, GetBlockBodies)
253 }
254 EthMessage::GetPooledTransactions(req) => {
255 on_request!(req, PooledTransactions, GetPooledTransactions)
256 }
257 EthMessage::PooledTransactions(resp) => {
258 on_response!(resp, GetPooledTransactions)
259 }
260 EthMessage::GetNodeData(req) => {
261 on_request!(req, NodeData, GetNodeData)
262 }
263 EthMessage::NodeData(resp) => {
264 on_response!(resp, GetNodeData)
265 }
266 EthMessage::GetReceipts(req) => {
267 if self.conn.version() >= EthVersion::Eth69 {
268 on_request!(req, Receipts69, GetReceipts69)
269 } else {
270 on_request!(req, Receipts, GetReceipts)
271 }
272 }
273 EthMessage::GetReceipts70(req) => {
274 on_request!(req, Receipts70, GetReceipts70)
275 }
276 EthMessage::Receipts(resp) => {
277 on_response!(resp, GetReceipts)
278 }
279 EthMessage::Receipts69(resp) => {
280 on_response!(resp, GetReceipts69)
281 }
282 EthMessage::Receipts70(resp) => {
283 on_response!(resp, GetReceipts70)
284 }
285 EthMessage::GetBlockAccessLists(req) => {
286 on_request!(req, BlockAccessLists, GetBlockAccessLists)
287 }
288 EthMessage::BlockAccessLists(resp) => {
289 on_response!(resp, GetBlockAccessLists)
290 }
291 EthMessage::BlockRangeUpdate(msg) => {
292 if msg.earliest > msg.latest {
294 return OnIncomingMessageOutcome::BadMessage {
295 error: EthStreamError::InvalidMessage(MessageError::Other(format!(
296 "invalid block range: earliest ({}) > latest ({})",
297 msg.earliest, msg.latest
298 ))),
299 message: EthMessage::BlockRangeUpdate(msg),
300 };
301 }
302
303 if msg.latest_hash.is_zero() {
305 return OnIncomingMessageOutcome::BadMessage {
306 error: EthStreamError::InvalidMessage(MessageError::Other(
307 "invalid block range: latest_hash cannot be zero".to_string(),
308 )),
309 message: EthMessage::BlockRangeUpdate(msg),
310 };
311 }
312
313 if let Some(range_info) = self.range_info.as_ref() {
314 range_info.update(msg.earliest, msg.latest, msg.latest_hash);
315 }
316
317 OnIncomingMessageOutcome::Ok
318 }
319 EthMessage::Other(bytes) => self.try_emit_broadcast(PeerMessage::Other(bytes)).into(),
320 }
321 }
322
323 fn on_internal_peer_request(&mut self, request: PeerRequest<N>, deadline: Instant) {
325 let version = self.conn.version();
326 if !Self::is_request_supported_for_version(&request, version) {
327 debug!(
328 target: "net",
329 ?request,
330 peer_id=?self.remote_peer_id,
331 ?version,
332 "Request not supported for negotiated eth version",
333 );
334 request.send_err_response(RequestError::UnsupportedCapability);
335 return;
336 }
337
338 let request_id = self.next_id();
339 trace!(?request, peer_id=?self.remote_peer_id, ?request_id, "sending request to peer");
340 let msg = request.create_request_message(request_id).map_versioned(version);
341
342 self.queued_outgoing.push_back(msg.into());
343 let req = InflightRequest {
344 request: RequestState::Waiting(request),
345 timestamp: Instant::now(),
346 deadline,
347 };
348 self.inflight_requests.insert(request_id, req);
349 }
350
351 #[inline]
352 fn is_request_supported_for_version(request: &PeerRequest<N>, version: EthVersion) -> bool {
353 request.is_supported_by_eth_version(version)
354 }
355
356 fn on_internal_peer_message(&mut self, msg: PeerMessage<N>) {
358 match msg {
359 PeerMessage::NewBlockHashes(msg) => {
360 self.queued_outgoing.push_back(EthMessage::NewBlockHashes(msg).into());
361 }
362 PeerMessage::NewBlock(msg) => {
363 self.queued_outgoing.push_back(EthBroadcastMessage::NewBlock(msg.block).into());
364 }
365 PeerMessage::PooledTransactions(msg) => {
366 if msg.is_valid_for_version(self.conn.version()) {
367 self.queued_outgoing.push_back(EthMessage::from(msg).into());
368 } else {
369 debug!(target: "net", ?msg, version=?self.conn.version(), "Message is invalid for connection version, skipping");
370 }
371 }
372 PeerMessage::EthRequest(req) => {
373 let deadline = self.request_deadline();
374 self.on_internal_peer_request(req, deadline);
375 }
376 PeerMessage::SendTransactions(msg) => {
377 self.queued_outgoing.push_back(EthBroadcastMessage::Transactions(msg).into());
378 }
379 PeerMessage::BlockRangeUpdated(_) => {}
380 PeerMessage::ReceivedTransaction(_) => {
381 unreachable!("Not emitted by network")
382 }
383 PeerMessage::Other(other) => {
384 self.queued_outgoing.push_back(OutgoingMessage::Raw(other));
385 }
386 }
387 }
388
389 fn request_deadline(&self) -> Instant {
391 Instant::now() +
392 Duration::from_millis(self.internal_request_timeout.load(Ordering::Relaxed))
393 }
394
395 fn handle_outgoing_response(&mut self, id: u64, resp: PeerResponseResult<N>) {
399 match resp.try_into_message(id) {
400 Ok(msg) => {
401 self.queued_outgoing.push_back(msg.into());
402 }
403 Err(err) => {
404 debug!(target: "net", %err, "Failed to respond to received request");
405 }
406 }
407 }
408
409 #[expect(clippy::result_large_err)]
413 fn try_emit_broadcast(&self, message: PeerMessage<N>) -> Result<(), ActiveSessionMessage<N>> {
414 let Some(sender) = self.to_session_manager.inner().get_ref() else { return Ok(()) };
415
416 match sender
417 .try_send(ActiveSessionMessage::ValidMessage { peer_id: self.remote_peer_id, message })
418 {
419 Ok(_) => Ok(()),
420 Err(err) => {
421 trace!(
422 target: "net",
423 %err,
424 "no capacity for incoming broadcast",
425 );
426 match err {
427 TrySendError::Full(msg) => Err(msg),
428 TrySendError::Closed(_) => Ok(()),
429 }
430 }
431 }
432 }
433
434 #[expect(clippy::result_large_err)]
439 fn try_emit_request(&self, message: PeerMessage<N>) -> Result<(), ActiveSessionMessage<N>> {
440 let Some(sender) = self.to_session_manager.inner().get_ref() else { return Ok(()) };
441
442 match sender
443 .try_send(ActiveSessionMessage::ValidMessage { peer_id: self.remote_peer_id, message })
444 {
445 Ok(_) => Ok(()),
446 Err(err) => {
447 trace!(
448 target: "net",
449 %err,
450 "no capacity for incoming request",
451 );
452 match err {
453 TrySendError::Full(msg) => Err(msg),
454 TrySendError::Closed(_) => {
455 Ok(())
458 }
459 }
460 }
461 }
462 }
463
464 fn on_bad_message(&self) {
466 let Some(sender) = self.to_session_manager.inner().get_ref() else { return };
467 let _ = sender.try_send(ActiveSessionMessage::BadMessage { peer_id: self.remote_peer_id });
468 }
469
470 fn emit_disconnect(&mut self, cx: &mut Context<'_>) -> Poll<()> {
472 trace!(target: "net::session", remote_peer_id=?self.remote_peer_id, "emitting disconnect");
473 let msg = ActiveSessionMessage::Disconnected {
474 peer_id: self.remote_peer_id,
475 remote_addr: self.remote_addr,
476 };
477
478 self.terminate_message = Some((self.to_session_manager.inner().clone(), msg));
479 self.poll_terminate_message(cx).expect("message is set")
480 }
481
482 fn close_on_error(&mut self, error: EthStreamError, cx: &mut Context<'_>) -> Poll<()> {
484 let msg = ActiveSessionMessage::ClosedOnConnectionError {
485 peer_id: self.remote_peer_id,
486 remote_addr: self.remote_addr,
487 error,
488 };
489 self.terminate_message = Some((self.to_session_manager.inner().clone(), msg));
490 self.poll_terminate_message(cx).expect("message is set")
491 }
492
493 fn start_disconnect(&mut self, reason: DisconnectReason) -> Result<(), EthStreamError> {
495 Ok(self.conn.inner_mut().start_disconnect(reason)?)
496 }
497
498 fn poll_disconnect(&mut self, cx: &mut Context<'_>) -> Poll<()> {
500 debug_assert!(self.is_disconnecting(), "not disconnecting");
501
502 let _ = ready!(self.conn.poll_close_unpin(cx));
504 self.emit_disconnect(cx)
505 }
506
507 fn try_disconnect(&mut self, reason: DisconnectReason, cx: &mut Context<'_>) -> Poll<()> {
509 match self.start_disconnect(reason) {
510 Ok(()) => {
511 self.poll_disconnect(cx)
513 }
514 Err(err) => {
515 debug!(target: "net::session", %err, remote_peer_id=?self.remote_peer_id, "could not send disconnect");
516 self.close_on_error(err, cx)
517 }
518 }
519 }
520
521 #[must_use]
530 fn check_timed_out_requests(&mut self, now: Instant) -> bool {
531 for (id, req) in &mut self.inflight_requests {
532 if req.is_timed_out(now) {
533 if req.is_waiting() {
534 debug!(target: "net::session", ?id, remote_peer_id=?self.remote_peer_id, "timed out outgoing request");
535 req.timeout();
536 } else if now - req.timestamp > self.protocol_breach_request_timeout {
537 return true
538 }
539 }
540 }
541
542 false
543 }
544
545 fn update_request_timeout(&mut self, sent: Instant, received: Instant) {
547 let elapsed = received.saturating_duration_since(sent);
548
549 let current = Duration::from_millis(self.internal_request_timeout.load(Ordering::Relaxed));
550 let request_timeout = calculate_new_timeout(current, elapsed);
551 self.internal_request_timeout.store(request_timeout.as_millis() as u64, Ordering::Relaxed);
552 self.internal_request_timeout_interval = tokio::time::interval(request_timeout);
553 }
554
555 fn poll_terminate_message(&mut self, cx: &mut Context<'_>) -> Option<Poll<()>> {
557 let (mut tx, msg) = self.terminate_message.take()?;
558 match tx.poll_reserve(cx) {
559 Poll::Pending => {
560 self.terminate_message = Some((tx, msg));
561 return Some(Poll::Pending)
562 }
563 Poll::Ready(Ok(())) => {
564 let _ = tx.send_item(msg);
565 }
566 Poll::Ready(Err(_)) => {
567 }
569 }
570 Some(Poll::Ready(()))
572 }
573}
574
575impl<N: NetworkPrimitives> Future for ActiveSession<N> {
576 type Output = ();
577
578 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
579 let this = self.get_mut();
580
581 if let Some(terminate) = this.poll_terminate_message(cx) {
583 return terminate
584 }
585
586 if this.is_disconnecting() {
587 return this.poll_disconnect(cx)
588 }
589
590 let mut budget = 4;
596
597 'main: loop {
599 let mut progress = false;
600
601 loop {
603 match this.commands_rx.poll_next_unpin(cx) {
604 Poll::Pending => break,
605 Poll::Ready(None) => {
606 return Poll::Ready(())
609 }
610 Poll::Ready(Some(cmd)) => {
611 progress = true;
612 match cmd {
613 SessionCommand::Disconnect { reason } => {
614 debug!(
615 target: "net::session",
616 ?reason,
617 remote_peer_id=?this.remote_peer_id,
618 "Received disconnect command for session"
619 );
620 let reason =
621 reason.unwrap_or(DisconnectReason::DisconnectRequested);
622
623 return this.try_disconnect(reason, cx)
624 }
625 SessionCommand::Message(msg) => {
626 this.on_internal_peer_message(msg);
627 }
628 }
629 }
630 }
631 }
632
633 let deadline = this.request_deadline();
634
635 while let Poll::Ready(Some(req)) = this.internal_request_rx.poll_next_unpin(cx) {
636 progress = true;
637 this.on_internal_peer_request(req, deadline);
638 }
639
640 for idx in (0..this.received_requests_from_remote.len()).rev() {
643 let mut req = this.received_requests_from_remote.swap_remove(idx);
644 match req.rx.poll(cx) {
645 Poll::Pending => {
646 this.received_requests_from_remote.push(req);
648 }
649 Poll::Ready(resp) => {
650 this.handle_outgoing_response(req.request_id, resp);
651 }
652 }
653 }
654
655 while this.conn.poll_ready_unpin(cx).is_ready() {
657 if let Some(msg) = this.queued_outgoing.pop_front() {
658 progress = true;
659 let res = match msg {
660 OutgoingMessage::Eth(msg) => this.conn.start_send_unpin(msg),
661 OutgoingMessage::Broadcast(msg) => this.conn.start_send_broadcast(msg),
662 OutgoingMessage::Raw(msg) => this.conn.start_send_raw(msg),
663 };
664 if let Err(err) = res {
665 debug!(target: "net::session", %err, remote_peer_id=?this.remote_peer_id, "failed to send message");
666 return this.close_on_error(err, cx)
668 }
669 } else {
670 break
672 }
673 }
674
675 'receive: loop {
677 budget -= 1;
679 if budget == 0 {
680 cx.waker().wake_by_ref();
682 break 'main
683 }
684
685 if let Some(msg) = this.pending_message_to_session.take() {
689 match this.to_session_manager.poll_reserve(cx) {
690 Poll::Ready(Ok(_)) => {
691 let _ = this.to_session_manager.send_item(msg);
692 }
693 Poll::Ready(Err(_)) => return Poll::Ready(()),
694 Poll::Pending => {
695 this.pending_message_to_session = Some(msg);
696 break 'receive
697 }
698 };
699 }
700
701 if this.received_requests_from_remote.len() > MAX_QUEUED_OUTGOING_RESPONSES {
703 break 'receive
709 }
710
711 if this.queued_outgoing.messages.len() > MAX_QUEUED_OUTGOING_RESPONSES &&
713 this.queued_response_count() > MAX_QUEUED_OUTGOING_RESPONSES
714 {
715 break 'receive
722 }
723
724 match this.conn.poll_next_unpin(cx) {
725 Poll::Pending => break,
726 Poll::Ready(None) => {
727 if this.is_disconnecting() {
728 break
729 }
730 debug!(target: "net::session", remote_peer_id=?this.remote_peer_id, "eth stream completed");
731 return this.emit_disconnect(cx)
732 }
733 Poll::Ready(Some(res)) => {
734 match res {
735 Ok(msg) => {
736 trace!(target: "net::session", msg_id=?msg.message_id(), remote_peer_id=?this.remote_peer_id, "received eth message");
737 match this.on_incoming_message(msg) {
739 OnIncomingMessageOutcome::Ok => {
740 progress = true;
742 }
743 OnIncomingMessageOutcome::BadMessage { error, message } => {
744 debug!(target: "net::session", %error, msg=?message, remote_peer_id=?this.remote_peer_id, "received invalid protocol message");
745 return this.close_on_error(error, cx)
746 }
747 OnIncomingMessageOutcome::NoCapacity(msg) => {
748 this.pending_message_to_session = Some(msg);
750 }
751 }
752 }
753 Err(err) => {
754 debug!(target: "net::session", %err, remote_peer_id=?this.remote_peer_id, "failed to receive message");
755 return this.close_on_error(err, cx)
756 }
757 }
758 }
759 }
760 }
761
762 if !progress {
763 break 'main
764 }
765 }
766
767 if let Some(interval) = &mut this.range_update_interval {
768 while interval.poll_tick(cx).is_ready() {
770 let current_latest = this.local_range_info.latest();
771 let should_send = if let Some(last_sent) = this.last_sent_latest_block {
772 current_latest.saturating_sub(last_sent) >= EPOCH_SLOTS
774 } else {
775 true };
777
778 if should_send {
779 this.queued_outgoing.push_back(
780 EthMessage::BlockRangeUpdate(this.local_range_info.to_message()).into(),
781 );
782 this.last_sent_latest_block = Some(current_latest);
783 }
784 }
785 }
786
787 while this.internal_request_timeout_interval.poll_tick(cx).is_ready() {
788 if this.check_timed_out_requests(Instant::now()) &&
790 let Poll::Ready(Ok(_)) = this.to_session_manager.poll_reserve(cx)
791 {
792 let msg = ActiveSessionMessage::ProtocolBreach { peer_id: this.remote_peer_id };
793 this.pending_message_to_session = Some(msg);
794 }
795 }
796
797 this.shrink_to_fit();
798
799 Poll::Pending
800 }
801}
802
/// A request received from the remote peer whose response is being produced locally.
pub(crate) struct ReceivedRequest<N: NetworkPrimitives> {
    /// Id the remote assigned to the request; used when turning the response into a message.
    request_id: u64,
    /// Receiving end for the locally produced response.
    rx: PeerResponse<N>,
    /// When the request was received (not read anywhere yet).
    #[expect(dead_code)]
    received: Instant,
}
813
/// A request sent to the remote peer that is awaiting a response.
pub(crate) struct InflightRequest<R> {
    /// Whether the request is still waiting for a response or has already timed out.
    request: RequestState<R>,
    /// When the request was queued for sending; used for round-trip estimation and
    /// protocol-breach detection.
    timestamp: Instant,
    /// Point in time after which the request counts as timed out.
    deadline: Instant,
}
823
824impl<N: NetworkPrimitives> InflightRequest<PeerRequest<N>> {
827 #[inline]
829 fn is_timed_out(&self, now: Instant) -> bool {
830 now > self.deadline
831 }
832
833 #[inline]
835 const fn is_waiting(&self) -> bool {
836 matches!(self.request, RequestState::Waiting(_))
837 }
838
839 fn timeout(&mut self) {
841 let mut req = RequestState::TimedOut;
842 std::mem::swap(&mut self.request, &mut req);
843
844 if let RequestState::Waiting(req) = req {
845 req.send_err_response(RequestError::Timeout);
846 }
847 }
848}
849
/// Result of handling a single message received from the remote peer.
enum OnIncomingMessageOutcome<N: NetworkPrimitives> {
    /// The message was handled.
    Ok,
    /// The message is invalid; the session is closed with `error`.
    BadMessage { error: EthStreamError, message: EthMessage<N> },
    /// The session manager channel was full; the message must be delivered later.
    NoCapacity(ActiveSessionMessage<N>),
}
859
860impl<N: NetworkPrimitives> From<Result<(), ActiveSessionMessage<N>>>
861 for OnIncomingMessageOutcome<N>
862{
863 fn from(res: Result<(), ActiveSessionMessage<N>>) -> Self {
864 match res {
865 Ok(_) => Self::Ok,
866 Err(msg) => Self::NoCapacity(msg),
867 }
868 }
869}
870
/// Lifecycle state of a request sent to the remote peer.
enum RequestState<R> {
    /// Still waiting for a response; holds the request so the response can be delivered.
    Waiting(R),
    /// Already timed out locally; the requester was notified, but a late response is still
    /// accepted as a round-trip sample.
    TimedOut,
}
877
/// A message queued for sending to the remote peer.
#[derive(Debug)]
pub(crate) enum OutgoingMessage<N: NetworkPrimitives> {
    /// A regular eth protocol message.
    Eth(EthMessage<N>),
    /// A broadcast message.
    Broadcast(EthBroadcastMessage<N>),
    /// A raw capability message, sent as-is.
    Raw(RawCapabilityMessage),
}
888
889impl<N: NetworkPrimitives> OutgoingMessage<N> {
890 const fn is_response(&self) -> bool {
892 match self {
893 Self::Eth(msg) => msg.is_response(),
894 _ => false,
895 }
896 }
897}
898
899impl<N: NetworkPrimitives> From<EthMessage<N>> for OutgoingMessage<N> {
900 fn from(value: EthMessage<N>) -> Self {
901 Self::Eth(value)
902 }
903}
904
905impl<N: NetworkPrimitives> From<EthBroadcastMessage<N>> for OutgoingMessage<N> {
906 fn from(value: EthBroadcastMessage<N>) -> Self {
907 Self::Broadcast(value)
908 }
909}
910
911#[inline]
913fn calculate_new_timeout(current_timeout: Duration, estimated_rtt: Duration) -> Duration {
914 let new_timeout = estimated_rtt.mul_f64(SAMPLE_IMPACT) * TIMEOUT_SCALING;
915
916 let smoothened_timeout = current_timeout.mul_f64(1.0 - SAMPLE_IMPACT) + new_timeout;
918
919 smoothened_timeout.clamp(MINIMUM_TIMEOUT, MAXIMUM_TIMEOUT)
920}
921
/// Queue of messages waiting to be written to the connection, whose length is mirrored into a
/// gauge metric.
pub(crate) struct QueuedOutgoingMessages<N: NetworkPrimitives> {
    /// The queued messages, oldest first.
    messages: VecDeque<OutgoingMessage<N>>,
    /// Gauge tracking the number of queued messages.
    count: Gauge,
}
927
928impl<N: NetworkPrimitives> QueuedOutgoingMessages<N> {
929 pub(crate) const fn new(metric: Gauge) -> Self {
930 Self { messages: VecDeque::new(), count: metric }
931 }
932
933 pub(crate) fn push_back(&mut self, message: OutgoingMessage<N>) {
934 self.messages.push_back(message);
935 self.count.increment(1);
936 }
937
938 pub(crate) fn pop_front(&mut self) -> Option<OutgoingMessage<N>> {
939 self.messages.pop_front().inspect(|_| self.count.decrement(1))
940 }
941
942 pub(crate) fn shrink_to_fit(&mut self) {
943 self.messages.shrink_to_fit();
944 }
945}
946
947impl<N: NetworkPrimitives> Drop for QueuedOutgoingMessages<N> {
948 fn drop(&mut self) {
949 let remaining = self.messages.len();
951 if remaining > 0 {
952 self.count.decrement(remaining as f64);
953 }
954 }
955}
956
957#[cfg(test)]
958mod tests {
959 use super::*;
960 use crate::session::{handle::PendingSessionEvent, start_pending_incoming_session};
961 use alloy_eips::eip2124::ForkFilter;
962 use reth_chainspec::MAINNET;
963 use reth_ecies::stream::ECIESStream;
964 use reth_eth_wire::{
965 handshake::EthHandshake, EthNetworkPrimitives, EthStream, GetBlockAccessLists,
966 GetBlockBodies, HelloMessageWithProtocols, P2PStream, StatusBuilder, UnauthedEthStream,
967 UnauthedP2PStream, UnifiedStatus,
968 };
969 use reth_ethereum_forks::EthereumHardfork;
970 use reth_network_peers::pk2id;
971 use reth_network_types::session::config::PROTOCOL_BREACH_REQUEST_TIMEOUT;
972 use secp256k1::{SecretKey, SECP256K1};
973 use tokio::{
974 net::{TcpListener, TcpStream},
975 sync::mpsc,
976 };
977
978 fn eth_hello(server_key: &SecretKey) -> HelloMessageWithProtocols {
980 HelloMessageWithProtocols::builder(pk2id(&server_key.public_key(SECP256K1))).build()
981 }
982
983 struct SessionBuilder<N: NetworkPrimitives = EthNetworkPrimitives> {
984 _remote_capabilities: Arc<Capabilities>,
985 active_session_tx: mpsc::Sender<ActiveSessionMessage<N>>,
986 active_session_rx: ReceiverStream<ActiveSessionMessage<N>>,
987 to_sessions: Vec<mpsc::Sender<SessionCommand<N>>>,
988 secret_key: SecretKey,
989 local_peer_id: PeerId,
990 hello: HelloMessageWithProtocols,
991 status: UnifiedStatus,
992 fork_filter: ForkFilter,
993 next_id: usize,
994 }
995
996 impl<N: NetworkPrimitives> SessionBuilder<N> {
997 fn next_id(&mut self) -> SessionId {
998 let id = self.next_id;
999 self.next_id += 1;
1000 SessionId(id)
1001 }
1002
1003 fn with_client_stream<F, O>(
1005 &self,
1006 local_addr: SocketAddr,
1007 f: F,
1008 ) -> Pin<Box<dyn Future<Output = ()> + Send>>
1009 where
1010 F: FnOnce(EthStream<P2PStream<ECIESStream<TcpStream>>, N>) -> O + Send + 'static,
1011 O: Future<Output = ()> + Send + Sync,
1012 {
1013 let mut status = self.status;
1014 let fork_filter = self.fork_filter.clone();
1015 let local_peer_id = self.local_peer_id;
1016 let mut hello = self.hello.clone();
1017 let key = SecretKey::new(&mut rand_08::thread_rng());
1018 hello.id = pk2id(&key.public_key(SECP256K1));
1019 Box::pin(async move {
1020 let outgoing = TcpStream::connect(local_addr).await.unwrap();
1021 let sink = ECIESStream::connect(outgoing, key, local_peer_id).await.unwrap();
1022
1023 let (p2p_stream, _) = UnauthedP2PStream::new(sink).handshake(hello).await.unwrap();
1024
1025 let eth_version = p2p_stream.shared_capabilities().eth_version().unwrap();
1026 status.set_eth_version(eth_version);
1027
1028 let (client_stream, _) = UnauthedEthStream::new(p2p_stream)
1029 .handshake(status, fork_filter)
1030 .await
1031 .unwrap();
1032 f(client_stream).await
1033 })
1034 }
1035
1036 async fn connect_incoming(&mut self, stream: TcpStream) -> ActiveSession<N> {
1037 let remote_addr = stream.local_addr().unwrap();
1038 let session_id = self.next_id();
1039 let (_disconnect_tx, disconnect_rx) = oneshot::channel();
1040 let (pending_sessions_tx, pending_sessions_rx) = mpsc::channel(1);
1041
1042 tokio::task::spawn(start_pending_incoming_session(
1043 Arc::new(EthHandshake::default()),
1044 disconnect_rx,
1045 session_id,
1046 stream,
1047 pending_sessions_tx,
1048 remote_addr,
1049 self.secret_key,
1050 self.hello.clone(),
1051 self.status,
1052 self.fork_filter.clone(),
1053 Default::default(),
1054 ));
1055
1056 let mut stream = ReceiverStream::new(pending_sessions_rx);
1057
1058 match stream.next().await.unwrap() {
1059 PendingSessionEvent::Established {
1060 session_id,
1061 remote_addr,
1062 peer_id,
1063 capabilities,
1064 conn,
1065 ..
1066 } => {
1067 let (_to_session_tx, messages_rx) = mpsc::channel(10);
1068 let (commands_to_session, commands_rx) = mpsc::channel(10);
1069 let poll_sender = PollSender::new(self.active_session_tx.clone());
1070
1071 self.to_sessions.push(commands_to_session);
1072
1073 ActiveSession {
1074 next_id: 0,
1075 remote_peer_id: peer_id,
1076 remote_addr,
1077 remote_capabilities: Arc::clone(&capabilities),
1078 session_id,
1079 commands_rx: ReceiverStream::new(commands_rx),
1080 to_session_manager: MeteredPollSender::new(
1081 poll_sender,
1082 "network_active_session",
1083 ),
1084 pending_message_to_session: None,
1085 internal_request_rx: ReceiverStream::new(messages_rx).fuse(),
1086 inflight_requests: Default::default(),
1087 conn,
1088 queued_outgoing: QueuedOutgoingMessages::new(Gauge::noop()),
1089 received_requests_from_remote: Default::default(),
1090 internal_request_timeout_interval: tokio::time::interval(
1091 INITIAL_REQUEST_TIMEOUT,
1092 ),
1093 internal_request_timeout: Arc::new(AtomicU64::new(
1094 INITIAL_REQUEST_TIMEOUT.as_millis() as u64,
1095 )),
1096 protocol_breach_request_timeout: PROTOCOL_BREACH_REQUEST_TIMEOUT,
1097 terminate_message: None,
1098 range_info: None,
1099 local_range_info: BlockRangeInfo::new(
1100 0,
1101 1000,
1102 alloy_primitives::B256::ZERO,
1103 ),
1104 range_update_interval: None,
1105 last_sent_latest_block: None,
1106 }
1107 }
1108 ev => {
1109 panic!("unexpected message {ev:?}")
1110 }
1111 }
1112 }
1113 }
1114
1115 impl Default for SessionBuilder {
1116 fn default() -> Self {
1117 let (active_session_tx, active_session_rx) = mpsc::channel(100);
1118
1119 let (secret_key, pk) = SECP256K1.generate_keypair(&mut rand_08::thread_rng());
1120 let local_peer_id = pk2id(&pk);
1121
1122 Self {
1123 next_id: 0,
1124 _remote_capabilities: Arc::new(Capabilities::from(vec![])),
1125 active_session_tx,
1126 active_session_rx: ReceiverStream::new(active_session_rx),
1127 to_sessions: vec![],
1128 hello: eth_hello(&secret_key),
1129 secret_key,
1130 local_peer_id,
1131 status: StatusBuilder::default().build(),
1132 fork_filter: MAINNET
1133 .hardfork_fork_filter(EthereumHardfork::Frontier)
1134 .expect("The Frontier fork filter should exist on mainnet"),
1135 }
1136 }
1137 }
1138
1139 #[tokio::test(flavor = "multi_thread")]
1140 async fn test_disconnect() {
1141 let mut builder = SessionBuilder::default();
1142
1143 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1144 let local_addr = listener.local_addr().unwrap();
1145
1146 let expected_disconnect = DisconnectReason::UselessPeer;
1147
1148 let fut = builder.with_client_stream(local_addr, async move |mut client_stream| {
1149 let msg = client_stream.next().await.unwrap().unwrap_err();
1150 assert_eq!(msg.as_disconnected().unwrap(), expected_disconnect);
1151 });
1152
1153 tokio::task::spawn(async move {
1154 let (incoming, _) = listener.accept().await.unwrap();
1155 let mut session = builder.connect_incoming(incoming).await;
1156
1157 session.start_disconnect(expected_disconnect).unwrap();
1158 session.await
1159 });
1160
1161 fut.await;
1162 }
1163
1164 #[tokio::test(flavor = "multi_thread")]
1165 async fn handle_dropped_stream() {
1166 let mut builder = SessionBuilder::default();
1167
1168 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1169 let local_addr = listener.local_addr().unwrap();
1170
1171 let fut = builder.with_client_stream(local_addr, async move |client_stream| {
1172 drop(client_stream);
1173 tokio::time::sleep(Duration::from_secs(1)).await
1174 });
1175
1176 let (tx, rx) = oneshot::channel();
1177
1178 tokio::task::spawn(async move {
1179 let (incoming, _) = listener.accept().await.unwrap();
1180 let session = builder.connect_incoming(incoming).await;
1181 session.await;
1182
1183 tx.send(()).unwrap();
1184 });
1185
1186 tokio::task::spawn(fut);
1187
1188 rx.await.unwrap();
1189 }
1190
1191 #[tokio::test(flavor = "multi_thread")]
1192 async fn test_send_many_messages() {
1193 reth_tracing::init_test_tracing();
1194 let mut builder = SessionBuilder::default();
1195
1196 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1197 let local_addr = listener.local_addr().unwrap();
1198
1199 let num_messages = 100;
1200
1201 let fut = builder.with_client_stream(local_addr, async move |mut client_stream| {
1202 for _ in 0..num_messages {
1203 client_stream
1204 .send(EthMessage::NewPooledTransactionHashes66(Vec::new().into()))
1205 .await
1206 .unwrap();
1207 }
1208 });
1209
1210 let (tx, rx) = oneshot::channel();
1211
1212 tokio::task::spawn(async move {
1213 let (incoming, _) = listener.accept().await.unwrap();
1214 let session = builder.connect_incoming(incoming).await;
1215 session.await;
1216
1217 tx.send(()).unwrap();
1218 });
1219
1220 tokio::task::spawn(fut);
1221
1222 rx.await.unwrap();
1223 }
1224
1225 #[tokio::test(flavor = "multi_thread")]
1226 async fn test_request_timeout() {
1227 reth_tracing::init_test_tracing();
1228
1229 let mut builder = SessionBuilder::default();
1230
1231 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1232 let local_addr = listener.local_addr().unwrap();
1233
1234 let request_timeout = Duration::from_millis(100);
1235 let drop_timeout = Duration::from_millis(1500);
1236
1237 let fut = builder.with_client_stream(local_addr, async move |client_stream| {
1238 let _client_stream = client_stream;
1239 tokio::time::sleep(drop_timeout * 60).await;
1240 });
1241 tokio::task::spawn(fut);
1242
1243 let (incoming, _) = listener.accept().await.unwrap();
1244 let mut session = builder.connect_incoming(incoming).await;
1245 session
1246 .internal_request_timeout
1247 .store(request_timeout.as_millis() as u64, Ordering::Relaxed);
1248 session.protocol_breach_request_timeout = drop_timeout;
1249 session.internal_request_timeout_interval =
1250 tokio::time::interval_at(tokio::time::Instant::now(), request_timeout);
1251 let (tx, rx) = oneshot::channel();
1252 let req = PeerRequest::GetBlockBodies { request: GetBlockBodies(vec![]), response: tx };
1253 session.on_internal_peer_request(req, Instant::now());
1254 tokio::spawn(session);
1255
1256 let err = rx.await.unwrap().unwrap_err();
1257 assert_eq!(err, RequestError::Timeout);
1258
1259 let msg = builder.active_session_rx.next().await.unwrap();
1261 match msg {
1262 ActiveSessionMessage::ProtocolBreach { .. } => {}
1263 ev => unreachable!("{ev:?}"),
1264 }
1265 }
1266
#[test]
fn test_reject_bal_request_for_eth70() {
    let (tx, _rx) = oneshot::channel();
    let request: PeerRequest<EthNetworkPrimitives> =
        PeerRequest::GetBlockAccessLists { request: GetBlockAccessLists(vec![]), response: tx };

    // Block-access-list requests must be refused on eth/70 and accepted on eth/71.
    let supported = |version: EthVersion| {
        ActiveSession::<EthNetworkPrimitives>::is_request_supported_for_version(&request, version)
    };

    assert!(!supported(EthVersion::Eth70));
    assert!(supported(EthVersion::Eth71));
}
1282
1283 #[tokio::test(flavor = "multi_thread")]
1284 async fn test_keep_alive() {
1285 let mut builder = SessionBuilder::default();
1286
1287 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1288 let local_addr = listener.local_addr().unwrap();
1289
1290 let fut = builder.with_client_stream(local_addr, async move |mut client_stream| {
1291 let _ = tokio::time::timeout(Duration::from_secs(5), client_stream.next()).await;
1292 client_stream.into_inner().disconnect(DisconnectReason::UselessPeer).await.unwrap();
1293 });
1294
1295 let (tx, rx) = oneshot::channel();
1296
1297 tokio::task::spawn(async move {
1298 let (incoming, _) = listener.accept().await.unwrap();
1299 let session = builder.connect_incoming(incoming).await;
1300 session.await;
1301
1302 tx.send(()).unwrap();
1303 });
1304
1305 tokio::task::spawn(fut);
1306
1307 rx.await.unwrap();
1308 }
1309
1310 #[tokio::test(flavor = "multi_thread")]
1312 async fn test_send_at_capacity() {
1313 let mut builder = SessionBuilder::default();
1314
1315 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1316 let local_addr = listener.local_addr().unwrap();
1317
1318 let fut = builder.with_client_stream(local_addr, async move |mut client_stream| {
1319 client_stream
1320 .send(EthMessage::NewPooledTransactionHashes68(Default::default()))
1321 .await
1322 .unwrap();
1323 let _ = tokio::time::timeout(Duration::from_secs(100), client_stream.next()).await;
1324 });
1325 tokio::task::spawn(fut);
1326
1327 let (incoming, _) = listener.accept().await.unwrap();
1328 let session = builder.connect_incoming(incoming).await;
1329
1330 let mut num_fill_messages = 0;
1332 loop {
1333 if builder
1334 .active_session_tx
1335 .try_send(ActiveSessionMessage::ProtocolBreach { peer_id: PeerId::random() })
1336 .is_err()
1337 {
1338 break
1339 }
1340 num_fill_messages += 1;
1341 }
1342
1343 tokio::task::spawn(async move {
1344 session.await;
1345 });
1346
1347 tokio::time::sleep(Duration::from_millis(100)).await;
1348
1349 for _ in 0..num_fill_messages {
1350 let message = builder.active_session_rx.next().await.unwrap();
1351 match message {
1352 ActiveSessionMessage::ProtocolBreach { .. } => {}
1353 ev => unreachable!("{ev:?}"),
1354 }
1355 }
1356
1357 let message = builder.active_session_rx.next().await.unwrap();
1358 match message {
1359 ActiveSessionMessage::ValidMessage {
1360 message: PeerMessage::PooledTransactions(_),
1361 ..
1362 } => {}
1363 _ => unreachable!(),
1364 }
1365 }
1366
#[test]
fn timeout_calculation_sanity_tests() {
    let rtt = Duration::from_secs(5);
    // A timeout that is already at its steady-state value for `rtt`.
    let timeout = rtt * TIMEOUT_SCALING;

    // Feeding back the same RTT leaves the timeout unchanged.
    assert_eq!(calculate_new_timeout(timeout, rtt), timeout);

    // A faster sample lowers the timeout, but by less than half.
    let after_fast_sample = calculate_new_timeout(timeout, rtt / 2);
    assert!(after_fast_sample < timeout);
    assert!(after_fast_sample > timeout / 2);

    // A slower sample raises the timeout, but by less than double.
    let after_slow_sample = calculate_new_timeout(timeout, rtt * 2);
    assert!(after_slow_sample > timeout);
    assert!(after_slow_sample < timeout * 2);
}
1382}