1use core::sync::atomic::Ordering;
4use std::{
5 collections::VecDeque,
6 future::Future,
7 net::SocketAddr,
8 pin::Pin,
9 sync::{atomic::AtomicU64, Arc},
10 task::{ready, Context, Poll},
11 time::{Duration, Instant},
12};
13
14use crate::{
15 message::{NewBlockMessage, PeerMessage, PeerResponse, PeerResponseResult},
16 session::{
17 conn::EthRlpxConnection,
18 handle::{ActiveSessionMessage, SessionCommand},
19 BlockRangeInfo, EthVersion, SessionId,
20 },
21};
22use alloy_primitives::Sealable;
23use futures::{stream::Fuse, SinkExt, StreamExt};
24use metrics::Gauge;
25use reth_eth_wire::{
26 errors::{EthHandshakeError, EthStreamError},
27 message::{EthBroadcastMessage, MessageError, RequestPair},
28 Capabilities, DisconnectP2P, DisconnectReason, EthMessage, NetworkPrimitives, NewBlockPayload,
29};
30use reth_eth_wire_types::RawCapabilityMessage;
31use reth_metrics::common::mpsc::MeteredPollSender;
32use reth_network_api::PeerRequest;
33use reth_network_p2p::error::RequestError;
34use reth_network_peers::PeerId;
35use reth_network_types::session::config::INITIAL_REQUEST_TIMEOUT;
36use reth_primitives_traits::Block;
37use rustc_hash::FxHashMap;
38use tokio::{
39 sync::{mpsc::error::TrySendError, oneshot},
40 time::Interval,
41};
42use tokio_stream::wrappers::ReceiverStream;
43use tokio_util::sync::PollSender;
44use tracing::{debug, trace};
45
46pub(super) const RANGE_UPDATE_INTERVAL: Duration = Duration::from_secs(120);
50
51const MINIMUM_TIMEOUT: Duration = Duration::from_secs(2);
55
56const MAXIMUM_TIMEOUT: Duration = INITIAL_REQUEST_TIMEOUT;
58const SAMPLE_IMPACT: f64 = 0.1;
60const TIMEOUT_SCALING: u32 = 3;
62
63const MAX_QUEUED_OUTGOING_RESPONSES: usize = 4;
75
76#[expect(dead_code)]
86pub(crate) struct ActiveSession<N: NetworkPrimitives> {
87 pub(crate) next_id: u64,
89 pub(crate) conn: EthRlpxConnection<N>,
91 pub(crate) remote_peer_id: PeerId,
93 pub(crate) remote_addr: SocketAddr,
95 pub(crate) remote_capabilities: Arc<Capabilities>,
97 pub(crate) session_id: SessionId,
99 pub(crate) commands_rx: ReceiverStream<SessionCommand<N>>,
101 pub(crate) to_session_manager: MeteredPollSender<ActiveSessionMessage<N>>,
103 pub(crate) pending_message_to_session: Option<ActiveSessionMessage<N>>,
105 pub(crate) internal_request_rx: Fuse<ReceiverStream<PeerRequest<N>>>,
107 pub(crate) inflight_requests: FxHashMap<u64, InflightRequest<PeerRequest<N>>>,
109 pub(crate) received_requests_from_remote: Vec<ReceivedRequest<N>>,
111 pub(crate) queued_outgoing: QueuedOutgoingMessages<N>,
113 pub(crate) internal_request_timeout: Arc<AtomicU64>,
115 pub(crate) internal_request_timeout_interval: Interval,
117 pub(crate) protocol_breach_request_timeout: Duration,
120 pub(crate) terminate_message:
122 Option<(PollSender<ActiveSessionMessage<N>>, ActiveSessionMessage<N>)>,
123 pub(crate) range_info: Option<BlockRangeInfo>,
125 pub(crate) local_range_info: BlockRangeInfo,
128 pub(crate) range_update_interval: Option<Interval>,
131}
132
133impl<N: NetworkPrimitives> ActiveSession<N> {
134 fn is_disconnecting(&self) -> bool {
136 self.conn.inner().is_disconnecting()
137 }
138
139 const fn next_id(&mut self) -> u64 {
141 let id = self.next_id;
142 self.next_id += 1;
143 id
144 }
145
146 pub fn shrink_to_fit(&mut self) {
148 self.received_requests_from_remote.shrink_to_fit();
149 self.queued_outgoing.shrink_to_fit();
150 }
151
152 fn queued_response_count(&self) -> usize {
154 self.queued_outgoing.messages.iter().filter(|m| m.is_response()).count()
155 }
156
157 fn on_incoming_message(&mut self, msg: EthMessage<N>) -> OnIncomingMessageOutcome<N> {
161 macro_rules! on_request {
165 ($req:ident, $resp_item:ident, $req_item:ident) => {{
166 let RequestPair { request_id, message: request } = $req;
167 let (tx, response) = oneshot::channel();
168 let received = ReceivedRequest {
169 request_id,
170 rx: PeerResponse::$resp_item { response },
171 received: Instant::now(),
172 };
173 self.received_requests_from_remote.push(received);
174 self.try_emit_request(PeerMessage::EthRequest(PeerRequest::$req_item {
175 request,
176 response: tx,
177 }))
178 .into()
179 }};
180 }
181
182 macro_rules! on_response {
184 ($resp:ident, $item:ident) => {{
185 let RequestPair { request_id, message } = $resp;
186 if let Some(req) = self.inflight_requests.remove(&request_id) {
187 match req.request {
188 RequestState::Waiting(PeerRequest::$item { response, .. }) => {
189 trace!(peer_id=?self.remote_peer_id, ?request_id, "received response from peer");
190 let _ = response.send(Ok(message));
191 self.update_request_timeout(req.timestamp, Instant::now());
192 }
193 RequestState::Waiting(request) => {
194 request.send_bad_response();
195 }
196 RequestState::TimedOut => {
197 self.update_request_timeout(req.timestamp, Instant::now());
199 }
200 }
201 } else {
202 trace!(peer_id=?self.remote_peer_id, ?request_id, "received response to unknown request");
203 self.on_bad_message();
205 }
206
207 OnIncomingMessageOutcome::Ok
208 }};
209 }
210
211 match msg {
212 message @ EthMessage::Status(_) => OnIncomingMessageOutcome::BadMessage {
213 error: EthStreamError::EthHandshakeError(EthHandshakeError::StatusNotInHandshake),
214 message,
215 },
216 EthMessage::NewBlockHashes(msg) => {
217 self.try_emit_broadcast(PeerMessage::NewBlockHashes(msg)).into()
218 }
219 EthMessage::NewBlock(msg) => {
220 let block = NewBlockMessage {
221 hash: msg.block().header().hash_slow(),
222 block: Arc::new(*msg),
223 };
224 self.try_emit_broadcast(PeerMessage::NewBlock(block)).into()
225 }
226 EthMessage::Transactions(msg) => {
227 self.try_emit_broadcast(PeerMessage::ReceivedTransaction(msg)).into()
228 }
229 EthMessage::NewPooledTransactionHashes66(msg) => {
230 self.try_emit_broadcast(PeerMessage::PooledTransactions(msg.into())).into()
231 }
232 EthMessage::NewPooledTransactionHashes68(msg) => {
233 if msg.hashes.len() != msg.types.len() || msg.hashes.len() != msg.sizes.len() {
234 return OnIncomingMessageOutcome::BadMessage {
235 error: EthStreamError::TransactionHashesInvalidLenOfFields {
236 hashes_len: msg.hashes.len(),
237 types_len: msg.types.len(),
238 sizes_len: msg.sizes.len(),
239 },
240 message: EthMessage::NewPooledTransactionHashes68(msg),
241 }
242 }
243 self.try_emit_broadcast(PeerMessage::PooledTransactions(msg.into())).into()
244 }
245 EthMessage::GetBlockHeaders(req) => {
246 on_request!(req, BlockHeaders, GetBlockHeaders)
247 }
248 EthMessage::BlockHeaders(resp) => {
249 on_response!(resp, GetBlockHeaders)
250 }
251 EthMessage::GetBlockBodies(req) => {
252 on_request!(req, BlockBodies, GetBlockBodies)
253 }
254 EthMessage::BlockBodies(resp) => {
255 on_response!(resp, GetBlockBodies)
256 }
257 EthMessage::GetPooledTransactions(req) => {
258 on_request!(req, PooledTransactions, GetPooledTransactions)
259 }
260 EthMessage::PooledTransactions(resp) => {
261 on_response!(resp, GetPooledTransactions)
262 }
263 EthMessage::GetNodeData(req) => {
264 on_request!(req, NodeData, GetNodeData)
265 }
266 EthMessage::NodeData(resp) => {
267 on_response!(resp, GetNodeData)
268 }
269 EthMessage::GetReceipts(req) => {
270 if self.conn.version() >= EthVersion::Eth69 {
271 on_request!(req, Receipts69, GetReceipts69)
272 } else {
273 on_request!(req, Receipts, GetReceipts)
274 }
275 }
276 EthMessage::Receipts(resp) => {
277 on_response!(resp, GetReceipts)
278 }
279 EthMessage::Receipts69(resp) => {
280 on_response!(resp, GetReceipts69)
281 }
282 EthMessage::BlockRangeUpdate(msg) => {
283 if msg.earliest > msg.latest {
285 return OnIncomingMessageOutcome::BadMessage {
286 error: EthStreamError::InvalidMessage(MessageError::Other(format!(
287 "invalid block range: earliest ({}) > latest ({})",
288 msg.earliest, msg.latest
289 ))),
290 message: EthMessage::BlockRangeUpdate(msg),
291 };
292 }
293
294 if let Some(range_info) = self.range_info.as_ref() {
295 range_info.update(msg.earliest, msg.latest, msg.latest_hash);
296 }
297
298 OnIncomingMessageOutcome::Ok
299 }
300 EthMessage::Other(bytes) => self.try_emit_broadcast(PeerMessage::Other(bytes)).into(),
301 }
302 }
303
304 fn on_internal_peer_request(&mut self, request: PeerRequest<N>, deadline: Instant) {
306 let request_id = self.next_id();
307
308 trace!(?request, peer_id=?self.remote_peer_id, ?request_id, "sending request to peer");
309 let msg = request.create_request_message(request_id);
310 self.queued_outgoing.push_back(msg.into());
311 let req = InflightRequest {
312 request: RequestState::Waiting(request),
313 timestamp: Instant::now(),
314 deadline,
315 };
316 self.inflight_requests.insert(request_id, req);
317 }
318
319 fn on_internal_peer_message(&mut self, msg: PeerMessage<N>) {
321 match msg {
322 PeerMessage::NewBlockHashes(msg) => {
323 self.queued_outgoing.push_back(EthMessage::NewBlockHashes(msg).into());
324 }
325 PeerMessage::NewBlock(msg) => {
326 self.queued_outgoing.push_back(EthBroadcastMessage::NewBlock(msg.block).into());
327 }
328 PeerMessage::PooledTransactions(msg) => {
329 if msg.is_valid_for_version(self.conn.version()) {
330 self.queued_outgoing.push_back(EthMessage::from(msg).into());
331 } else {
332 debug!(target: "net", ?msg, version=?self.conn.version(), "Message is invalid for connection version, skipping");
333 }
334 }
335 PeerMessage::EthRequest(req) => {
336 let deadline = self.request_deadline();
337 self.on_internal_peer_request(req, deadline);
338 }
339 PeerMessage::SendTransactions(msg) => {
340 self.queued_outgoing.push_back(EthBroadcastMessage::Transactions(msg).into());
341 }
342 PeerMessage::BlockRangeUpdated(_) => {}
343 PeerMessage::ReceivedTransaction(_) => {
344 unreachable!("Not emitted by network")
345 }
346 PeerMessage::Other(other) => {
347 self.queued_outgoing.push_back(OutgoingMessage::Raw(other));
348 }
349 }
350 }
351
352 fn request_deadline(&self) -> Instant {
354 Instant::now() +
355 Duration::from_millis(self.internal_request_timeout.load(Ordering::Relaxed))
356 }
357
358 fn handle_outgoing_response(&mut self, id: u64, resp: PeerResponseResult<N>) {
362 match resp.try_into_message(id) {
363 Ok(msg) => {
364 self.queued_outgoing.push_back(msg.into());
365 }
366 Err(err) => {
367 debug!(target: "net", %err, "Failed to respond to received request");
368 }
369 }
370 }
371
372 #[expect(clippy::result_large_err)]
376 fn try_emit_broadcast(&self, message: PeerMessage<N>) -> Result<(), ActiveSessionMessage<N>> {
377 let Some(sender) = self.to_session_manager.inner().get_ref() else { return Ok(()) };
378
379 match sender
380 .try_send(ActiveSessionMessage::ValidMessage { peer_id: self.remote_peer_id, message })
381 {
382 Ok(_) => Ok(()),
383 Err(err) => {
384 trace!(
385 target: "net",
386 %err,
387 "no capacity for incoming broadcast",
388 );
389 match err {
390 TrySendError::Full(msg) => Err(msg),
391 TrySendError::Closed(_) => Ok(()),
392 }
393 }
394 }
395 }
396
397 #[expect(clippy::result_large_err)]
402 fn try_emit_request(&self, message: PeerMessage<N>) -> Result<(), ActiveSessionMessage<N>> {
403 let Some(sender) = self.to_session_manager.inner().get_ref() else { return Ok(()) };
404
405 match sender
406 .try_send(ActiveSessionMessage::ValidMessage { peer_id: self.remote_peer_id, message })
407 {
408 Ok(_) => Ok(()),
409 Err(err) => {
410 trace!(
411 target: "net",
412 %err,
413 "no capacity for incoming request",
414 );
415 match err {
416 TrySendError::Full(msg) => Err(msg),
417 TrySendError::Closed(_) => {
418 Ok(())
421 }
422 }
423 }
424 }
425 }
426
427 fn on_bad_message(&self) {
429 let Some(sender) = self.to_session_manager.inner().get_ref() else { return };
430 let _ = sender.try_send(ActiveSessionMessage::BadMessage { peer_id: self.remote_peer_id });
431 }
432
433 fn emit_disconnect(&mut self, cx: &mut Context<'_>) -> Poll<()> {
435 trace!(target: "net::session", remote_peer_id=?self.remote_peer_id, "emitting disconnect");
436 let msg = ActiveSessionMessage::Disconnected {
437 peer_id: self.remote_peer_id,
438 remote_addr: self.remote_addr,
439 };
440
441 self.terminate_message = Some((self.to_session_manager.inner().clone(), msg));
442 self.poll_terminate_message(cx).expect("message is set")
443 }
444
445 fn close_on_error(&mut self, error: EthStreamError, cx: &mut Context<'_>) -> Poll<()> {
447 let msg = ActiveSessionMessage::ClosedOnConnectionError {
448 peer_id: self.remote_peer_id,
449 remote_addr: self.remote_addr,
450 error,
451 };
452 self.terminate_message = Some((self.to_session_manager.inner().clone(), msg));
453 self.poll_terminate_message(cx).expect("message is set")
454 }
455
456 fn start_disconnect(&mut self, reason: DisconnectReason) -> Result<(), EthStreamError> {
458 Ok(self.conn.inner_mut().start_disconnect(reason)?)
459 }
460
461 fn poll_disconnect(&mut self, cx: &mut Context<'_>) -> Poll<()> {
463 debug_assert!(self.is_disconnecting(), "not disconnecting");
464
465 let _ = ready!(self.conn.poll_close_unpin(cx));
467 self.emit_disconnect(cx)
468 }
469
470 fn try_disconnect(&mut self, reason: DisconnectReason, cx: &mut Context<'_>) -> Poll<()> {
472 match self.start_disconnect(reason) {
473 Ok(()) => {
474 self.poll_disconnect(cx)
476 }
477 Err(err) => {
478 debug!(target: "net::session", %err, remote_peer_id=?self.remote_peer_id, "could not send disconnect");
479 self.close_on_error(err, cx)
480 }
481 }
482 }
483
484 #[must_use]
493 fn check_timed_out_requests(&mut self, now: Instant) -> bool {
494 for (id, req) in &mut self.inflight_requests {
495 if req.is_timed_out(now) {
496 if req.is_waiting() {
497 debug!(target: "net::session", ?id, remote_peer_id=?self.remote_peer_id, "timed out outgoing request");
498 req.timeout();
499 } else if now - req.timestamp > self.protocol_breach_request_timeout {
500 return true
501 }
502 }
503 }
504
505 false
506 }
507
508 fn update_request_timeout(&mut self, sent: Instant, received: Instant) {
510 let elapsed = received.saturating_duration_since(sent);
511
512 let current = Duration::from_millis(self.internal_request_timeout.load(Ordering::Relaxed));
513 let request_timeout = calculate_new_timeout(current, elapsed);
514 self.internal_request_timeout.store(request_timeout.as_millis() as u64, Ordering::Relaxed);
515 self.internal_request_timeout_interval = tokio::time::interval(request_timeout);
516 }
517
518 fn poll_terminate_message(&mut self, cx: &mut Context<'_>) -> Option<Poll<()>> {
520 let (mut tx, msg) = self.terminate_message.take()?;
521 match tx.poll_reserve(cx) {
522 Poll::Pending => {
523 self.terminate_message = Some((tx, msg));
524 return Some(Poll::Pending)
525 }
526 Poll::Ready(Ok(())) => {
527 let _ = tx.send_item(msg);
528 }
529 Poll::Ready(Err(_)) => {
530 }
532 }
533 Some(Poll::Ready(()))
535 }
536}
537
538impl<N: NetworkPrimitives> Future for ActiveSession<N> {
539 type Output = ();
540
541 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
542 let this = self.get_mut();
543
544 if let Some(terminate) = this.poll_terminate_message(cx) {
546 return terminate
547 }
548
549 if this.is_disconnecting() {
550 return this.poll_disconnect(cx)
551 }
552
553 let mut budget = 4;
559
560 'main: loop {
562 let mut progress = false;
563
564 loop {
566 match this.commands_rx.poll_next_unpin(cx) {
567 Poll::Pending => break,
568 Poll::Ready(None) => {
569 return Poll::Ready(())
572 }
573 Poll::Ready(Some(cmd)) => {
574 progress = true;
575 match cmd {
576 SessionCommand::Disconnect { reason } => {
577 debug!(
578 target: "net::session",
579 ?reason,
580 remote_peer_id=?this.remote_peer_id,
581 "Received disconnect command for session"
582 );
583 let reason =
584 reason.unwrap_or(DisconnectReason::DisconnectRequested);
585
586 return this.try_disconnect(reason, cx)
587 }
588 SessionCommand::Message(msg) => {
589 this.on_internal_peer_message(msg);
590 }
591 }
592 }
593 }
594 }
595
596 let deadline = this.request_deadline();
597
598 while let Poll::Ready(Some(req)) = this.internal_request_rx.poll_next_unpin(cx) {
599 progress = true;
600 this.on_internal_peer_request(req, deadline);
601 }
602
603 for idx in (0..this.received_requests_from_remote.len()).rev() {
606 let mut req = this.received_requests_from_remote.swap_remove(idx);
607 match req.rx.poll(cx) {
608 Poll::Pending => {
609 this.received_requests_from_remote.push(req);
611 }
612 Poll::Ready(resp) => {
613 this.handle_outgoing_response(req.request_id, resp);
614 }
615 }
616 }
617
618 while this.conn.poll_ready_unpin(cx).is_ready() {
620 if let Some(msg) = this.queued_outgoing.pop_front() {
621 progress = true;
622 let res = match msg {
623 OutgoingMessage::Eth(msg) => this.conn.start_send_unpin(msg),
624 OutgoingMessage::Broadcast(msg) => this.conn.start_send_broadcast(msg),
625 OutgoingMessage::Raw(msg) => this.conn.start_send_raw(msg),
626 };
627 if let Err(err) = res {
628 debug!(target: "net::session", %err, remote_peer_id=?this.remote_peer_id, "failed to send message");
629 return this.close_on_error(err, cx)
631 }
632 } else {
633 break
635 }
636 }
637
638 'receive: loop {
640 budget -= 1;
642 if budget == 0 {
643 cx.waker().wake_by_ref();
645 break 'main
646 }
647
648 if let Some(msg) = this.pending_message_to_session.take() {
652 match this.to_session_manager.poll_reserve(cx) {
653 Poll::Ready(Ok(_)) => {
654 let _ = this.to_session_manager.send_item(msg);
655 }
656 Poll::Ready(Err(_)) => return Poll::Ready(()),
657 Poll::Pending => {
658 this.pending_message_to_session = Some(msg);
659 break 'receive
660 }
661 };
662 }
663
664 if this.received_requests_from_remote.len() > MAX_QUEUED_OUTGOING_RESPONSES {
666 break 'receive
672 }
673
674 if this.queued_outgoing.messages.len() > MAX_QUEUED_OUTGOING_RESPONSES &&
676 this.queued_response_count() > MAX_QUEUED_OUTGOING_RESPONSES
677 {
678 break 'receive
685 }
686
687 match this.conn.poll_next_unpin(cx) {
688 Poll::Pending => break,
689 Poll::Ready(None) => {
690 if this.is_disconnecting() {
691 break
692 }
693 debug!(target: "net::session", remote_peer_id=?this.remote_peer_id, "eth stream completed");
694 return this.emit_disconnect(cx)
695 }
696 Poll::Ready(Some(res)) => {
697 match res {
698 Ok(msg) => {
699 trace!(target: "net::session", msg_id=?msg.message_id(), remote_peer_id=?this.remote_peer_id, "received eth message");
700 match this.on_incoming_message(msg) {
702 OnIncomingMessageOutcome::Ok => {
703 progress = true;
705 }
706 OnIncomingMessageOutcome::BadMessage { error, message } => {
707 debug!(target: "net::session", %error, msg=?message, remote_peer_id=?this.remote_peer_id, "received invalid protocol message");
708 return this.close_on_error(error, cx)
709 }
710 OnIncomingMessageOutcome::NoCapacity(msg) => {
711 this.pending_message_to_session = Some(msg);
713 }
714 }
715 }
716 Err(err) => {
717 debug!(target: "net::session", %err, remote_peer_id=?this.remote_peer_id, "failed to receive message");
718 return this.close_on_error(err, cx)
719 }
720 }
721 }
722 }
723 }
724
725 if !progress {
726 break 'main
727 }
728 }
729
730 if let Some(interval) = &mut this.range_update_interval {
731 while interval.poll_tick(cx).is_ready() {
733 this.queued_outgoing.push_back(
734 EthMessage::BlockRangeUpdate(this.local_range_info.to_message()).into(),
735 );
736 }
737 }
738
739 while this.internal_request_timeout_interval.poll_tick(cx).is_ready() {
740 if this.check_timed_out_requests(Instant::now()) {
742 if let Poll::Ready(Ok(_)) = this.to_session_manager.poll_reserve(cx) {
743 let msg = ActiveSessionMessage::ProtocolBreach { peer_id: this.remote_peer_id };
744 this.pending_message_to_session = Some(msg);
745 }
746 }
747 }
748
749 this.shrink_to_fit();
750
751 Poll::Pending
752 }
753}
754
755pub(crate) struct ReceivedRequest<N: NetworkPrimitives> {
757 request_id: u64,
759 rx: PeerResponse<N>,
761 #[expect(dead_code)]
763 received: Instant,
764}
765
766pub(crate) struct InflightRequest<R> {
768 request: RequestState<R>,
770 timestamp: Instant,
772 deadline: Instant,
774}
775
776impl<N: NetworkPrimitives> InflightRequest<PeerRequest<N>> {
779 #[inline]
781 fn is_timed_out(&self, now: Instant) -> bool {
782 now > self.deadline
783 }
784
785 #[inline]
787 const fn is_waiting(&self) -> bool {
788 matches!(self.request, RequestState::Waiting(_))
789 }
790
791 fn timeout(&mut self) {
793 let mut req = RequestState::TimedOut;
794 std::mem::swap(&mut self.request, &mut req);
795
796 if let RequestState::Waiting(req) = req {
797 req.send_err_response(RequestError::Timeout);
798 }
799 }
800}
801
802enum OnIncomingMessageOutcome<N: NetworkPrimitives> {
804 Ok,
806 BadMessage { error: EthStreamError, message: EthMessage<N> },
808 NoCapacity(ActiveSessionMessage<N>),
810}
811
812impl<N: NetworkPrimitives> From<Result<(), ActiveSessionMessage<N>>>
813 for OnIncomingMessageOutcome<N>
814{
815 fn from(res: Result<(), ActiveSessionMessage<N>>) -> Self {
816 match res {
817 Ok(_) => Self::Ok,
818 Err(msg) => Self::NoCapacity(msg),
819 }
820 }
821}
822
/// Lifecycle of an outgoing request.
enum RequestState<R> {
    /// Still awaiting the peer's response; holds the request (and its response channel).
    Waiting(R),
    /// Already timed out locally; a late response is still accepted without error.
    TimedOut,
}
829
830#[derive(Debug)]
832pub(crate) enum OutgoingMessage<N: NetworkPrimitives> {
833 Eth(EthMessage<N>),
835 Broadcast(EthBroadcastMessage<N>),
837 Raw(RawCapabilityMessage),
839}
840
841impl<N: NetworkPrimitives> OutgoingMessage<N> {
842 const fn is_response(&self) -> bool {
844 match self {
845 Self::Eth(msg) => msg.is_response(),
846 _ => false,
847 }
848 }
849}
850
851impl<N: NetworkPrimitives> From<EthMessage<N>> for OutgoingMessage<N> {
852 fn from(value: EthMessage<N>) -> Self {
853 Self::Eth(value)
854 }
855}
856
857impl<N: NetworkPrimitives> From<EthBroadcastMessage<N>> for OutgoingMessage<N> {
858 fn from(value: EthBroadcastMessage<N>) -> Self {
859 Self::Broadcast(value)
860 }
861}
862
863#[inline]
865fn calculate_new_timeout(current_timeout: Duration, estimated_rtt: Duration) -> Duration {
866 let new_timeout = estimated_rtt.mul_f64(SAMPLE_IMPACT) * TIMEOUT_SCALING;
867
868 let smoothened_timeout = current_timeout.mul_f64(1.0 - SAMPLE_IMPACT) + new_timeout;
870
871 smoothened_timeout.clamp(MINIMUM_TIMEOUT, MAXIMUM_TIMEOUT)
872}
873
874pub(crate) struct QueuedOutgoingMessages<N: NetworkPrimitives> {
876 messages: VecDeque<OutgoingMessage<N>>,
877 count: Gauge,
878}
879
880impl<N: NetworkPrimitives> QueuedOutgoingMessages<N> {
881 pub(crate) const fn new(metric: Gauge) -> Self {
882 Self { messages: VecDeque::new(), count: metric }
883 }
884
885 pub(crate) fn push_back(&mut self, message: OutgoingMessage<N>) {
886 self.messages.push_back(message);
887 self.count.increment(1);
888 }
889
890 pub(crate) fn pop_front(&mut self) -> Option<OutgoingMessage<N>> {
891 self.messages.pop_front().inspect(|_| self.count.decrement(1))
892 }
893
894 pub(crate) fn shrink_to_fit(&mut self) {
895 self.messages.shrink_to_fit();
896 }
897}
898
#[cfg(test)]
mod tests {
    use super::*;
    use crate::session::{handle::PendingSessionEvent, start_pending_incoming_session};
    use alloy_eips::eip2124::ForkFilter;
    use reth_chainspec::MAINNET;
    use reth_ecies::stream::ECIESStream;
    use reth_eth_wire::{
        handshake::EthHandshake, EthNetworkPrimitives, EthStream, GetBlockBodies,
        HelloMessageWithProtocols, P2PStream, StatusBuilder, UnauthedEthStream, UnauthedP2PStream,
        UnifiedStatus,
    };
    use reth_ethereum_forks::EthereumHardfork;
    use reth_network_peers::pk2id;
    use reth_network_types::session::config::PROTOCOL_BREACH_REQUEST_TIMEOUT;
    use secp256k1::{SecretKey, SECP256K1};
    use tokio::{
        net::{TcpListener, TcpStream},
        sync::mpsc,
    };

