reth_tokio_util/
ratelimit.rs

1//! A rate limit implementation to enforce a specific rate.
2
3use std::{
4    future::{poll_fn, Future},
5    pin::Pin,
6    task::{Context, Poll},
7    time::Duration,
8};
9use tokio::time::Sleep;
10
11/// Given a [`Rate`] this type enforces a rate limit.
12#[derive(Debug)]
13pub struct RateLimit {
14    rate: Rate,
15    state: State,
16    sleep: Pin<Box<Sleep>>,
17}
18
19// === impl RateLimit ===
20
21impl RateLimit {
22    /// Create a new rate limiter
23    pub fn new(rate: Rate) -> Self {
24        let until = tokio::time::Instant::now();
25        let state = State::Ready { until, remaining: rate.limit() };
26
27        Self { rate, state, sleep: Box::pin(tokio::time::sleep_until(until)) }
28    }
29
30    /// Returns the configured limit of the [`RateLimit`]
31    pub const fn limit(&self) -> u64 {
32        self.rate.limit()
33    }
34
35    /// Checks if the [`RateLimit`] is ready to handle a new call
36    pub fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<()> {
37        match self.state {
38            State::Ready { .. } => return Poll::Ready(()),
39            State::Limited => {
40                if Pin::new(&mut self.sleep).poll(cx).is_pending() {
41                    return Poll::Pending
42                }
43            }
44        }
45
46        self.state = State::Ready {
47            until: tokio::time::Instant::now() + self.rate.duration(),
48            remaining: self.rate.limit(),
49        };
50
51        Poll::Ready(())
52    }
53
54    /// Wait until the [`RateLimit`] is ready.
55    pub async fn wait(&mut self) {
56        poll_fn(|cx| self.poll_ready(cx)).await
57    }
58
59    /// Updates the [`RateLimit`] when a new call was triggered
60    ///
61    /// # Panics
62    ///
63    /// Panics if [`RateLimit::poll_ready`] returned [`Poll::Pending`]
64    pub fn tick(&mut self) {
65        match self.state {
66            State::Ready { mut until, remaining: mut rem } => {
67                let now = tokio::time::Instant::now();
68
69                // If the period has elapsed, reset it.
70                if now >= until {
71                    until = now + self.rate.duration();
72                    rem = self.rate.limit();
73                }
74
75                if rem > 1 {
76                    rem -= 1;
77                    self.state = State::Ready { until, remaining: rem };
78                } else {
79                    // rate limited until elapsed
80                    self.sleep.as_mut().reset(until);
81                    self.state = State::Limited;
82                }
83            }
84            State::Limited => panic!("RateLimit limited; poll_ready must be called first"),
85        }
86    }
87}
88
89/// Tracks the state of the [`RateLimit`]
90#[derive(Debug)]
91enum State {
92    /// Currently limited
93    Limited,
94    Ready {
95        until: tokio::time::Instant,
96        remaining: u64,
97    },
98}
99
/// A rate of requests per time period: at most `limit` calls per `duration`.
///
/// Note that [`RateLimit`] treats a `limit` of `0` the same as `1`: one call per period is
/// still let through before the limiter becomes limited.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub struct Rate {
    /// Number of calls permitted per period.
    limit: u64,
    /// Length of a single period.
    duration: Duration,
}

impl Rate {
    /// Create a new [Rate] with the given `limit/duration` ratio.
    ///
    /// `limit` is the number of calls allowed within each window of length `duration`.
    pub const fn new(limit: u64, duration: Duration) -> Self {
        Self { limit, duration }
    }

    /// Returns the number of calls permitted per period.
    const fn limit(&self) -> u64 {
        self.limit
    }

