use futures::{ready, Stream};
use std::{
io,
net::SocketAddr,
pin::Pin,
task::{Context, Poll},
};
use tokio::net::{TcpListener, TcpStream};
/// A TCP listener that produces [`ListenerEvent`]s when polled via
/// [`ConnectionListener::poll`].
#[must_use = "ConnectionListener does nothing unless polled."]
#[pin_project::pin_project]
#[derive(Debug)]
pub struct ConnectionListener {
    /// Address the underlying socket is bound to.
    local_address: SocketAddr,
    /// Stream of accepted connections; structurally pinned so it can be polled
    /// through `Pin<&mut Self>`.
    #[pin]
    incoming: TcpListenerStream,
}
impl ConnectionListener {
pub async fn bind(addr: SocketAddr) -> io::Result<Self> {
let listener = TcpListener::bind(addr).await?;
let local_addr = listener.local_addr()?;
Ok(Self::new(listener, local_addr))
}
pub(crate) const fn new(listener: TcpListener, local_address: SocketAddr) -> Self {
Self { local_address, incoming: TcpListenerStream { inner: listener } }
}
pub fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<ListenerEvent> {
let this = self.project();
match ready!(this.incoming.poll_next(cx)) {
Some(Ok((stream, remote_addr))) => {
if let Err(err) = stream.set_nodelay(true) {
tracing::warn!(target: "net", "set nodelay failed: {:?}", err);
}
Poll::Ready(ListenerEvent::Incoming { stream, remote_addr })
}
Some(Err(err)) => Poll::Ready(ListenerEvent::Error(err)),
None => {
Poll::Ready(ListenerEvent::ListenerClosed { local_address: *this.local_address })
}
}
}
pub const fn local_address(&self) -> SocketAddr {
self.local_address
}
}
/// Events produced by polling a [`ConnectionListener`].
#[derive(Debug)]
pub enum ListenerEvent {
    /// A new inbound TCP connection was accepted.
    Incoming {
        /// The accepted stream (the listener requests `TCP_NODELAY` on it).
        stream: TcpStream,
        /// Address of the remote peer.
        remote_addr: SocketAddr,
    },
    /// The listener's stream of accepted connections ended.
    ListenerClosed {
        /// Address the closed listener was bound to.
        local_address: SocketAddr,
    },
    /// Accepting a connection failed.
    Error(io::Error),
}
/// Adapts a [`TcpListener`] into a [`Stream`] of accepted connections.
#[derive(Debug)]
struct TcpListenerStream {
    /// The bound listener that connections are accepted from.
    inner: TcpListener,
}
impl Stream for TcpListenerStream {
    /// Either an accepted `(stream, peer address)` pair or the accept error.
    type Item = io::Result<(TcpStream, SocketAddr)>;

    /// Polls the listener for the next accepted connection.
    ///
    /// Every accept outcome, whether success or error, is yielded wrapped in `Some`. This
    /// stream therefore never yields `None` and does not terminate on its own.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        self.inner.poll_accept(cx).map(Some)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::{
        future::poll_fn,
        net::{Ipv4Addr, SocketAddrV4},
        pin::pin,
    };

    /// Binds a listener on an ephemeral loopback port, connects to it, and checks that
    /// the listener reports the connection as [`ListenerEvent::Incoming`].
    #[tokio::test(flavor = "multi_thread")]
    async fn test_incoming_listener() {
        // Bind to loopback rather than the unspecified address: connecting to `0.0.0.0`
        // is not portable (it fails on Windows, for example).
        let listener =
            ConnectionListener::bind(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0)))
                .await
                .unwrap();
        let local_addr = listener.local_address();

        let accept = tokio::task::spawn(async move {
            let mut listener = pin!(listener);
            let event = poll_fn(|cx| listener.as_mut().poll(cx)).await;
            assert!(matches!(event, ListenerEvent::Incoming { .. }), "unexpected event");
        });

        // Keep the client side alive until the listener has observed the connection.
        let _client = TcpStream::connect(local_addr).await.unwrap();
        // Await the task so a panic inside it fails this test instead of being dropped.
        accept.await.unwrap();
    }
}