// reth_tokio_util/ratelimit.rs
//! A rate limit implementation to enforce a specific rate.
use std::{
future::{poll_fn, Future},
pin::Pin,
task::{Context, Poll},
time::Duration,
};
use tokio::time::Sleep;
/// Given a [`Rate`] this type enforces a rate limit.
///
/// Callers check readiness via [`RateLimit::poll_ready`] (or await [`RateLimit::wait`]) and
/// record every admitted call with [`RateLimit::tick`].
#[derive(Debug)]
pub struct RateLimit {
    /// Configured number of calls allowed per period.
    rate: Rate,
    /// Whether calls are currently admitted and, if so, how many remain in the current period.
    state: State,
    /// Timer used while limited: `tick` resets it to the end of the exhausted period and
    /// `poll_ready` polls it until it elapses.
    sleep: Pin<Box<Sleep>>,
}
// === impl RateLimit ===

impl RateLimit {
    /// Create a new rate limiter
    ///
    /// The limiter starts out ready with the full quota of calls. The first period begins with
    /// the first call to [`RateLimit::tick`].
    pub fn new(rate: Rate) -> Self {
        let until = tokio::time::Instant::now();
        let state = State::Ready { until, remaining: rate.limit() };

        Self { rate, state, sleep: Box::pin(tokio::time::sleep_until(until)) }
    }

    /// Returns the configured limit of the [`RateLimit`]
    pub const fn limit(&self) -> u64 {
        self.rate.limit()
    }

    /// Checks if the [`RateLimit`] is ready to handle a new call.
    ///
    /// Returns [`Poll::Ready`] immediately while the limiter is not limited. While limited, it
    /// polls the internal sleep. That arranges for the task to be woken once the exhausted
    /// period ends. After the sleep completes, a fresh period with the full quota starts, and
    /// the limiter reports ready.
    pub fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<()> {
        match self.state {
            State::Ready { .. } => return Poll::Ready(()),
            State::Limited => {
                if self.sleep.as_mut().poll(cx).is_pending() {
                    return Poll::Pending
                }
            }
        }

        let now = tokio::time::Instant::now();
        self.state = State::Ready { until: self.period_end(now), remaining: self.rate.limit() };
        Poll::Ready(())
    }

    /// Wait until the [`RateLimit`] is ready.
    pub async fn wait(&mut self) {
        poll_fn(|cx| self.poll_ready(cx)).await
    }

    /// Updates the [`RateLimit`] when a new call was triggered.
    ///
    /// Consumes one call from the current period's quota. If the previous period has already
    /// elapsed, a new period with a full quota is started first. Consuming the last call of a
    /// period makes the limiter limited until that period ends.
    ///
    /// A configured limit of `0` behaves like a limit of `1`: one call is admitted per period.
    ///
    /// # Panics
    ///
    /// Panics if the limiter is currently limited. That happens when an earlier `tick` exhausted
    /// the quota and [`RateLimit::poll_ready`] has not returned [`Poll::Ready`] since. Wait for
    /// readiness (via `poll_ready` or [`RateLimit::wait`]) before every `tick`.
    pub fn tick(&mut self) {
        match self.state {
            State::Ready { mut until, remaining: mut rem } => {
                let now = tokio::time::Instant::now();

                // If the period has elapsed, reset it.
                if now >= until {
                    until = self.period_end(now);
                    rem = self.rate.limit();
                }

                if rem > 1 {
                    rem -= 1;
                    self.state = State::Ready { until, remaining: rem };
                } else {
                    // Quota exhausted: stay limited until the current period ends.
                    self.sleep.as_mut().reset(until);
                    self.state = State::Limited;
                }
            }
            State::Limited => panic!("RateLimit limited; poll_ready must be called first"),
        }
    }

    /// Returns the end of a period that starts at `now`.
    ///
    /// `Instant + Duration` panics on overflow, which a very large configured duration such as
    /// `Duration::MAX` would trigger. In that case this saturates to a deadline decades away,
    /// which is effectively "never" for a rate limiter.
    fn period_end(&self, now: tokio::time::Instant) -> tokio::time::Instant {
        const FAR_FUTURE: Duration = Duration::from_secs(86_400 * 365 * 30);
        now.checked_add(self.rate.duration()).unwrap_or_else(|| now + FAR_FUTURE)
    }
}
/// Tracks the state of the [`RateLimit`]
#[derive(Debug)]
enum State {
    /// Currently limited
    ///
    /// The quota of the last period was exhausted by `tick`. `RateLimit::poll_ready` must
    /// observe the internal sleep completing before another call is admitted. Calling `tick`
    /// in this state panics.
    Limited,
    /// Calls are currently admitted.
    Ready {
        /// End of the current period. Once it has passed, the next `tick` starts a new period
        /// with a full quota.
        until: tokio::time::Instant,
        /// Number of calls still admitted in the period ending at `until`. The call that
        /// consumes the last one switches the state to `Limited`; a value of 0 is treated
        /// like 1.
        remaining: u64,
    },
}
/// A rate of requests per time period.
#[derive(Debug, Copy, Clone)]
pub struct Rate {
    /// Maximum number of calls admitted within one period.
    limit: u64,
    /// Length of one period.
    duration: Duration,
}

