reth_tasks/
shutdown.rs

1//! Helper for shutdown signals
2
3use futures_util::{
4    future::{FusedFuture, Shared},
5    FutureExt,
6};
7use std::{
8    future::Future,
9    pin::Pin,
10    sync::{atomic::AtomicUsize, Arc},
11    task::{ready, Context, Poll},
12};
13use tokio::sync::oneshot;
14
15/// A Future that resolves when the shutdown event has been fired.
16#[derive(Debug)]
17pub struct GracefulShutdown {
18    shutdown: Shutdown,
19    guard: Option<GracefulShutdownGuard>,
20}
21
22impl GracefulShutdown {
23    pub(crate) const fn new(shutdown: Shutdown, guard: GracefulShutdownGuard) -> Self {
24        Self { shutdown, guard: Some(guard) }
25    }
26
27    /// Returns a new shutdown future that is ignores the returned [`GracefulShutdownGuard`].
28    ///
29    /// This just maps the return value of the future to `()`, it does not drop the guard.
30    pub fn ignore_guard(self) -> impl Future<Output = ()> + Send + Sync + Unpin + 'static {
31        self.map(drop)
32    }
33}
34
35impl Future for GracefulShutdown {
36    type Output = GracefulShutdownGuard;
37
38    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
39        ready!(self.shutdown.poll_unpin(cx));
40        Poll::Ready(self.get_mut().guard.take().expect("Future polled after completion"))
41    }
42}
43
44impl Clone for GracefulShutdown {
45    fn clone(&self) -> Self {
46        Self {
47            shutdown: self.shutdown.clone(),
48            guard: self.guard.as_ref().map(|g| GracefulShutdownGuard::new(Arc::clone(&g.0))),
49        }
50    }
51}
52
/// A guard that fires once dropped to signal the [`TaskManager`](crate::TaskManager) that the
/// [`GracefulShutdown`] has completed.
///
/// Every live guard accounts for one unit in a shared counter. Creating a guard increments the
/// counter and dropping it decrements the counter again.
#[derive(Debug)]
#[must_use = "if unused the task will not be gracefully shutdown"]
pub struct GracefulShutdownGuard(Arc<AtomicUsize>);

impl GracefulShutdownGuard {
    /// Registers a new guard by incrementing the shared `counter`.
    pub(crate) fn new(counter: Arc<AtomicUsize>) -> Self {
        use std::sync::atomic::Ordering;

        counter.fetch_add(1, Ordering::SeqCst);
        Self(counter)
    }
}

impl Drop for GracefulShutdownGuard {
    /// Releases this guard's unit of the shared counter.
    fn drop(&mut self) {
        use std::sync::atomic::Ordering;

        let Self(counter) = self;
        counter.fetch_sub(1, Ordering::SeqCst);
    }
}
71
72/// A Future that resolves when the shutdown event has been fired.
73#[derive(Debug, Clone)]
74pub struct Shutdown(Shared<oneshot::Receiver<()>>);
75
76impl Future for Shutdown {
77    type Output = ();
78
79    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
80        let pin = self.get_mut();
81        if pin.0.is_terminated() || pin.0.poll_unpin(cx).is_ready() {
82            Poll::Ready(())
83        } else {
84            Poll::Pending
85        }
86    }
87}
88
89/// Shutdown signal that fires either manually or on drop by closing the channel
90#[derive(Debug)]
91pub struct Signal(oneshot::Sender<()>);
92
93impl Signal {
94    /// Fire the signal manually.
95    pub fn fire(self) {
96        let _ = self.0.send(());
97    }
98}
99
100/// Create a channel pair that's used to propagate shutdown event
101pub fn signal() -> (Signal, Shutdown) {
102    let (sender, receiver) = oneshot::channel();
103    (Signal(sender), Shutdown(receiver.shared()))
104}
105
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::future::join_all;
    use std::{sync::atomic::Ordering, time::Duration};

    #[tokio::test(flavor = "multi_thread")]
    async fn test_shutdown() {
        let (signal, shutdown) = signal();
        // Nothing has fired yet, so the shutdown future must still be pending.
        assert!(shutdown.clone().now_or_never().is_none());

        drop(signal);
        assert!(shutdown.now_or_never().is_some());
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_fire_signal() {
        let (signal, shutdown) = signal();
        let waiter = tokio::task::spawn(shutdown.clone());

        signal.fire();

        waiter.await.unwrap();
        shutdown.await;
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_poll_after_completion() {
        let (signal, mut shutdown) = signal();
        signal.fire();

        (&mut shutdown).await;
        // Polling a completed shutdown again, or a clone taken after completion, must not panic.
        (&mut shutdown).await;
        shutdown.clone().await;
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_drop_signal() {
        let (signal, shutdown) = signal();

        tokio::task::spawn(async move {
            tokio::time::sleep(Duration::from_millis(500)).await;
            drop(signal)
        });

        shutdown.await;
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_multi_shutdowns() {
        let (signal, shutdown) = signal();

        let mut tasks = Vec::with_capacity(100);
        for _ in 0..100 {
            let shutdown = shutdown.clone();
            let task = tokio::task::spawn(async move {
                shutdown.await;
            });
            tasks.push(task);
        }

        drop(signal);

        join_all(tasks).await;
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_drop_signal_from_thread() {
        let (signal, shutdown) = signal();

        let _thread = std::thread::spawn(|| {
            std::thread::sleep(Duration::from_millis(500));
            drop(signal)
        });

        shutdown.await;
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_graceful_shutdown_guard_counter() {
        let (signal, shutdown) = signal();
        let counter = Arc::new(AtomicUsize::new(0));
        let graceful =
            GracefulShutdown::new(shutdown, GracefulShutdownGuard::new(Arc::clone(&counter)));
        assert_eq!(counter.load(Ordering::SeqCst), 1);

        // Cloning a pending graceful shutdown registers an additional guard.
        let cloned = graceful.clone();
        assert_eq!(counter.load(Ordering::SeqCst), 2);

        signal.fire();

        // The returned guard stays registered until it is dropped.
        let guard = graceful.await;
        assert_eq!(counter.load(Ordering::SeqCst), 2);
        drop(guard);
        assert_eq!(counter.load(Ordering::SeqCst), 1);

        // `ignore_guard` releases the guard as soon as the shutdown fires.
        cloned.ignore_guard().await;
        assert_eq!(counter.load(Ordering::SeqCst), 0);
    }
}