use futures_util::{
future::{FusedFuture, Shared},
FutureExt,
};
use std::{
future::Future,
pin::Pin,
sync::{atomic::AtomicUsize, Arc},
task::{ready, Context, Poll},
};
use tokio::sync::oneshot;
#[derive(Debug)]
pub struct GracefulShutdown {
shutdown: Shutdown,
guard: Option<GracefulShutdownGuard>,
}
impl GracefulShutdown {
pub(crate) const fn new(shutdown: Shutdown, guard: GracefulShutdownGuard) -> Self {
Self { shutdown, guard: Some(guard) }
}
pub fn ignore_guard(self) -> impl Future<Output = ()> + Send + Sync + Unpin + 'static {
self.map(drop)
}
}
impl Future for GracefulShutdown {
type Output = GracefulShutdownGuard;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
ready!(self.shutdown.poll_unpin(cx));
Poll::Ready(self.get_mut().guard.take().expect("Future polled after completion"))
}
}
impl Clone for GracefulShutdown {
fn clone(&self) -> Self {
Self {
shutdown: self.shutdown.clone(),
guard: self.guard.as_ref().map(|g| GracefulShutdownGuard::new(Arc::clone(&g.0))),
}
}
}
#[derive(Debug)]
#[must_use = "if unused the task will not be gracefully shutdown"]
pub struct GracefulShutdownGuard(Arc<AtomicUsize>);
impl GracefulShutdownGuard {
pub(crate) fn new(counter: Arc<AtomicUsize>) -> Self {
counter.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
Self(counter)
}
}
impl Drop for GracefulShutdownGuard {
fn drop(&mut self) {
self.0.fetch_sub(1, std::sync::atomic::Ordering::SeqCst);
}
}
#[derive(Debug, Clone)]
pub struct Shutdown(Shared<oneshot::Receiver<()>>);
impl Future for Shutdown {
type Output = ();
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let pin = self.get_mut();
if pin.0.is_terminated() || pin.0.poll_unpin(cx).is_ready() {
Poll::Ready(())
} else {
Poll::Pending
}
}
}
#[derive(Debug)]
pub struct Signal(oneshot::Sender<()>);
impl Signal {
pub fn fire(self) {
let _ = self.0.send(());
}
}
pub fn signal() -> (Signal, Shutdown) {
let (sender, receiver) = oneshot::channel();
(Signal(sender), Shutdown(receiver.shared()))
}
#[cfg(test)]
mod tests {
use super::*;
use futures_util::future::join_all;
use std::time::Duration;
#[tokio::test(flavor = "multi_thread")]
async fn test_shutdown() {
let (_signal, _shutdown) = signal();
}
#[tokio::test(flavor = "multi_thread")]
async fn test_drop_signal() {
let (signal, shutdown) = signal();
tokio::task::spawn(async move {
tokio::time::sleep(Duration::from_millis(500)).await;
drop(signal)
});
shutdown.await;
}
#[tokio::test(flavor = "multi_thread")]
async fn test_multi_shutdowns() {
let (signal, shutdown) = signal();
let mut tasks = Vec::with_capacity(100);
for _ in 0..100 {
let shutdown = shutdown.clone();
let task = tokio::task::spawn(async move {
shutdown.await;
});
tasks.push(task);
}
drop(signal);
join_all(tasks).await;
}
#[tokio::test(flavor = "multi_thread")]
async fn test_drop_signal_from_thread() {
let (signal, shutdown) = signal();
let _thread = std::thread::spawn(|| {
std::thread::sleep(Duration::from_millis(500));
drop(signal)
});
shutdown.await;
}
}