impl Rate {
    /// Creates a [`Rate`] that admits at most `limit` calls per `duration`.
    pub const fn new(limit: u64, duration: Duration) -> Self {
        Self { duration, limit }
    }

    /// Maximum number of calls admitted within a single period.
    const fn limit(&self) -> u64 {
        let Self { limit, .. } = *self;
        limit
    }

    /// Length of a single period.
    const fn duration(&self) -> Duration {
        let Self { duration, .. } = *self;
        duration
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time;

    /// Polls `limit` exactly once from within the current task and returns the outcome, so tests
    /// can assert on readiness without hand-writing a `poll_fn` closure each time.
    async fn poll_once(limit: &mut RateLimit) -> Poll<()> {
        poll_fn(|cx| Poll::Ready(limit.poll_ready(cx))).await
    }

    #[tokio::test]
    async fn test_rate_limit() {
        let mut limit = RateLimit::new(Rate::new(2, Duration::from_millis(500)));

        assert!(poll_once(&mut limit).await.is_ready());
        limit.tick();

        assert!(poll_once(&mut limit).await.is_ready());
        limit.tick();

        assert!(poll_once(&mut limit).await.is_pending());

        time::sleep(limit.rate.duration()).await;

        assert!(poll_once(&mut limit).await.is_ready());
    }

    #[tokio::test]
    async fn test_rate_limit_initialization() {
        let rate = Rate::new(5, Duration::from_secs(1));
        let limit = RateLimit::new(rate);

        // Verify the limit is correctly set
        assert_eq!(limit.limit(), 5);
    }

    #[tokio::test]
    async fn test_rate_limit_allows_within_limit() {
        // The period must comfortably outlast the test body. With a 1ms period, a slow scheduler
        // could let the period expire mid-test. That either resets the quota between ticks or
        // completes the sleep before the final poll, making the `is_pending` assertion flaky.
        // This test never sleeps, so a long period costs nothing.
        let mut limit = RateLimit::new(Rate::new(3, Duration::from_secs(1)));

        // Every call within the limit should be admitted.
        for _ in 0..3 {
            assert!(poll_once(&mut limit).await.is_ready());
            // Signal that a request has been made
            limit.tick();
        }

        // After 3 requests the rate limit is hit.
        assert!(poll_once(&mut limit).await.is_pending());
    }

    #[tokio::test]
    async fn test_rate_limit_enforces_wait_after_limit() {
        let mut limit = RateLimit::new(Rate::new(2, Duration::from_millis(500)));

        // Consume the limit
        for _ in 0..2 {
            assert!(poll_once(&mut limit).await.is_ready());
            limit.tick();
        }

        // Should now be limited (pending)
        assert!(poll_once(&mut limit).await.is_pending());

        // Wait until the rate period elapses
        time::sleep(limit.rate.duration()).await;

        // Now it should be ready again after the wait
        assert!(poll_once(&mut limit).await.is_ready());
    }

    #[tokio::test]
    async fn test_wait_method_awaits_readiness() {
        let mut limit = RateLimit::new(Rate::new(1, Duration::from_millis(500)));

        assert!(poll_once(&mut limit).await.is_ready());
        limit.tick();

        // The limit should now be exceeded
        assert!(poll_once(&mut limit).await.is_pending());

        // The `wait` method should block until the rate period elapses
        limit.wait().await;

        // After `wait`, it should now be ready
        assert!(poll_once(&mut limit).await.is_ready());
    }

    #[tokio::test]
    #[should_panic(expected = "RateLimit limited; poll_ready must be called first")]
    async fn test_tick_panics_when_limited() {
        let mut limit = RateLimit::new(Rate::new(1, Duration::from_secs(1)));

        assert!(poll_once(&mut limit).await.is_ready());

        // Consume the limit
        limit.tick();

        // Attempting to tick again without poll_ready being ready should panic
        limit.tick();
    }
}