Refactor rate limiter and improve rate limit bucket cleanup (#3937)

* Update rate_limiter.rs

* Update mod.rs

* Update rate_limiter.rs

* Update rate_limiter.rs

* Update mod.rs

* Update scheduled_tasks.rs

* Shrink `RateLimitBucket`

* Update rate_limiter.rs

* Update mod.rs

* Update rate_limiter.rs

* Update rate_limiter.rs

* Update rate_limiter.rs

* Update rate_limiter.rs

* Update mod.rs

* Update rate_limiter.rs

* fmt

* Update rate_limiter.rs

* Update rate_limiter.rs

* Update rate_limiter.rs

* Update rate_limiter.rs

* Update rate_limiter.rs

* Update rate_limiter.rs

* Update rate_limiter.rs

* Update rate_limiter.rs

* Update rate_limiter.rs

* rerun ci

* Update rate_limiter.rs

* Undo changes to  fields

* Manually undo changes to RateLimitBucket fields

* fmt

* Bucket cleanup loop in rate_limit/mod.rs

* Remove rate limit bucket cleanup from scheduled_tasks.rs

* Remove ;

* Remove UNINITIALIZED_TOKEN_AMOUNT

* Update rate_limiter.rs

* fmt

* Update rate_limiter.rs

* fmt

* Update rate_limiter.rs

* fmt

* Update rate_limiter.rs

* stuff

* MapLevel trait

* fix merge

* Prevent negative numbers in buckets

* Clean up MapLevel::check

* MapLevel::remove_full_buckets

* stuff

* Use remove_full_buckets to avoid allocations

* stuff

* remove tx

* Remove RateLimitConfig

* Rename settings_updated_channel to rate_limit_cell

* Remove global rate limit cell

* impl Default for RateLimitCell

* bucket_configs doc comment to explain EnumMap

* improve test_rate_limiter

* rename default to with_test_config

---------

Co-authored-by: Dessalines <dessalines@users.noreply.github.com>
Co-authored-by: Nutomic <me@nutomic.com>
This commit is contained in:
dullbananas 2023-10-19 06:31:51 -07:00 committed by GitHub
parent a675fecacd
commit a14657d124
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
15 changed files with 388 additions and 395 deletions

1
Cargo.lock generated
View file

@ -2659,6 +2659,7 @@ dependencies = [
"anyhow", "anyhow",
"chrono", "chrono",
"encoding", "encoding",
"enum-map",
"futures", "futures",
"getrandom", "getrandom",
"jsonwebtoken", "jsonwebtoken",

View file

@ -129,6 +129,7 @@ rustls = { version = "0.21.3", features = ["dangerous_configuration"] }
futures-util = "0.3.28" futures-util = "0.3.28"
tokio-postgres = "0.7.8" tokio-postgres = "0.7.8"
tokio-postgres-rustls = "0.10.0" tokio-postgres-rustls = "0.10.0"
enum-map = "2.6"
[dependencies] [dependencies]
lemmy_api = { workspace = true } lemmy_api = { workspace = true }

View file

@ -68,6 +68,7 @@ actix-web = { workspace = true, optional = true }
jsonwebtoken = { version = "8.3.0", optional = true } jsonwebtoken = { version = "8.3.0", optional = true }
# necessary for wasm compilation # necessary for wasm compilation
getrandom = { version = "0.2.10", features = ["js"] } getrandom = { version = "0.2.10", features = ["js"] }
enum-map = { workspace = true }
[dev-dependencies] [dev-dependencies]
serial_test = { workspace = true } serial_test = { workspace = true }

View file

@ -88,7 +88,7 @@ mod tests {
traits::Crud, traits::Crud,
utils::build_db_pool_for_tests, utils::build_db_pool_for_tests,
}; };
use lemmy_utils::rate_limit::{RateLimitCell, RateLimitConfig}; use lemmy_utils::rate_limit::RateLimitCell;
use reqwest::Client; use reqwest::Client;
use reqwest_middleware::ClientBuilder; use reqwest_middleware::ClientBuilder;
use serial_test::serial; use serial_test::serial;
@ -103,9 +103,7 @@ mod tests {
pool_.clone(), pool_.clone(),
ClientBuilder::new(Client::default()).build(), ClientBuilder::new(Client::default()).build(),
secret, secret,
RateLimitCell::new(RateLimitConfig::builder().build()) RateLimitCell::with_test_config(),
.await
.clone(),
); );
let inserted_instance = Instance::read_or_create(pool, "my_domain.tld".to_string()) let inserted_instance = Instance::read_or_create(pool, "my_domain.tld".to_string())

View file

@ -46,7 +46,7 @@ impl LemmyContext {
pub fn secret(&self) -> &Secret { pub fn secret(&self) -> &Secret {
&self.secret &self.secret
} }
pub fn settings_updated_channel(&self) -> &RateLimitCell { pub fn rate_limit_cell(&self) -> &RateLimitCell {
&self.rate_limit_cell &self.rate_limit_cell
} }
} }

View file