    /// Builds a default `Hello` message whose node id is derived from `server_key`.
    fn eth_hello(server_key: &SecretKey) -> HelloMessageWithProtocols {
        HelloMessageWithProtocols::builder(pk2id(&server_key.public_key(SECP256K1))).build()
    }

    /// Test harness that plays the session manager: it owns the manager-side channel ends and
    /// builds `ActiveSession`s over real TCP connections.
    struct SessionBuilder<N: NetworkPrimitives = EthNetworkPrimitives> {
        _remote_capabilities: Arc<Capabilities>,
        // Manager channel: sessions send into `active_session_tx`, tests read `active_session_rx`.
        active_session_tx: mpsc::Sender<ActiveSessionMessage<N>>,
        active_session_rx: ReceiverStream<ActiveSessionMessage<N>>,
        // Command senders for every session created, kept alive so `commands_rx` stays open.
        to_sessions: Vec<mpsc::Sender<SessionCommand<N>>>,
        secret_key: SecretKey,
        local_peer_id: PeerId,
        hello: HelloMessageWithProtocols,
        status: UnifiedStatus,
        fork_filter: ForkFilter,
        next_id: usize,
    }

    impl<N: NetworkPrimitives> SessionBuilder<N> {
        /// Returns a fresh `SessionId`, incrementing the internal counter.
        fn next_id(&mut self) -> SessionId {
            let id = self.next_id;
            self.next_id += 1;
            SessionId(id)
        }

