1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
//! A rate limit implementation to enforce a specific rate.

use std::{
    future::{poll_fn, Future},
    pin::Pin,
    task::{Context, Poll},
    time::Duration,
};
use tokio::time::Sleep;

/// Given a [Rate] this type enforces a rate limit.
///
/// The limiter is driven by two calls: [`RateLimit::poll_ready`] (or [`RateLimit::wait`])
/// reports whether a call may proceed, and [`RateLimit::tick`] records that a call was made.
/// `tick` panics if it is invoked while the limiter is limited.
#[derive(Debug)]
pub struct RateLimit {
    /// The configured quota and the length of the window it applies to.
    rate: Rate,
    /// Whether calls are currently admitted and, if so, how many are left in the running window.
    state: State,
    /// Timer polled while the limiter is limited. It is re-armed in place with the end of the
    /// running window each time the quota runs out.
    sleep: Pin<Box<Sleep>>,
}

// === impl RateLimit ===

impl RateLimit {
    /// Create a new rate limiter
    pub fn new(rate: Rate) -> Self {
        let until = tokio::time::Instant::now();
        let state = State::Ready { until, remaining: rate.limit() };

        Self { rate, state, sleep: Box::pin(tokio::time::sleep_until(until)) }
    }

    /// Returns the configured limit of the [`RateLimit`]
    pub const fn limit(&self) -> u64 {
        self.rate.limit()
    }

    /// Checks if the [`RateLimit`] is ready to handle a new call
    pub fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<()> {
        match self.state {
            State::Ready { .. } => return Poll::Ready(()),
            State::Limited => {
                if Pin::new(&mut self.sleep).poll(cx).is_pending() {
                    return Poll::Pending
                }
            }
        }

        self.state = State::Ready {
            until: tokio::time::Instant::now() + self.rate.duration(),
            remaining: self.rate.limit(),
        };

        Poll::Ready(())
    }

    /// Wait until the [`RateLimit`] is ready.
    pub async fn wait(&mut self) {
        poll_fn(|cx| self.poll_ready(cx)).await
    }

    /// Updates the [`RateLimit`] when a new call was triggered
    ///
    /// # Panics
    ///
    /// Panics if [`RateLimit::poll_ready`] returned [`Poll::Pending`]
    pub fn tick(&mut self) {
        match self.state {
            State::Ready { mut until, remaining: mut rem } => {
                let now = tokio::time::Instant::now();

                // If the period has elapsed, reset it.
                if now >= until {
                    until = now + self.rate.duration();
                    rem = self.rate.limit();
                }

                if rem > 1 {
                    rem -= 1;
                    self.state = State::Ready { until, remaining: rem };
                } else {
                    // rate limited until elapsed
                    self.sleep.as_mut().reset(until);
                    self.state = State::Limited;
                }
            }
            State::Limited => panic!("RateLimit limited; poll_ready must be called first"),
        }
    }
}

/// Tracks the state of the [`RateLimit`]
#[derive(Debug)]
enum State {
    /// Currently limited: the quota of the running window is used up and no call is admitted
    /// until the limiter's timer has fired.
    Limited,
    /// Calls are currently admitted.
    Ready {
        /// Instant at which the running window ends. Once it is reached, the next `tick` opens
        /// a new window with the full quota.
        until: tokio::time::Instant,
        /// Number of calls still allowed in the running window.
        remaining: u64,
    },
}

/// A rate of requests per time period.
#[derive(Debug, Copy, Clone)]
pub struct Rate {
    /// Number of requests allowed within one `duration`.
    limit: u64,
    /// Length of the period the `limit` applies to.
    duration: Duration,
}

impl Rate {
    /// Create a new [Rate] that allows `limit` requests per `duration`.
    ///
    /// Neither value is validated; both are stored exactly as given.
    pub const fn new(limit: u64, duration: Duration) -> Self {
        Self { duration, limit }
    }

    /// Number of requests allowed per period.
    const fn limit(&self) -> u64 {
        let Self { limit, .. } = *self;
        limit
    }

    /// Length of one period.
    const fn duration(&self) -> Duration {
        let Self { duration, .. } = *self;
        duration
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Polls the limiter exactly once and reports whether it was ready.
    async fn ready_now(limiter: &mut RateLimit) -> bool {
        poll_fn(|cx| Poll::Ready(limiter.poll_ready(cx).is_ready())).await
    }

    #[tokio::test]
    async fn test_rate_limit() {
        let period = Duration::from_millis(500);
        let mut limiter = RateLimit::new(Rate::new(2, period));

        // Both calls of the quota are admitted back to back.
        for _ in 0..2 {
            assert!(ready_now(&mut limiter).await);
            limiter.tick();
        }

        // The quota is used up, so the limiter reports pending.
        assert!(!ready_now(&mut limiter).await);

        // Once the period has passed, calls are admitted again.
        tokio::time::sleep(period).await;
        assert!(ready_now(&mut limiter).await);
    }
}