// reth_tokio_util/ratelimit.rs
use std::{
    future::{poll_fn, Future},
    pin::Pin,
    task::{Context, Poll},
    time::Duration,
};
use tokio::time::Sleep;

11#[derive(Debug)]
13pub struct RateLimit {
14 rate: Rate,
15 state: State,
16 sleep: Pin<Box<Sleep>>,
17}
18
19impl RateLimit {
22 pub fn new(rate: Rate) -> Self {
24 let until = tokio::time::Instant::now();
25 let state = State::Ready { until, remaining: rate.limit() };
26
27 Self { rate, state, sleep: Box::pin(tokio::time::sleep_until(until)) }
28 }
29
30 pub const fn limit(&self) -> u64 {
32 self.rate.limit()
33 }
34
35 pub fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<()> {
37 match self.state {
38 State::Ready { .. } => return Poll::Ready(()),
39 State::Limited => {
40 if Pin::new(&mut self.sleep).poll(cx).is_pending() {
41 return Poll::Pending
42 }
43 }
44 }
45
46 self.state = State::Ready {
47 until: tokio::time::Instant::now() + self.rate.duration(),
48 remaining: self.rate.limit(),
49 };
50
51 Poll::Ready(())
52 }
53
54 pub async fn wait(&mut self) {
56 poll_fn(|cx| self.poll_ready(cx)).await
57 }
58
59 pub fn tick(&mut self) {
65 match self.state {
66 State::Ready { mut until, remaining: mut rem } => {
67 let now = tokio::time::Instant::now();
68
69 if now >= until {
71 until = now + self.rate.duration();
72 rem = self.rate.limit();
73 }
74
75 if rem > 1 {
76 rem -= 1;
77 self.state = State::Ready { until, remaining: rem };
78 } else {
79 self.sleep.as_mut().reset(until);
81 self.state = State::Limited;
82 }
83 }
84 State::Limited => panic!("RateLimit limited; poll_ready must be called first"),
85 }
86 }
87}
88
89#[derive(Debug)]
91enum State {
92 Limited,
94 Ready {
95 until: tokio::time::Instant,
96 remaining: u64,
97 },
98}
99
/// A rate of `limit` calls per window of length `duration`.
#[derive(Clone, Copy, Debug)]
pub struct Rate {
    /// Number of calls admitted per window.
    limit: u64,
    /// Length of one window.
    duration: Duration,
}

impl Rate {
    /// Creates a [`Rate`] admitting `limit` calls per `duration`.
    pub const fn new(limit: u64, duration: Duration) -> Self {
        Rate { limit, duration }
    }

    /// Number of calls admitted per window.
    const fn limit(&self) -> u64 {
        let Self { limit, .. } = *self;
        limit
    }

    /// Length of one window.
    const fn duration(&self) -> Duration {
        let Self { duration, .. } = *self;
        duration
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts, without waiting, that `limit` currently admits a call.
    async fn assert_ready(limit: &mut RateLimit) {
        poll_fn(|cx| {
            assert!(limit.poll_ready(cx).is_ready());
            Poll::Ready(())
        })
        .await;
    }

    /// Asserts, without waiting, that `limit` is currently limited.
    async fn assert_pending(limit: &mut RateLimit) {
        poll_fn(|cx| {
            assert!(limit.poll_ready(cx).is_pending());
            Poll::Ready(())
        })
        .await;
    }

    #[tokio::test]
    async fn test_rate_limit() {
        let mut limit = RateLimit::new(Rate::new(2, Duration::from_millis(500)));

        assert_ready(&mut limit).await;
        limit.tick();

        assert_ready(&mut limit).await;
        limit.tick();

        assert_pending(&mut limit).await;

        tokio::time::sleep(limit.rate.duration()).await;

        assert_ready(&mut limit).await;
    }

    #[tokio::test]
    async fn test_rate_limit_initialization() {
        let rate = Rate::new(5, Duration::from_secs(1));
        let limit = RateLimit::new(rate);

        assert_eq!(limit.limit(), 5);
    }

    #[tokio::test]
    async fn test_rate_limit_allows_within_limit() {
        // The window must comfortably outlast this test: with a 1ms window the quota could be
        // reset between ticks, or the window could end before the final poll, making the test
        // flaky.
        let mut limit = RateLimit::new(Rate::new(3, Duration::from_secs(60)));

        for _ in 0..3 {
            assert_ready(&mut limit).await;
            limit.tick();
        }

        assert_pending(&mut limit).await;
    }

    #[tokio::test]
    async fn test_rate_limit_enforces_wait_after_limit() {
        let mut limit = RateLimit::new(Rate::new(2, Duration::from_millis(500)));

        for _ in 0..2 {
            assert_ready(&mut limit).await;
            limit.tick();
        }

        assert_pending(&mut limit).await;

        tokio::time::sleep(limit.rate.duration()).await;

        assert_ready(&mut limit).await;
    }

    #[tokio::test]
    async fn test_tick_starts_new_window_after_elapsed() {
        let mut limit = RateLimit::new(Rate::new(2, Duration::from_millis(200)));

        assert_ready(&mut limit).await;
        limit.tick();

        // Let the window opened by the first tick expire without reaching the limit.
        tokio::time::sleep(limit.rate.duration() * 2).await;

        // This tick opens a fresh window with the full quota, so it must not limit yet.
        assert_ready(&mut limit).await;
        limit.tick();
        assert!(matches!(limit.state, State::Ready { remaining: 1, .. }));

        // The next tick exhausts the fresh window's quota.
        limit.tick();
        assert_pending(&mut limit).await;
    }

    #[tokio::test]
    async fn test_wait_method_awaits_readiness() {
        let mut limit = RateLimit::new(Rate::new(1, Duration::from_millis(500)));

        assert_ready(&mut limit).await;
        limit.tick();

        assert_pending(&mut limit).await;

        limit.wait().await;

        assert_ready(&mut limit).await;
    }

    #[tokio::test]
    #[should_panic(expected = "RateLimit limited; poll_ready must be called first")]
    async fn test_tick_panics_when_limited() {
        let mut limit = RateLimit::new(Rate::new(1, Duration::from_secs(1)));

        assert_ready(&mut limit).await;

        limit.tick();

        limit.tick();
    }
}