reth_ress_protocol/
handlers.rs

1use crate::{
2    connection::{RessPeerRequest, RessProtocolConnection},
3    NodeType, RessProtocolMessage, RessProtocolProvider,
4};
5use reth_eth_wire::{
6    capability::SharedCapabilities, multiplex::ProtocolConnection, protocol::Protocol,
7};
8use reth_network::protocol::{ConnectionHandler, OnNotSupported, ProtocolHandler};
9use reth_network_api::{test_utils::PeersHandle, Direction, PeerId};
10use std::{
11    fmt,
12    net::SocketAddr,
13    sync::{
14        atomic::{AtomicU64, Ordering},
15        Arc,
16    },
17};
18use tokio::sync::mpsc;
19use tokio_stream::wrappers::UnboundedReceiverStream;
20use tracing::*;
21
22/// The events that can be emitted by our custom protocol.
23#[derive(Debug)]
24pub enum ProtocolEvent {
25    /// Connection established.
26    Established {
27        /// Connection direction.
28        direction: Direction,
29        /// Peer ID.
30        peer_id: PeerId,
31        /// Sender part for forwarding commands.
32        to_connection: mpsc::UnboundedSender<RessPeerRequest>,
33    },
34    /// Number of max active connections exceeded. New connection was rejected.
35    MaxActiveConnectionsExceeded {
36        /// The current number
37        num_active: u64,
38    },
39}
40
41/// Protocol state is an helper struct to store the protocol events.
42#[derive(Clone, Debug)]
43pub struct ProtocolState {
44    /// Protocol event sender.
45    pub events_sender: mpsc::UnboundedSender<ProtocolEvent>,
46    /// The number of active connections.
47    pub active_connections: Arc<AtomicU64>,
48}
49
50impl ProtocolState {
51    /// Create new protocol state.
52    pub fn new(events_sender: mpsc::UnboundedSender<ProtocolEvent>) -> Self {
53        Self { events_sender, active_connections: Arc::default() }
54    }
55
56    /// Returns the current number of active connections.
57    pub fn active_connections(&self) -> u64 {
58        self.active_connections.load(Ordering::Relaxed)
59    }
60}
61
62/// The protocol handler takes care of incoming and outgoing connections.
63#[derive(Clone)]
64pub struct RessProtocolHandler<P> {
65    /// Provider.
66    pub provider: P,
67    /// Node type.
68    pub node_type: NodeType,
69    /// Peers handle.
70    pub peers_handle: PeersHandle,
71    /// The maximum number of active connections.
72    pub max_active_connections: u64,
73    /// Current state of the protocol.
74    pub state: ProtocolState,
75}
76
77impl<P> fmt::Debug for RessProtocolHandler<P> {
78    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
79        f.debug_struct("RessProtocolHandler")
80            .field("node_type", &self.node_type)
81            .field("peers_handle", &self.peers_handle)
82            .field("max_active_connections", &self.max_active_connections)
83            .field("state", &self.state)
84            .finish_non_exhaustive()
85    }
86}
87
88impl<P> ProtocolHandler for RessProtocolHandler<P>
89where
90    P: RessProtocolProvider + Clone + Unpin + 'static,
91{
92    type ConnectionHandler = Self;
93
94    fn on_incoming(&self, socket_addr: SocketAddr) -> Option<Self::ConnectionHandler> {
95        let num_active = self.state.active_connections();
96        if num_active >= self.max_active_connections {
97            trace!(
98                target: "ress::net",
99                num_active, max_connections = self.max_active_connections, %socket_addr,
100                "ignoring incoming connection, max active reached"
101            );
102            let _ = self
103                .state
104                .events_sender
105                .send(ProtocolEvent::MaxActiveConnectionsExceeded { num_active });
106            None
107        } else {
108            Some(self.clone())
109        }
110    }
111
112    fn on_outgoing(
113        &self,
114        socket_addr: SocketAddr,
115        peer_id: PeerId,
116    ) -> Option<Self::ConnectionHandler> {
117        let num_active = self.state.active_connections();
118        if num_active >= self.max_active_connections {
119            trace!(
120                target: "ress::net",
121                num_active, max_connections = self.max_active_connections, %socket_addr, %peer_id,
122                "ignoring outgoing connection, max active reached"
123            );
124            let _ = self
125                .state
126                .events_sender
127                .send(ProtocolEvent::MaxActiveConnectionsExceeded { num_active });
128            None
129        } else {
130            Some(self.clone())
131        }
132    }
133}
134
135impl<P> ConnectionHandler for RessProtocolHandler<P>
136where
137    P: RessProtocolProvider + Clone + Unpin + 'static,
138{
139    type Connection = RessProtocolConnection<P>;
140
141    fn protocol(&self) -> Protocol {
142        RessProtocolMessage::protocol()
143    }
144
145    fn on_unsupported_by_peer(
146        self,
147        _supported: &SharedCapabilities,
148        _direction: Direction,
149        _peer_id: PeerId,
150    ) -> OnNotSupported {
151        if self.node_type.is_stateful() {
152            OnNotSupported::KeepAlive
153        } else {
154            OnNotSupported::Disconnect
155        }
156    }
157
158    fn into_connection(
159        self,
160        direction: Direction,
161        peer_id: PeerId,
162        conn: ProtocolConnection,
163    ) -> Self::Connection {
164        let (tx, rx) = mpsc::unbounded_channel();
165
166        // Emit connection established event.
167        self.state
168            .events_sender
169            .send(ProtocolEvent::Established { direction, peer_id, to_connection: tx })
170            .ok();
171
172        // Increment the number of active sessions.
173        self.state.active_connections.fetch_add(1, Ordering::Relaxed);
174
175        RessProtocolConnection::new(
176            self.provider.clone(),
177            self.node_type,
178            self.peers_handle,
179            peer_id,
180            conn,
181            UnboundedReceiverStream::from(rx),
182            self.state.active_connections,
183        )
184    }
185}