reth_tasks/
shutdown.rs
1use futures_util::{
4 future::{FusedFuture, Shared},
5 FutureExt,
6};
7use std::{
8 future::Future,
9 pin::Pin,
10 sync::{atomic::AtomicUsize, Arc},
11 task::{ready, Context, Poll},
12};
13use tokio::sync::oneshot;
14
15#[derive(Debug)]
17pub struct GracefulShutdown {
18 shutdown: Shutdown,
19 guard: Option<GracefulShutdownGuard>,
20}
21
22impl GracefulShutdown {
23 pub(crate) const fn new(shutdown: Shutdown, guard: GracefulShutdownGuard) -> Self {
24 Self { shutdown, guard: Some(guard) }
25 }
26
27 pub fn ignore_guard(self) -> impl Future<Output = ()> + Send + Sync + Unpin + 'static {
31 self.map(drop)
32 }
33}
34
35impl Future for GracefulShutdown {
36 type Output = GracefulShutdownGuard;
37
38 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
39 ready!(self.shutdown.poll_unpin(cx));
40 Poll::Ready(self.get_mut().guard.take().expect("Future polled after completion"))
41 }
42}
43
44impl Clone for GracefulShutdown {
45 fn clone(&self) -> Self {
46 Self {
47 shutdown: self.shutdown.clone(),
48 guard: self.guard.as_ref().map(|g| GracefulShutdownGuard::new(Arc::clone(&g.0))),
49 }
50 }
51}
52
/// Registration of one participant on a shared graceful-shutdown counter.
///
/// Constructing a guard adds one to the counter and dropping it subtracts one again, so each
/// live guard contributes exactly one to the counter's value.
#[derive(Debug)]
#[must_use = "if unused the task will not be gracefully shutdown"]
pub struct GracefulShutdownGuard(Arc<AtomicUsize>);

impl GracefulShutdownGuard {
    /// Registers a new guard on `counter`, incrementing it by one.
    pub(crate) fn new(counter: Arc<AtomicUsize>) -> Self {
        use std::sync::atomic::Ordering;

        counter.fetch_add(1, Ordering::SeqCst);
        Self(counter)
    }
}

impl Drop for GracefulShutdownGuard {
    /// Removes this guard's contribution from the shared counter.
    fn drop(&mut self) {
        use std::sync::atomic::Ordering;

        let Self(counter) = self;
        counter.fetch_sub(1, Ordering::SeqCst);
    }
}
71
72#[derive(Debug, Clone)]
74pub struct Shutdown(Shared<oneshot::Receiver<()>>);
75
76impl Future for Shutdown {
77 type Output = ();
78
79 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
80 let pin = self.get_mut();
81 if pin.0.is_terminated() || pin.0.poll_unpin(cx).is_ready() {
82 Poll::Ready(())
83 } else {
84 Poll::Pending
85 }
86 }
87}
88
89#[derive(Debug)]
91pub struct Signal(oneshot::Sender<()>);
92
93impl Signal {
94 pub fn fire(self) {
96 let _ = self.0.send(());
97 }
98}
99
100pub fn signal() -> (Signal, Shutdown) {
102 let (sender, receiver) = oneshot::channel();
103 (Signal(sender), Shutdown(receiver.shared()))
104}
105
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::future::join_all;
    use std::{sync::atomic::Ordering, time::Duration};

    /// Firing the signal explicitly resolves the paired `Shutdown`.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_shutdown() {
        let (signal, shutdown) = signal();
        signal.fire();
        shutdown.await;
    }

    /// A completed `Shutdown` (and any clone of it) must resolve again instead of panicking.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_shutdown_poll_after_completion() {
        let (signal, mut shutdown) = signal();
        signal.fire();
        (&mut shutdown).await;
        shutdown.clone().await;
        shutdown.await;
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_drop_signal() {
        let (signal, shutdown) = signal();

        tokio::task::spawn(async move {
            tokio::time::sleep(Duration::from_millis(500)).await;
            drop(signal)
        });

        shutdown.await;
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_multi_shutdowns() {
        let (signal, shutdown) = signal();

        let mut tasks = Vec::with_capacity(100);
        for _ in 0..100 {
            let shutdown = shutdown.clone();
            let task = tokio::task::spawn(async move {
                shutdown.await;
            });
            tasks.push(task);
        }

        drop(signal);

        join_all(tasks).await;
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_drop_signal_from_thread() {
        let (signal, shutdown) = signal();

        let _thread = std::thread::spawn(|| {
            std::thread::sleep(Duration::from_millis(500));
            drop(signal)
        });

        shutdown.await;
    }

    /// Every `GracefulShutdown` (including clones) holds one count until its guard is dropped.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_graceful_shutdown_guard_counting() {
        let (signal, shutdown) = signal();
        let counter = Arc::new(AtomicUsize::new(0));

        let initial_guard = GracefulShutdownGuard::new(Arc::clone(&counter));
        let graceful = GracefulShutdown::new(shutdown, initial_guard);
        assert_eq!(counter.load(Ordering::SeqCst), 1);

        let cloned = graceful.clone();
        assert_eq!(counter.load(Ordering::SeqCst), 2);

        signal.fire();

        // The guard is handed out, not released, so the count is unchanged.
        let guard = graceful.await;
        assert_eq!(counter.load(Ordering::SeqCst), 2);
        drop(guard);
        assert_eq!(counter.load(Ordering::SeqCst), 1);

        // `ignore_guard` releases the clone's guard as soon as shutdown resolves.
        cloned.ignore_guard().await;
        assert_eq!(counter.load(Ordering::SeqCst), 0);
    }
}