        /// Returns a future that connects to `local_addr` as a client with a random key, performs
        /// the ECIES, p2p and eth handshakes, and then runs `f` with the resulting stream.
        fn with_client_stream<F, O>(
            &self,
            local_addr: SocketAddr,
            f: F,
        ) -> Pin<Box<dyn Future<Output = ()> + Send>>
        where
            F: FnOnce(EthStream<P2PStream<ECIESStream<TcpStream>>, N>) -> O + Send + 'static,
            O: Future<Output = ()> + Send + Sync,
        {
            let mut status = self.status;
            let fork_filter = self.fork_filter.clone();
            let local_peer_id = self.local_peer_id;
            let mut hello = self.hello.clone();
            let key = SecretKey::new(&mut rand_08::thread_rng());
            hello.id = pk2id(&key.public_key(SECP256K1));
            Box::pin(async move {
                let outgoing = TcpStream::connect(local_addr).await.unwrap();
                let sink = ECIESStream::connect(outgoing, key, local_peer_id).await.unwrap();

                let (p2p_stream, _) = UnauthedP2PStream::new(sink).handshake(hello).await.unwrap();

                // the eth status must carry the version negotiated in the p2p handshake
                let eth_version = p2p_stream.shared_capabilities().eth_version().unwrap();
                status.set_eth_version(eth_version);

                let (client_stream, _) = UnauthedEthStream::new(p2p_stream)
                    .handshake(status, fork_filter)
                    .await
                    .unwrap();
                f(client_stream).await
            })
        }

