Reduce memory usage of rate limiting (#3111)

* Reduce Vec allocations

* Optimize stuff

* Move embedded migrations to separate crate

* Revert "Move embedded migrations to separate crate"

This reverts commit 44b104997016ee2a1f2c0bb735b75e654666860d.

* clippy, fmt

* Shrink rate limit allowance to f32

* Initialize rate limit allowance directly

* Add removal of old rate limit buckets

* Improve readability

* Remove usage of is_okay_and for Rust 1.67 compatibility

* Add dhat-heap feature

* Fix api_benchmark.sh and add run_and_benchmark.sh

* Revert "Fix api_benchmark.sh and add run_and_benchmark.sh"

This reverts commit b4528e5b85dd3f13cea43d72ada9382200c8fc77.

* Revert "Add dhat-heap feature"

This reverts commit 08e835d487b983c44ce2570d8c396d570d426916.

* Manually revert remaining stuff

* Use Ipv6Addr in RateLimitStorage

* Shrink last_checked in RateLimitBucket to 32 bits

* Fix rate_limit::get_ip

* Stuff (#1)

* Update rate_limiter.rs

* Update mod.rs

* Update scheduled_tasks.rs

* Fix rate_limiter.rs

* Dullbananas patch 1 (#2)

* Update rate_limiter.rs

* Update mod.rs

* Update scheduled_tasks.rs

* Fix rate_limiter.rs

* Rate limit IPv6 addresses in groups

* Fmt lib.rs

* woodpicker trigger

* Refactor and comment `check_rate_limit_full`

* Add `test_split_ipv6`

* Replace -2.0 with UNINITIALIZED_TOKEN_AMOUNT

* Add `test_rate_limiter`

---------

Co-authored-by: Dessalines <dessalines@users.noreply.github.com>
This commit is contained in:
dullbananas 2023-06-21 01:28:20 -07:00 committed by GitHub
parent b214d3dc00
commit 45818fb4c5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 405 additions and 100 deletions

21
Cargo.lock generated
View file

@ -1685,6 +1685,26 @@ version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b5320ae4c3782150d900b79807611a59a99fc9a1d61d686faafc24b93fc8d7ca" checksum = "b5320ae4c3782150d900b79807611a59a99fc9a1d61d686faafc24b93fc8d7ca"
[[package]]
name = "enum-map"
version = "2.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "988f0d17a0fa38291e5f41f71ea8d46a5d5497b9054d5a759fae2cbb819f2356"
dependencies = [
"enum-map-derive",
]
[[package]]
name = "enum-map-derive"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2a4da76b3b6116d758c7ba93f7ec6a35d2e2cf24feda76c6e38a375f4d5c59f2"
dependencies = [
"proc-macro2",
"quote",
"syn 1.0.103",
]
[[package]] [[package]]
name = "enum_delegate" name = "enum_delegate"
version = "0.2.0" version = "0.2.0"
@ -2816,6 +2836,7 @@ dependencies = [
"deser-hjson", "deser-hjson",
"diesel", "diesel",
"doku", "doku",
"enum-map",
"futures", "futures",
"html2text", "html2text",
"http", "http",

View file

@ -45,6 +45,7 @@ jsonwebtoken = "8.1.1"
lettre = "0.10.1" lettre = "0.10.1"
markdown-it = "0.5.0" markdown-it = "0.5.0"
totp-rs = { version = "5.0.2", features = ["gen_secret", "otpauth"] } totp-rs = { version = "5.0.2", features = ["gen_secret", "otpauth"] }
enum-map = "2.5"
[dev-dependencies] [dev-dependencies]
reqwest = { workspace = true } reqwest = { workspace = true }

View file

@ -14,21 +14,12 @@ pub mod request;
pub mod utils; pub mod utils;
pub mod version; pub mod version;
use std::{fmt, time::Duration}; use std::time::Duration;
pub type ConnectionId = usize; pub type ConnectionId = usize;
pub const REQWEST_TIMEOUT: Duration = Duration::from_secs(10); pub const REQWEST_TIMEOUT: Duration = Duration::from_secs(10);
#[derive(PartialEq, Eq, Hash, Debug, Clone)]
pub struct IpAddr(pub String);
impl fmt::Display for IpAddr {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.0)
}
}
#[macro_export] #[macro_export]
macro_rules! location_info { macro_rules! location_info {
() => { () => {

View file

@ -1,14 +1,18 @@
use crate::{error::LemmyError, IpAddr}; use crate::error::LemmyError;
use actix_web::dev::{ConnectionInfo, Service, ServiceRequest, ServiceResponse, Transform}; use actix_web::dev::{ConnectionInfo, Service, ServiceRequest, ServiceResponse, Transform};
use enum_map::enum_map;
use futures::future::{ok, Ready}; use futures::future::{ok, Ready};
use rate_limiter::{RateLimitStorage, RateLimitType}; use rate_limiter::{InstantSecs, RateLimitStorage, RateLimitType};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::{ use std::{
future::Future, future::Future,
net::{IpAddr, Ipv4Addr, SocketAddr},
pin::Pin, pin::Pin,
rc::Rc, rc::Rc,
str::FromStr,
sync::{Arc, Mutex}, sync::{Arc, Mutex},
task::{Context, Poll}, task::{Context, Poll},
time::Duration,
}; };
use tokio::sync::{mpsc, mpsc::Sender, OnceCell}; use tokio::sync::{mpsc, mpsc::Sender, OnceCell};
use typed_builder::TypedBuilder; use typed_builder::TypedBuilder;
@ -105,6 +109,35 @@ impl RateLimitCell {
Ok(()) Ok(())
} }
/// Asks the rate limiter to remove buckets older than `duration`.
///
/// The cutoff is never shorter than the longest interval in the current rate limit config, which
/// preserves buckets that would not pass the rate limit check.
pub fn remove_older_than(&self, duration: Duration) {
  let mut state = self
    .rate_limit
    .lock()
    .expect("Failed to lock rate limit mutex for reading");
  let config = &state.rate_limit_config;

  // Longest configured interval. A value that cannot be converted to `u64` counts as zero, which
  // leaves `duration` as the effective cutoff.
  let longest_interval = enum_map! {
    RateLimitType::Message => config.message_per_second,
    RateLimitType::Post => config.post_per_second,
    RateLimitType::Register => config.register_per_second,
    RateLimitType::Image => config.image_per_second,
    RateLimitType::Comment => config.comment_per_second,
    RateLimitType::Search => config.search_per_second,
  }
  .into_values()
  .max()
  .and_then(|secs| u64::try_from(secs).ok())
  .map_or(Duration::ZERO, Duration::from_secs);

  let cutoff = duration.max(longest_interval);

  state
    .rate_limiter
    .remove_older_than(cutoff, InstantSecs::now())
}
pub fn message(&self) -> RateLimitedGuard { pub fn message(&self) -> RateLimitedGuard {
self.kind(RateLimitType::Message) self.kind(RateLimitType::Message)
} }
@ -163,7 +196,7 @@ impl RateLimitedGuard {
}; };
let limiter = &mut guard.rate_limiter; let limiter = &mut guard.rate_limiter;
limiter.check_rate_limit_full(self.type_, &ip_addr, kind, interval) limiter.check_rate_limit_full(self.type_, ip_addr, kind, interval, InstantSecs::now())
} }
} }
@ -222,13 +255,37 @@ where
} }
fn get_ip(conn_info: &ConnectionInfo) -> IpAddr { fn get_ip(conn_info: &ConnectionInfo) -> IpAddr {
IpAddr(
conn_info conn_info
.realip_remote_addr() .realip_remote_addr()
.unwrap_or("127.0.0.1:12345") .and_then(parse_ip)
.split(':') .unwrap_or(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)))
.next() }
.unwrap_or("127.0.0.1")
.to_string(), fn parse_ip(addr: &str) -> Option<IpAddr> {
) if let Some(s) = addr.strip_suffix(']') {
IpAddr::from_str(s.get(1..)?).ok()
} else if let Ok(ip) = IpAddr::from_str(addr) {
Some(ip)
} else if let Ok(socket) = SocketAddr::from_str(addr) {
Some(socket.ip())
} else {
None
}
}
#[cfg(test)]
mod tests {
/// Checks that `parse_ip` extracts the expected address from every supported notation
/// (bare, bracketed, and with a port) and rejects input that is not an address.
#[test]
fn test_parse_ip() {
  use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};