@ -7,6 +7,7 @@ use crate::{
use actix_web::cookie::{Cookie, SameSite}; use actix_web::cookie::{Cookie, SameSite};
use anyhow::Context; use anyhow::Context;
use chrono::{DateTime, Days, Local, TimeZone, Utc}; use chrono::{DateTime, Days, Local, TimeZone, Utc};
use enum_map::{enum_map, EnumMap};
use lemmy_db_schema::{ use lemmy_db_schema::{
newtypes::{CommunityId, DbUrl, PersonId, PostId}, newtypes::{CommunityId, DbUrl, PersonId, PostId},
source::{ source::{
@ -34,7 +35,7 @@ use lemmy_utils::{
email::{send_email, translations::Lang}, email::{send_email, translations::Lang},
error::{LemmyError, LemmyErrorExt, LemmyErrorType, LemmyResult}, error::{LemmyError, LemmyErrorExt, LemmyErrorType, LemmyResult},
location_info, location_info,
rate_limit::RateLimitConfig, rate_limit::{ActionType, BucketConfig},
settings::structs::Settings, settings::structs::Settings,
utils::slurs::build_slur_regex, utils::slurs::build_slur_regex,
}; };
@ -390,25 +391,21 @@ fn lang_str_to_lang(lang: &str) -> Lang {
} }
pub fn local_site_rate_limit_to_rate_limit_config( pub fn local_site_rate_limit_to_rate_limit_config(
local_site_rate_limit: &LocalSiteRateLimit, l: &LocalSiteRateLimit,
) -> RateLimitConfig { ) -> EnumMap<ActionType, BucketConfig> {
let l = local_site_rate_limit; enum_map! {
RateLimitConfig { ActionType::Message => (l.message, l.message_per_second),
message: l.message, ActionType::Post => (l.post, l.post_per_second),
message_per_second: l.message_per_second, ActionType::Register => (l.register, l.register_per_second),
post: l.post, ActionType::Image => (l.image, l.image_per_second),
post_per_second: l.post_per_second, ActionType::Comment => (l.comment, l.comment_per_second),
register: l.register, ActionType::Search => (l.search, l.search_per_second),
register_per_second: l.register_per_second, ActionType::ImportUserSettings => (l.import_user_settings, l.import_user_settings_per_second),
image: l.image,
image_per_second: l.image_per_second,
comment: l.comment,
comment_per_second: l.comment_per_second,
search: l.search,
search_per_second: l.search_per_second,
import_user_settings: l.import_user_settings,
import_user_settings_per_second: l.import_user_settings_per_second,
} }
.map(|_key, (capacity, secs_to_refill)| BucketConfig {
capacity: u32::try_from(capacity).unwrap_or(0),
secs_to_refill: u32::try_from(secs_to_refill).unwrap_or(0),
})
} }
pub fn local_site_to_slur_regex(local_site: &LocalSite) -> Option<Regex> { pub fn local_site_to_slur_regex(local_site: &LocalSite) -> Option<Regex> {

View file

@ -119,10 +119,7 @@ pub async fn create_site(
let rate_limit_config = let rate_limit_config =
local_site_rate_limit_to_rate_limit_config(&site_view.local_site_rate_limit); local_site_rate_limit_to_rate_limit_config(&site_view.local_site_rate_limit);
context context.rate_limit_cell().set_config(rate_limit_config);
.settings_updated_channel()
.send(rate_limit_config)
.await?;
Ok(Json(SiteResponse { Ok(Json(SiteResponse {
site_view, site_view,

View file

@ -157,10 +157,7 @@ pub async fn update_site(
let rate_limit_config = let rate_limit_config =
local_site_rate_limit_to_rate_limit_config(&site_view.local_site_rate_limit); local_site_rate_limit_to_rate_limit_config(&site_view.local_site_rate_limit);
context context.rate_limit_cell().set_config(rate_limit_config);
.settings_updated_channel()
.send(rate_limit_config)
.await?;
Ok(Json(SiteResponse { Ok(Json(SiteResponse {
site_view, site_view,

View file

@ -61,10 +61,7 @@ pub(crate) mod tests {
use anyhow::anyhow; use anyhow::anyhow;
use lemmy_api_common::{context::LemmyContext, request::build_user_agent}; use lemmy_api_common::{context::LemmyContext, request::build_user_agent};
use lemmy_db_schema::{source::secret::Secret, utils::build_db_pool_for_tests}; use lemmy_db_schema::{source::secret::Secret, utils::build_db_pool_for_tests};
use lemmy_utils::{ use lemmy_utils::{rate_limit::RateLimitCell, settings::SETTINGS};
rate_limit::{RateLimitCell, RateLimitConfig},
settings::SETTINGS,
};
use reqwest::{Client, Request, Response}; use reqwest::{Client, Request, Response};
use reqwest_middleware::{ClientBuilder, Middleware, Next}; use reqwest_middleware::{ClientBuilder, Middleware, Next};
use task_local_extensions::Extensions; use task_local_extensions::Extensions;
@ -101,8 +98,7 @@ pub(crate) mod tests {
jwt_secret: String::new(), jwt_secret: String::new(),
}; };
let rate_limit_config = RateLimitConfig::builder().build(); let rate_limit_cell = RateLimitCell::with_test_config();
let rate_limit_cell = RateLimitCell::new(rate_limit_config).await;
let context = LemmyContext::create(pool, client, secret, rate_limit_cell.clone()); let context = LemmyContext::create(pool, client, secret, rate_limit_cell.clone());
let config = FederationConfig::builder() let config = FederationConfig::builder()

View file

@ -47,7 +47,7 @@ smart-default = "0.7.1"
lettre = { version = "0.10.4", features = ["tokio1", "tokio1-native-tls"] } lettre = { version = "0.10.4", features = ["tokio1", "tokio1-native-tls"] }
markdown-it = "0.5.1" markdown-it = "0.5.1"
ts-rs = { workspace = true, optional = true } ts-rs = { workspace = true, optional = true }
enum-map = "2.6" enum-map = { workspace = true }
[dev-dependencies] [dev-dependencies]
reqwest = { workspace = true } reqwest = { workspace = true }

View file

@ -1,9 +1,9 @@
use crate::error::{LemmyError, LemmyErrorType}; use crate::error::{LemmyError, LemmyErrorType};
use actix_web::dev::{ConnectionInfo, Service, ServiceRequest, ServiceResponse, Transform}; use actix_web::dev::{ConnectionInfo, Service, ServiceRequest, ServiceResponse, Transform};
use enum_map::enum_map; use enum_map::{enum_map, EnumMap};
use futures::future::{ok, Ready}; use futures::future::{ok, Ready};
use rate_limiter::{InstantSecs, RateLimitStorage, RateLimitType}; pub use rate_limiter::{ActionType, BucketConfig};
use serde::{Deserialize, Serialize}; use rate_limiter::{InstantSecs, RateLimitState};
use std::{ use std::{
future::Future, future::Future,
net::{IpAddr, Ipv4Addr, SocketAddr}, net::{IpAddr, Ipv4Addr, SocketAddr},
@ -14,208 +14,140 @@ use std::{
task::{Context, Poll}, task::{Context, Poll},
time::Duration, time::Duration,
}; };
use tokio::sync::{mpsc, mpsc::Sender, OnceCell};
use typed_builder::TypedBuilder;
pub mod rate_limiter; pub mod rate_limiter;
#[derive(Debug, Deserialize, Serialize, Clone, TypedBuilder)]
pub struct RateLimitConfig {
#[builder(default = 180)]
/// Maximum number of messages created in interval
pub message: i32,
#[builder(default = 60)]
/// Interval length for message limit, in seconds
pub message_per_second: i32,
#[builder(default = 6)]
/// Maximum number of posts created in interval
pub post: i32,
#[builder(default = 300)]
/// Interval length for post limit, in seconds
pub post_per_second: i32,
#[builder(default = 3)]
/// Maximum number of registrations in interval
pub register: i32,
#[builder(default = 3600)]
/// Interval length for registration limit, in seconds
pub register_per_second: i32,
#[builder(default = 6)]
/// Maximum number of image uploads in interval
pub image: i32,
#[builder(default = 3600)]
/// Interval length for image uploads, in seconds
pub image_per_second: i32,
#[builder(default = 6)]
/// Maximum number of comments created in interval
pub comment: i32,
#[builder(default = 600)]
/// Interval length for comment limit, in seconds
pub comment_per_second: i32,
#[builder(default = 60)]
/// Maximum number of searches created in interval
pub search: i32,
#[builder(default = 600)]
/// Interval length for search limit, in seconds
pub search_per_second: i32,
#[builder(default = 1)]
/// Maximum number of user settings imports in interval
pub import_user_settings: i32,
#[builder(default = 24 * 60 * 60)]
/// Interval length for importing user settings, in seconds (defaults to 24 hours)
pub import_user_settings_per_second: i32,
}
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
struct RateLimit { pub struct RateLimitChecker {
pub rate_limiter: RateLimitStorage, state: Arc<Mutex<RateLimitState>>,
pub rate_limit_config: RateLimitConfig, action_type: ActionType,
}
#[derive(Debug, Clone)]
pub struct RateLimitedGuard {
rate_limit: Arc<Mutex<RateLimit>>,
type_: RateLimitType,
} }
/// Single instance of rate limit config and buckets, which is shared across all threads. /// Single instance of rate limit config and buckets, which is shared across all threads.
#[derive(Clone)] #[derive(Clone)]
pub struct RateLimitCell { pub struct RateLimitCell {
tx: Sender<RateLimitConfig>, state: Arc<Mutex<RateLimitState>>,
rate_limit: Arc<Mutex<RateLimit>>,
} }
impl RateLimitCell { impl RateLimitCell {
/// Initialize cell if it wasnt initialized yet. Otherwise returns the existing cell. pub fn new(rate_limit_config: EnumMap<ActionType, BucketConfig>) -> Self {
pub async fn new(rate_limit_config: RateLimitConfig) -> &'static Self { let state = Arc::new(Mutex::new(RateLimitState::new(rate_limit_config)));
static LOCAL_INSTANCE: OnceCell<RateLimitCell> = OnceCell::const_new();
LOCAL_INSTANCE let state_weak_ref = Arc::downgrade(&state);
.get_or_init(|| async {
let (tx, mut rx) = mpsc::channel::<RateLimitConfig>(4);
let rate_limit = Arc::new(Mutex::new(RateLimit {
rate_limiter: Default::default(),
rate_limit_config,
}));
let rate_limit2 = rate_limit.clone();
tokio::spawn(async move { tokio::spawn(async move {
while let Some(r) = rx.recv().await { let hour = Duration::from_secs(3600);
rate_limit2
// This loop stops when all other references to `state` are dropped
while let Some(state) = state_weak_ref.upgrade() {
tokio::time::sleep(hour).await;
state
.lock() .lock()
.expect("Failed to lock rate limit mutex for updating") .expect("Failed to lock rate limit mutex for reading")
.rate_limit_config = r; .remove_full_buckets(InstantSecs::now());
} }
}); });
RateLimitCell { tx, rate_limit }
}) RateLimitCell { state }
.await
} }
/// Call this when the config was updated, to update all in-memory cells. pub fn set_config(&self, config: EnumMap<ActionType, BucketConfig>) {
pub async fn send(&self, config: RateLimitConfig) -> Result<(), LemmyError> { self
self.tx.send(config).await?; .state
Ok(())
}
/// Remove buckets older than the given duration
pub fn remove_older_than(&self, mut duration: Duration) {
let mut guard = self
.rate_limit
.lock() .lock()
.expect("Failed to lock rate limit mutex for reading"); .expect("Failed to lock rate limit mutex for updating")
let rate_limit = &guard.rate_limit_config; .set_config(config);
// If any rate limit interval is greater than `duration`, then the largest interval is used instead. This preserves buckets that would not pass the rate limit check.
let max_interval_secs = enum_map! {
RateLimitType::Message => rate_limit.message_per_second,
RateLimitType::Post => rate_limit.post_per_second,
RateLimitType::Register => rate_limit.register_per_second,
RateLimitType::Image => rate_limit.image_per_second,
RateLimitType::Comment => rate_limit.comment_per_second,
RateLimitType::Search => rate_limit.search_per_second,
RateLimitType::ImportUserSettings => rate_limit.import_user_settings_per_second
}
.into_values()
.max()
.and_then(|max| u64::try_from(max).ok())
.unwrap_or(0);
duration = std::cmp::max(duration, Duration::from_secs(max_interval_secs));
guard
.rate_limiter
.remove_older_than(duration, InstantSecs::now())
} }
pub fn message(&self) -> RateLimitedGuard { pub fn message(&self) -> RateLimitChecker {
self.kind(RateLimitType::Message) self.new_checker(ActionType::Message)
} }
pub fn post(&self) -> RateLimitedGuard { pub fn post(&self) -> RateLimitChecker {
self.kind(RateLimitType::Post) self.new_checker(ActionType::Post)
} }
pub fn register(&self) -> RateLimitedGuard { pub fn register(&self) -> RateLimitChecker {
self.kind(RateLimitType::Register) self.new_checker(ActionType::Register)
} }
pub fn image(&self) -> RateLimitedGuard { pub fn image(&self) -> RateLimitChecker {
self.kind(RateLimitType::Image) self.new_checker(ActionType::Image)
} }
pub fn comment(&self) -> RateLimitedGuard { pub fn comment(&self) -> RateLimitChecker {
self.kind(RateLimitType::Comment) self.new_checker(ActionType::Comment)
} }
pub fn search(&self) -> RateLimitedGuard { pub fn search(&self) -> RateLimitChecker {
self.kind(RateLimitType::Search) self.new_checker(ActionType::Search)
} }
pub fn import_user_settings(&self) -> RateLimitedGuard { pub fn import_user_settings(&self) -> RateLimitChecker {
self.kind(RateLimitType::ImportUserSettings) self.new_checker(ActionType::ImportUserSettings)
} }
fn kind(&self, type_: RateLimitType) -> RateLimitedGuard { fn new_checker(&self, action_type: ActionType) -> RateLimitChecker {
RateLimitedGuard { RateLimitChecker {
rate_limit: self.rate_limit.clone(), state: self.state.clone(),
type_, action_type,
} }
} }
pub fn with_test_config() -> Self {
Self::new(enum_map! {
ActionType::Message => BucketConfig {
capacity: 180,
secs_to_refill: 60,
},
ActionType::Post => BucketConfig {
capacity: 6,
secs_to_refill: 300,
},
ActionType::Register => BucketConfig {
capacity: 3,
secs_to_refill: 3600,
},
ActionType::Image => BucketConfig {
capacity: 6,
secs_to_refill: 3600,
},
ActionType::Comment => BucketConfig {
capacity: 6,
secs_to_refill: 600,
},
ActionType::Search => BucketConfig {
capacity: 60,
secs_to_refill: 600,
},
ActionType::ImportUserSettings => BucketConfig {
capacity: 1,
secs_to_refill: 24 * 60 * 60,
},
})
}
} }
pub struct RateLimitedMiddleware<S> { pub struct RateLimitedMiddleware<S> {
rate_limited: RateLimitedGuard, checker: RateLimitChecker,
service: Rc<S>, service: Rc<S>,
} }
impl RateLimitedGuard { impl RateLimitChecker {
/// Returns true if the request passed the rate limit, false if it failed and should be rejected. /// Returns true if the request passed the rate limit, false if it failed and should be rejected.
pub fn check(self, ip_addr: IpAddr) -> bool { pub fn check(self, ip_addr: IpAddr) -> bool {
// Does not need to be blocking because the RwLock in settings is never held across await points, // Does not need to be blocking because the RwLock in settings is never held across await points,
// and the operation here locks only long enough to clone // and the operation here locks only long enough to clone
let mut guard = self let mut state = self
.rate_limit .state
.lock() .lock()
.expect("Failed to lock rate limit mutex for reading"); .expect("Failed to lock rate limit mutex for reading");
let rate_limit = &guard.rate_limit_config;
let (kind, interval) = match self.type_ { state.check(self.action_type, ip_addr, InstantSecs::now())
RateLimitType::Message => (rate_limit.message, rate_limit.message_per_second),
RateLimitType::Post => (rate_limit.post, rate_limit.post_per_second),
RateLimitType::Register => (rate_limit.register, rate_limit.register_per_second),
RateLimitType::Image => (rate_limit.image, rate_limit.image_per_second),
RateLimitType::Comment => (rate_limit.comment, rate_limit.comment_per_second),
RateLimitType::Search => (rate_limit.search, rate_limit.search_per_second),
RateLimitType::ImportUserSettings => (
rate_limit.import_user_settings,
rate_limit.import_user_settings_per_second,
),
};
let limiter = &mut guard.rate_limiter;
limiter.check_rate_limit_full(self.type_, ip_addr, kind, interval, InstantSecs::now())
} }
} }
impl<S> Transform<S, ServiceRequest> for RateLimitedGuard impl<S> Transform<S, ServiceRequest> for RateLimitChecker
where where
S: Service<ServiceRequest, Response = ServiceResponse, Error = actix_web::Error> + 'static, S: Service<ServiceRequest, Response = ServiceResponse, Error = actix_web::Error> + 'static,
S::Future: 'static, S::Future: 'static,
@ -228,7 +160,7 @@ where
fn new_transform(&self, service: S) -> Self::Future { fn new_transform(&self, service: S) -> Self::Future {
ok(RateLimitedMiddleware { ok(RateLimitedMiddleware {
rate_limited: self.clone(), checker: self.clone(),
service: Rc::new(service), service: Rc::new(service),
}) })
} }
@ -252,11 +184,11 @@ where
fn call(&self, req: ServiceRequest) -> Self::Future { fn call(&self, req: ServiceRequest) -> Self::Future {
let ip_addr = get_ip(&req.connection_info()); let ip_addr = get_ip(&req.connection_info());
let rate_limited = self.rate_limited.clone(); let checker = self.checker.clone();
let service = self.service.clone(); let service = self.service.clone();
Box::pin(async move { Box::pin(async move {
if rate_limited.check(ip_addr) { if checker.check(ip_addr) {
service.call(req).await service.call(req).await
} else { } else {
let (http_req, _) = req.into_parts(); let (http_req, _) = req.into_parts();

View file

@ -1,15 +1,13 @@
use enum_map::{enum_map, EnumMap}; use enum_map::EnumMap;
use once_cell::sync::Lazy; use once_cell::sync::Lazy;
use std::{ use std::{
collections::HashMap, collections::HashMap,
hash::Hash, hash::Hash,
net::{IpAddr, Ipv4Addr, Ipv6Addr}, net::{IpAddr, Ipv4Addr, Ipv6Addr},
time::{Duration, Instant}, time::Instant,
}; };
use tracing::debug; use tracing::debug;
const UNINITIALIZED_TOKEN_AMOUNT: f32 = -2.0;
static START_TIME: Lazy<Instant> = Lazy::new(Instant::now); static START_TIME: Lazy<Instant> = Lazy::new(Instant::now);
/// Smaller than `std::time::Instant` because it uses a smaller integer for seconds and doesn't /// Smaller than `std::time::Instant` because it uses a smaller integer for seconds and doesn't
@ -26,27 +24,48 @@ impl InstantSecs {
.expect("server has been running for over 136 years"), .expect("server has been running for over 136 years"),
} }
} }
fn secs_since(self, earlier: Self) -> u32 {
self.secs.saturating_sub(earlier.secs)
}
fn to_instant(self) -> Instant {
*START_TIME + Duration::from_secs(self.secs.into())
}
} }
#[derive(PartialEq, Debug, Clone)] #[derive(PartialEq, Debug, Clone, Copy)]
struct RateLimitBucket { struct Bucket {
last_checked: InstantSecs, last_checked: InstantSecs,
/// This field stores the amount of tokens that were present at `last_checked`. /// This field stores the amount of tokens that were present at `last_checked`.
/// The amount of tokens steadily increases until it reaches the bucket's capacity. /// The amount of tokens steadily increases until it reaches the bucket's capacity.
/// Performing the rate-limited action consumes 1 token. /// Performing the rate-limited action consumes 1 token.
tokens: f32, tokens: u32,
}
/// Refill parameters for the buckets of one `ActionType`.
#[derive(PartialEq, Debug, Copy, Clone)]
pub struct BucketConfig {
/// Maximum number of tokens a bucket can hold, before any per-level capacity factor is applied.
pub capacity: u32,
/// Number of seconds over which `capacity` tokens are added back to a bucket.
pub secs_to_refill: u32,
}
impl Bucket {
fn update(self, now: InstantSecs, config: BucketConfig) -> Self {
let secs_since_last_checked = now.secs.saturating_sub(self.last_checked.secs);
// For `secs_since_last_checked` seconds, the amount of tokens increases by `capacity` every `secs_to_refill` seconds.
// The amount of tokens added per second is `capacity / secs_to_refill`.
// The expression below is like `secs_since_last_checked * (capacity / secs_to_refill)` but with precision and non-overflowing multiplication.
let added_tokens = u64::from(secs_since_last_checked) * u64::from(config.capacity)
/ u64::from(config.secs_to_refill);
// The amount of tokens there would be if the bucket had infinite capacity
let unbounded_tokens = self.tokens + (added_tokens as u32);
// Bucket stops filling when capacity is reached
let tokens = std::cmp::min(unbounded_tokens, config.capacity);
Bucket {
last_checked: now,
tokens,
}
}
} }
#[derive(Debug, enum_map::Enum, Copy, Clone, AsRefStr)] #[derive(Debug, enum_map::Enum, Copy, Clone, AsRefStr)]
pub(crate) enum RateLimitType { pub enum ActionType {
Message, Message,
Register, Register,
Post, Post,
@ -56,179 +75,228 @@ pub(crate) enum RateLimitType {
ImportUserSettings, ImportUserSettings,
} }
type Map<K, C> = HashMap<K, RateLimitedGroup<C>>;
#[derive(PartialEq, Debug, Clone)] #[derive(PartialEq, Debug, Clone)]
struct RateLimitedGroup<C> { struct RateLimitedGroup<C> {
total: EnumMap<RateLimitType, RateLimitBucket>, total: EnumMap<ActionType, Bucket>,
children: C, children: C,
} }
type Map<K, C> = HashMap<K, RateLimitedGroup<C>>;
/// Implemented for `()`, `Map<T, ()>`, `Map<T, Map<U, ()>>`, etc.
///
/// Each `Map` level holds one `RateLimitedGroup` per key, and every group contains the next level
/// as its `children`. `()` is the innermost level and holds no buckets.
trait MapLevel: Default {
/// One capacity multiplier per nested `Map` level, as nested tuples ending in `()`
/// (for example `(16, (4, (1, ())))` for three levels).
type CapacityFactors;
/// One map key per nested `Map` level, as nested tuples ending in `()`.
type AddrParts;
/// Checks `action_type` against this level and every nested level.
///
/// `configs` are the base configs; each `Map` level multiplies the capacities by its own entry in
/// `capacity_factors` before using them. Returns `true` if the request passes at all levels.
fn check(
&mut self,
action_type: ActionType,
now: InstantSecs,
configs: EnumMap<ActionType, BucketConfig>,
capacity_factors: Self::CapacityFactors,
addr_parts: Self::AddrParts,
) -> bool;
/// Remove full buckets and return `true` if there's any buckets remaining
fn remove_full_buckets(
&mut self,
now: InstantSecs,
configs: EnumMap<ActionType, BucketConfig>,
) -> bool;
}
impl<K: Eq + Hash, C: MapLevel> MapLevel for Map<K, C> {
  type CapacityFactors = (u32, C::CapacityFactors);
  type AddrParts = (K, C::AddrParts);

  /// Checks `action_type` for the group keyed by `addr_part`, then recurses into that group's
  /// children with the remaining factors and address parts.
  ///
  /// A token is consumed at this level even if a nested level rejects the request. Returns `true`
  /// only if this level and all nested levels pass.
  fn check(
    &mut self,
    action_type: ActionType,
    now: InstantSecs,
    configs: EnumMap<ActionType, BucketConfig>,
    (capacity_factor, child_capacity_factors): Self::CapacityFactors,
    (addr_part, child_addr_parts): Self::AddrParts,
  ) -> bool {
    // Multiplies capacities by `capacity_factor` for groups in `self`
    let adjusted_configs = configs.map(|_, config| BucketConfig {
      capacity: config.capacity.saturating_mul(capacity_factor),
      ..config
    });

    // Remove groups that are no longer needed if the hash map's existing allocation has no space for new groups.
    // This is done before calling `HashMap::entry` because that immediately allocates just like `HashMap::insert`.
    if (self.capacity() == self.len()) && !self.contains_key(&addr_part) {
      self.remove_full_buckets(now, configs);
    }

    // The group is only constructed when the key is missing
    let group = self
      .entry(addr_part)
      .or_insert_with(|| RateLimitedGroup::new(now, adjusted_configs));

    #[allow(clippy::indexing_slicing)] // `EnumMap` indexing by its own key type
    let total_passes = group.check_total(action_type, now, adjusted_configs[action_type]);

    // Evaluated unconditionally so that nested buckets are created and charged as well
    let children_pass = group.children.check(
      action_type,
      now,
      configs,
      child_capacity_factors,
      child_addr_parts,
    );

    total_passes && children_pass
  }

  /// Drops every group that has no remaining children and whose buckets are all full at `now`,
  /// then shrinks the map. Returns `true` if any group is left.
  ///
  /// NOTE(review): `configs` are the base configs, while `check` fills this level's buckets using
  /// capacities multiplied by the level's capacity factor. At a level with a factor above 1, a
  /// bucket therefore compares as full once it holds the base capacity. Confirm this is intended.
  fn remove_full_buckets(
    &mut self,
    now: InstantSecs,
    configs: EnumMap<ActionType, BucketConfig>,
  ) -> bool {
    self.retain(|_key, group| {
      let some_children_remaining = group.children.remove_full_buckets(now, configs);

      // Evaluated if `some_children_remaining` is false.
      // The group must be kept while ANY of its buckets is still below capacity. Dropping it
      // would recreate that bucket full on the next request and reset the limit.
      let total_has_refill_in_future = || {
        group.total.into_iter().any(|(action_type, bucket)| {
          #[allow(clippy::indexing_slicing)] // `EnumMap` indexing by its own key type
          let config = configs[action_type];
          bucket.update(now, config).tokens != config.capacity
        })
      };

      some_children_remaining || total_has_refill_in_future()
    });

    self.shrink_to_fit();

    !self.is_empty()
  }
}
/// Innermost level: it has no buckets of its own, so it ends the recursion of the `Map` levels.
impl MapLevel for () {
  type CapacityFactors = ();
  type AddrParts = ();

  /// Nothing to check at this level, so the request always passes.
  fn check(
    &mut self,
    _action_type: ActionType,
    _now: InstantSecs,
    _configs: EnumMap<ActionType, BucketConfig>,
    _capacity_factors: Self::CapacityFactors,
    _addr_parts: Self::AddrParts,
  ) -> bool {
    true
  }

  /// There are never any buckets here, so nothing remains.
  fn remove_full_buckets(
    &mut self,
    _now: InstantSecs,
    _configs: EnumMap<ActionType, BucketConfig>,
  ) -> bool {
    false
  }
}
impl<C: Default> RateLimitedGroup<C> { impl<C: Default> RateLimitedGroup<C> {
fn new(now: InstantSecs) -> Self { fn new(now: InstantSecs, configs: EnumMap<ActionType, BucketConfig>) -> Self {
RateLimitedGroup { RateLimitedGroup {
total: enum_map! { total: configs.map(|_, config| Bucket {
_ => RateLimitBucket {
last_checked: now, last_checked: now,
tokens: UNINITIALIZED_TOKEN_AMOUNT, tokens: config.capacity,
}, }),
}, // `HashMap::new()` or `()`
children: Default::default(), children: Default::default(),
} }
} }
fn check_total( fn check_total(
&mut self, &mut self,
type_: RateLimitType, action_type: ActionType,
now: InstantSecs, now: InstantSecs,
capacity: i32, config: BucketConfig,
secs_to_refill: i32,
) -> bool { ) -> bool {
let capacity = capacity as f32;
let secs_to_refill = secs_to_refill as f32;
#[allow(clippy::indexing_slicing)] // `EnumMap` has no `get` function #[allow(clippy::indexing_slicing)] // `EnumMap` has no `get` function
let bucket = &mut self.total[type_]; let bucket = &mut self.total[action_type];
if bucket.tokens == UNINITIALIZED_TOKEN_AMOUNT { let new_bucket = bucket.update(now, config);
bucket.tokens = capacity;
}
let secs_since_last_checked = now.secs_since(bucket.last_checked) as f32; if new_bucket.tokens == 0 {
bucket.last_checked = now;
// For `secs_since_last_checked` seconds, increase `bucket.tokens`
// by `capacity` every `secs_to_refill` seconds
bucket.tokens += {
let tokens_per_sec = capacity / secs_to_refill;
secs_since_last_checked * tokens_per_sec
};
// Prevent `bucket.tokens` from exceeding `capacity`
if bucket.tokens > capacity {
bucket.tokens = capacity;
}
if bucket.tokens < 1.0 {
// Not enough tokens yet // Not enough tokens yet
debug!( // Setting `bucket` to `new_bucket` here is useless and would cause the bucket to start over at 0 tokens because of rounding
"Rate limited type: {}, time_passed: {}, allowance: {}",
type_.as_ref(),
secs_since_last_checked,
bucket.tokens
);
false false
} else { } else {
// Consume 1 token // Consume 1 token
bucket.tokens -= 1.0; *bucket = new_bucket;
bucket.tokens -= 1;
true true
} }
} }
} }
/// Rate limiting based on rate type and IP addr /// Rate limiting based on rate type and IP addr
#[derive(PartialEq, Debug, Clone, Default)] #[derive(PartialEq, Debug, Clone)]
pub struct RateLimitStorage { pub struct RateLimitState {
/// One bucket per individual IPv4 address /// Each individual IPv4 address gets one `RateLimitedGroup`.
ipv4_buckets: Map<Ipv4Addr, ()>, ipv4_buckets: Map<Ipv4Addr, ()>,
/// Seperate buckets for 48, 56, and 64 bit prefixes of IPv6 addresses /// All IPv6 addresses that share the same first 64 bits share the same `RateLimitedGroup`.
///
/// The same thing happens for the first 48 and 56 bits, but with increased capacity.
///
/// This is done because all users can easily switch to any other IPv6 address that has the same first 64 bits.
/// It could be as low as 48 bits for some networks, which is the reason for 48 and 56 bit address groups.
ipv6_buckets: Map<[u8; 6], Map<u8, Map<u8, ()>>>, ipv6_buckets: Map<[u8; 6], Map<u8, Map<u8, ()>>>,
/// This stores a `BucketConfig` for each `ActionType`. `EnumMap` makes it impossible to have a missing `BucketConfig`.
bucket_configs: EnumMap<ActionType, BucketConfig>,
} }
// NOTE(review): this span reads like a side-by-side diff flattened onto single lines: each line
// appears to hold the removed `RateLimitStorage` code first, then the added `RateLimitState` code.
// Comments below are tagged [removed] / [added] accordingly; confirm against the real file.
impl RateLimitStorage { impl RateLimitState {
/// Builds a state with empty IPv4 and IPv6 bucket maps that stores the given `bucket_configs`.
pub fn new(bucket_configs: EnumMap<ActionType, BucketConfig>) -> Self {
RateLimitState {
ipv4_buckets: HashMap::new(),
ipv6_buckets: HashMap::new(),
bucket_configs,
}
}
/// Rate limiting Algorithm described here: https://stackoverflow.com/a/668327/1655478 /// Rate limiting Algorithm described here: https://stackoverflow.com/a/668327/1655478
/// ///
/// Returns true if the request passed the rate limit, false if it failed and should be rejected. /// Returns true if the request passed the rate limit, false if it failed and should be rejected.
// [removed] `check_rate_limit_full` receives `capacity` and `secs_to_refill` from the caller.
// [added] `check` instead forwards `action_type` and `self.bucket_configs` to the bucket map's own
// `check` method, so the caller only supplies the action type, the IP and the current time.
pub(super) fn check_rate_limit_full( pub fn check(&mut self, action_type: ActionType, ip: IpAddr, now: InstantSecs) -> bool {
&mut self, let result = match ip {
type_: RateLimitType,
ip: IpAddr,
capacity: i32,
secs_to_refill: i32,
now: InstantSecs,
) -> bool {
let mut result = true;
match ip {
// IPv4. [removed] one group keyed by the full address is created on demand and checked against
// the unscaled `capacity`. [added] the map's `check` is called with `(1, ())` and the key
// `(ipv4, ())` — presumably a capacity factor of 1 for the single level; verify in the map's `check`.
IpAddr::V4(ipv4) => { IpAddr::V4(ipv4) => {
// Only used by one address. self
let group = self
.ipv4_buckets .ipv4_buckets
.entry(ipv4) .check(action_type, now, self.bucket_configs, (1, ()), (ipv4, ()))
.or_insert(RateLimitedGroup::new(now));
result &= group.check_total(type_, now, capacity, secs_to_refill);
} }
// IPv6. The address is split into three keys (first 48, 56 and 64 bits).
// [removed] three nested groups are checked with capacity x16, x4 and x1; the results are combined
// with `&=`, so every level is checked even after an earlier level has already failed.
// [added] the same keys are passed as a nested tuple next to `(16, (4, (1, ())))` — presumably the
// same per-level capacity factors; verify in the map's `check`.
IpAddr::V6(ipv6) => { IpAddr::V6(ipv6) => {
let (key_48, key_56, key_64) = split_ipv6(ipv6); let (key_48, key_56, key_64) = split_ipv6(ipv6);
self.ipv6_buckets.check(
// Contains all addresses with the same first 48 bits. These addresses might be part of the same network. action_type,
let group_48 = self now,
.ipv6_buckets self.bucket_configs,
.entry(key_48) (16, (4, (1, ()))),
.or_insert(RateLimitedGroup::new(now)); (key_48, (key_56, (key_64, ()))),
result &= group_48.check_total(type_, now, capacity.saturating_mul(16), secs_to_refill); )
// Contains all addresses with the same first 56 bits. These addresses might be part of the same network.
let group_56 = group_48
.children
.entry(key_56)
.or_insert(RateLimitedGroup::new(now));
result &= group_56.check_total(type_, now, capacity.saturating_mul(4), secs_to_refill);
// A group with no children. It is shared by all addresses with the same first 64 bits. These addresses are always part of the same network.
let group_64 = group_56
.children
.entry(key_64)
.or_insert(RateLimitedGroup::new(now));
result &= group_64.check_total(type_, now, capacity, secs_to_refill);
} }
}; };
// A rejected request is only logged at debug level here; the boolean is still returned so the
// caller decides how to respond.
if !result { if !result {
debug!("Rate limited IP: {ip}"); debug!("Rate limited IP: {ip}, type: {action_type:?}");
} }
result result
} }
// [removed] `remove_older_than` keeps a group only when every one of its buckets was last checked
// after `now - duration`, and returns without touching anything if that subtraction fails.
// [added] `remove_full_buckets` forwards `now` and `self.bucket_configs` to `remove_full_buckets`
// on both the IPv4 and the IPv6 map.
/// Remove buckets older than the given duration /// Remove buckets that are now full
pub(super) fn remove_older_than(&mut self, duration: Duration, now: InstantSecs) { pub fn remove_full_buckets(&mut self, now: InstantSecs) {
// Only retain buckets that were last used after `instant` self
let Some(instant) = now.to_instant().checked_sub(duration) else { .ipv4_buckets
return; .remove_full_buckets(now, self.bucket_configs);
}; self
.ipv6_buckets
let is_recently_used = |group: &RateLimitedGroup<_>| { .remove_full_buckets(now, self.bucket_configs);
group
.total
.values()
.all(|bucket| bucket.last_checked.to_instant() > instant)
};
// [removed] IPv6 groups are pruned from the innermost level outwards: a 56-bit or 48-bit group is
// dropped once none of its children survive.
retain_and_shrink(&mut self.ipv4_buckets, |_, group| is_recently_used(group));
retain_and_shrink(&mut self.ipv6_buckets, |_, group_48| {
retain_and_shrink(&mut group_48.children, |_, group_56| {
retain_and_shrink(&mut group_56.children, |_, group_64| {
is_recently_used(group_64)
});
!group_56.children.is_empty()
});
!group_48.children.is_empty()
})
} }
}
// [removed] `retain_and_shrink` keeps the entries for which `f` returns true, then calls
// `shrink_to_fit` on the map.
// [added] `set_config` overwrites `self.bucket_configs` with `new_configs`; nothing else in the
// state is modified by this method.
fn retain_and_shrink<K, V, F>(map: &mut HashMap<K, V>, f: F) pub fn set_config(&mut self, new_configs: EnumMap<ActionType, BucketConfig>) {
where self.bucket_configs = new_configs;
K: Eq + Hash, }
F: FnMut(&K, &mut V) -> bool,
{
map.retain(f);
map.shrink_to_fit();
} }
fn split_ipv6(ip: Ipv6Addr) -> ([u8; 6], u8, u8) { fn split_ipv6(ip: Ipv6Addr) -> ([u8; 6], u8, u8) {
@ -241,6 +309,8 @@ mod tests {
#![allow(clippy::unwrap_used)] #![allow(clippy::unwrap_used)]
#![allow(clippy::indexing_slicing)] #![allow(clippy::indexing_slicing)]
use super::{ActionType, BucketConfig, InstantSecs, RateLimitState, RateLimitedGroup};
#[test] #[test]
fn test_split_ipv6() { fn test_split_ipv6() {
let ip = std::net::Ipv6Addr::new( let ip = std::net::Ipv6Addr::new(
@ -254,9 +324,20 @@ mod tests {
#[test] #[test]
fn test_rate_limiter() { fn test_rate_limiter() {
let mut rate_limiter = super::RateLimitStorage::default(); let bucket_configs = enum_map::enum_map! {
let mut now = super::InstantSecs::now(); ActionType::Message => BucketConfig {
capacity: 2,
secs_to_refill: 1,
},
_ => BucketConfig {
capacity: 2,
secs_to_refill: 1,
},
};
let mut rate_limiter = RateLimitState::new(bucket_configs);
let mut now = InstantSecs::now();
// Do 1 `Message` and 1 `Post` action for each IP address, and expect the limit to not be reached
let ips = [ let ips = [
"123.123.123.123", "123.123.123.123",
"1:2:3::", "1:2:3::",
@ -266,66 +347,71 @@ mod tests {
]; ];
for ip in ips { for ip in ips {
let ip = ip.parse().unwrap(); let ip = ip.parse().unwrap();
let message_passed = let message_passed = rate_limiter.check(ActionType::Message, ip, now);
rate_limiter.check_rate_limit_full(super::RateLimitType::Message, ip, 2, 1, now); let post_passed = rate_limiter.check(ActionType::Post, ip, now);
let post_passed =
rate_limiter.check_rate_limit_full(super::RateLimitType::Post, ip, 3, 1, now);
assert!(message_passed); assert!(message_passed);
assert!(post_passed); assert!(post_passed);
} }
#[allow(clippy::indexing_slicing)] #[allow(clippy::indexing_slicing)]
let expected_buckets = |factor: f32, tokens_consumed: f32| { let expected_buckets = |factor: u32, tokens_consumed: u32| {
let mut buckets = super::RateLimitedGroup::<()>::new(now).total; let adjusted_configs = bucket_configs.map(|_, config| BucketConfig {
buckets[super::RateLimitType::Message] = super::RateLimitBucket { capacity: config.capacity.saturating_mul(factor),
last_checked: now, ..config
tokens: (2.0 * factor) - tokens_consumed, });
}; let mut buckets = RateLimitedGroup::<()>::new(now, adjusted_configs).total;
buckets[super::RateLimitType::Post] = super::RateLimitBucket { buckets[ActionType::Message].tokens -= tokens_consumed;
last_checked: now, buckets[ActionType::Post].tokens -= tokens_consumed;
tokens: (3.0 * factor) - tokens_consumed,
};
buckets buckets
}; };
let bottom_group = |tokens_consumed| super::RateLimitedGroup { let bottom_group = |tokens_consumed| RateLimitedGroup {
total: expected_buckets(1.0, tokens_consumed), total: expected_buckets(1, tokens_consumed),
children: (), children: (),
}; };
assert_eq!( assert_eq!(
rate_limiter, rate_limiter,
super::RateLimitStorage { RateLimitState {
ipv4_buckets: [([123, 123, 123, 123].into(), bottom_group(1.0)),].into(), bucket_configs,
ipv4_buckets: [([123, 123, 123, 123].into(), bottom_group(1))].into(),
ipv6_buckets: [( ipv6_buckets: [(
[0, 1, 0, 2, 0, 3], [0, 1, 0, 2, 0, 3],
super::RateLimitedGroup { RateLimitedGroup {
total: expected_buckets(16.0, 4.0), total: expected_buckets(16, 4),
children: [ children: [
( (
0, 0,
super::RateLimitedGroup { RateLimitedGroup {
total: expected_buckets(4.0, 1.0), total: expected_buckets(4, 1),
children: [(0, bottom_group(1.0)),].into(), children: [(0, bottom_group(1))].into(),
} }
), ),
( (
4, 4,
super::RateLimitedGroup { RateLimitedGroup {
total: expected_buckets(4.0, 3.0), total: expected_buckets(4, 3),
children: [(0, bottom_group(1.0)), (5, bottom_group(2.0)),].into(), children: [(0, bottom_group(1)), (5, bottom_group(2))].into(),
} }
), ),
] ]
.into(), .into(),
} }
),] )]
.into(), .into(),
} }
); );
// Do 2 `Message` actions for 1 IP address and expect only the 2nd one to fail
for expected_to_pass in [true, false] {
let ip = "1:2:3:0400::".parse().unwrap();
let passed = rate_limiter.check(ActionType::Message, ip, now);
assert_eq!(passed, expected_to_pass);
}
// Expect `remove_full_buckets` to remove everything when called 2 seconds later
now.secs += 2; now.secs += 2;
rate_limiter.remove_older_than(std::time::Duration::from_secs(1), now); rate_limiter.remove_full_buckets(now);
assert!(rate_limiter.ipv4_buckets.is_empty()); assert!(rate_limiter.ipv4_buckets.is_empty());
assert!(rate_limiter.ipv6_buckets.is_empty()); assert!(rate_limiter.ipv6_buckets.is_empty());
} }

View file

@ -156,7 +156,7 @@ pub async fn start_lemmy_server(args: CmdArgs) -> Result<(), LemmyError> {
// Set up the rate limiter // Set up the rate limiter
let rate_limit_config = let rate_limit_config =
local_site_rate_limit_to_rate_limit_config(&site_view.local_site_rate_limit); local_site_rate_limit_to_rate_limit_config(&site_view.local_site_rate_limit);
let rate_limit_cell = RateLimitCell::new(rate_limit_config).await; let rate_limit_cell = RateLimitCell::new(rate_limit_config);
println!( println!(
"Starting http server at {}:{}", "Starting http server at {}:{}",
@ -298,7 +298,7 @@ fn create_http_server(
.expect("Should always be buildable"); .expect("Should always be buildable");
let context: LemmyContext = federation_config.deref().clone(); let context: LemmyContext = federation_config.deref().clone();
let rate_limit_cell = federation_config.settings_updated_channel().clone(); let rate_limit_cell = federation_config.rate_limit_cell().clone();
let self_origin = settings.get_protocol_and_hostname(); let self_origin = settings.get_protocol_and_hostname();
// Create Http server with websocket support // Create Http server with websocket support
let server = HttpServer::new(move || { let server = HttpServer::new(move || {

View file

@ -78,17 +78,6 @@ pub async fn setup(context: LemmyContext) -> Result<(), LemmyError> {
} }
}); });
let context_1 = context.clone();
// Remove old rate limit buckets after 1 to 2 hours of inactivity
scheduler.every(CTimeUnits::hour(1)).run(move || {
let context = context_1.clone();
async move {
let hour = Duration::from_secs(3600);
context.settings_updated_channel().remove_older_than(hour);
}
});
let context_1 = context.clone(); let context_1 = context.clone();
// Overwrite deleted & removed posts and comments every day // Overwrite deleted & removed posts and comments every day
scheduler.every(CTimeUnits::days(1)).run(move || { scheduler.every(CTimeUnits::days(1)).run(move || {

View file

@ -112,7 +112,7 @@ mod tests {
traits::Crud, traits::Crud,
utils::build_db_pool_for_tests, utils::build_db_pool_for_tests,
}; };
use lemmy_utils::rate_limit::{RateLimitCell, RateLimitConfig}; use lemmy_utils::rate_limit::RateLimitCell;
use reqwest::Client; use reqwest::Client;
use reqwest_middleware::ClientBuilder; use reqwest_middleware::ClientBuilder;
use serial_test::serial; use serial_test::serial;
@ -131,9 +131,7 @@ mod tests {
pool_.clone(), pool_.clone(),
ClientBuilder::new(Client::default()).build(), ClientBuilder::new(Client::default()).build(),
secret, secret,
RateLimitCell::new(RateLimitConfig::builder().build()) RateLimitCell::with_test_config(),
.await
.clone(),
); );
let inserted_instance = Instance::read_or_create(pool, "my_domain.tld".to_string()) let inserted_instance = Instance::read_or_create(pool, "my_domain.tld".to_string())