        /// Runs the incoming-session handshake on `stream`.
        ///
        /// Returns an `ActiveSession` wired to this builder's manager channel. Panics if the
        /// pending session does not report `Established`.
        async fn connect_incoming(&mut self, stream: TcpStream) -> ActiveSession<N> {
            let remote_addr = stream.local_addr().unwrap();
            let session_id = self.next_id();
            let (_disconnect_tx, disconnect_rx) = oneshot::channel();
            let (pending_sessions_tx, pending_sessions_rx) = mpsc::channel(1);

            tokio::task::spawn(start_pending_incoming_session(
                Arc::new(EthHandshake::default()),
                disconnect_rx,
                session_id,
                stream,
                pending_sessions_tx,
                remote_addr,
                self.secret_key,
                self.hello.clone(),
                self.status,
                self.fork_filter.clone(),
                Default::default(),
            ));

            let mut stream = ReceiverStream::new(pending_sessions_rx);

            match stream.next().await.unwrap() {
                PendingSessionEvent::Established {
                    session_id,
                    remote_addr,
                    peer_id,
                    capabilities,
                    conn,
                    ..
                } => {
                    // `_to_session_tx` is dropped at the end of this arm, so the session's
                    // `internal_request_rx` is presumably closed from the start.
                    let (_to_session_tx, messages_rx) = mpsc::channel(10);
                    let (commands_to_session, commands_rx) = mpsc::channel(10);
                    let poll_sender = PollSender::new(self.active_session_tx.clone());

                    self.to_sessions.push(commands_to_session);

                    ActiveSession {
                        next_id: 0,
                        remote_peer_id: peer_id,
                        remote_addr,
                        remote_capabilities: Arc::clone(&capabilities),
                        session_id,
                        commands_rx: ReceiverStream::new(commands_rx),
                        to_session_manager: MeteredPollSender::new(
                            poll_sender,
                            "network_active_session",
                        ),
                        pending_message_to_session: None,
                        internal_request_rx: ReceiverStream::new(messages_rx).fuse(),
                        inflight_requests: Default::default(),
                        conn,
                        queued_outgoing: QueuedOutgoingMessages::new(Gauge::noop()),
                        received_requests_from_remote: Default::default(),
                        internal_request_timeout_interval: tokio::time::interval(
                            INITIAL_REQUEST_TIMEOUT,
                        ),
                        internal_request_timeout: Arc::new(AtomicU64::new(
                            INITIAL_REQUEST_TIMEOUT.as_millis() as u64,
                        )),
                        protocol_breach_request_timeout: PROTOCOL_BREACH_REQUEST_TIMEOUT,
                        terminate_message: None,
                        range_info: None,
                        local_range_info: BlockRangeInfo::new(
                            0,
                            1000,
                            alloy_primitives::B256::ZERO,
                        ),
                        range_update_interval: None,
                    }
                }
                ev => {
                    panic!("unexpected message {ev:?}")
                }
            }
        }
    }

