reth_ress_protocol/
handlers.rs
1use crate::{
2 connection::{RessPeerRequest, RessProtocolConnection},
3 NodeType, RessProtocolMessage, RessProtocolProvider,
4};
5use reth_eth_wire::{
6 capability::SharedCapabilities, multiplex::ProtocolConnection, protocol::Protocol,
7};
8use reth_network::protocol::{ConnectionHandler, OnNotSupported, ProtocolHandler};
9use reth_network_api::{test_utils::PeersHandle, Direction, PeerId};
10use std::{
11 fmt,
12 net::SocketAddr,
13 sync::{
14 atomic::{AtomicU64, Ordering},
15 Arc,
16 },
17};
18use tokio::sync::mpsc;
19use tokio_stream::wrappers::UnboundedReceiverStream;
20use tracing::*;
21
22#[derive(Debug)]
24pub enum ProtocolEvent {
25 Established {
27 direction: Direction,
29 peer_id: PeerId,
31 to_connection: mpsc::UnboundedSender<RessPeerRequest>,
33 },
34 MaxActiveConnectionsExceeded {
36 num_active: u64,
38 },
39}
40
41#[derive(Clone, Debug)]
43pub struct ProtocolState {
44 pub events_sender: mpsc::UnboundedSender<ProtocolEvent>,
46 pub active_connections: Arc<AtomicU64>,
48}
49
50impl ProtocolState {
51 pub fn new(events_sender: mpsc::UnboundedSender<ProtocolEvent>) -> Self {
53 Self { events_sender, active_connections: Arc::default() }
54 }
55
56 pub fn active_connections(&self) -> u64 {
58 self.active_connections.load(Ordering::Relaxed)
59 }
60}
61
62#[derive(Clone)]
64pub struct RessProtocolHandler<P> {
65 pub provider: P,
67 pub node_type: NodeType,
69 pub peers_handle: PeersHandle,
71 pub max_active_connections: u64,
73 pub state: ProtocolState,
75}
76
77impl<P> fmt::Debug for RessProtocolHandler<P> {
78 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
79 f.debug_struct("RessProtocolHandler")
80 .field("node_type", &self.node_type)
81 .field("peers_handle", &self.peers_handle)
82 .field("max_active_connections", &self.max_active_connections)
83 .field("state", &self.state)
84 .finish_non_exhaustive()
85 }
86}
87
88impl<P> ProtocolHandler for RessProtocolHandler<P>
89where
90 P: RessProtocolProvider + Clone + Unpin + 'static,
91{
92 type ConnectionHandler = Self;
93
94 fn on_incoming(&self, socket_addr: SocketAddr) -> Option<Self::ConnectionHandler> {
95 let num_active = self.state.active_connections();
96 if num_active >= self.max_active_connections {
97 trace!(
98 target: "ress::net",
99 num_active, max_connections = self.max_active_connections, %socket_addr,
100 "ignoring incoming connection, max active reached"
101 );
102 let _ = self
103 .state
104 .events_sender
105 .send(ProtocolEvent::MaxActiveConnectionsExceeded { num_active });
106 None
107 } else {
108 Some(self.clone())
109 }
110 }
111
112 fn on_outgoing(
113 &self,
114 socket_addr: SocketAddr,
115 peer_id: PeerId,
116 ) -> Option<Self::ConnectionHandler> {
117 let num_active = self.state.active_connections();
118 if num_active >= self.max_active_connections {
119 trace!(
120 target: "ress::net",
121 num_active, max_connections = self.max_active_connections, %socket_addr, %peer_id,
122 "ignoring outgoing connection, max active reached"
123 );
124 let _ = self
125 .state
126 .events_sender
127 .send(ProtocolEvent::MaxActiveConnectionsExceeded { num_active });
128 None
129 } else {
130 Some(self.clone())
131 }
132 }
133}
134
135impl<P> ConnectionHandler for RessProtocolHandler<P>
136where
137 P: RessProtocolProvider + Clone + Unpin + 'static,
138{
139 type Connection = RessProtocolConnection<P>;
140
141 fn protocol(&self) -> Protocol {
142 RessProtocolMessage::protocol()
143 }
144
145 fn on_unsupported_by_peer(
146 self,
147 _supported: &SharedCapabilities,
148 _direction: Direction,
149 _peer_id: PeerId,
150 ) -> OnNotSupported {
151 if self.node_type.is_stateful() {
152 OnNotSupported::KeepAlive
153 } else {
154 OnNotSupported::Disconnect
155 }
156 }
157
158 fn into_connection(
159 self,
160 direction: Direction,
161 peer_id: PeerId,
162 conn: ProtocolConnection,
163 ) -> Self::Connection {
164 let (tx, rx) = mpsc::unbounded_channel();
165
166 self.state
168 .events_sender
169 .send(ProtocolEvent::Established { direction, peer_id, to_connection: tx })
170 .ok();
171
172 self.state.active_connections.fetch_add(1, Ordering::Relaxed);
174
175 RessProtocolConnection::new(
176 self.provider.clone(),
177 self.node_type,
178 self.peers_handle,
179 peer_id,
180 conn,
181 UnboundedReceiverStream::from(rx),
182 self.state.active_connections,
183 )
184 }
185}