reth_network/session/
mod.rs

1//! Support for handling peer sessions.
2
3mod active;
4mod conn;
5mod counter;
6mod handle;
7
8use active::QueuedOutgoingMessages;
9pub use conn::EthRlpxConnection;
10pub use handle::{
11    ActiveSessionHandle, ActiveSessionMessage, PendingSessionEvent, PendingSessionHandle,
12    SessionCommand,
13};
14
15pub use reth_network_api::{Direction, PeerInfo};
16
17use std::{
18    collections::HashMap,
19    future::Future,
20    net::SocketAddr,
21    sync::{atomic::AtomicU64, Arc},
22    task::{Context, Poll},
23    time::{Duration, Instant},
24};
25
26use crate::{
27    message::PeerMessage,
28    metrics::SessionManagerMetrics,
29    protocol::{IntoRlpxSubProtocol, OnNotSupported, RlpxSubProtocolHandlers, RlpxSubProtocols},
30    session::active::ActiveSession,
31};
32use counter::SessionCounter;
33use futures::{future::Either, io, FutureExt, StreamExt};
34use reth_ecies::{stream::ECIESStream, ECIESError};
35use reth_eth_wire::{
36    errors::EthStreamError, handshake::EthRlpxHandshake, multiplex::RlpxProtocolMultiplexer,
37    Capabilities, DisconnectReason, EthStream, EthVersion, HelloMessageWithProtocols,
38    NetworkPrimitives, Status, UnauthedP2PStream, HANDSHAKE_TIMEOUT,
39};
40use reth_ethereum_forks::{ForkFilter, ForkId, ForkTransition, Head};
41use reth_metrics::common::mpsc::MeteredPollSender;
42use reth_network_api::{PeerRequest, PeerRequestSender};
43use reth_network_peers::PeerId;
44use reth_network_types::SessionsConfig;
45use reth_tasks::TaskSpawner;
46use rustc_hash::FxHashMap;
47use secp256k1::SecretKey;
48use tokio::{
49    io::{AsyncRead, AsyncWrite},
50    net::TcpStream,
51    sync::{mpsc, mpsc::error::TrySendError, oneshot},
52};
53use tokio_stream::wrappers::ReceiverStream;
54use tokio_util::sync::PollSender;
55use tracing::{debug, instrument, trace};
56
/// Opaque identifier the [`SessionManager`] assigns to every session it starts.
///
/// The same id follows a session from the pending (handshaking) phase into the active phase;
/// values are handed out sequentially from an internal counter.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq, PartialOrd)]
pub struct SessionId(usize);
60
61/// Manages a set of sessions.
62#[must_use = "Session Manager must be polled to process session events."]
63#[derive(Debug)]
64pub struct SessionManager<N: NetworkPrimitives> {
65    /// Tracks the identifier for the next session.
66    next_id: usize,
67    /// Keeps track of all sessions
68    counter: SessionCounter,
69    ///  The maximum initial time an [`ActiveSession`] waits for a response from the peer before it
70    /// responds to an _internal_ request with a `TimeoutError`
71    initial_internal_request_timeout: Duration,
72    /// If an [`ActiveSession`] does not receive a response at all within this duration then it is
73    /// considered a protocol violation and the session will initiate a drop.
74    protocol_breach_request_timeout: Duration,
75    /// The timeout after which a pending session attempt is considered failed.
76    pending_session_timeout: Duration,
77    /// The secret key used for authenticating sessions.
78    secret_key: SecretKey,
79    /// The `Status` message to send to peers.
80    status: Status,
81    /// The `HelloMessage` message to send to peers.
82    hello_message: HelloMessageWithProtocols,
83    /// The [`ForkFilter`] used to validate the peer's `Status` message.
84    fork_filter: ForkFilter,
85    /// Size of the command buffer per session.
86    session_command_buffer: usize,
87    /// The executor for spawned tasks.
88    executor: Box<dyn TaskSpawner>,
89    /// All pending session that are currently handshaking, exchanging `Hello`s.
90    ///
91    /// Events produced during the authentication phase are reported to this manager. Once the
92    /// session is authenticated, it can be moved to the `active_session` set.
93    pending_sessions: FxHashMap<SessionId, PendingSessionHandle>,
94    /// All active sessions that are ready to exchange messages.
95    active_sessions: HashMap<PeerId, ActiveSessionHandle<N>>,
96    /// The original Sender half of the [`PendingSessionEvent`] channel.
97    ///
98    /// When a new (pending) session is created, the corresponding [`PendingSessionHandle`] will
99    /// get a clone of this sender half.
100    pending_sessions_tx: mpsc::Sender<PendingSessionEvent<N>>,
101    /// Receiver half that listens for [`PendingSessionEvent`] produced by pending sessions.
102    pending_session_rx: ReceiverStream<PendingSessionEvent<N>>,
103    /// The original Sender half of the [`ActiveSessionMessage`] channel.
104    ///
105    /// When active session state is reached, the corresponding [`ActiveSessionHandle`] will get a
106    /// clone of this sender half.
107    active_session_tx: MeteredPollSender<ActiveSessionMessage<N>>,
108    /// Receiver half that listens for [`ActiveSessionMessage`] produced by pending sessions.
109    active_session_rx: ReceiverStream<ActiveSessionMessage<N>>,
110    /// Additional `RLPx` sub-protocols to be used by the session manager.
111    extra_protocols: RlpxSubProtocols,
112    /// Tracks the ongoing graceful disconnections attempts for incoming connections.
113    disconnections_counter: DisconnectionsCounter,
114    /// Metrics for the session manager.
115    metrics: SessionManagerMetrics,
116    /// The [`EthRlpxHandshake`] is used to perform the initial handshake with the peer.
117    handshake: Arc<dyn EthRlpxHandshake>,
118}
119
120// === impl SessionManager ===
121
122impl<N: NetworkPrimitives> SessionManager<N> {
123    /// Creates a new empty [`SessionManager`].
124    #[allow(clippy::too_many_arguments)]
125    pub fn new(
126        secret_key: SecretKey,
127        config: SessionsConfig,
128        executor: Box<dyn TaskSpawner>,
129        status: Status,
130        hello_message: HelloMessageWithProtocols,
131        fork_filter: ForkFilter,
132        extra_protocols: RlpxSubProtocols,
133        handshake: Arc<dyn EthRlpxHandshake>,
134    ) -> Self {
135        let (pending_sessions_tx, pending_sessions_rx) = mpsc::channel(config.session_event_buffer);
136        let (active_session_tx, active_session_rx) = mpsc::channel(config.session_event_buffer);
137        let active_session_tx = PollSender::new(active_session_tx);
138
139        Self {
140            next_id: 0,
141            counter: SessionCounter::new(config.limits),
142            initial_internal_request_timeout: config.initial_internal_request_timeout,
143            protocol_breach_request_timeout: config.protocol_breach_request_timeout,
144            pending_session_timeout: config.pending_session_timeout,
145            secret_key,
146            status,
147            hello_message,
148            fork_filter,
149            session_command_buffer: config.session_command_buffer,
150            executor,
151            pending_sessions: Default::default(),
152            active_sessions: Default::default(),
153            pending_sessions_tx,
154            pending_session_rx: ReceiverStream::new(pending_sessions_rx),
155            active_session_tx: MeteredPollSender::new(active_session_tx, "network_active_session"),
156            active_session_rx: ReceiverStream::new(active_session_rx),
157            extra_protocols,
158            disconnections_counter: Default::default(),
159            metrics: Default::default(),
160            handshake,
161        }
162    }
163
164    /// Check whether the provided [`ForkId`] is compatible based on the validation rules in
165    /// `EIP-2124`.
166    pub fn is_valid_fork_id(&self, fork_id: ForkId) -> bool {
167        self.fork_filter.validate(fork_id).is_ok()
168    }
169
170    /// Returns the next unique [`SessionId`].
171    fn next_id(&mut self) -> SessionId {
172        let id = self.next_id;
173        self.next_id += 1;
174        SessionId(id)
175    }
176
177    /// Returns the current status of the session.
178    pub const fn status(&self) -> Status {
179        self.status
180    }
181
182    /// Returns the secret key used for authenticating sessions.
183    pub const fn secret_key(&self) -> SecretKey {
184        self.secret_key
185    }
186
187    /// Returns a borrowed reference to the active sessions.
188    pub const fn active_sessions(&self) -> &HashMap<PeerId, ActiveSessionHandle<N>> {
189        &self.active_sessions
190    }
191
192    /// Returns the session hello message.
193    pub fn hello_message(&self) -> HelloMessageWithProtocols {
194        self.hello_message.clone()
195    }
196
197    /// Adds an additional protocol handler to the `RLPx` sub-protocol list.
198    pub(crate) fn add_rlpx_sub_protocol(&mut self, protocol: impl IntoRlpxSubProtocol) {
199        self.extra_protocols.push(protocol)
200    }
201
202    /// Returns the number of currently pending connections.
203    #[inline]
204    pub(crate) fn num_pending_connections(&self) -> usize {
205        self.pending_sessions.len()
206    }
207
208    /// Spawns the given future onto a new task that is tracked in the `spawned_tasks`
209    /// [`JoinSet`](tokio::task::JoinSet).
210    fn spawn<F>(&self, f: F)
211    where
212        F: Future<Output = ()> + Send + 'static,
213    {
214        self.executor.spawn(f.boxed());
215    }
216
217    /// Invoked on a received status update.
218    ///
219    /// If the updated activated another fork, this will return a [`ForkTransition`] and updates the
220    /// active [`ForkId`]. See also [`ForkFilter::set_head`].
221    pub(crate) fn on_status_update(&mut self, head: Head) -> Option<ForkTransition> {
222        self.status.blockhash = head.hash;
223        self.status.total_difficulty = head.total_difficulty;
224        let transition = self.fork_filter.set_head(head);
225        self.status.forkid = self.fork_filter.current();
226        transition
227    }
228
229    /// An incoming TCP connection was received. This starts the authentication process to turn this
230    /// stream into an active peer session.
231    ///
232    /// Returns an error if the configured limit has been reached.
233    pub(crate) fn on_incoming(
234        &mut self,
235        stream: TcpStream,
236        remote_addr: SocketAddr,
237    ) -> Result<SessionId, ExceedsSessionLimit> {
238        self.counter.ensure_pending_inbound()?;
239
240        let session_id = self.next_id();
241
242        trace!(
243            target: "net::session",
244            ?remote_addr,
245            ?session_id,
246            "new pending incoming session"
247        );
248
249        let (disconnect_tx, disconnect_rx) = oneshot::channel();
250        let pending_events = self.pending_sessions_tx.clone();
251        let secret_key = self.secret_key;
252        let hello_message = self.hello_message.clone();
253        let status = self.status;
254        let fork_filter = self.fork_filter.clone();
255        let extra_handlers = self.extra_protocols.on_incoming(remote_addr);
256        self.spawn(pending_session_with_timeout(
257            self.pending_session_timeout,
258            session_id,
259            remote_addr,
260            Direction::Incoming,
261            pending_events.clone(),
262            start_pending_incoming_session(
263                self.handshake.clone(),
264                disconnect_rx,
265                session_id,
266                stream,
267                pending_events,
268                remote_addr,
269                secret_key,
270                hello_message,
271                status,
272                fork_filter,
273                extra_handlers,
274            ),
275        ));
276
277        let handle = PendingSessionHandle {
278            disconnect_tx: Some(disconnect_tx),
279            direction: Direction::Incoming,
280        };
281        self.pending_sessions.insert(session_id, handle);
282        self.counter.inc_pending_inbound();
283        Ok(session_id)
284    }
285
286    /// Starts a new pending session from the local node to the given remote node.
287    pub fn dial_outbound(&mut self, remote_addr: SocketAddr, remote_peer_id: PeerId) {
288        // The error can be dropped because no dial will be made if it would exceed the limit
289        if self.counter.ensure_pending_outbound().is_ok() {
290            let session_id = self.next_id();
291            let (disconnect_tx, disconnect_rx) = oneshot::channel();
292            let pending_events = self.pending_sessions_tx.clone();
293            let secret_key = self.secret_key;
294            let hello_message = self.hello_message.clone();
295            let fork_filter = self.fork_filter.clone();
296            let status = self.status;
297            let extra_handlers = self.extra_protocols.on_outgoing(remote_addr, remote_peer_id);
298            self.spawn(pending_session_with_timeout(
299                self.pending_session_timeout,
300                session_id,
301                remote_addr,
302                Direction::Outgoing(remote_peer_id),
303                pending_events.clone(),
304                start_pending_outbound_session(
305                    self.handshake.clone(),
306                    disconnect_rx,
307                    pending_events,
308                    session_id,
309                    remote_addr,
310                    remote_peer_id,
311                    secret_key,
312                    hello_message,
313                    status,
314                    fork_filter,
315                    extra_handlers,
316                ),
317            ));
318
319            let handle = PendingSessionHandle {
320                disconnect_tx: Some(disconnect_tx),
321                direction: Direction::Outgoing(remote_peer_id),
322            };
323            self.pending_sessions.insert(session_id, handle);
324            self.counter.inc_pending_outbound();
325        }
326    }
327
328    /// Initiates a shutdown of the channel.
329    ///
330    /// This will trigger the disconnect on the session task to gracefully terminate. The result
331    /// will be picked up by the receiver.
332    pub fn disconnect(&self, node: PeerId, reason: Option<DisconnectReason>) {
333        if let Some(session) = self.active_sessions.get(&node) {
334            session.disconnect(reason);
335        }
336    }
337
338    /// Initiates a shutdown of all sessions.
339    ///
340    /// It will trigger the disconnect on all the session tasks to gracefully terminate. The result
341    /// will be picked by the receiver.
342    pub fn disconnect_all(&self, reason: Option<DisconnectReason>) {
343        for session in self.active_sessions.values() {
344            session.disconnect(reason);
345        }
346    }
347
348    /// Disconnects all pending sessions.
349    pub fn disconnect_all_pending(&mut self) {
350        for session in self.pending_sessions.values_mut() {
351            session.disconnect();
352        }
353    }
354
355    /// Sends a message to the peer's session
356    pub fn send_message(&self, peer_id: &PeerId, msg: PeerMessage<N>) {
357        if let Some(session) = self.active_sessions.get(peer_id) {
358            let _ = session.commands_to_session.try_send(SessionCommand::Message(msg)).inspect_err(
359                |e| {
360                    if let TrySendError::Full(_) = e {
361                        debug!(
362                            target: "net::session",
363                            ?peer_id,
364                            "session command buffer full, dropping message"
365                        );
366                        self.metrics.total_outgoing_peer_messages_dropped.increment(1);
367                    }
368                },
369            );
370        }
371    }
372
373    /// Removes the [`PendingSessionHandle`] if it exists.
374    fn remove_pending_session(&mut self, id: &SessionId) -> Option<PendingSessionHandle> {
375        let session = self.pending_sessions.remove(id)?;
376        self.counter.dec_pending(&session.direction);
377        Some(session)
378    }
379
380    /// Removes the [`PendingSessionHandle`] if it exists.
381    fn remove_active_session(&mut self, id: &PeerId) -> Option<ActiveSessionHandle<N>> {
382        let session = self.active_sessions.remove(id)?;
383        self.counter.dec_active(&session.direction);
384        Some(session)
385    }
386
387    /// Try to gracefully disconnect an incoming connection by initiating a ECIES connection and
388    /// sending a disconnect. If [`SessionManager`] is at capacity for ongoing disconnections, will
389    /// simply drop the incoming connection.
390    pub(crate) fn try_disconnect_incoming_connection(
391        &self,
392        stream: TcpStream,
393        reason: DisconnectReason,
394    ) {
395        if !self.disconnections_counter.has_capacity() {
396            // drop the connection if we don't have capacity for gracefully disconnecting
397            return
398        }
399
400        let guard = self.disconnections_counter.clone();
401        let secret_key = self.secret_key;
402
403        self.spawn(async move {
404            trace!(
405                target: "net::session",
406                "gracefully disconnecting incoming connection"
407            );
408            if let Ok(stream) = get_ecies_stream(stream, secret_key, Direction::Incoming).await {
409                let mut unauth = UnauthedP2PStream::new(stream);
410                let _ = unauth.send_disconnect(reason).await;
411                drop(guard);
412            }
413        });
414    }
415
416    /// This polls all the session handles and returns [`SessionEvent`].
417    ///
418    /// Active sessions are prioritized.
419    pub(crate) fn poll(&mut self, cx: &mut Context<'_>) -> Poll<SessionEvent<N>> {
420        // Poll events from active sessions
421        match self.active_session_rx.poll_next_unpin(cx) {
422            Poll::Pending => {}
423            Poll::Ready(None) => {
424                unreachable!("Manager holds both channel halves.")
425            }
426            Poll::Ready(Some(event)) => {
427                return match event {
428                    ActiveSessionMessage::Disconnected { peer_id, remote_addr } => {
429                        trace!(
430                            target: "net::session",
431                            ?peer_id,
432                            "gracefully disconnected active session."
433                        );
434                        self.remove_active_session(&peer_id);
435                        Poll::Ready(SessionEvent::Disconnected { peer_id, remote_addr })
436                    }
437                    ActiveSessionMessage::ClosedOnConnectionError {
438                        peer_id,
439                        remote_addr,
440                        error,
441                    } => {
442                        trace!(target: "net::session", ?peer_id, %error,"closed session.");
443                        self.remove_active_session(&peer_id);
444                        Poll::Ready(SessionEvent::SessionClosedOnConnectionError {
445                            remote_addr,
446                            peer_id,
447                            error,
448                        })
449                    }
450                    ActiveSessionMessage::ValidMessage { peer_id, message } => {
451                        Poll::Ready(SessionEvent::ValidMessage { peer_id, message })
452                    }
453                    ActiveSessionMessage::BadMessage { peer_id } => {
454                        Poll::Ready(SessionEvent::BadMessage { peer_id })
455                    }
456                    ActiveSessionMessage::ProtocolBreach { peer_id } => {
457                        Poll::Ready(SessionEvent::ProtocolBreach { peer_id })
458                    }
459                }
460            }
461        }
462
463        // Poll the pending session event stream
464        let event = match self.pending_session_rx.poll_next_unpin(cx) {
465            Poll::Pending => return Poll::Pending,
466            Poll::Ready(None) => unreachable!("Manager holds both channel halves."),
467            Poll::Ready(Some(event)) => event,
468        };
469        match event {
470            PendingSessionEvent::Established {
471                session_id,
472                remote_addr,
473                local_addr,
474                peer_id,
475                capabilities,
476                conn,
477                status,
478                direction,
479                client_id,
480            } => {
481                // move from pending to established.
482                self.remove_pending_session(&session_id);
483
484                // If there's already a session to the peer then we disconnect right away
485                if self.active_sessions.contains_key(&peer_id) {
486                    trace!(
487                        target: "net::session",
488                        ?session_id,
489                        ?remote_addr,
490                        ?peer_id,
491                        ?direction,
492                        "already connected"
493                    );
494
495                    self.spawn(async move {
496                        // send a disconnect message
497                        let _ =
498                            conn.into_inner().disconnect(DisconnectReason::AlreadyConnected).await;
499                    });
500
501                    return Poll::Ready(SessionEvent::AlreadyConnected {
502                        peer_id,
503                        remote_addr,
504                        direction,
505                    })
506                }
507
508                let (commands_to_session, commands_rx) = mpsc::channel(self.session_command_buffer);
509
510                let (to_session_tx, messages_rx) = mpsc::channel(self.session_command_buffer);
511
512                let messages = PeerRequestSender::new(peer_id, to_session_tx);
513
514                let timeout = Arc::new(AtomicU64::new(
515                    self.initial_internal_request_timeout.as_millis() as u64,
516                ));
517
518                // negotiated version
519                let version = conn.version();
520
521                let session = ActiveSession {
522                    next_id: 0,
523                    remote_peer_id: peer_id,
524                    remote_addr,
525                    remote_capabilities: Arc::clone(&capabilities),
526                    session_id,
527                    commands_rx: ReceiverStream::new(commands_rx),
528                    to_session_manager: self.active_session_tx.clone(),
529                    pending_message_to_session: None,
530                    internal_request_rx: ReceiverStream::new(messages_rx).fuse(),
531                    inflight_requests: Default::default(),
532                    conn,
533                    queued_outgoing: QueuedOutgoingMessages::new(
534                        self.metrics.queued_outgoing_messages.clone(),
535                    ),
536                    received_requests_from_remote: Default::default(),
537                    internal_request_timeout_interval: tokio::time::interval(
538                        self.initial_internal_request_timeout,
539                    ),
540                    internal_request_timeout: Arc::clone(&timeout),
541                    protocol_breach_request_timeout: self.protocol_breach_request_timeout,
542                    terminate_message: None,
543                };
544
545                self.spawn(session);
546
547                let client_version = client_id.into();
548                let handle = ActiveSessionHandle {
549                    status: status.clone(),
550                    direction,
551                    session_id,
552                    remote_id: peer_id,
553                    version,
554                    established: Instant::now(),
555                    capabilities: Arc::clone(&capabilities),
556                    commands_to_session,
557                    client_version: Arc::clone(&client_version),
558                    remote_addr,
559                    local_addr,
560                };
561
562                self.active_sessions.insert(peer_id, handle);
563                self.counter.inc_active(&direction);
564
565                if direction.is_outgoing() {
566                    self.metrics.total_dial_successes.increment(1);
567                }
568
569                Poll::Ready(SessionEvent::SessionEstablished {
570                    peer_id,
571                    remote_addr,
572                    client_version,
573                    version,
574                    capabilities,
575                    status,
576                    messages,
577                    direction,
578                    timeout,
579                })
580            }
581            PendingSessionEvent::Disconnected { remote_addr, session_id, direction, error } => {
582                trace!(
583                    target: "net::session",
584                    ?session_id,
585                    ?remote_addr,
586                    ?error,
587                    "disconnected pending session"
588                );
589                self.remove_pending_session(&session_id);
590                match direction {
591                    Direction::Incoming => {
592                        Poll::Ready(SessionEvent::IncomingPendingSessionClosed {
593                            remote_addr,
594                            error,
595                        })
596                    }
597                    Direction::Outgoing(peer_id) => {
598                        Poll::Ready(SessionEvent::OutgoingPendingSessionClosed {
599                            remote_addr,
600                            peer_id,
601                            error,
602                        })
603                    }
604                }
605            }
606            PendingSessionEvent::OutgoingConnectionError {
607                remote_addr,
608                session_id,
609                peer_id,
610                error,
611            } => {
612                trace!(
613                    target: "net::session",
614                    %error,
615                    ?session_id,
616                    ?remote_addr,
617                    ?peer_id,
618                    "connection refused"
619                );
620                self.remove_pending_session(&session_id);
621                Poll::Ready(SessionEvent::OutgoingConnectionError { remote_addr, peer_id, error })
622            }
623            PendingSessionEvent::EciesAuthError { remote_addr, session_id, error, direction } => {
624                trace!(
625                    target: "net::session",
626                    %error,
627                    ?session_id,
628                    ?remote_addr,
629                    "ecies auth failed"
630                );
631                self.remove_pending_session(&session_id);
632                match direction {
633                    Direction::Incoming => {
634                        Poll::Ready(SessionEvent::IncomingPendingSessionClosed {
635                            remote_addr,
636                            error: Some(PendingSessionHandshakeError::Ecies(error)),
637                        })
638                    }
639                    Direction::Outgoing(peer_id) => {
640                        Poll::Ready(SessionEvent::OutgoingPendingSessionClosed {
641                            remote_addr,
642                            peer_id,
643                            error: Some(PendingSessionHandshakeError::Ecies(error)),
644                        })
645                    }
646                }
647            }
648        }
649    }
650}
651
/// Limits how many graceful disconnections of incoming connections may run concurrently.
///
/// Every in-flight attempt holds a clone of this value, so the strong count of the inner
/// [`Arc`] is the number of outstanding attempts plus the instance owned by the manager.
#[derive(Default, Debug, Clone)]
struct DisconnectionsCounter(Arc<()>);