  let v4 = IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4));
  let v6 = IpAddr::V6(Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 0));

  // Compare against the exact address so that a parser returning the wrong IP does not pass.
  let valid_addrs = [
    ("1.2.3.4", v4),
    ("1.2.3.4:8000", v4),
    ("2001:db8::", v6),
    ("[2001:db8::]", v6),
    ("[2001:db8::]:8000", v6),
  ];
  for (addr, expected) in valid_addrs {
    assert_eq!(
      super::parse_ip(addr),
      Some(expected),
      "failed to parse {addr}"
    );
  }

  // Malformed input must yield `None` instead of some arbitrary address.
  let invalid_addrs = ["", "]", "not an ip", "1.2.3.4:", "[2001:db8::"];
  for addr in invalid_addrs {
    assert_eq!(super::parse_ip(addr), None, "unexpectedly parsed {addr:?}");
  }
}
} }

View file

@ -1,15 +1,50 @@
use crate::IpAddr; use enum_map::{enum_map, EnumMap};
use std::{collections::HashMap, time::Instant}; use once_cell::sync::Lazy;
use strum::IntoEnumIterator; use std::{
collections::HashMap,
net::{IpAddr, Ipv4Addr, Ipv6Addr},
time::{Duration, Instant},
};
use tracing::debug; use tracing::debug;
#[derive(Debug, Clone)] const UNINITIALIZED_TOKEN_AMOUNT: f32 = -2.0;
struct RateLimitBucket {
last_checked: Instant, static START_TIME: Lazy<Instant> = Lazy::new(Instant::now);
allowance: f64,
/// Smaller than `std::time::Instant` because it uses a smaller integer for seconds and doesn't
/// store nanoseconds
#[derive(PartialEq, Debug, Clone, Copy)]
pub struct InstantSecs {
secs: u32,
} }
#[derive(Eq, PartialEq, Hash, Debug, EnumIter, Copy, Clone, AsRefStr)] impl InstantSecs {
/// Captures the current time as the number of whole seconds elapsed since `START_TIME`.
///
/// Panics if that number of seconds does not fit in a `u32`.
pub fn now() -> Self {
  let elapsed_secs = START_TIME.elapsed().as_secs();
  let secs = u32::try_from(elapsed_secs).expect("server has been running for over 136 years");
  InstantSecs { secs }
}
fn secs_since(self, earlier: Self) -> u32 {
self.secs.saturating_sub(earlier.secs)
}
/// Converts back to an `Instant` by adding the stored seconds to `START_TIME`.
fn to_instant(self) -> Instant {
  let since_start = Duration::from_secs(u64::from(self.secs));
  *START_TIME + since_start
}
}
/// Token bucket state for one rate limit type within one group.
#[derive(PartialEq, Debug, Clone)]
struct RateLimitBucket {
/// The time at which `tokens` was last brought up to date by a rate limit check
/// (or the creation time of the bucket if it has never been checked).
last_checked: InstantSecs,
/// This field stores the amount of tokens that were present at `last_checked`.
/// The amount of tokens steadily increases until it reaches the bucket's capacity.
/// Performing the rate-limited action consumes 1 token.
///
/// A new bucket holds `UNINITIALIZED_TOKEN_AMOUNT` until its first check, which replaces the
/// placeholder with the bucket's capacity.
tokens: f32,
}
#[derive(Debug, enum_map::Enum, Copy, Clone, AsRefStr)]
pub(crate) enum RateLimitType { pub(crate) enum RateLimitType {
Message, Message,
Register, Register,
@ -19,79 +54,263 @@ pub(crate) enum RateLimitType {
Search, Search,
} }
/// Rate limiting based on rate type and IP addr type Map<K, C> = HashMap<K, RateLimitedGroup<C>>;
#[derive(Debug, Clone, Default)]
pub struct RateLimitStorage { #[derive(PartialEq, Debug, Clone)]
buckets: HashMap<RateLimitType, HashMap<IpAddr, RateLimitBucket>>, struct RateLimitedGroup<C> {
total: EnumMap<RateLimitType, RateLimitBucket>,
children: C,
} }
impl RateLimitStorage { impl<C: Default> RateLimitedGroup<C> {
fn insert_ip(&mut self, ip: &IpAddr) { fn new(now: InstantSecs) -> Self {
for rate_limit_type in RateLimitType::iter() { RateLimitedGroup {
if self.buckets.get(&rate_limit_type).is_none() { total: enum_map! {
self.buckets.insert(rate_limit_type, HashMap::new()); _ => RateLimitBucket {
} last_checked: now,
tokens: UNINITIALIZED_TOKEN_AMOUNT,
if let Some(bucket) = self.buckets.get_mut(&rate_limit_type) {
if bucket.get(ip).is_none() {
bucket.insert(
ip.clone(),
RateLimitBucket {
last_checked: Instant::now(),
allowance: -2f64,
}, },
); },
} children: Default::default(),
}
} }
} }
/// Rate limiting Algorithm described here: https://stackoverflow.com/a/668327/1655478 fn check_total(
///
/// Returns true if the request passed the rate limit, false if it failed and should be rejected.
#[allow(clippy::float_cmp)]
pub(super) fn check_rate_limit_full(
&mut self, &mut self,
type_: RateLimitType, type_: RateLimitType,
ip: &IpAddr, now: InstantSecs,
rate: i32, capacity: i32,
per: i32, secs_to_refill: i32,
) -> bool { ) -> bool {
self.insert_ip(ip); let capacity = capacity as f32;
if let Some(bucket) = self.buckets.get_mut(&type_) { let secs_to_refill = secs_to_refill as f32;
if let Some(rate_limit) = bucket.get_mut(ip) {
let current = Instant::now();
let time_passed = current.duration_since(rate_limit.last_checked).as_secs() as f64;
// The initial value #[allow(clippy::indexing_slicing)] // `EnumMap` has no `get` funciton
if rate_limit.allowance == -2f64 { let bucket = &mut self.total[type_];
rate_limit.allowance = f64::from(rate);
};
rate_limit.last_checked = current; if bucket.tokens == UNINITIALIZED_TOKEN_AMOUNT {
rate_limit.allowance += time_passed * (f64::from(rate) / f64::from(per)); bucket.tokens = capacity;
if rate_limit.allowance > f64::from(rate) {
rate_limit.allowance = f64::from(rate);
} }
if rate_limit.allowance < 1.0 { let secs_since_last_checked = now.secs_since(bucket.last_checked) as f32;
bucket.last_checked = now;
// For `secs_since_last_checked` seconds, increase `bucket.tokens`
// by `capacity` every `secs_to_refill` seconds
bucket.tokens += {
let tokens_per_sec = capacity / secs_to_refill;
secs_since_last_checked * tokens_per_sec
};
// Prevent `bucket.tokens` from exceeding `capacity`
if bucket.tokens > capacity {
bucket.tokens = capacity;
}
if bucket.tokens < 1.0 {
// Not enough tokens yet
debug!( debug!(
"Rate limited type: {}, IP: {}, time_passed: {}, allowance: {}", "Rate limited type: {}, time_passed: {}, allowance: {}",
type_.as_ref(), type_.as_ref(),
ip, secs_since_last_checked,
time_passed, bucket.tokens
rate_limit.allowance
); );
false false
} else { } else {
rate_limit.allowance -= 1.0; // Consume 1 token
true bucket.tokens -= 1.0;
}
} else {
true
}
} else {
true true
} }
} }
} }
/// Rate limiting based on rate type and IP addr
#[derive(PartialEq, Debug, Clone, Default)]
pub struct RateLimitStorage {
/// One group of buckets (one bucket per `RateLimitType`) per individual IPv4 address
ipv4_buckets: Map<Ipv4Addr, ()>,
/// Separate buckets for 48, 56, and 64 bit prefixes of IPv6 addresses.
///
/// The maps are nested: the outer key is the first 6 bytes of the address, the middle key is
/// the 7th byte and the inner key is the 8th byte (see `split_ipv6`).
ipv6_buckets: Map<[u8; 6], Map<u8, Map<u8, ()>>>,
}
impl RateLimitStorage {
  /// Rate limiting Algorithm described here: https://stackoverflow.com/a/668327/1655478
  ///
  /// Checks the `type_` bucket of every group that `ip` belongs to. An IPv4 address has a single
  /// group. An IPv6 address belongs to a /48, a /56 and a /64 group, which are checked with 16x,
  /// 4x and 1x `capacity` respectively. Missing groups are created on first use, and every group
  /// is checked even if an earlier one already failed.
  ///
  /// Returns true if the request passed the rate limit, false if it failed and should be rejected.
  pub(super) fn check_rate_limit_full(
    &mut self,
    type_: RateLimitType,
    ip: IpAddr,
    capacity: i32,
    secs_to_refill: i32,
    now: InstantSecs,
  ) -> bool {
    let mut result = true;

    match ip {
      IpAddr::V4(ipv4) => {
        // Only used by one address.
        let group = self
          .ipv4_buckets
          .entry(ipv4)
          .or_insert_with(|| RateLimitedGroup::new(now));

        result &= group.check_total(type_, now, capacity, secs_to_refill);
      }

      IpAddr::V6(ipv6) => {
        let (key_48, key_56, key_64) = split_ipv6(ipv6);

        // Contains all addresses with the same first 48 bits. These addresses might be part of
        // the same network.
        let group_48 = self
          .ipv6_buckets
          .entry(key_48)
          .or_insert_with(|| RateLimitedGroup::new(now));
        result &= group_48.check_total(type_, now, capacity.saturating_mul(16), secs_to_refill);

        // Contains all addresses with the same first 56 bits. These addresses might be part of
        // the same network.
        let group_56 = group_48
          .children
          .entry(key_56)
          .or_insert_with(|| RateLimitedGroup::new(now));
        result &= group_56.check_total(type_, now, capacity.saturating_mul(4), secs_to_refill);

        // A group with no children. It is shared by all addresses with the same first 64 bits.
        // These addresses are always part of the same network.
        let group_64 = group_56
          .children
          .entry(key_64)
          .or_insert_with(|| RateLimitedGroup::new(now));

        result &= group_64.check_total(type_, now, capacity, secs_to_refill);
      }
    };

    if !result {
      debug!("Rate limited IP: {ip}");
    }

    result
  }

