132 lines
3.3 KiB
Rust
132 lines
3.3 KiB
Rust
|
use super::*;
|
||
|
|
||
|
/// Per-IP bookkeeping for a single rate-limit category.
#[derive(Clone, Debug)]
pub struct RateLimitBucket {
    /// Moment at which this bucket was last evaluated by
    /// `check_rate_limit_full` (set to "now" when the bucket is created).
    last_checked: SystemTime,
    /// Remaining request budget. New buckets are created by `insert_ip` with
    /// the placeholder `-2.0`, which the first check replaces with the
    /// configured rate.
    allowance: f64,
}
|
||
|
|
||
|
#[derive(Eq, PartialEq, Hash, Debug, EnumIter, Copy, Clone)]
|
||
|
pub enum RateLimitType {
|
||
|
Message,
|
||
|
Register,
|
||
|
Post,
|
||
|
}
|
||
|
|
||
|
/// Rate limiting based on rate type and IP addr
|
||
|
#[derive(Debug, Clone)]
|
||
|
pub struct RateLimiter {
|
||
|
pub buckets: HashMap<RateLimitType, HashMap<IPAddr, RateLimitBucket>>,
|
||
|
}
|
||
|
|
||
|
impl Default for RateLimiter {
|
||
|
fn default() -> Self {
|
||
|
Self {
|
||
|
buckets: HashMap::new(),
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
impl RateLimiter {
|
||
|
fn insert_ip(&mut self, ip: &str) {
|
||
|
for rate_limit_type in RateLimitType::iter() {
|
||
|
if self.buckets.get(&rate_limit_type).is_none() {
|
||
|
self.buckets.insert(rate_limit_type, HashMap::new());
|
||
|
}
|
||
|
|
||
|
if let Some(bucket) = self.buckets.get_mut(&rate_limit_type) {
|
||
|
if bucket.get(ip).is_none() {
|
||
|
bucket.insert(
|
||
|
ip.to_string(),
|
||
|
RateLimitBucket {
|
||
|
last_checked: SystemTime::now(),
|
||
|
allowance: -2f64,
|
||
|
},
|
||
|
);
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
pub fn check_rate_limit_register(&mut self, ip: &str, check_only: bool) -> Result<(), Error> {
|
||
|
self.check_rate_limit_full(
|
||
|
RateLimitType::Register,
|
||
|
ip,
|
||
|
Settings::get().rate_limit.register,
|
||
|
Settings::get().rate_limit.register_per_second,
|
||
|
check_only,
|
||
|
)
|
||
|
}
|
||
|
|
||
|
pub fn check_rate_limit_post(&mut self, ip: &str, check_only: bool) -> Result<(), Error> {
|
||
|
self.check_rate_limit_full(
|
||
|
RateLimitType::Post,
|
||
|
ip,
|
||
|
Settings::get().rate_limit.post,
|
||
|
Settings::get().rate_limit.post_per_second,
|
||
|
check_only,
|
||
|
)
|
||
|
}
|
||
|
|
||
|
pub fn check_rate_limit_message(&mut self, ip: &str, check_only: bool) -> Result<(), Error> {
|
||
|
self.check_rate_limit_full(
|
||
|
RateLimitType::Message,
|
||
|
ip,
|
||
|
Settings::get().rate_limit.message,
|
||
|
Settings::get().rate_limit.message_per_second,
|
||
|
check_only,
|
||
|
)
|
||
|
}
|
||
|
|
||
|
#[allow(clippy::float_cmp)]
|
||
|
fn check_rate_limit_full(
|
||
|
&mut self,
|
||
|
type_: RateLimitType,
|
||
|
ip: &str,
|
||
|
rate: i32,
|
||
|
per: i32,
|
||
|
check_only: bool,
|
||
|
) -> Result<(), Error> {
|
||
|
self.insert_ip(ip);
|
||
|
if let Some(bucket) = self.buckets.get_mut(&type_) {
|
||
|
if let Some(rate_limit) = bucket.get_mut(ip) {
|
||
|
let current = SystemTime::now();
|
||
|
let time_passed = current.duration_since(rate_limit.last_checked)?.as_secs() as f64;
|
||
|
|
||
|
// The initial value
|
||
|
if rate_limit.allowance == -2f64 {
|
||
|
rate_limit.allowance = rate as f64;
|
||
|
};
|
||
|
|
||
|
rate_limit.last_checked = current;
|
||
|
rate_limit.allowance += time_passed * (rate as f64 / per as f64);
|
||
|
if !check_only && rate_limit.allowance > rate as f64 {
|
||
|
rate_limit.allowance = rate as f64;
|
||
|
}
|
||
|
|
||
|
if rate_limit.allowance < 1.0 {
|
||
|
warn!(
|
||
|
"Rate limited IP: {}, time_passed: {}, allowance: {}",
|
||
|
ip, time_passed, rate_limit.allowance
|
||
|
);
|
||
|
Err(
|
||
|
APIError {
|
||
|
message: format!("Too many requests. {} per {} seconds", rate, per),
|
||
|
}
|
||
|
.into(),
|
||
|
)
|
||
|
} else {
|
||
|
if !check_only {
|
||
|
rate_limit.allowance -= 1.0;
|
||
|
}
|
||
|
Ok(())
|
||
|
}
|
||
|
} else {
|
||
|
Ok(())
|
||
|
}
|
||
|
} else {
|
||
|
Ok(())
|
||
|
}
|
||
|
}
|
||
|
}
|