diff --git a/Cargo.lock b/Cargo.lock index 760688125..d454831e6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1685,6 +1685,26 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b5320ae4c3782150d900b79807611a59a99fc9a1d61d686faafc24b93fc8d7ca" +[[package]] +name = "enum-map" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "988f0d17a0fa38291e5f41f71ea8d46a5d5497b9054d5a759fae2cbb819f2356" +dependencies = [ + "enum-map-derive", +] + +[[package]] +name = "enum-map-derive" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2a4da76b3b6116d758c7ba93f7ec6a35d2e2cf24feda76c6e38a375f4d5c59f2" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.103", +] + [[package]] name = "enum_delegate" version = "0.2.0" @@ -2816,6 +2836,7 @@ dependencies = [ "deser-hjson", "diesel", "doku", + "enum-map", "futures", "html2text", "http", diff --git a/crates/utils/Cargo.toml b/crates/utils/Cargo.toml index 09c89b9e8..1ec8d4ba2 100644 --- a/crates/utils/Cargo.toml +++ b/crates/utils/Cargo.toml @@ -45,6 +45,7 @@ jsonwebtoken = "8.1.1" lettre = "0.10.1" markdown-it = "0.5.0" totp-rs = { version = "5.0.2", features = ["gen_secret", "otpauth"] } +enum-map = "2.5" [dev-dependencies] reqwest = { workspace = true } diff --git a/crates/utils/src/lib.rs b/crates/utils/src/lib.rs index f3213390a..e5d07db2c 100644 --- a/crates/utils/src/lib.rs +++ b/crates/utils/src/lib.rs @@ -14,21 +14,12 @@ pub mod request; pub mod utils; pub mod version; -use std::{fmt, time::Duration}; +use std::time::Duration; pub type ConnectionId = usize; pub const REQWEST_TIMEOUT: Duration = Duration::from_secs(10); -#[derive(PartialEq, Eq, Hash, Debug, Clone)] -pub struct IpAddr(pub String); - -impl fmt::Display for IpAddr { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}", self.0) - } -} - #[macro_export] macro_rules! location_info { () => { diff --git a/crates/utils/src/rate_limit/mod.rs b/crates/utils/src/rate_limit/mod.rs index b7000e3e6..d1c51265d 100644 --- a/crates/utils/src/rate_limit/mod.rs +++ b/crates/utils/src/rate_limit/mod.rs @@ -1,14 +1,18 @@ -use crate::{error::LemmyError, IpAddr}; +use crate::error::LemmyError; use actix_web::dev::{ConnectionInfo, Service, ServiceRequest, ServiceResponse, Transform}; +use enum_map::enum_map; use futures::future::{ok, Ready}; -use rate_limiter::{RateLimitStorage, RateLimitType}; +use rate_limiter::{InstantSecs, RateLimitStorage, RateLimitType}; use serde::{Deserialize, Serialize}; use std::{ future::Future, + net::{IpAddr, Ipv4Addr, SocketAddr}, pin::Pin, rc::Rc, + str::FromStr, sync::{Arc, Mutex}, task::{Context, Poll}, + time::Duration, }; use tokio::sync::{mpsc, mpsc::Sender, OnceCell}; use typed_builder::TypedBuilder; @@ -105,6 +109,35 @@ impl RateLimitCell { Ok(()) } + /// Remove buckets older than the given duration + pub fn remove_older_than(&self, mut duration: Duration) { + let mut guard = self + .rate_limit + .lock() + .expect("Failed to lock rate limit mutex for reading"); + let rate_limit = &guard.rate_limit_config; + + // If any rate limit interval is greater than `duration`, then the largest interval is used instead. This preserves buckets that would not pass the rate limit check. + let max_interval_secs = enum_map! { + RateLimitType::Message => rate_limit.message_per_second, + RateLimitType::Post => rate_limit.post_per_second, + RateLimitType::Register => rate_limit.register_per_second, + RateLimitType::Image => rate_limit.image_per_second, + RateLimitType::Comment => rate_limit.comment_per_second, + RateLimitType::Search => rate_limit.search_per_second, + } + .into_values() + .max() + .and_then(|max| u64::try_from(max).ok()) + .unwrap_or(0); + + duration = std::cmp::max(duration, Duration::from_secs(max_interval_secs)); + + guard + .rate_limiter + .remove_older_than(duration, InstantSecs::now()) + } + pub fn message(&self) -> RateLimitedGuard { self.kind(RateLimitType::Message) } @@ -163,7 +196,7 @@ impl RateLimitedGuard { }; let limiter = &mut guard.rate_limiter; - limiter.check_rate_limit_full(self.type_, &ip_addr, kind, interval) + limiter.check_rate_limit_full(self.type_, ip_addr, kind, interval, InstantSecs::now()) } } @@ -222,13 +255,37 @@ where } fn get_ip(conn_info: &ConnectionInfo) -> IpAddr { - IpAddr( - conn_info - .realip_remote_addr() - .unwrap_or("127.0.0.1:12345") - .split(':') - .next() - .unwrap_or("127.0.0.1") - .to_string(), - ) + conn_info + .realip_remote_addr() + .and_then(parse_ip) + .unwrap_or(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1))) +} + +fn parse_ip(addr: &str) -> Option { + if let Some(s) = addr.strip_suffix(']') { + IpAddr::from_str(s.get(1..)?).ok() + } else if let Ok(ip) = IpAddr::from_str(addr) { + Some(ip) + } else if let Ok(socket) = SocketAddr::from_str(addr) { + Some(socket.ip()) + } else { + None + } +} + +#[cfg(test)] +mod tests { + #[test] + fn test_parse_ip() { + let ip_addrs = [ + "1.2.3.4", + "1.2.3.4:8000", + "2001:db8::", + "[2001:db8::]", + "[2001:db8::]:8000", + ]; + for addr in ip_addrs { + assert!(super::parse_ip(addr).is_some(), "failed to parse {addr}"); + } + } } diff --git a/crates/utils/src/rate_limit/rate_limiter.rs b/crates/utils/src/rate_limit/rate_limiter.rs index d40db5239..12f264fae 100644 --- a/crates/utils/src/rate_limit/rate_limiter.rs +++ b/crates/utils/src/rate_limit/rate_limiter.rs @@ -1,15 +1,50 @@ -use crate::IpAddr; -use std::{collections::HashMap, time::Instant}; -use strum::IntoEnumIterator; +use enum_map::{enum_map, EnumMap}; +use once_cell::sync::Lazy; +use std::{ + collections::HashMap, + net::{IpAddr, Ipv4Addr, Ipv6Addr}, + time::{Duration, Instant}, +}; use tracing::debug; -#[derive(Debug, Clone)] -struct RateLimitBucket { - last_checked: Instant, - allowance: f64, +const UNINITIALIZED_TOKEN_AMOUNT: f32 = -2.0; + +static START_TIME: Lazy = Lazy::new(Instant::now); + +/// Smaller than `std::time::Instant` because it uses a smaller integer for seconds and doesn't +/// store nanoseconds +#[derive(PartialEq, Debug, Clone, Copy)] +pub struct InstantSecs { + secs: u32, } -#[derive(Eq, PartialEq, Hash, Debug, EnumIter, Copy, Clone, AsRefStr)] +impl InstantSecs { + pub fn now() -> Self { + InstantSecs { + secs: u32::try_from(START_TIME.elapsed().as_secs()) + .expect("server has been running for over 136 years"), + } + } + + fn secs_since(self, earlier: Self) -> u32 { + self.secs.saturating_sub(earlier.secs) + } + + fn to_instant(self) -> Instant { + *START_TIME + Duration::from_secs(self.secs.into()) + } +} + +#[derive(PartialEq, Debug, Clone)] +struct RateLimitBucket { + last_checked: InstantSecs, + /// This field stores the amount of tokens that were present at `last_checked`. + /// The amount of tokens steadily increases until it reaches the bucket's capacity. + /// Performing the rate-limited action consumes 1 token. + tokens: f32, +} + +#[derive(Debug, enum_map::Enum, Copy, Clone, AsRefStr)] pub(crate) enum RateLimitType { Message, Register, @@ -19,79 +54,263 @@ pub(crate) enum RateLimitType { Search, } -/// Rate limiting based on rate type and IP addr -#[derive(Debug, Clone, Default)] -pub struct RateLimitStorage { - buckets: HashMap>, +type Map = HashMap>; + +#[derive(PartialEq, Debug, Clone)] +struct RateLimitedGroup { + total: EnumMap, + children: C, } -impl RateLimitStorage { - fn insert_ip(&mut self, ip: &IpAddr) { - for rate_limit_type in RateLimitType::iter() { - if self.buckets.get(&rate_limit_type).is_none() { - self.buckets.insert(rate_limit_type, HashMap::new()); - } - - if let Some(bucket) = self.buckets.get_mut(&rate_limit_type) { - if bucket.get(ip).is_none() { - bucket.insert( - ip.clone(), - RateLimitBucket { - last_checked: Instant::now(), - allowance: -2f64, - }, - ); - } - } +impl RateLimitedGroup { + fn new(now: InstantSecs) -> Self { + RateLimitedGroup { + total: enum_map! { + _ => RateLimitBucket { + last_checked: now, + tokens: UNINITIALIZED_TOKEN_AMOUNT, + }, + }, + children: Default::default(), } } - /// Rate limiting Algorithm described here: https://stackoverflow.com/a/668327/1655478 - /// - /// Returns true if the request passed the rate limit, false if it failed and should be rejected. - #[allow(clippy::float_cmp)] - pub(super) fn check_rate_limit_full( + fn check_total( &mut self, type_: RateLimitType, - ip: &IpAddr, - rate: i32, - per: i32, + now: InstantSecs, + capacity: i32, + secs_to_refill: i32, ) -> bool { - self.insert_ip(ip); - if let Some(bucket) = self.buckets.get_mut(&type_) { - if let Some(rate_limit) = bucket.get_mut(ip) { - let current = Instant::now(); - let time_passed = current.duration_since(rate_limit.last_checked).as_secs() as f64; + let capacity = capacity as f32; + let secs_to_refill = secs_to_refill as f32; - // The initial value - if rate_limit.allowance == -2f64 { - rate_limit.allowance = f64::from(rate); - }; + #[allow(clippy::indexing_slicing)] // `EnumMap` has no `get` funciton + let bucket = &mut self.total[type_]; - rate_limit.last_checked = current; - rate_limit.allowance += time_passed * (f64::from(rate) / f64::from(per)); - if rate_limit.allowance > f64::from(rate) { - rate_limit.allowance = f64::from(rate); - } + if bucket.tokens == UNINITIALIZED_TOKEN_AMOUNT { + bucket.tokens = capacity; + } - if rate_limit.allowance < 1.0 { - debug!( - "Rate limited type: {}, IP: {}, time_passed: {}, allowance: {}", - type_.as_ref(), - ip, - time_passed, - rate_limit.allowance - ); - false - } else { - rate_limit.allowance -= 1.0; - true - } - } else { - true - } + let secs_since_last_checked = now.secs_since(bucket.last_checked) as f32; + bucket.last_checked = now; + + // For `secs_since_last_checked` seconds, increase `bucket.tokens` + // by `capacity` every `secs_to_refill` seconds + bucket.tokens += { + let tokens_per_sec = capacity / secs_to_refill; + secs_since_last_checked * tokens_per_sec + }; + + // Prevent `bucket.tokens` from exceeding `capacity` + if bucket.tokens > capacity { + bucket.tokens = capacity; + } + + if bucket.tokens < 1.0 { + // Not enough tokens yet + debug!( + "Rate limited type: {}, time_passed: {}, allowance: {}", + type_.as_ref(), + secs_since_last_checked, + bucket.tokens + ); + false } else { + // Consume 1 token + bucket.tokens -= 1.0; true } } } + +/// Rate limiting based on rate type and IP addr +#[derive(PartialEq, Debug, Clone, Default)] +pub struct RateLimitStorage { + /// One bucket per individual IPv4 address + ipv4_buckets: Map, + /// Seperate buckets for 48, 56, and 64 bit prefixes of IPv6 addresses + ipv6_buckets: Map<[u8; 6], Map>>, +} + +impl RateLimitStorage { + /// Rate limiting Algorithm described here: https://stackoverflow.com/a/668327/1655478 + /// + /// Returns true if the request passed the rate limit, false if it failed and should be rejected. + pub(super) fn check_rate_limit_full( + &mut self, + type_: RateLimitType, + ip: IpAddr, + capacity: i32, + secs_to_refill: i32, + now: InstantSecs, + ) -> bool { + let mut result = true; + + match ip { + IpAddr::V4(ipv4) => { + // Only used by one address. + let group = self + .ipv4_buckets + .entry(ipv4) + .or_insert(RateLimitedGroup::new(now)); + + result &= group.check_total(type_, now, capacity, secs_to_refill); + } + + IpAddr::V6(ipv6) => { + let (key_48, key_56, key_64) = split_ipv6(ipv6); + + // Contains all addresses with the same first 48 bits. These addresses might be part of the same network. + let group_48 = self + .ipv6_buckets + .entry(key_48) + .or_insert(RateLimitedGroup::new(now)); + result &= group_48.check_total(type_, now, capacity.saturating_mul(16), secs_to_refill); + + // Contains all addresses with the same first 56 bits. These addresses might be part of the same network. + let group_56 = group_48 + .children + .entry(key_56) + .or_insert(RateLimitedGroup::new(now)); + result &= group_56.check_total(type_, now, capacity.saturating_mul(4), secs_to_refill); + + // A group with no children. It is shared by all addresses with the same first 64 bits. These addresses are always part of the same network. + let group_64 = group_56 + .children + .entry(key_64) + .or_insert(RateLimitedGroup::new(now)); + + result &= group_64.check_total(type_, now, capacity, secs_to_refill); + } + }; + + if !result { + debug!("Rate limited IP: {ip}"); + } + + result + } + + /// Remove buckets older than the given duration + pub(super) fn remove_older_than(&mut self, duration: Duration, now: InstantSecs) { + // Only retain buckets that were last used after `instant` + let Some(instant) = now.to_instant().checked_sub(duration) else { return }; + + let is_recently_used = |group: &RateLimitedGroup<_>| { + group + .total + .values() + .all(|bucket| bucket.last_checked.to_instant() > instant) + }; + + self.ipv4_buckets.retain(|_, group| is_recently_used(group)); + + self.ipv6_buckets.retain(|_, group_48| { + group_48.children.retain(|_, group_56| { + group_56 + .children + .retain(|_, group_64| is_recently_used(group_64)); + !group_56.children.is_empty() + }); + !group_48.children.is_empty() + }) + } +} + +fn split_ipv6(ip: Ipv6Addr) -> ([u8; 6], u8, u8) { + let [a0, a1, a2, a3, a4, a5, b, c, ..] = ip.octets(); + ([a0, a1, a2, a3, a4, a5], b, c) +} + +#[cfg(test)] +mod tests { + #[test] + fn test_split_ipv6() { + let ip = std::net::Ipv6Addr::new( + 0x0011, 0x2233, 0x4455, 0x6677, 0x8899, 0xAABB, 0xCCDD, 0xEEFF, + ); + assert_eq!( + super::split_ipv6(ip), + ([0x00, 0x11, 0x22, 0x33, 0x44, 0x55], 0x66, 0x77) + ); + } + + #[test] + fn test_rate_limiter() { + let mut rate_limiter = super::RateLimitStorage::default(); + let mut now = super::InstantSecs::now(); + + let ips = [ + "123.123.123.123", + "1:2:3::", + "1:2:3:0400::", + "1:2:3:0405::", + "1:2:3:0405:6::", + ]; + for ip in ips { + let ip = ip.parse().unwrap(); + let message_passed = + rate_limiter.check_rate_limit_full(super::RateLimitType::Message, ip, 2, 1, now); + let post_passed = + rate_limiter.check_rate_limit_full(super::RateLimitType::Post, ip, 3, 1, now); + assert!(message_passed); + assert!(post_passed); + } + + #[allow(clippy::indexing_slicing)] + let expected_buckets = |factor: f32, tokens_consumed: f32| { + let mut buckets = super::RateLimitedGroup::<()>::new(now).total; + buckets[super::RateLimitType::Message] = super::RateLimitBucket { + last_checked: now, + tokens: (2.0 * factor) - tokens_consumed, + }; + buckets[super::RateLimitType::Post] = super::RateLimitBucket { + last_checked: now, + tokens: (3.0 * factor) - tokens_consumed, + }; + buckets + }; + + let bottom_group = |tokens_consumed| super::RateLimitedGroup { + total: expected_buckets(1.0, tokens_consumed), + children: (), + }; + + assert_eq!( + rate_limiter, + super::RateLimitStorage { + ipv4_buckets: [([123, 123, 123, 123].into(), bottom_group(1.0)),].into(), + ipv6_buckets: [( + [0, 1, 0, 2, 0, 3], + super::RateLimitedGroup { + total: expected_buckets(16.0, 4.0), + children: [ + ( + 0, + super::RateLimitedGroup { + total: expected_buckets(4.0, 1.0), + children: [(0, bottom_group(1.0)),].into(), + } + ), + ( + 4, + super::RateLimitedGroup { + total: expected_buckets(4.0, 3.0), + children: [(0, bottom_group(1.0)), (5, bottom_group(2.0)),].into(), + } + ), + ] + .into(), + } + ),] + .into(), + } + ); + + now.secs += 2; + rate_limiter.remove_older_than(std::time::Duration::from_secs(1), now); + assert!(rate_limiter.ipv4_buckets.is_empty()); + assert!(rate_limiter.ipv6_buckets.is_empty()); + } +} diff --git a/src/lib.rs b/src/lib.rs index b76dd106c..1bc00f70c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -121,23 +121,28 @@ pub async fn start_lemmy_server() -> Result<(), LemmyError> { .with(TracingMiddleware::default()) .build(); + let context = LemmyContext::create( + pool.clone(), + client.clone(), + secret.clone(), + rate_limit_cell.clone(), + ); + if scheduled_tasks_enabled { // Schedules various cleanup tasks for the DB - thread::spawn(move || { - scheduled_tasks::setup(db_url, user_agent).expect("Couldn't set up scheduled_tasks"); + thread::spawn({ + let context = context.clone(); + move || { + scheduled_tasks::setup(db_url, user_agent, context) + .expect("Couldn't set up scheduled_tasks"); + } }); } // Create Http server with websocket support let settings_bind = settings.clone(); HttpServer::new(move || { - let context = LemmyContext::create( - pool.clone(), - client.clone(), - secret.clone(), - rate_limit_cell.clone(), - ); - + let context = context.clone(); let federation_config = FederationConfig::builder() .domain(settings.hostname.clone()) .app_data(context.clone()) diff --git a/src/scheduled_tasks.rs b/src/scheduled_tasks.rs index aae78b6f8..9fb1ba702 100644 --- a/src/scheduled_tasks.rs +++ b/src/scheduled_tasks.rs @@ -7,6 +7,7 @@ use diesel::{ }; // Import week days and WeekDay use diesel::{sql_query, PgConnection, RunQueryDsl}; +use lemmy_api_common::context::LemmyContext; use lemmy_db_schema::{ schema::{ activity, @@ -27,7 +28,11 @@ use std::{thread, time::Duration}; use tracing::{error, info}; /// Schedules various cleanup tasks for lemmy in a background thread -pub fn setup(db_url: String, user_agent: String) -> Result<(), LemmyError> { +pub fn setup( + db_url: String, + user_agent: String, + context_1: LemmyContext, +) -> Result<(), LemmyError> { // Setup the connections let mut scheduler = Scheduler::new(); @@ -55,6 +60,12 @@ pub fn setup(db_url: String, user_agent: String) -> Result<(), LemmyError> { clear_old_activities(&mut conn); }); + // Remove old rate limit buckets after 1 to 2 hours of inactivity + scheduler.every(CTimeUnits::hour(1)).run(move || { + let hour = Duration::from_secs(3600); + context_1.settings_updated_channel().remove_older_than(hour); + }); + // Update the Instance Software scheduler.every(CTimeUnits::days(1)).run(move || { let mut conn = PgConnection::establish(&db_url).expect("could not establish connection");