  /// Remove groups that have not been used within `duration` before `now`.
  ///
  /// IPv4 groups and IPv6 /64 groups are kept while at least one of their buckets was checked
  /// within that time. IPv6 /56 and /48 groups are kept for as long as they still have children.
  pub(super) fn remove_older_than(&mut self, duration: Duration, now: InstantSecs) {
    // Only retain groups that were last used after `instant`. If that point in time cannot be
    // represented, nothing is removed.
    let Some(instant) = now.to_instant().checked_sub(duration) else { return };

    // A rate limit check only updates the bucket of the checked type, so a group counts as
    // recently used when any of its buckets is recent. Requiring all buckets to be recent would
    // drop groups that are actively used for only some types and reset their rate limit state.
    let is_recently_used = |group: &RateLimitedGroup<_>| {
      group
        .total
        .values()
        .any(|bucket| bucket.last_checked.to_instant() > instant)
    };

    self.ipv4_buckets.retain(|_, group| is_recently_used(group));

    self.ipv6_buckets.retain(|_, group_48| {
      group_48.children.retain(|_, group_56| {
        group_56
          .children
          .retain(|_, group_64| is_recently_used(group_64));
        !group_56.children.is_empty()
      });
      !group_48.children.is_empty()
    })
  }
}
/// Splits the first 64 bits of `ip` into the keys used for grouping IPv6 addresses:
/// the first 6 bytes (48 bit prefix), then the 7th byte and the 8th byte.
fn split_ipv6(ip: Ipv6Addr) -> ([u8; 6], u8, u8) {
  // Only the first four 16-bit segments (64 bits) are relevant.
  let [seg_0, seg_1, seg_2, seg_3, ..] = ip.segments();
  let [a0, a1] = seg_0.to_be_bytes();
  let [a2, a3] = seg_1.to_be_bytes();
  let [a4, a5] = seg_2.to_be_bytes();
  let [byte_7, byte_8] = seg_3.to_be_bytes();
  ([a0, a1, a2, a3, a4, a5], byte_7, byte_8)
}
#[cfg(test)]
mod tests {
// The first 6 bytes of the address form the 48 bit key; the 7th and 8th bytes are the keys for
// the 56 and 64 bit prefixes.
#[test]
fn test_split_ipv6() {
let ip = std::net::Ipv6Addr::new(
0x0011, 0x2233, 0x4455, 0x6677, 0x8899, 0xAABB, 0xCCDD, 0xEEFF,
);
assert_eq!(
super::split_ipv6(ip),
([0x00, 0x11, 0x22, 0x33, 0x44, 0x55], 0x66, 0x77)
);
}
// Sends one Message and one Post request from several addresses, compares the whole storage
// against the expected token counts, then checks that old buckets are removed.
#[test]
fn test_rate_limiter() {
let mut rate_limiter = super::RateLimitStorage::default();
let mut now = super::InstantSecs::now();
// One IPv4 address and four IPv6 addresses that share the same first 48 bits.
// The second IPv6 address has a different 7th byte than the first; the third differs from the
// second only in the 8th byte; the last two share the same first 64 bits.
let ips = [
"123.123.123.123",
"1:2:3::",
"1:2:3:0400::",
"1:2:3:0405::",
"1:2:3:0405:6::",
];
// Message is checked with capacity 2 and Post with capacity 3, so every request should pass.
for ip in ips {
let ip = ip.parse().unwrap();
let message_passed =
rate_limiter.check_rate_limit_full(super::RateLimitType::Message, ip, 2, 1, now);
let post_passed =
rate_limiter.check_rate_limit_full(super::RateLimitType::Post, ip, 3, 1, now);
assert!(message_passed);
assert!(post_passed);
}
// Expected buckets of a group whose capacity is `factor` times the capacity passed to
// `check_rate_limit_full`, after `tokens_consumed` requests of each of the two types.
// The other types keep the state they get from `RateLimitedGroup::new`.
#[allow(clippy::indexing_slicing)]
let expected_buckets = |factor: f32, tokens_consumed: f32| {
let mut buckets = super::RateLimitedGroup::<()>::new(now).total;
buckets[super::RateLimitType::Message] = super::RateLimitBucket {
last_checked: now,
tokens: (2.0 * factor) - tokens_consumed,
};
buckets[super::RateLimitType::Post] = super::RateLimitBucket {
last_checked: now,
tokens: (3.0 * factor) - tokens_consumed,
};
buckets
};
// A group without children: an IPv4 address or a 64 bit IPv6 prefix.
let bottom_group = |tokens_consumed| super::RateLimitedGroup {
total: expected_buckets(1.0, tokens_consumed),
children: (),
};
assert_eq!(
rate_limiter,
super::RateLimitStorage {
ipv4_buckets: [([123, 123, 123, 123].into(), bottom_group(1.0)),].into(),
ipv6_buckets: [(
[0, 1, 0, 2, 0, 3],
// 48 bit group: 16x capacity, used by all four IPv6 addresses
super::RateLimitedGroup {
total: expected_buckets(16.0, 4.0),
children: [
(
0,
// 56 bit group with 7th byte 0: 4x capacity, used by "1:2:3::" only
super::RateLimitedGroup {
total: expected_buckets(4.0, 1.0),
children: [(0, bottom_group(1.0)),].into(),
}
),
(
4,
// 56 bit group with 7th byte 4: 4x capacity, used by the last three addresses.
// The 64 bit group with 8th byte 5 is shared by the last two addresses.
super::RateLimitedGroup {
total: expected_buckets(4.0, 3.0),
children: [(0, bottom_group(1.0)), (5, bottom_group(2.0)),].into(),
}
),
]
.into(),
}
),]
.into(),
}
);
// Advance the clock by 2 seconds so that every bucket is older than the 1 second cutoff.
now.secs += 2;
rate_limiter.remove_older_than(std::time::Duration::from_secs(1), now);
assert!(rate_limiter.ipv4_buckets.is_empty());
assert!(rate_limiter.ipv6_buckets.is_empty());
}
}

