// reth_network/listener.rs
1use futures::{ready, Stream, StreamExt};
4use std::{
5 io,
6 net::SocketAddr,
7 pin::Pin,
8 task::{Context, Poll},
9};
10use tokio::net::{TcpListener, TcpStream};
11
12#[must_use = "Transport does nothing unless polled."]
16#[derive(Debug)]
17pub struct ConnectionListener {
18 local_address: SocketAddr,
20 incoming: TcpListenerStream,
22}
23
24impl ConnectionListener {
25 pub async fn bind(addr: SocketAddr) -> io::Result<Self> {
27 let listener = TcpListener::bind(addr).await?;
28 let local_addr = listener.local_addr()?;
29 Ok(Self::new(listener, local_addr))
30 }
31
32 pub(crate) const fn new(listener: TcpListener, local_address: SocketAddr) -> Self {
34 Self { local_address, incoming: TcpListenerStream { inner: listener } }
35 }
36
37 pub fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<ListenerEvent> {
39 let this = self.get_mut();
40 match ready!(this.incoming.poll_next_unpin(cx)) {
41 Some(Ok((stream, remote_addr))) => {
42 if let Err(err) = stream.set_nodelay(true) {
43 tracing::warn!(target: "net", "set nodelay failed: {:?}", err);
44 }
45 Poll::Ready(ListenerEvent::Incoming { stream, remote_addr })
46 }
47 Some(Err(err)) => Poll::Ready(ListenerEvent::Error(err)),
48 None => {
49 Poll::Ready(ListenerEvent::ListenerClosed { local_address: this.local_address })
50 }
51 }
52 }
53
54 pub const fn local_address(&self) -> SocketAddr {
56 self.local_address
57 }
58}
59
60pub enum ListenerEvent {
62 Incoming {
64 stream: TcpStream,
66 remote_addr: SocketAddr,
68 },
69 ListenerClosed {
73 local_address: SocketAddr,
75 },
76 Error(io::Error),
81}
82
83#[derive(Debug)]
85struct TcpListenerStream {
86 inner: TcpListener,
88}
89
90impl Stream for TcpListenerStream {
91 type Item = io::Result<(TcpStream, SocketAddr)>;
92
93 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
94 match self.inner.poll_accept(cx) {
95 Poll::Ready(Ok(conn)) => Poll::Ready(Some(Ok(conn))),
96 Poll::Ready(Err(err)) => Poll::Ready(Some(Err(err))),
97 Poll::Pending => Poll::Pending,
98 }
99 }
100}
101
#[cfg(test)]
mod tests {
    use super::*;
    use std::{
        // Stable std replacement for tokio's hidden `macros::support::poll_fn`.
        future::poll_fn,
        net::{Ipv4Addr, SocketAddrV4},
        pin::pin,
    };

    /// Binds on loopback, connects once, and checks that the listener reports the
    /// connection as `ListenerEvent::Incoming`.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_incoming_listener() {
        // Loopback rather than 0.0.0.0: connecting to the unspecified address is not
        // portable (it fails on Windows), and the test should not listen on every interface.
        let listener =
            ConnectionListener::bind(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0)))
                .await
                .unwrap();
        let local_addr = listener.local_address();

        let accept_task = tokio::task::spawn(async move {
            let mut listener = pin!(listener);
            match poll_fn(|cx| listener.as_mut().poll(cx)).await {
                ListenerEvent::Incoming { .. } => {}
                ListenerEvent::ListenerClosed { .. } => panic!("unexpected event"),
                ListenerEvent::Error(err) => panic!("unexpected event: {err}"),
            }
        });

        // Keep the client side open until the listener has observed the connection.
        let _client = TcpStream::connect(local_addr).await.unwrap();

        // Awaiting the handle propagates a panic from the accept task. Previously the handle
        // was dropped, so an unexpected event could never fail the test.
        accept_task.await.unwrap();
    }
}