impl DisconnectionsCounter {
    const MAX_CONCURRENT_GRACEFUL_DISCONNECTIONS: usize = 15;

    /// Returns `true` while one more graceful disconnection may be started.
    fn has_capacity(&self) -> bool {
        let holders = Arc::strong_count(&self.0);
        Self::MAX_CONCURRENT_GRACEFUL_DISCONNECTIONS >= holders
    }
}
665
666/// Events produced by the [`SessionManager`]
667#[derive(Debug)]
668pub enum SessionEvent<N: NetworkPrimitives> {
669    /// A new session was successfully authenticated.
670    ///
671    /// This session is now able to exchange data.
672    SessionEstablished {
673        /// The remote node's public key
674        peer_id: PeerId,
675        /// The remote node's socket address
676        remote_addr: SocketAddr,
677        /// The user agent of the remote node, usually containing the client name and version
678        client_version: Arc<str>,
679        /// The capabilities the remote node has announced
680        capabilities: Arc<Capabilities>,
681        /// negotiated eth version
682        version: EthVersion,
683        /// The Status message the peer sent during the `eth` handshake
684        status: Arc<Status>,
685        /// The channel for sending messages to the peer with the session
686        messages: PeerRequestSender<PeerRequest<N>>,
687        /// The direction of the session, either `Inbound` or `Outgoing`
688        direction: Direction,
689        /// The maximum time that the session waits for a response from the peer before timing out
690        /// the connection
691        timeout: Arc<AtomicU64>,
692    },
693    /// The peer was already connected with another session.
694    AlreadyConnected {
695        /// The remote node's public key
696        peer_id: PeerId,
697        /// The remote node's socket address
698        remote_addr: SocketAddr,
699        /// The direction of the session, either `Inbound` or `Outgoing`
700        direction: Direction,
701    },
702    /// A session received a valid message via `RLPx`.
703    ValidMessage {
704        /// The remote node's public key
705        peer_id: PeerId,
706        /// Message received from the peer.
707        message: PeerMessage<N>,
708    },
709    /// Received a bad message from the peer.
710    BadMessage {
711        /// Identifier of the remote peer.
712        peer_id: PeerId,
713    },
714    /// Remote peer is considered in protocol violation
715    ProtocolBreach {
716        /// Identifier of the remote peer.
717        peer_id: PeerId,
718    },
719    /// Closed an incoming pending session during handshaking.
720    IncomingPendingSessionClosed {
721        /// The remote node's socket address
722        remote_addr: SocketAddr,
723        /// The pending handshake session error that caused the session to close
724        error: Option<PendingSessionHandshakeError>,
725    },
726    /// Closed an outgoing pending session during handshaking.
727    OutgoingPendingSessionClosed {
728        /// The remote node's socket address
729        remote_addr: SocketAddr,
730        /// The remote node's public key
731        peer_id: PeerId,
732        /// The pending handshake session error that caused the session to close
733        error: Option<PendingSessionHandshakeError>,
734    },
735    /// Failed to establish a tcp stream
736    OutgoingConnectionError {
737        /// The remote node's socket address
738        remote_addr: SocketAddr,
739        /// The remote node's public key
740        peer_id: PeerId,
741        /// The error that caused the outgoing connection to fail
742        error: io::Error,
743    },
744    /// Session was closed due to an error
745    SessionClosedOnConnectionError {
746        /// The id of the remote peer.
747        peer_id: PeerId,
748        /// The socket we were connected to.
749        remote_addr: SocketAddr,
750        /// The error that caused the session to close
751        error: EthStreamError,
752    },
753    /// Active session was gracefully disconnected.
754    Disconnected {
755        /// The remote node's public key
756        peer_id: PeerId,
757        /// The remote node's socket address that we were connected to
758        remote_addr: SocketAddr,
759    },
760}
761
762/// Errors that can occur during handshaking/authenticating the underlying streams.
763#[derive(Debug, thiserror::Error)]
764pub enum PendingSessionHandshakeError {
765    /// The pending session failed due to an error while establishing the `eth` stream
766    #[error(transparent)]
767    Eth(EthStreamError),
768    /// The pending session failed due to an error while establishing the ECIES stream
769    #[error(transparent)]
770    Ecies(ECIESError),
771    /// Thrown when the authentication timed out
772    #[error("authentication timed out")]
773    Timeout,
774    /// Thrown when the remote lacks the required capability
775    #[error("Mandatory extra capability unsupported")]
776    UnsupportedExtraCapability,
777}
778
779impl PendingSessionHandshakeError {
780    /// Returns the [`DisconnectReason`] if the error is a disconnect message
781    pub const fn as_disconnected(&self) -> Option<DisconnectReason> {
782        match self {
783            Self::Eth(eth_err) => eth_err.as_disconnected(),
784            _ => None,
785        }
786    }
787}
788
789/// The error thrown when the max configured limit has been reached and no more connections are
790/// accepted.
791#[derive(Debug, Clone, thiserror::Error)]
792#[error("session limit reached {0}")]
793pub struct ExceedsSessionLimit(pub(crate) u32);
794
795/// Starts a pending session authentication with a timeout.
796pub(crate) async fn pending_session_with_timeout<F, N: NetworkPrimitives>(
797    timeout: Duration,
798    session_id: SessionId,
799    remote_addr: SocketAddr,
800    direction: Direction,
801    events: mpsc::Sender<PendingSessionEvent<N>>,
802    f: F,
803) where
804    F: Future<Output = ()>,
805{
806    if tokio::time::timeout(timeout, f).await.is_err() {
807        trace!(target: "net::session", ?remote_addr, ?direction, "pending session timed out");
808        let event = PendingSessionEvent::Disconnected {
809            remote_addr,
810            session_id,
811            direction,
812            error: Some(PendingSessionHandshakeError::Timeout),
813        };
814        let _ = events.send(event).await;
815    }
816}
817
818/// Starts the authentication process for a connection initiated by a remote peer.
819///
820/// This will wait for the _incoming_ handshake request and answer it.
821#[allow(clippy::too_many_arguments)]
822pub(crate) async fn start_pending_incoming_session<N: NetworkPrimitives>(
823    handshake: Arc<dyn EthRlpxHandshake>,
824    disconnect_rx: oneshot::Receiver<()>,
825    session_id: SessionId,
826    stream: TcpStream,
827    events: mpsc::Sender<PendingSessionEvent<N>>,
828    remote_addr: SocketAddr,
829    secret_key: SecretKey,
830    hello: HelloMessageWithProtocols,
831    status: Status,
832    fork_filter: ForkFilter,
833    extra_handlers: RlpxSubProtocolHandlers,
834) {
835    authenticate(
836        handshake,
837        disconnect_rx,
838        events,
839        stream,
840        session_id,
841        remote_addr,
842        secret_key,
843        Direction::Incoming,
844        hello,
845        status,
846        fork_filter,
847        extra_handlers,
848    )
849    .await
850}
851
852/// Starts the authentication process for a connection initiated by a remote peer.
853#[instrument(skip_all, fields(%remote_addr, peer_id), target = "net")]
854#[allow(clippy::too_many_arguments)]
855async fn start_pending_outbound_session<N: NetworkPrimitives>(
856    handshake: Arc<dyn EthRlpxHandshake>,
857    disconnect_rx: oneshot::Receiver<()>,
858    events: mpsc::Sender<PendingSessionEvent<N>>,
859    session_id: SessionId,
860    remote_addr: SocketAddr,
861    remote_peer_id: PeerId,
862    secret_key: SecretKey,
863    hello: HelloMessageWithProtocols,
864    status: Status,
865    fork_filter: ForkFilter,
866    extra_handlers: RlpxSubProtocolHandlers,
867) {
868    let stream = match TcpStream::connect(remote_addr).await {
869        Ok(stream) => {
870            if let Err(err) = stream.set_nodelay(true) {
871                tracing::warn!(target: "net::session", "set nodelay failed: {:?}", err);
872            }
873            stream
874        }
875        Err(error) => {
876            let _ = events
877                .send(PendingSessionEvent::OutgoingConnectionError {
878                    remote_addr,
879                    session_id,
880                    peer_id: remote_peer_id,
881                    error,
882                })
883                .await;
884            return
885        }
886    };
887    authenticate(
888        handshake,
889        disconnect_rx,
890        events,
891        stream,
892        session_id,
893        remote_addr,
894        secret_key,
895        Direction::Outgoing(remote_peer_id),
896        hello,
897        status,
898        fork_filter,
899        extra_handlers,
900    )
901    .await
902}
903
904/// Authenticates a session
905#[allow(clippy::too_many_arguments)]
906async fn authenticate<N: NetworkPrimitives>(
907    handshake: Arc<dyn EthRlpxHandshake>,
908    disconnect_rx: oneshot::Receiver<()>,
909    events: mpsc::Sender<PendingSessionEvent<N>>,
910    stream: TcpStream,
911    session_id: SessionId,
912    remote_addr: SocketAddr,
913    secret_key: SecretKey,
914    direction: Direction,
915    hello: HelloMessageWithProtocols,
916    status: Status,
917    fork_filter: ForkFilter,
918    extra_handlers: RlpxSubProtocolHandlers,
919) {
920    let local_addr = stream.local_addr().ok();
921    let stream = match get_ecies_stream(stream, secret_key, direction).await {
922        Ok(stream) => stream,
923        Err(error) => {
924            let _ = events
925                .send(PendingSessionEvent::EciesAuthError {
926                    remote_addr,
927                    session_id,
928                    error,
929                    direction,
930                })
931                .await;
932            return
933        }
934    };
935
936    let unauthed = UnauthedP2PStream::new(stream);
937
938    let auth = authenticate_stream(
939        handshake,
940        unauthed,
941        session_id,
942        remote_addr,
943        local_addr,
944        direction,
945        hello,
946        status,
947        fork_filter,
948        extra_handlers,
949    )
950    .boxed();
951
952    match futures::future::select(disconnect_rx, auth).await {
953        Either::Left((_, _)) => {
954            let _ = events
955                .send(PendingSessionEvent::Disconnected {
956                    remote_addr,
957                    session_id,
958                    direction,
959                    error: None,
960                })
961                .await;
962        }
963        Either::Right((res, _)) => {
964            let _ = events.send(res).await;
965        }
966    }
967}
968
969/// Returns an [`ECIESStream`] if it can be built. If not, send a
970/// [`PendingSessionEvent::EciesAuthError`] and returns `None`
971async fn get_ecies_stream<Io: AsyncRead + AsyncWrite + Unpin>(
972    stream: Io,
973    secret_key: SecretKey,
974    direction: Direction,
975) -> Result<ECIESStream<Io>, ECIESError> {
976    match direction {
977        Direction::Incoming => ECIESStream::incoming(stream, secret_key).await,
978        Direction::Outgoing(remote_peer_id) => {
979            ECIESStream::connect(stream, secret_key, remote_peer_id).await
980        }
981    }
982}
983
984/// Authenticate the stream via handshake
985///
986/// On Success return the authenticated stream as [`PendingSessionEvent`].
987///
988/// If additional [`RlpxSubProtocolHandlers`] are provided, the hello message will be updated to
989/// also negotiate the additional protocols.
990#[allow(clippy::too_many_arguments)]
991async fn authenticate_stream<N: NetworkPrimitives>(
992    handshake: Arc<dyn EthRlpxHandshake>,
993    stream: UnauthedP2PStream<ECIESStream<TcpStream>>,
994    session_id: SessionId,
995    remote_addr: SocketAddr,
996    local_addr: Option<SocketAddr>,
997    direction: Direction,
998    mut hello: HelloMessageWithProtocols,
999    mut status: Status,
1000    fork_filter: ForkFilter,
1001    mut extra_handlers: RlpxSubProtocolHandlers,
1002) -> PendingSessionEvent<N> {
1003    // Add extra protocols to the hello message
1004    extra_handlers.retain(|handler| hello.try_add_protocol(handler.protocol()).is_ok());
1005
1006    // conduct the p2p rlpx handshake and return the rlpx authenticated stream
1007    let (mut p2p_stream, their_hello) = match stream.handshake(hello).await {
1008        Ok(stream_res) => stream_res,
1009        Err(err) => {
1010            return PendingSessionEvent::Disconnected {
1011                remote_addr,
1012                session_id,
1013                direction,
1014                error: Some(PendingSessionHandshakeError::Eth(err.into())),
1015            }
1016        }
1017    };
1018
1019    // if we have extra handlers, check if it must be supported by the remote
1020    if !extra_handlers.is_empty() {
1021        // ensure that no extra handlers that aren't supported are not mandatory
1022        while let Some(pos) = extra_handlers.iter().position(|handler| {
1023            p2p_stream
1024                .shared_capabilities()
1025                .ensure_matching_capability(&handler.protocol().cap)
1026                .is_err()
1027        }) {
1028            let handler = extra_handlers.remove(pos);
1029            if handler.on_unsupported_by_peer(
1030                p2p_stream.shared_capabilities(),
1031                direction,
1032                their_hello.id,
1033            ) == OnNotSupported::Disconnect
1034            {
1035                return PendingSessionEvent::Disconnected {
1036                    remote_addr,
1037                    session_id,
1038                    direction,
1039                    error: Some(PendingSessionHandshakeError::UnsupportedExtraCapability),
1040                };
1041            }
1042        }
1043    }
1044
1045    // Ensure we negotiated mandatory eth protocol
1046    let eth_version = match p2p_stream.shared_capabilities().eth_version() {
1047        Ok(version) => version,
1048        Err(err) => {
1049            return PendingSessionEvent::Disconnected {
1050                remote_addr,
1051                session_id,
1052                direction,
1053                error: Some(PendingSessionHandshakeError::Eth(err.into())),
1054            }
1055        }
1056    };
1057
1058    let (conn, their_status) = if p2p_stream.shared_capabilities().len() == 1 {
1059        // if the shared caps are 1, we know both support the eth version
1060        // if the hello handshake was successful we can try status handshake
1061        //
1062        // Before trying status handshake, set up the version to negotiated shared version
1063        status.set_eth_version(eth_version);
1064
1065        // perform the eth protocol handshake
1066        match handshake
1067            .handshake(&mut p2p_stream, status, fork_filter.clone(), HANDSHAKE_TIMEOUT)
1068            .await
1069        {
1070            Ok(their_status) => {
1071                let eth_stream = EthStream::new(status.version, p2p_stream);
1072                (eth_stream.into(), their_status)
1073            }
1074            Err(err) => {
1075                return PendingSessionEvent::Disconnected {
1076                    remote_addr,
1077                    session_id,
1078                    direction,
1079                    error: Some(PendingSessionHandshakeError::Eth(err)),
1080                }
1081            }
1082        }
1083    } else {
1084        // Multiplex the stream with the extra protocols
1085        let mut multiplex_stream = RlpxProtocolMultiplexer::new(p2p_stream);
1086
1087        // install additional handlers
1088        for handler in extra_handlers.into_iter() {
1089            let cap = handler.protocol().cap;
1090            let remote_peer_id = their_hello.id;
1091
1092            multiplex_stream
1093                .install_protocol(&cap, move |conn| {
1094                    handler.into_connection(direction, remote_peer_id, conn)
1095                })
1096                .ok();
1097        }
1098
1099        let (multiplex_stream, their_status) =
1100            match multiplex_stream.into_eth_satellite_stream(status, fork_filter).await {
1101                Ok((multiplex_stream, their_status)) => (multiplex_stream, their_status),
1102                Err(err) => {
1103                    return PendingSessionEvent::Disconnected {
1104                        remote_addr,
1105                        session_id,
1106                        direction,
1107                        error: Some(PendingSessionHandshakeError::Eth(err)),
1108                    }
1109                }
1110            };
1111
1112        (multiplex_stream.into(), their_status)
1113    };
1114
1115    PendingSessionEvent::Established {
1116        session_id,
1117        remote_addr,
1118        local_addr,
1119        peer_id: their_hello.id,
1120        capabilities: Arc::new(Capabilities::from(their_hello.capabilities)),
1121        status: Arc::new(their_status),
1122        conn,
1123        direction,
1124        client_id: their_hello.client_version,
1125    }
1126}