SleeplessOne1917 1ef375e5c3
Replace 3rd party once_cell crate with recently stabilized standard library API (#4919)
* Replace 3rd party once_cell crate with recently stabilized standard library API

* Bump CI rust version
2024-07-30 10:11:39 -04:00

444 lines
13 KiB

use enum_map::EnumMap;
use std::{
net::{IpAddr, Ipv4Addr, Ipv6Addr},
use strum::{AsRefStr, Display};
use tracing::debug;
/// Reference instant that `InstantSecs` values are measured from. It is captured lazily, the
/// first time this static is dereferenced.
static START_TIME: LazyLock<Instant> = LazyLock::new(Instant::now);
/// Smaller than `std::time::Instant` because it uses a smaller integer for seconds and doesn't
/// store nanoseconds
#[derive(PartialEq, Debug, Clone, Copy)]
pub struct InstantSecs {
/// Whole seconds elapsed since `START_TIME` (see `InstantSecs::now`).
secs: u32,
impl InstantSecs {
/// Returns the current time as the number of whole seconds elapsed since `START_TIME`.
///
/// # Panics
///
/// Panics if the elapsed number of seconds does not fit in a `u32`.
pub fn now() -> Self {
InstantSecs {
secs: u32::try_from(START_TIME.elapsed().as_secs())
.expect("server has been running for over 136 years"),
/// Token-bucket state for a single action type: a token count and the time it was recorded.
#[derive(PartialEq, Debug, Clone, Copy)]
struct Bucket {
/// Time at which `tokens` was last recorded.
last_checked: InstantSecs,
/// This field stores the amount of tokens that were present at `last_checked`.
/// The amount of tokens steadily increases until it reaches the bucket's capacity.
/// Performing the rate-limited action consumes 1 token.
tokens: u32,
/// Refill parameters of a `Bucket`.
#[derive(PartialEq, Debug, Copy, Clone)]
pub struct BucketConfig {
/// Maximum number of tokens a bucket can hold. New buckets start with this many tokens.
pub capacity: u32,
/// Number of seconds over which `capacity` tokens are added back to a bucket.
pub secs_to_refill: u32,
impl Bucket {
fn update(self, now: InstantSecs, config: BucketConfig) -> Self {
let secs_since_last_checked = now.secs.saturating_sub(self.last_checked.secs);
// For `secs_since_last_checked` seconds, the amount of tokens increases by `capacity` every
// `secs_to_refill` seconds. The amount of tokens added per second is `capacity /
// secs_to_refill`. The expression below is like `secs_since_last_checked * (capacity /
// secs_to_refill)` but with precision and non-overflowing multiplication.
let added_tokens = u64::from(secs_since_last_checked) * u64::from(config.capacity)
/ u64::from(config.secs_to_refill);
// The amount of tokens there would be if the bucket had infinite capacity
let unbounded_tokens = self.tokens + (added_tokens as u32);
// Bucket stops filling when capacity is reached
let tokens = std::cmp::min(unbounded_tokens, config.capacity);
Bucket {
last_checked: now,
/// Kind of rate-limited action. It is the key type of every `EnumMap` of buckets and bucket
/// configs in this file, so each variant gets its own bucket and its own `BucketConfig`.
#[derive(Debug, enum_map::Enum, Copy, Clone, Display, AsRefStr)]
pub enum ActionType {
/// One node of the rate-limit tree: one bucket per `ActionType` plus the nested level `C`.
#[derive(PartialEq, Debug, Clone)]
struct RateLimitedGroup<C> {
/// Buckets shared by every address that maps to this group.
total: EnumMap<ActionType, Bucket>,
/// Nested groups (a `Map`), or `()` at the last level.
children: C,
/// One level of the rate-limit tree: groups keyed by one part of the address.
type Map<K, C> = HashMap<K, RateLimitedGroup<C>>;
/// Operations shared by every level of the rate-limit tree.
///
/// Implemented for `()`, `Map<T, ()>`, `Map<T, Map<U, ()>>`, etc.
trait MapLevel: Default {
/// Nested tuple holding one capacity multiplier per level, starting with this level and ending
/// with `()`.
type CapacityFactors;
/// Nested tuple holding one address-derived key per level, starting with this level and ending
/// with `()`.
type AddrParts;
/// Checks `action_type` at time `now` against the group selected by `addr_parts` at this level
/// and at the nested levels. The `Map` implementation returns `true` only when this level and
/// its children both pass.
///
/// `configs` holds the unscaled bucket configs; each level scales the capacities by its own
/// entry in `capacity_factors`.
fn check(
&mut self,
action_type: ActionType,
now: InstantSecs,
configs: EnumMap<ActionType, BucketConfig>,
capacity_factors: Self::CapacityFactors,
addr_parts: Self::AddrParts,
) -> bool;
/// Remove full buckets and return `true` if there are any buckets remaining
fn remove_full_buckets(
&mut self,
now: InstantSecs,
configs: EnumMap<ActionType, BucketConfig>,
) -> bool;
impl<K: Eq + Hash, C: MapLevel> MapLevel for Map<K, C> {
/// This level's capacity multiplier, followed by the factors of the nested levels.
type CapacityFactors = (u32, C::CapacityFactors);
/// This level's key, followed by the keys of the nested levels.
type AddrParts = (K, C::AddrParts);
/// Checks `action_type` against the group stored under `addr_part` and against that group's
/// children. A missing group is created with full buckets (see `RateLimitedGroup::new`).
///
/// Both the group's own check and the children's check are evaluated before their results are
/// combined with `&&`, so a token may be consumed at one level even when another level rejects.
fn check(
&mut self,
action_type: ActionType,
now: InstantSecs,
configs: EnumMap<ActionType, BucketConfig>,
(capacity_factor, child_capacity_factors): Self::CapacityFactors,
(addr_part, child_addr_parts): Self::AddrParts,
) -> bool {
// Multiplies capacities by `capacity_factor` for groups in `self`
let adjusted_configs = configs.map(|_, config| BucketConfig {
capacity: config.capacity.saturating_mul(capacity_factor),
// Remove groups that are no longer needed if the hash map's existing allocation has no space
// for new groups. This is done before calling `HashMap::entry` because that immediately
// allocates just like `HashMap::insert`.
if (self.capacity() == self.len()) && !self.contains_key(&addr_part) {
// NOTE(review): this passes the unscaled `configs`, not `adjusted_configs`, although the
// groups in `self` are created and checked with `adjusted_configs` -- confirm this is intended.
self.remove_full_buckets(now, configs);
let group = self
.or_insert(RateLimitedGroup::new(now, adjusted_configs));
let total_passes = group.check_total(action_type, now, adjusted_configs[action_type]);
let children_pass = group.children.check(
total_passes && children_pass
/// Drops every group that has no remaining children and whose buckets are all full at `now`.
///
/// A group is kept when its children report remaining buckets, or when at least one of its own
/// buckets, refilled up to `now`, holds a token count different from `config.capacity`.
fn remove_full_buckets(
&mut self,
now: InstantSecs,
configs: EnumMap<ActionType, BucketConfig>,
) -> bool {
self.retain(|_key, group| {
let some_children_remaining = group.children.remove_full_buckets(now, configs);
// Evaluated if `some_children_remaining` is false
let total_has_refill_in_future = || {
group.total.into_iter().any(|(action_type, bucket)| {
let config = configs[action_type];
// `update` returns a refilled copy; the stored bucket is not modified here
bucket.update(now, config).tokens != config.capacity
some_children_remaining || total_has_refill_in_future()
/// Terminal level of the rate-limit tree: `()` stores no groups, takes no capacity factors and
/// no address parts, and both methods ignore all of their arguments.
// NOTE(review): the return values of the two methods are not visible in this view -- presumably
// `check` always passes and `remove_full_buckets` reports nothing remaining; confirm.
impl MapLevel for () {
type CapacityFactors = ();
type AddrParts = ();
fn check(
&mut self,
_: ActionType,
_: InstantSecs,
_: EnumMap<ActionType, BucketConfig>,
_: Self::CapacityFactors,
_: Self::AddrParts,
) -> bool {
fn remove_full_buckets(&mut self, _: InstantSecs, _: EnumMap<ActionType, BucketConfig>) -> bool {
impl<C: Default> RateLimitedGroup<C> {
/// Creates a group in which the bucket of every action type starts full (`tokens` equals the
/// configured `capacity`) with `last_checked` set to `now`, and whose children are the default
/// value of `C`.
fn new(now: InstantSecs, configs: EnumMap<ActionType, BucketConfig>) -> Self {
RateLimitedGroup {
total: configs.map(|_, config| Bucket {
last_checked: now,
tokens: config.capacity,
// `HashMap::new()` or `()`
children: Default::default(),
/// Refills this group's `action_type` bucket up to `now` and, if the refilled bucket holds at
/// least one token, stores it with one token consumed. If it holds no token, the stored bucket
/// is left untouched.
// NOTE(review): the returned values are not visible in this view -- presumably `false` when no
// token is available and `true` when one was consumed; confirm.
fn check_total(
&mut self,
action_type: ActionType,
now: InstantSecs,
config: BucketConfig,
) -> bool {
#[allow(clippy::indexing_slicing)] // `EnumMap` has no `get` function
let bucket = &mut self.total[action_type];
// `update` takes the bucket by value (`Bucket` is `Copy`), so `*bucket` is unchanged so far
let new_bucket = bucket.update(now, config);
if new_bucket.tokens == 0 {
// Not enough tokens yet
// Setting `bucket` to `new_bucket` here is useless and would cause the bucket to start over
// at 0 tokens because of rounding
} else {
// Consume 1 token
*bucket = new_bucket;
// Cannot underflow: this branch is only reached when `new_bucket.tokens != 0`
bucket.tokens -= 1;
/// Rate limiting based on rate type and IP addr
#[derive(PartialEq, Debug, Clone)]
pub struct RateLimitState {
/// Each individual IPv4 address gets one `RateLimitedGroup`.
ipv4_buckets: Map<Ipv4Addr, ()>,
/// All IPv6 addresses that share the same first 64 bits share the same `RateLimitedGroup`.
/// The same thing happens for the first 48 and 56 bits, but with increased capacity.
/// This is done because all users can easily switch to any other IPv6 address that has the same
/// first 64 bits. It could be as low as 48 bits for some networks, which is the reason for 48
/// and 56 bit address groups.
///
/// The keys are, from the outermost level inwards: the first 6 bytes of the address, its 7th
/// byte and its 8th byte (see `split_ipv6`).
ipv6_buckets: Map<[u8; 6], Map<u8, Map<u8, ()>>>,
/// This stores a `BucketConfig` for each `ActionType`. `EnumMap` makes it impossible to have a
/// missing `BucketConfig`.
bucket_configs: EnumMap<ActionType, BucketConfig>,
impl RateLimitState {
/// Creates a rate limiter with empty IPv4 and IPv6 maps.
pub fn new(bucket_configs: EnumMap<ActionType, BucketConfig>) -> Self {
RateLimitState {
ipv4_buckets: HashMap::new(),
ipv6_buckets: HashMap::new(),
/// Rate limiting Algorithm described here: https://stackoverflow.com/a/668327/1655478
/// Returns true if the request passed the rate limit, false if it failed and should be rejected.
///
/// IPv4 addresses are checked against a single group with capacity factor 1. IPv6 addresses are
/// checked against their 48, 56 and 64 bit groups with capacity factors 16, 4 and 1.
/// A rejected request is logged at debug level.
pub fn check(&mut self, action_type: ActionType, ip: IpAddr, now: InstantSecs) -> bool {
let result = match ip {
IpAddr::V4(ipv4) => {
.check(action_type, now, self.bucket_configs, (1, ()), (ipv4, ()))
IpAddr::V6(ipv6) => {
let (key_48, key_56, key_64) = split_ipv6(ipv6);
// Capacity factors and keys are listed from the widest group (48 bits) to the narrowest
(16, (4, (1, ()))),
(key_48, (key_56, (key_64, ()))),
if !result {
debug!("Rate limited IP: {ip}, type: {action_type:?}");
/// Remove buckets that are now full
pub fn remove_full_buckets(&mut self, now: InstantSecs) {
.remove_full_buckets(now, self.bucket_configs);
.remove_full_buckets(now, self.bucket_configs);
/// Replaces the stored bucket configs with `new_configs`.
pub fn set_config(&mut self, new_configs: EnumMap<ActionType, BucketConfig>) {
self.bucket_configs = new_configs;
/// Splits the first 8 bytes of `ip` into the keys used by the IPv6 group levels: bytes 0-5
/// (first 48 bits), byte 6 (extends the prefix to 56 bits) and byte 7 (extends it to 64 bits).
/// The remaining 8 bytes of the address are ignored.
fn split_ipv6(ip: Ipv6Addr) -> ([u8; 6], u8, u8) {
let [a0, a1, a2, a3, a4, a5, b, c, ..] = ip.octets();
([a0, a1, a2, a3, a4, a5], b, c)
// Unit tests for `split_ipv6` and `RateLimitState`.
mod tests {
use super::{ActionType, BucketConfig, InstantSecs, RateLimitState, RateLimitedGroup};
use pretty_assertions::assert_eq;
// Checks that `split_ipv6` returns bytes 0-5, byte 6 and byte 7 of the address.
fn test_split_ipv6() {
let ip = std::net::Ipv6Addr::new(
0x0011, 0x2233, 0x4455, 0x6677, 0x8899, 0xAABB, 0xCCDD, 0xEEFF,
([0x00, 0x11, 0x22, 0x33, 0x44, 0x55], 0x66, 0x77)
// Exercises `RateLimitState::check` and `RateLimitState::remove_full_buckets`.
fn test_rate_limiter() {
// Every action type gets the same config: 2 tokens, refilled over 1 second
let bucket_configs = enum_map::enum_map! {
ActionType::Message => BucketConfig {
capacity: 2,
secs_to_refill: 1,
_ => BucketConfig {
capacity: 2,
secs_to_refill: 1,
let mut rate_limiter = RateLimitState::new(bucket_configs);
let mut now = InstantSecs::now();
// Do 1 `Message` and 1 `Post` action for each IP address, and expect the limit to not be
// reached
// NOTE(review): no address literals are visible inside `ips` in this view -- confirm the list
// matches the groups expected further down.
let ips = [
for ip in ips {
let ip = ip.parse().unwrap();
let message_passed = rate_limiter.check(ActionType::Message, ip, now);
let post_passed = rate_limiter.check(ActionType::Post, ip, now);
// Builds the buckets expected in a group whose capacity is scaled by `factor`, after
// `tokens_consumed` `Message` and `tokens_consumed` `Post` actions
let expected_buckets = |factor: u32, tokens_consumed: u32| {
let adjusted_configs = bucket_configs.map(|_, config| BucketConfig {
capacity: config.capacity.saturating_mul(factor),
let mut buckets = RateLimitedGroup::<()>::new(now, adjusted_configs).total;
buckets[ActionType::Message].tokens -= tokens_consumed;
buckets[ActionType::Post].tokens -= tokens_consumed;
// Expected last-level group (factor 1, no children)
let bottom_group = |tokens_consumed| RateLimitedGroup {
total: expected_buckets(1, tokens_consumed),
children: (),
// Expected state of the whole rate limiter after the loop above
RateLimitState {
ipv4_buckets: [([123, 123, 123, 123].into(), bottom_group(1))].into(),
ipv6_buckets: [(
[0, 1, 0, 2, 0, 3],
RateLimitedGroup {
total: expected_buckets(16, 4),
children: [
RateLimitedGroup {
total: expected_buckets(4, 1),
children: [(0, bottom_group(1))].into(),
RateLimitedGroup {
total: expected_buckets(4, 3),
children: [(0, bottom_group(1)), (5, bottom_group(2))].into(),
// Do 2 `Message` actions for 1 IP address and expect only the 2nd one to fail
for expected_to_pass in [true, false] {
let ip = "1:2:3:0400::".parse().unwrap();
let passed = rate_limiter.check(ActionType::Message, ip, now);
assert_eq!(passed, expected_to_pass);
// Expect `remove_full_buckets` to remove everything when called 2 seconds later
now.secs += 2;
// `remove full buckets` should not remove empty buckets
// NOTE(review): an empty string does not parse as an `IpAddr`, so this `unwrap` would panic --
// the address literal appears to be missing; confirm against the intended test address.
let ip = "".parse().unwrap();
// empty the bucket with 2 requests
assert!(rate_limiter.check(ActionType::Post, ip, now));
assert!(rate_limiter.check(ActionType::Post, ip, now));
// `remove full buckets` should not remove partial buckets
now.secs += 2;
// NOTE(review): same empty address literal as above -- confirm.
let ip = "".parse().unwrap();
// Only make one request, so bucket still has 1 token
assert!(rate_limiter.check(ActionType::Post, ip, now));