    impl Default for SessionBuilder {
        /// Random local key, mainnet Frontier fork filter, default status, and a manager channel
        /// with capacity 100.
        fn default() -> Self {
            let (active_session_tx, active_session_rx) = mpsc::channel(100);

            let (secret_key, pk) = SECP256K1.generate_keypair(&mut rand_08::thread_rng());
            let local_peer_id = pk2id(&pk);

            Self {
                next_id: 0,
                _remote_capabilities: Arc::new(Capabilities::from(vec![])),
                active_session_tx,
                active_session_rx: ReceiverStream::new(active_session_rx),
                to_sessions: vec![],
                hello: eth_hello(&secret_key),
                secret_key,
                local_peer_id,
                status: StatusBuilder::default().build(),
                fork_filter: MAINNET
                    .hardfork_fork_filter(EthereumHardfork::Frontier)
                    .expect("The Frontier fork filter should exist on mainnet"),
            }
        }
    }

    /// The session starts a disconnect; the client must observe it with the same reason.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_disconnect() {
        let mut builder = SessionBuilder::default();

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();

        let expected_disconnect = DisconnectReason::UselessPeer;

        let fut = builder.with_client_stream(local_addr, move |mut client_stream| async move {
            let msg = client_stream.next().await.unwrap().unwrap_err();
            assert_eq!(msg.as_disconnected().unwrap(), expected_disconnect);
        });

