1use crate::errors::PingerError;
2use std::{
3 future::Future,
4 pin::Pin,
5 task::{Context, Poll},
6 time::Duration,
7};
8use tokio::time::{Instant, Interval, Sleep};
9use tokio_stream::Stream;
10
/// Ping/pong keepalive state machine.
///
/// While no ping is outstanding it yields a ping each time `ping_interval` ticks. After a
/// ping it reports a timeout if no pong is recorded within `timeout`.
#[derive(Debug)]
pub(crate) struct Pinger {
    /// Schedules when the next ping is due; restarted whenever a pong is recorded.
    ping_interval: Interval,
    /// Re-armed on every ping; firing while waiting for a pong means the pong is overdue.
    timeout_timer: Pin<Box<Sleep>>,
    /// How long to wait for a pong after a ping before reporting a timeout.
    timeout: Duration,
    /// Current position in the ping/pong state machine.
    state: PingState,
}
24
25impl Pinger {
28 pub(crate) fn new(ping_interval: Duration, timeout_duration: Duration) -> Self {
31 let now = Instant::now();
32 let timeout_timer = tokio::time::sleep(timeout_duration);
33 Self {
34 state: PingState::Ready,
35 ping_interval: tokio::time::interval_at(now + ping_interval, ping_interval),
36 timeout_timer: Box::pin(timeout_timer),
37 timeout: timeout_duration,
38 }
39 }
40
41 pub(crate) fn on_pong(&mut self) -> Result<(), PingerError> {
44 match self.state {
45 PingState::Ready => Err(PingerError::UnexpectedPong),
46 PingState::WaitingForPong => {
47 self.state = PingState::Ready;
48 self.ping_interval.reset();
49 Ok(())
50 }
51 PingState::TimedOut => {
52 self.state = PingState::Ready;
55 self.ping_interval.reset();
56 Ok(())
57 }
58 }
59 }
60
61 pub(crate) const fn state(&self) -> PingState {
63 self.state
64 }
65
66 pub(crate) fn poll_ping(
69 &mut self,
70 cx: &mut Context<'_>,
71 ) -> Poll<Result<PingerEvent, PingerError>> {
72 match self.state() {
73 PingState::Ready => {
74 if self.ping_interval.poll_tick(cx).is_ready() {
75 self.timeout_timer.as_mut().reset(Instant::now() + self.timeout);
76 self.state = PingState::WaitingForPong;
77 return Poll::Ready(Ok(PingerEvent::Ping))
78 }
79 }
80 PingState::WaitingForPong => {
81 if self.timeout_timer.as_mut().poll(cx).is_ready() {
82 self.state = PingState::TimedOut;
83 return Poll::Ready(Ok(PingerEvent::Timeout))
84 }
85 }
86 PingState::TimedOut => {
87 return Poll::Pending
90 }
91 };
92 Poll::Pending
93 }
94}
95
96impl Stream for Pinger {
97 type Item = Result<PingerEvent, PingerError>;
98
99 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
100 self.get_mut().poll_ping(cx).map(Some)
101 }
102}
103
/// Position of a [`Pinger`] in its ping/pong cycle.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum PingState {
    /// No ping is outstanding; a ping event is yielded on the next interval tick.
    Ready,
    /// A ping event was yielded; waiting for a pong or for the timeout timer to fire.
    WaitingForPong,
    /// The timeout fired before a pong was recorded; polling stays pending until a pong
    /// is recorded.
    TimedOut,
}
115
/// Events yielded by polling a [`Pinger`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) enum PingerEvent {
    /// A ping is due. Presumably the caller sends a ping to the peer in response
    /// (confirm at call sites).
    Ping,

    /// No pong was recorded within the configured timeout after the last ping.
    Timeout,
}
127
#[cfg(test)]
mod tests {
    use super::*;
    use futures::StreamExt;

    /// Awaits the pinger's next event; panics if the stream ended or yielded an error.
    async fn next_event(pinger: &mut Pinger) -> PingerEvent {
        pinger.next().await.unwrap().unwrap()
    }

    /// Walks the pinger through: ping → pong → ping → timeout → late pong → ping.
    #[tokio::test]
    async fn test_ping_timeout() {
        let ping_every = Duration::from_millis(300);
        let pong_deadline = Duration::from_millis(20);
        let mut pinger = Pinger::new(ping_every, pong_deadline);

        // First ping, answered in time; the next ping follows one interval later.
        assert_eq!(next_event(&mut pinger).await, PingerEvent::Ping);
        pinger.on_pong().unwrap();
        assert_eq!(next_event(&mut pinger).await, PingerEvent::Ping);

        // Let the short pong deadline pass without answering.
        tokio::time::sleep(ping_every).await;
        assert_eq!(next_event(&mut pinger).await, PingerEvent::Timeout);

        // A late pong is still accepted and restarts the ping cycle.
        pinger.on_pong().unwrap();
        assert_eq!(next_event(&mut pinger).await, PingerEvent::Ping);
    }
}