2023-06-21 08:28:20 +00:00
|
|
|
use enum_map::{enum_map, EnumMap};
|
|
|
|
use once_cell::sync::Lazy;
|
|
|
|
use std::{
|
|
|
|
collections::HashMap,
|
2023-07-10 20:52:37 +00:00
|
|
|
hash::Hash,
|
2023-06-21 08:28:20 +00:00
|
|
|
net::{IpAddr, Ipv4Addr, Ipv6Addr},
|
|
|
|
time::{Duration, Instant},
|
|
|
|
};
|
2021-11-23 12:16:47 +00:00
|
|
|
use tracing::debug;
|
2020-04-19 22:08:25 +00:00
|
|
|
|
2023-06-21 08:28:20 +00:00
|
|
|
/// Sentinel stored in `RateLimitBucket::tokens` by `RateLimitedGroup::new` until the bucket is
/// first checked; `check_total` replaces it with the bucket's capacity on that first check.
const UNINITIALIZED_TOKEN_AMOUNT: f32 = -2.0;
|
|
|
|
|
|
|
|
/// Reference point for `InstantSecs`: the `Instant` captured when this static is first accessed.
static START_TIME: Lazy<Instant> = Lazy::new(Instant::now);
|
|
|
|
|
|
|
|
/// A compact point in time, stored as whole seconds since `START_TIME`.
///
/// Takes less space than `std::time::Instant` because the seconds are kept in a `u32` and the
/// sub-second part is dropped entirely.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct InstantSecs {
  /// Whole seconds elapsed since `START_TIME`.
  secs: u32,
}
|
|
|
|
|
|
|
|
impl InstantSecs {
|
|
|
|
pub fn now() -> Self {
|
|
|
|
InstantSecs {
|
|
|
|
secs: u32::try_from(START_TIME.elapsed().as_secs())
|
|
|
|
.expect("server has been running for over 136 years"),
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
fn secs_since(self, earlier: Self) -> u32 {
|
|
|
|
self.secs.saturating_sub(earlier.secs)
|
|
|
|
}
|
|
|
|
|
|
|
|
fn to_instant(self) -> Instant {
|
|
|
|
*START_TIME + Duration::from_secs(self.secs.into())
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
#[derive(PartialEq, Debug, Clone)]
|
2020-11-16 15:44:04 +00:00
|
|
|
struct RateLimitBucket {
|
2023-06-21 08:28:20 +00:00
|
|
|
last_checked: InstantSecs,
|
|
|
|
/// This field stores the amount of tokens that were present at `last_checked`.
|
|
|
|
/// The amount of tokens steadily increases until it reaches the bucket's capacity.
|
|
|
|
/// Performing the rate-limited action consumes 1 token.
|
|
|
|
tokens: f32,
|
2020-04-19 22:08:25 +00:00
|
|
|
}
|
|
|
|
|
2023-06-21 08:28:20 +00:00
|
|
|
#[derive(Debug, enum_map::Enum, Copy, Clone, AsRefStr)]
|
2020-11-16 15:44:04 +00:00
|
|
|
pub(crate) enum RateLimitType {
|
2020-04-19 22:08:25 +00:00
|
|
|
Message,
|
|
|
|
Register,
|
|
|
|
Post,
|
2020-08-05 16:00:00 +00:00
|
|
|
Image,
|
2021-11-11 20:40:25 +00:00
|
|
|
Comment,
|
2022-03-29 15:46:03 +00:00
|
|
|
Search,
|
2020-04-19 22:08:25 +00:00
|
|
|
}
|
|
|
|
|
2023-06-21 08:28:20 +00:00
|
|
|
/// Hash map from a key of type `K` to a `RateLimitedGroup` whose `children` field has type `C`.
type Map<K, C> = HashMap<K, RateLimitedGroup<C>>;
|
|
|
|
|
|
|
|
#[derive(PartialEq, Debug, Clone)]
|
|
|
|
struct RateLimitedGroup<C> {
|
|
|
|
total: EnumMap<RateLimitType, RateLimitBucket>,
|
|
|
|
children: C,
|
2020-04-19 22:08:25 +00:00
|
|
|
}
|
|
|
|
|
2023-06-21 08:28:20 +00:00
|
|
|
impl<C: Default> RateLimitedGroup<C> {
|
|
|
|
fn new(now: InstantSecs) -> Self {
|
|
|
|
RateLimitedGroup {
|
|
|
|
total: enum_map! {
|
|
|
|
_ => RateLimitBucket {
|
|
|
|
last_checked: now,
|
|
|
|
tokens: UNINITIALIZED_TOKEN_AMOUNT,
|
|
|
|
},
|
|
|
|
},
|
|
|
|
children: Default::default(),
|
|
|
|
}
|
|
|
|
}
|
2020-04-19 22:08:25 +00:00
|
|
|
|
2023-06-21 08:28:20 +00:00
|
|
|
fn check_total(
|
|
|
|
&mut self,
|
|
|
|
type_: RateLimitType,
|
|
|
|
now: InstantSecs,
|
|
|
|
capacity: i32,
|
|
|
|
secs_to_refill: i32,
|
|
|
|
) -> bool {
|
|
|
|
let capacity = capacity as f32;
|
|
|
|
let secs_to_refill = secs_to_refill as f32;
|
|
|
|
|
|
|
|
#[allow(clippy::indexing_slicing)] // `EnumMap` has no `get` funciton
|
|
|
|
let bucket = &mut self.total[type_];
|
|
|
|
|
|
|
|
if bucket.tokens == UNINITIALIZED_TOKEN_AMOUNT {
|
|
|
|
bucket.tokens = capacity;
|
|
|
|
}
|
|
|
|
|
|
|
|
let secs_since_last_checked = now.secs_since(bucket.last_checked) as f32;
|
|
|
|
bucket.last_checked = now;
|
|
|
|
|
|
|
|
// For `secs_since_last_checked` seconds, increase `bucket.tokens`
|
|
|
|
// by `capacity` every `secs_to_refill` seconds
|
|
|
|
bucket.tokens += {
|
|
|
|
let tokens_per_sec = capacity / secs_to_refill;
|
|
|
|
secs_since_last_checked * tokens_per_sec
|
|
|
|
};
|
|
|
|
|
|
|
|
// Prevent `bucket.tokens` from exceeding `capacity`
|
|
|
|
if bucket.tokens > capacity {
|
|
|
|
bucket.tokens = capacity;
|
|
|
|
}
|
|
|
|
|
|
|
|
if bucket.tokens < 1.0 {
|
|
|
|
// Not enough tokens yet
|
|
|
|
debug!(
|
|
|
|
"Rate limited type: {}, time_passed: {}, allowance: {}",
|
|
|
|
type_.as_ref(),
|
|
|
|
secs_since_last_checked,
|
|
|
|
bucket.tokens
|
|
|
|
);
|
|
|
|
false
|
|
|
|
} else {
|
|
|
|
// Consume 1 token
|
|
|
|
bucket.tokens -= 1.0;
|
|
|
|
true
|
2020-04-19 22:08:25 +00:00
|
|
|
}
|
|
|
|
}
|
2023-06-21 08:28:20 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
/// Rate limiting based on rate type and IP addr
|
|
|
|
#[derive(PartialEq, Debug, Clone, Default)]
|
|
|
|
pub struct RateLimitStorage {
|
|
|
|
/// One bucket per individual IPv4 address
|
|
|
|
ipv4_buckets: Map<Ipv4Addr, ()>,
|
|
|
|
/// Seperate buckets for 48, 56, and 64 bit prefixes of IPv6 addresses
|
|
|
|
ipv6_buckets: Map<[u8; 6], Map<u8, Map<u8, ()>>>,
|
|
|
|
}
|
2020-04-19 22:08:25 +00:00
|
|
|
|
2023-06-21 08:28:20 +00:00
|
|
|
impl RateLimitStorage {
|
2022-03-18 15:26:16 +00:00
|
|
|
/// Rate limiting Algorithm described here: https://stackoverflow.com/a/668327/1655478
|
2022-03-25 15:41:38 +00:00
|
|
|
///
|
|
|
|
/// Returns true if the request passed the rate limit, false if it failed and should be rejected.
|
2020-04-20 03:59:07 +00:00
|
|
|
pub(super) fn check_rate_limit_full(
|
2020-04-19 22:08:25 +00:00
|
|
|
&mut self,
|
|
|
|
type_: RateLimitType,
|
2023-06-21 08:28:20 +00:00
|
|
|
ip: IpAddr,
|
|
|
|
capacity: i32,
|
|
|
|
secs_to_refill: i32,
|
|
|
|
now: InstantSecs,
|
2022-03-25 15:41:38 +00:00
|
|
|
) -> bool {
|
2023-06-21 08:28:20 +00:00
|
|
|
let mut result = true;
|
|
|
|
|
|
|
|
match ip {
|
|
|
|
IpAddr::V4(ipv4) => {
|
|
|
|
// Only used by one address.
|
|
|
|
let group = self
|
|
|
|
.ipv4_buckets
|
|
|
|
.entry(ipv4)
|
|
|
|
.or_insert(RateLimitedGroup::new(now));
|
|
|
|
|
|
|
|
result &= group.check_total(type_, now, capacity, secs_to_refill);
|
2020-04-19 22:08:25 +00:00
|
|
|
}
|
2023-06-21 08:28:20 +00:00
|
|
|
|
|
|
|
IpAddr::V6(ipv6) => {
|
|
|
|
let (key_48, key_56, key_64) = split_ipv6(ipv6);
|
|
|
|
|
|
|
|
// Contains all addresses with the same first 48 bits. These addresses might be part of the same network.
|
|
|
|
let group_48 = self
|
|
|
|
.ipv6_buckets
|
|
|
|
.entry(key_48)
|
|
|
|
.or_insert(RateLimitedGroup::new(now));
|
|
|
|
result &= group_48.check_total(type_, now, capacity.saturating_mul(16), secs_to_refill);
|
|
|
|
|
|
|
|
// Contains all addresses with the same first 56 bits. These addresses might be part of the same network.
|
|
|
|
let group_56 = group_48
|
|
|
|
.children
|
|
|
|
.entry(key_56)
|
|
|
|
.or_insert(RateLimitedGroup::new(now));
|
|
|
|
result &= group_56.check_total(type_, now, capacity.saturating_mul(4), secs_to_refill);
|
|
|
|
|
|
|
|
// A group with no children. It is shared by all addresses with the same first 64 bits. These addresses are always part of the same network.
|
|
|
|
let group_64 = group_56
|
|
|
|
.children
|
|
|
|
.entry(key_64)
|
|
|
|
.or_insert(RateLimitedGroup::new(now));
|
|
|
|
|
|
|
|
result &= group_64.check_total(type_, now, capacity, secs_to_refill);
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
|
|
|
if !result {
|
|
|
|
debug!("Rate limited IP: {ip}");
|
2020-04-19 22:08:25 +00:00
|
|
|
}
|
2023-06-21 08:28:20 +00:00
|
|
|
|
|
|
|
result
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Remove buckets older than the given duration
|
|
|
|
pub(super) fn remove_older_than(&mut self, duration: Duration, now: InstantSecs) {
|
|
|
|
// Only retain buckets that were last used after `instant`
|
2023-07-03 09:45:53 +00:00
|
|
|
let Some(instant) = now.to_instant().checked_sub(duration) else {
|
|
|
|
return;
|
|
|
|
};
|
2023-06-21 08:28:20 +00:00
|
|
|
|
|
|
|
let is_recently_used = |group: &RateLimitedGroup<_>| {
|
|
|
|
group
|
|
|
|
.total
|
|
|
|
.values()
|
|
|
|
.all(|bucket| bucket.last_checked.to_instant() > instant)
|
|
|
|
};
|
|
|
|
|
2023-07-10 20:52:37 +00:00
|
|
|
retain_and_shrink(&mut self.ipv4_buckets, |_, group| is_recently_used(group));
|
2023-06-21 08:28:20 +00:00
|
|
|
|
2023-07-10 20:52:37 +00:00
|
|
|
retain_and_shrink(&mut self.ipv6_buckets, |_, group_48| {
|
|
|
|
retain_and_shrink(&mut group_48.children, |_, group_56| {
|
|
|
|
retain_and_shrink(&mut group_56.children, |_, group_64| {
|
|
|
|
is_recently_used(group_64)
|
|
|
|
});
|
2023-06-21 08:28:20 +00:00
|
|
|
!group_56.children.is_empty()
|
|
|
|
});
|
|
|
|
!group_48.children.is_empty()
|
|
|
|
})
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2023-07-10 20:52:37 +00:00
|
|
|
/// Keeps only the entries of `map` for which `f` returns `true`, then releases as much of the
/// map's unused capacity as possible.
fn retain_and_shrink<K: Eq + Hash, V, F: FnMut(&K, &mut V) -> bool>(map: &mut HashMap<K, V>, f: F) {
  HashMap::retain(map, f);
  HashMap::shrink_to_fit(map);
}
|
|
|
|
|
2023-06-21 08:28:20 +00:00
|
|
|
/// Splits the first 64 bits of `ip` into the keys used for the nested IPv6 groups: the first
/// 6 bytes (48 bit prefix), the 7th byte and the 8th byte. The remaining 64 bits are ignored.
fn split_ipv6(ip: Ipv6Addr) -> ([u8; 6], u8, u8) {
  let [seg0, seg1, seg2, seg3, ..] = ip.segments();
  let [a0, a1] = seg0.to_be_bytes();
  let [a2, a3] = seg1.to_be_bytes();
  let [a4, a5] = seg2.to_be_bytes();
  let [byte_7, byte_8] = seg3.to_be_bytes();
  ([a0, a1, a2, a3, a4, a5], byte_7, byte_8)
}
|
|
|
|
|
|
|
|
#[cfg(test)]
mod tests {
  #![allow(clippy::unwrap_used)]
  #![allow(clippy::indexing_slicing)]

  use super::{
    split_ipv6,
    InstantSecs,
    RateLimitBucket,
    RateLimitStorage,
    RateLimitType,
    RateLimitedGroup,
  };
  use std::{
    collections::HashMap,
    net::{IpAddr, Ipv4Addr, Ipv6Addr},
    time::Duration,
  };

  /// The 48 bit prefix, 7th byte and 8th byte are extracted in network byte order.
  #[test]
  fn test_split_ipv6() {
    let ip = Ipv6Addr::new(
      0x0011, 0x2233, 0x4455, 0x6677, 0x8899, 0xAABB, 0xCCDD, 0xEEFF,
    );
    let expected: ([u8; 6], u8, u8) = ([0x00, 0x11, 0x22, 0x33, 0x44, 0x55], 0x66, 0x77);
    assert_eq!(split_ipv6(ip), expected);
  }

  /// Sends one message and one post from each address, checks the exact bucket tree that
  /// results, and then checks that everything is purged once it is older than the cutoff.
  #[test]
  fn test_rate_limiter() {
    let mut rate_limiter = RateLimitStorage::default();
    let mut now = InstantSecs::now();

    let addresses = [
      "123.123.123.123",
      "1:2:3::",
      "1:2:3:0400::",
      "1:2:3:0405::",
      "1:2:3:0405:6::",
    ];
    for address in addresses {
      let ip: IpAddr = address.parse().unwrap();
      let message_passed =
        rate_limiter.check_rate_limit_full(RateLimitType::Message, ip, 2, 1, now);
      let post_passed = rate_limiter.check_rate_limit_full(RateLimitType::Post, ip, 3, 1, now);
      assert!(message_passed);
      assert!(post_passed);
    }

    // Buckets of a group whose capacity is `factor` times the base capacity (2 for messages,
    // 3 for posts) and from which `tokens_consumed` tokens were taken per type.
    let expected_buckets = |factor: f32, tokens_consumed: f32| {
      let mut buckets = RateLimitedGroup::<()>::new(now).total;
      buckets[RateLimitType::Message] = RateLimitBucket {
        last_checked: now,
        tokens: (2.0 * factor) - tokens_consumed,
      };
      buckets[RateLimitType::Post] = RateLimitBucket {
        last_checked: now,
        tokens: (3.0 * factor) - tokens_consumed,
      };
      buckets
    };

    // A group without children, using the base capacity.
    let bottom_group = |tokens_consumed| RateLimitedGroup {
      total: expected_buckets(1.0, tokens_consumed),
      children: (),
    };

    // 56 bit prefix with 7th byte 0: only `1:2:3::`.
    let group_56_byte_0 = RateLimitedGroup {
      total: expected_buckets(4.0, 1.0),
      children: HashMap::from([(0, bottom_group(1.0))]),
    };
    // 56 bit prefix with 7th byte 4: `1:2:3:0400::` plus the two `1:2:3:0405:` addresses.
    let group_56_byte_4 = RateLimitedGroup {
      total: expected_buckets(4.0, 3.0),
      children: HashMap::from([(0, bottom_group(1.0)), (5, bottom_group(2.0))]),
    };
    // 48 bit prefix shared by all four IPv6 addresses.
    let group_48 = RateLimitedGroup {
      total: expected_buckets(16.0, 4.0),
      children: HashMap::from([(0, group_56_byte_0), (4, group_56_byte_4)]),
    };

    let expected_storage = RateLimitStorage {
      ipv4_buckets: HashMap::from([(Ipv4Addr::from([123, 123, 123, 123]), bottom_group(1.0))]),
      ipv6_buckets: HashMap::from([([0, 1, 0, 2, 0, 3], group_48)]),
    };
    assert_eq!(rate_limiter, expected_storage);

    now.secs += 2;
    rate_limiter.remove_older_than(Duration::from_secs(1), now);
    assert!(rate_limiter.ipv4_buckets.is_empty());
    assert!(rate_limiter.ipv6_buckets.is_empty());
  }
}
|