    /// Returns the length of a single period.
    const fn duration(&self) -> Duration {
        self.duration
    }
}
121
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time;

    #[tokio::test]
    async fn test_rate_limit() {
        let mut limit = RateLimit::new(Rate::new(2, Duration::from_millis(500)));

        poll_fn(|cx| {
            assert!(limit.poll_ready(cx).is_ready());
            Poll::Ready(())
        })
        .await;

        limit.tick();

        poll_fn(|cx| {
            assert!(limit.poll_ready(cx).is_ready());
            Poll::Ready(())
        })
        .await;

        limit.tick();

        poll_fn(|cx| {
            assert!(limit.poll_ready(cx).is_pending());
            Poll::Ready(())
        })
        .await;

        time::sleep(limit.rate.duration()).await;

        poll_fn(|cx| {
            assert!(limit.poll_ready(cx).is_ready());
            Poll::Ready(())
        })
        .await;
    }

    #[tokio::test]
    async fn test_rate_limit_initialization() {
        let rate = Rate::new(5, Duration::from_secs(1));
        let limit = RateLimit::new(rate);

        // Verify the limit is correctly set
        assert_eq!(limit.limit(), 5);
    }

    #[tokio::test]
    async fn test_rate_limit_allows_within_limit() {
        // The period must be much longer than the test body: with a very short period the
        // window can elapse between ticks, resetting the budget so the limiter never becomes
        // limited and the final `is_pending` assertion fails intermittently.
        let mut limit = RateLimit::new(Rate::new(3, Duration::from_secs(1)));

        // Check that the rate limiter is ready initially
        for _ in 0..3 {
            poll_fn(|cx| {
                // Should be ready within the limit
                assert!(limit.poll_ready(cx).is_ready());
                Poll::Ready(())
            })
            .await;
            // Signal that a request has been made
            limit.tick();
        }

        // After 3 requests, it should be pending (rate limit hit)
        poll_fn(|cx| {
            // Exceeded limit, should now be limited
            assert!(limit.poll_ready(cx).is_pending());
            Poll::Ready(())
        })
        .await;
    }

    #[tokio::test]
    async fn test_rate_limit_enforces_wait_after_limit() {
        let mut limit = RateLimit::new(Rate::new(2, Duration::from_millis(500)));

        // Consume the limit
        for _ in 0..2 {
            poll_fn(|cx| {
                assert!(limit.poll_ready(cx).is_ready());
                Poll::Ready(())
            })
            .await;
            limit.tick();
        }

        // Should now be limited (pending)
        poll_fn(|cx| {
            assert!(limit.poll_ready(cx).is_pending());
            Poll::Ready(())
        })
        .await;

        // Wait until the rate period elapses
        time::sleep(limit.rate.duration()).await;

        // Now it should be ready again after the wait
        poll_fn(|cx| {
            assert!(limit.poll_ready(cx).is_ready());
            Poll::Ready(())
        })
        .await;
    }

    #[tokio::test]
    async fn test_rate_limit_resets_after_period() {
        let mut limit = RateLimit::new(Rate::new(2, Duration::from_millis(50)));

        // Use one unit of the budget, leaving a single call in this period.
        limit.tick();
        assert!(matches!(limit.state, State::Ready { remaining: 1, .. }));

        // Let the period elapse before the budget is exhausted.
        time::sleep(limit.rate.duration()).await;

        // The next tick opens a fresh period with the full budget instead of exhausting the
        // stale one, so the limiter stays ready with one call left.
        limit.tick();
        assert!(matches!(limit.state, State::Ready { remaining: 1, .. }));
    }

    #[tokio::test]
    async fn test_wait_method_awaits_readiness() {
        let mut limit = RateLimit::new(Rate::new(1, Duration::from_millis(500)));

        poll_fn(|cx| {
            assert!(limit.poll_ready(cx).is_ready());
            Poll::Ready(())
        })
        .await;

        limit.tick();

        // The limit should now be exceeded
        poll_fn(|cx| {
            assert!(limit.poll_ready(cx).is_pending());
            Poll::Ready(())
        })
        .await;

        // The `wait` method should block until the rate period elapses
        limit.wait().await;

        // After `wait`, it should now be ready
        poll_fn(|cx| {
            assert!(limit.poll_ready(cx).is_ready());
            Poll::Ready(())
        })
        .await;
    }

    #[tokio::test]
    #[should_panic(expected = "RateLimit limited; poll_ready must be called first")]
    async fn test_tick_panics_when_limited() {
        let mut limit = RateLimit::new(Rate::new(1, Duration::from_secs(1)));

        poll_fn(|cx| {
            assert!(limit.poll_ready(cx).is_ready());
            Poll::Ready(())
        })
        .await;

        // Consume the limit
        limit.tick();

        // Attempting to tick again without poll_ready being ready should panic
        limit.tick();
    }
}