Created
June 14, 2024 08:49
-
-
Save boyswan/c4948c81835a18a97fc4b80fd678814d to your computer and use it in GitHub Desktop.
limiter.rs
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use futures::Future; | |
use std::sync::Arc; | |
use tokio::sync::Notify; | |
use tokio::sync::Semaphore; | |
use tokio::task::spawn; | |
use tokio::time::{interval, Duration}; | |
#[derive(Clone)] | |
pub struct Limiter { | |
capacity: usize, | |
sem: Arc<Semaphore>, | |
notify: Arc<Notify>, | |
} | |
impl Limiter { | |
pub fn new(capacity: usize) -> Self { | |
Self { | |
capacity, | |
sem: Arc::new(Semaphore::new(capacity)), | |
notify: Arc::new(Notify::new()), | |
} | |
} | |
pub fn available_permits(&self) -> usize { | |
self.sem.available_permits() | |
} | |
pub fn replenish(&self) { | |
self.sem.forget_permits(self.capacity); | |
self.sem.add_permits(self.capacity); | |
self.notify.notify_waiters(); | |
} | |
pub async fn acquire(&self, amount: usize) { | |
if self.sem.available_permits() < amount { | |
self.notify.notified().await; | |
} | |
self.sem.acquire_many(amount as u32).await.unwrap().forget(); | |
} | |
} | |
#[derive(Clone)] | |
pub struct ChatGptLimiter { | |
request_limiter: Limiter, | |
token_limiter: Limiter, | |
} | |
pub struct ChatGptLimiterConfig { | |
pub interval_duration: Duration, | |
pub limit_requests: usize, | |
pub limit_tokens: usize, | |
} | |
impl ChatGptLimiter { | |
pub fn new(config: ChatGptLimiterConfig) -> Self { | |
let token_limiter = Limiter::new(config.limit_tokens); | |
let request_limiter = Limiter::new(config.limit_requests); | |
spawn({ | |
let token_limiter = token_limiter.clone(); | |
let request_limiter = request_limiter.clone(); | |
let mut interval = interval(config.interval_duration); | |
async move { | |
loop { | |
interval.tick().await; | |
request_limiter.replenish(); | |
token_limiter.replenish(); | |
} | |
} | |
}); | |
Self { | |
token_limiter, | |
request_limiter, | |
} | |
} | |
pub async fn acquire<T>(&self, tokens: usize, fut: impl Future<Output = T>) -> T { | |
let a = self.token_limiter.acquire(tokens); | |
let b = self.request_limiter.acquire(1); | |
tokio::join!(a, b); | |
fut.await | |
} | |
pub fn available_tokens(&self) -> usize { | |
self.token_limiter.available_permits() | |
} | |
pub fn available_requests(&self) -> usize { | |
self.request_limiter.available_permits() | |
} | |
} | |
#[cfg(test)]
mod tests {
    use super::{ChatGptLimiter, ChatGptLimiterConfig};
    use tokio::time::{self, Duration};

    /// Sleeps for `ms` milliseconds.
    async fn sleep(ms: u64) {
        time::sleep(Duration::from_millis(ms)).await;
    }

    /// Admits a trivial future that costs `tokens` tokens and one request.
    async fn acquire(limiter: &ChatGptLimiter, tokens: usize) {
        limiter.acquire(tokens, sleep(0)).await;
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn works() {
        let limiter = ChatGptLimiter::new(ChatGptLimiterConfig {
            interval_duration: Duration::from_millis(500),
            limit_requests: 5,
            limit_tokens: 15_000,
        });
        sleep(100).await;

        // ...Interval 1: two calls fit, each costing 5000 tokens and one request.
        acquire(&limiter, 5000).await;
        acquire(&limiter, 5000).await;
        assert_eq!(limiter.available_tokens(), 5000);
        assert_eq!(limiter.available_requests(), 3);

        // ...Interval 2: 9999 tokens do not fit in what is left of interval 1. The call waits and
        // takes both its tokens and its request from interval 2's fresh budget.
        acquire(&limiter, 9999).await;
        assert_eq!(limiter.available_tokens(), 5001);
        assert_eq!(limiter.available_requests(), 4);
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
A future may only run in an interval if permits for both its tokens and its request are available in that interval. Otherwise it should be deferred to the next interval. No interval may consume more than its allocated limits.
The 9999-token acquisition is "pushed" into interval 2, because there are not enough token permits left in interval 1. This works, and the available tokens in interval 2 are updated, but the available requests are not updated accordingly.