        tokio::task::spawn(async move {
            let (incoming, _) = listener.accept().await.unwrap();
            let mut session = builder.connect_incoming(incoming).await;

            session.start_disconnect(expected_disconnect).unwrap();
            session.await
        });

        fut.await;
    }

    /// Dropping the client stream must make the session future complete.
    #[tokio::test(flavor = "multi_thread")]
    async fn handle_dropped_stream() {
        let mut builder = SessionBuilder::default();

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();

        let fut = builder.with_client_stream(local_addr, move |client_stream| async move {
            drop(client_stream);
            tokio::time::sleep(Duration::from_secs(1)).await
        });

        let (tx, rx) = oneshot::channel();

        tokio::task::spawn(async move {
            let (incoming, _) = listener.accept().await.unwrap();
            let session = builder.connect_incoming(incoming).await;
            session.await;

            tx.send(()).unwrap();
        });

        tokio::task::spawn(fut);

        rx.await.unwrap();
    }

    /// The client sends many broadcasts and then drops its stream. The session must keep up and
    /// eventually complete.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_send_many_messages() {
        reth_tracing::init_test_tracing();
        let mut builder = SessionBuilder::default();

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();

        let num_messages = 100;

        let fut = builder.with_client_stream(local_addr, move |mut client_stream| async move {
            for _ in 0..num_messages {
                client_stream
                    .send(EthMessage::NewPooledTransactionHashes66(Vec::new().into()))
                    .await
                    .unwrap();
            }
        });

        let (tx, rx) = oneshot::channel();

        tokio::task::spawn(async move {
            let (incoming, _) = listener.accept().await.unwrap();
            let session = builder.connect_incoming(incoming).await;
            session.await;

            tx.send(()).unwrap();
        });

        tokio::task::spawn(fut);

        rx.await.unwrap();
    }

    /// A request with an already-passed deadline to a peer that never answers must:
    /// - fail locally with `RequestError::Timeout`, and
    /// - be reported as a `ProtocolBreach` once `drop_timeout` has elapsed.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_request_timeout() {
        reth_tracing::init_test_tracing();

        let mut builder = SessionBuilder::default();

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();

        let request_timeout = Duration::from_millis(100);
        let drop_timeout = Duration::from_millis(1500);

        // the client keeps the connection open but never responds
        let fut = builder.with_client_stream(local_addr, move |client_stream| async move {
            let _client_stream = client_stream;
            tokio::time::sleep(drop_timeout * 60).await;
        });
        tokio::task::spawn(fut);

        let (incoming, _) = listener.accept().await.unwrap();
        let mut session = builder.connect_incoming(incoming).await;
        session
            .internal_request_timeout
            .store(request_timeout.as_millis() as u64, Ordering::Relaxed);
        session.protocol_breach_request_timeout = drop_timeout;
        session.internal_request_timeout_interval =
            tokio::time::interval_at(tokio::time::Instant::now(), request_timeout);
        let (tx, rx) = oneshot::channel();
        let req = PeerRequest::GetBlockBodies { request: GetBlockBodies(vec![]), response: tx };
        session.on_internal_peer_request(req, Instant::now());
        tokio::spawn(session);

        let err = rx.await.unwrap().unwrap_err();
        assert_eq!(err, RequestError::Timeout);

        let msg = builder.active_session_rx.next().await.unwrap();
        match msg {
            ActiveSessionMessage::ProtocolBreach { .. } => {}
            ev => unreachable!("{ev:?}"),
        }
    }

    /// The client waits up to 5s for a message, then disconnects; the session future must
    /// complete.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_keep_alive() {
        let mut builder = SessionBuilder::default();

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();

        let fut = builder.with_client_stream(local_addr, move |mut client_stream| async move {
            let _ = tokio::time::timeout(Duration::from_secs(5), client_stream.next()).await;
            client_stream.into_inner().disconnect(DisconnectReason::UselessPeer).await.unwrap();
        });

        let (tx, rx) = oneshot::channel();

        tokio::task::spawn(async move {
            let (incoming, _) = listener.accept().await.unwrap();
            let session = builder.connect_incoming(incoming).await;
            session.await;

            tx.send(()).unwrap();
        });

        tokio::task::spawn(fut);

        rx.await.unwrap();
    }

    /// Fills the manager channel before the session runs, so the client's broadcast hits a full
    /// channel. After the filler messages are drained, the broadcast must still be delivered.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_send_at_capacity() {
        let mut builder = SessionBuilder::default();

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();

        let fut = builder.with_client_stream(local_addr, move |mut client_stream| async move {
            client_stream
                .send(EthMessage::NewPooledTransactionHashes68(Default::default()))
                .await
                .unwrap();
            let _ = tokio::time::timeout(Duration::from_secs(100), client_stream.next()).await;
        });
        tokio::task::spawn(fut);

        let (incoming, _) = listener.accept().await.unwrap();
        let session = builder.connect_incoming(incoming).await;

        // fill the manager channel until `try_send` fails
        let mut num_fill_messages = 0;
        loop {
            if builder
                .active_session_tx
                .try_send(ActiveSessionMessage::ProtocolBreach { peer_id: PeerId::random() })
                .is_err()
            {
                break
            }
            num_fill_messages += 1;
        }

        tokio::task::spawn(async move {
            session.await;
        });

        tokio::time::sleep(Duration::from_millis(100)).await;

        for _ in 0..num_fill_messages {
            let message = builder.active_session_rx.next().await.unwrap();
            match message {
                ActiveSessionMessage::ProtocolBreach { .. } => {}
                ev => unreachable!("{ev:?}"),
            }
        }

        let message = builder.active_session_rx.next().await.unwrap();
        match message {
            ActiveSessionMessage::ValidMessage {
                message: PeerMessage::PooledTransactions(_),
                ..
            } => {}
            _ => unreachable!(),
        }
    }

    /// Sanity checks for `calculate_new_timeout`:
    /// - a sample equal to the steady state leaves the timeout unchanged;
    /// - faster or slower samples move it in the matching direction, by less than the sample's
    ///   own factor.
    #[test]
    fn timeout_calculation_sanity_tests() {
        let rtt = Duration::from_secs(5);
        let timeout = rtt * TIMEOUT_SCALING;

        assert_eq!(calculate_new_timeout(timeout, rtt), timeout);

        assert!(calculate_new_timeout(timeout, rtt / 2) < timeout);
        assert!(calculate_new_timeout(timeout, rtt / 2) > timeout / 2);
        assert!(calculate_new_timeout(timeout, rtt * 2) > timeout);
        assert!(calculate_new_timeout(timeout, rtt * 2) < timeout * 2);
    }
}