reth_network/
listener.rs

1//! Contains connection-oriented interfaces.
2
3use futures::{ready, Stream, StreamExt};
4use std::{
5    io,
6    net::SocketAddr,
7    pin::Pin,
8    task::{Context, Poll},
9};
10use tokio::net::{TcpListener, TcpStream};
11
12/// A tcp connection listener.
13///
14/// Listens for incoming connections.
15#[must_use = "Transport does nothing unless polled."]
16#[derive(Debug)]
17pub struct ConnectionListener {
18    /// Local address of the listener stream.
19    local_address: SocketAddr,
20    /// The active tcp listener for incoming connections.
21    incoming: TcpListenerStream,
22}
23
24impl ConnectionListener {
25    /// Creates a new [`TcpListener`] that listens for incoming connections.
26    pub async fn bind(addr: SocketAddr) -> io::Result<Self> {
27        let listener = TcpListener::bind(addr).await?;
28        let local_addr = listener.local_addr()?;
29        Ok(Self::new(listener, local_addr))
30    }
31
32    /// Creates a new connection listener stream.
33    pub(crate) const fn new(listener: TcpListener, local_address: SocketAddr) -> Self {
34        Self { local_address, incoming: TcpListenerStream { inner: listener } }
35    }
36
37    /// Polls the type to make progress.
38    pub fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<ListenerEvent> {
39        let this = self.get_mut();
40        match ready!(this.incoming.poll_next_unpin(cx)) {
41            Some(Ok((stream, remote_addr))) => {
42                if let Err(err) = stream.set_nodelay(true) {
43                    tracing::warn!(target: "net", "set nodelay failed: {:?}", err);
44                }
45                Poll::Ready(ListenerEvent::Incoming { stream, remote_addr })
46            }
47            Some(Err(err)) => Poll::Ready(ListenerEvent::Error(err)),
48            None => {
49                Poll::Ready(ListenerEvent::ListenerClosed { local_address: this.local_address })
50            }
51        }
52    }
53
54    /// Returns the socket address this listener listens on.
55    pub const fn local_address(&self) -> SocketAddr {
56        self.local_address
57    }
58}
59
60/// Event type produced by the [`TcpListenerStream`].
61pub enum ListenerEvent {
62    /// Received a new incoming.
63    Incoming {
64        /// Accepted connection
65        stream: TcpStream,
66        /// Address of the remote peer.
67        remote_addr: SocketAddr,
68    },
69    /// Returned when the underlying connection listener has been closed.
70    ///
71    /// This is the case if the [`TcpListenerStream`] should ever return `None`
72    ListenerClosed {
73        /// Address of the closed listener.
74        local_address: SocketAddr,
75    },
76    /// Encountered an error when accepting a connection.
77    ///
78    /// This is a non-fatal error as the listener continues to listen for new connections to
79    /// accept.
80    Error(io::Error),
81}
82
83/// A stream of incoming [`TcpStream`]s.
84#[derive(Debug)]
85struct TcpListenerStream {
86    /// listener for incoming connections.
87    inner: TcpListener,
88}
89
90impl Stream for TcpListenerStream {
91    type Item = io::Result<(TcpStream, SocketAddr)>;
92
93    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
94        match self.inner.poll_accept(cx) {
95            Poll::Ready(Ok(conn)) => Poll::Ready(Some(Ok(conn))),
96            Poll::Ready(Err(err)) => Poll::Ready(Some(Err(err))),
97            Poll::Pending => Poll::Pending,
98        }
99    }
100}
101
#[cfg(test)]
mod tests {
    use super::*;
    use std::{
        future::poll_fn,
        net::{Ipv4Addr, SocketAddrV4},
        pin::pin,
    };

    /// Binds a listener on an ephemeral port, connects to it and checks that the listener
    /// reports the connection as [`ListenerEvent::Incoming`].
    #[tokio::test(flavor = "multi_thread")]
    async fn test_incoming_listener() {
        let listener =
            ConnectionListener::bind(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)))
                .await
                .unwrap();
        let local_addr = listener.local_address();

        let accept = tokio::task::spawn(async move {
            let mut listener = pin!(listener);
            let event = poll_fn(|cx| listener.as_mut().poll(cx)).await;
            assert!(matches!(event, ListenerEvent::Incoming { .. }), "unexpected event");
        });

        // The listener is bound to the unspecified address; connect through loopback on the
        // assigned port since connecting to `0.0.0.0` is not portable.
        let target = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, local_addr.port()));
        // Keep the client side alive until the listener has observed the connection.
        let _stream = TcpStream::connect(target).await.unwrap();

        // Join the task so that a failed assertion inside it fails the test.
        accept.await.unwrap();
    }
}