View file

@ -121,16 +121,6 @@ pub async fn start_lemmy_server() -> Result<(), LemmyError> {
.with(TracingMiddleware::default()) .with(TracingMiddleware::default())
.build(); .build();
if scheduled_tasks_enabled {
// Schedules various cleanup tasks for the DB
thread::spawn(move || {
scheduled_tasks::setup(db_url, user_agent).expect("Couldn't set up scheduled_tasks");
});
}
// Create Http server with websocket support
let settings_bind = settings.clone();
HttpServer::new(move || {
let context = LemmyContext::create( let context = LemmyContext::create(
pool.clone(), pool.clone(),
client.clone(), client.clone(),
@ -138,6 +128,21 @@ pub async fn start_lemmy_server() -> Result<(), LemmyError> {
rate_limit_cell.clone(), rate_limit_cell.clone(),
); );
if scheduled_tasks_enabled {
// Schedules various cleanup tasks for the DB
thread::spawn({
let context = context.clone();
move || {
scheduled_tasks::setup(db_url, user_agent, context)
.expect("Couldn't set up scheduled_tasks");
}
});
}
// Create Http server with websocket support
let settings_bind = settings.clone();
HttpServer::new(move || {
let context = context.clone();
let federation_config = FederationConfig::builder() let federation_config = FederationConfig::builder()
.domain(settings.hostname.clone()) .domain(settings.hostname.clone())
.app_data(context.clone()) .app_data(context.clone())

View file

@ -7,6 +7,7 @@ use diesel::{
}; };
// Import week days and WeekDay // Import week days and WeekDay
use diesel::{sql_query, PgConnection, RunQueryDsl}; use diesel::{sql_query, PgConnection, RunQueryDsl};
use lemmy_api_common::context::LemmyContext;
use lemmy_db_schema::{ use lemmy_db_schema::{
schema::{ schema::{
activity, activity,
@ -27,7 +28,11 @@ use std::{thread, time::Duration};
use tracing::{error, info}; use tracing::{error, info};
/// Schedules various cleanup tasks for lemmy in a background thread /// Schedules various cleanup tasks for lemmy in a background thread
pub fn setup(db_url: String, user_agent: String) -> Result<(), LemmyError> { pub fn setup(
db_url: String,
user_agent: String,
context_1: LemmyContext,
) -> Result<(), LemmyError> {
// Setup the connections // Setup the connections
let mut scheduler = Scheduler::new(); let mut scheduler = Scheduler::new();
@ -55,6 +60,12 @@ pub fn setup(db_url: String, user_agent: String) -> Result<(), LemmyError> {
clear_old_activities(&mut conn); clear_old_activities(&mut conn);
}); });
// Remove old rate limit buckets after 1 to 2 hours of inactivity
scheduler.every(CTimeUnits::hour(1)).run(move || {
let hour = Duration::from_secs(3600);
context_1.settings_updated_channel().remove_older_than(hour);
});
// Update the Instance Software // Update the Instance Software
scheduler.every(CTimeUnits::days(1)).run(move || { scheduler.every(CTimeUnits::days(1)).run(move || {
let mut conn = PgConnection::establish(&db_url).expect("could not establish connection"); let mut conn = PgConnection::establish(&db_url).expect("could not establish connection");