Reduce memory usage of rate limiting (#3111)
* Reduce Vec allocations
* Optimize stuff
* Move embedded migrations to separate crate
* Revert "Move embedded migrations to separate crate"
  This reverts commit 44b104997016ee2a1f2c0bb735b75e654666860d.
* clippy, fmt
* Shrink rate limit allowance to f32
* Initialize rate limit allowance directly
* Add removal of old rate limit buckets
* Improve readability
* Remove usage of is_okay_and for Rust 1.67 compatibility
* Add dhat-heap feature
* Fix api_benchmark.sh and add run_and_benchmark.sh
* Revert "Fix api_benchmark.sh and add run_and_benchmark.sh"
  This reverts commit b4528e5b85dd3f13cea43d72ada9382200c8fc77.
* Revert "Add dhat-heap feature"
  This reverts commit 08e835d487b983c44ce2570d8c396d570d426916.
* Manually revert remaining stuff
* Use Ipv6Addr in RateLimitStorage
* Shrink last_checked in RateLimitBucket to 32 bits
* Fix rate_limit::get_ip
* Stuff (#1)
* Update rate_limiter.rs
* Update mod.rs
* Update scheduled_tasks.rs
* Fix rate_limiter.rs
* Dullbananas patch 1 (#2)
* Update rate_limiter.rs
* Update mod.rs
* Update scheduled_tasks.rs
* Fix rate_limiter.rs
* Rate limit IPv6 addresses in groups
* Fmt lib.rs
* woodpicker trigger
* Refactor and comment `check_rate_limit_full`
* Add `test_split_ipv6`
* Replace -2.0 with UNINITIALIZED_TOKEN_AMOUNT
* Add `test_rate_limiter`

---------

Co-authored-by: Dessalines <dessalines@users.noreply.github.com>
parent b214d3dc00, commit 45818fb4c5
7 changed files with 405 additions and 100 deletions
Cargo.lock (generated): 21 changes
@@ -1685,6 +1685,26 @@ version = "1.0.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "b5320ae4c3782150d900b79807611a59a99fc9a1d61d686faafc24b93fc8d7ca"
 
+[[package]]
+name = "enum-map"
+version = "2.5.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "988f0d17a0fa38291e5f41f71ea8d46a5d5497b9054d5a759fae2cbb819f2356"
+dependencies = [
+ "enum-map-derive",
+]
+
+[[package]]
+name = "enum-map-derive"
+version = "0.11.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "2a4da76b3b6116d758c7ba93f7ec6a35d2e2cf24feda76c6e38a375f4d5c59f2"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn 1.0.103",
+]
+
 [[package]]
 name = "enum_delegate"
 version = "0.2.0"

@@ -2816,6 +2836,7 @@ dependencies = [
  "deser-hjson",
  "diesel",
  "doku",
+ "enum-map",
  "futures",
  "html2text",
  "http",

Cargo.toml

@@ -45,6 +45,7 @@ jsonwebtoken = "8.1.1"
 lettre = "0.10.1"
 markdown-it = "0.5.0"
 totp-rs = { version = "5.0.2", features = ["gen_secret", "otpauth"] }
+enum-map = "2.5"
 
 [dev-dependencies]
 reqwest = { workspace = true }

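The new `enum-map` dependency backs the fixed-size per-type bucket maps used further down in rate_limiter.rs: instead of a `HashMap` keyed by `RateLimitType`, every enum variant gets one inline slot. The following is a minimal, illustrative sketch of the parts of the crate's API that the new code relies on; it is not part of this diff, and the `Kind` enum is made up for the example.

use enum_map::{enum_map, Enum, EnumMap};

#[derive(Debug, Enum, Copy, Clone)]
enum Kind {
  Message,
  Post,
}

fn main() {
  // One inline slot per variant; no hashing and no per-entry heap allocation.
  let mut counts: EnumMap<Kind, u32> = enum_map! { _ => 0 };
  counts[Kind::Message] += 1;
  assert_eq!(counts[Kind::Message], 1);
  assert_eq!(counts[Kind::Post], 0);
}
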
lib.rs

@@ -14,21 +14,12 @@ pub mod request;
 pub mod utils;
 pub mod version;
 
-use std::{fmt, time::Duration};
+use std::time::Duration;
 
 pub type ConnectionId = usize;
 
 pub const REQWEST_TIMEOUT: Duration = Duration::from_secs(10);
 
-#[derive(PartialEq, Eq, Hash, Debug, Clone)]
-pub struct IpAddr(pub String);
-
-impl fmt::Display for IpAddr {
-  fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
-    write!(f, "{}", self.0)
-  }
-}
-
 #[macro_export]
 macro_rules! location_info {
   () => {

mod.rs

@@ -1,14 +1,18 @@
-use crate::{error::LemmyError, IpAddr};
+use crate::error::LemmyError;
 use actix_web::dev::{ConnectionInfo, Service, ServiceRequest, ServiceResponse, Transform};
+use enum_map::enum_map;
 use futures::future::{ok, Ready};
-use rate_limiter::{RateLimitStorage, RateLimitType};
+use rate_limiter::{InstantSecs, RateLimitStorage, RateLimitType};
 use serde::{Deserialize, Serialize};
 use std::{
   future::Future,
+  net::{IpAddr, Ipv4Addr, SocketAddr},
   pin::Pin,
   rc::Rc,
+  str::FromStr,
   sync::{Arc, Mutex},
   task::{Context, Poll},
+  time::Duration,
 };
 use tokio::sync::{mpsc, mpsc::Sender, OnceCell};
 use typed_builder::TypedBuilder;

@@ -105,6 +109,35 @@ impl RateLimitCell {
     Ok(())
   }
 
+  /// Remove buckets older than the given duration
+  pub fn remove_older_than(&self, mut duration: Duration) {
+    let mut guard = self
+      .rate_limit
+      .lock()
+      .expect("Failed to lock rate limit mutex for reading");
+    let rate_limit = &guard.rate_limit_config;
+
+    // If any rate limit interval is greater than `duration`, then the largest interval is used instead. This preserves buckets that would not pass the rate limit check.
+    let max_interval_secs = enum_map! {
+      RateLimitType::Message => rate_limit.message_per_second,
+      RateLimitType::Post => rate_limit.post_per_second,
+      RateLimitType::Register => rate_limit.register_per_second,
+      RateLimitType::Image => rate_limit.image_per_second,
+      RateLimitType::Comment => rate_limit.comment_per_second,
+      RateLimitType::Search => rate_limit.search_per_second,
+    }
+    .into_values()
+    .max()
+    .and_then(|max| u64::try_from(max).ok())
+    .unwrap_or(0);
+
+    duration = std::cmp::max(duration, Duration::from_secs(max_interval_secs));
+
+    guard
+      .rate_limiter
+      .remove_older_than(duration, InstantSecs::now())
+  }
+
   pub fn message(&self) -> RateLimitedGuard {
     self.kind(RateLimitType::Message)
   }

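The clamping of `duration` to the largest configured interval means the caller's value is only a lower bound. With illustrative numbers that are not from this commit: if the caller asks to drop buckets idle for more than one hour but some rate limit refills over two hours, the effective cutoff becomes two hours, so buckets that could still deny a request are kept.

use std::time::Duration;

fn main() {
  let requested = Duration::from_secs(3600); // caller wants to drop buckets idle for over 1 hour
  let max_interval = Duration::from_secs(7200); // largest configured refill interval
  // Same clamping as in `remove_older_than`: the larger of the two wins.
  assert_eq!(std::cmp::max(requested, max_interval), Duration::from_secs(7200));
}
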
@@ -163,7 +196,7 @@ impl RateLimitedGuard {
     };
     let limiter = &mut guard.rate_limiter;
 
-    limiter.check_rate_limit_full(self.type_, &ip_addr, kind, interval)
+    limiter.check_rate_limit_full(self.type_, ip_addr, kind, interval, InstantSecs::now())
   }
 }

@@ -222,13 +255,37 @@ where
   }
 
 fn get_ip(conn_info: &ConnectionInfo) -> IpAddr {
-  IpAddr(
-    conn_info
-      .realip_remote_addr()
-      .unwrap_or("127.0.0.1:12345")
-      .split(':')
-      .next()
-      .unwrap_or("127.0.0.1")
-      .to_string(),
-  )
+  conn_info
+    .realip_remote_addr()
+    .and_then(parse_ip)
+    .unwrap_or(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)))
+}
+
+fn parse_ip(addr: &str) -> Option<IpAddr> {
+  if let Some(s) = addr.strip_suffix(']') {
+    IpAddr::from_str(s.get(1..)?).ok()
+  } else if let Ok(ip) = IpAddr::from_str(addr) {
+    Some(ip)
+  } else if let Ok(socket) = SocketAddr::from_str(addr) {
+    Some(socket.ip())
+  } else {
+    None
+  }
+}
+
+#[cfg(test)]
+mod tests {
+  #[test]
+  fn test_parse_ip() {
+    let ip_addrs = [
+      "1.2.3.4",
+      "1.2.3.4:8000",
+      "2001:db8::",
+      "[2001:db8::]",
+      "[2001:db8::]:8000",
+    ];
+    for addr in ip_addrs {
+      assert!(super::parse_ip(addr).is_some(), "failed to parse {addr}");
+    }
+  }
 }

rate_limiter.rs

@@ -1,15 +1,50 @@
-use crate::IpAddr;
-use std::{collections::HashMap, time::Instant};
-use strum::IntoEnumIterator;
+use enum_map::{enum_map, EnumMap};
+use once_cell::sync::Lazy;
+use std::{
+  collections::HashMap,
+  net::{IpAddr, Ipv4Addr, Ipv6Addr},
+  time::{Duration, Instant},
+};
 use tracing::debug;
 
-#[derive(Debug, Clone)]
-struct RateLimitBucket {
-  last_checked: Instant,
-  allowance: f64,
+const UNINITIALIZED_TOKEN_AMOUNT: f32 = -2.0;
+
+static START_TIME: Lazy<Instant> = Lazy::new(Instant::now);
+
+/// Smaller than `std::time::Instant` because it uses a smaller integer for seconds and doesn't
+/// store nanoseconds
+#[derive(PartialEq, Debug, Clone, Copy)]
+pub struct InstantSecs {
+  secs: u32,
 }
 
-#[derive(Eq, PartialEq, Hash, Debug, EnumIter, Copy, Clone, AsRefStr)]
+impl InstantSecs {
+  pub fn now() -> Self {
+    InstantSecs {
+      secs: u32::try_from(START_TIME.elapsed().as_secs())
+        .expect("server has been running for over 136 years"),
+    }
+  }
+
+  fn secs_since(self, earlier: Self) -> u32 {
+    self.secs.saturating_sub(earlier.secs)
+  }
+
+  fn to_instant(self) -> Instant {
+    *START_TIME + Duration::from_secs(self.secs.into())
+  }
+}
+
+#[derive(PartialEq, Debug, Clone)]
+struct RateLimitBucket {
+  last_checked: InstantSecs,
+  /// This field stores the amount of tokens that were present at `last_checked`.
+  /// The amount of tokens steadily increases until it reaches the bucket's capacity.
+  /// Performing the rate-limited action consumes 1 token.
+  tokens: f32,
+}
+
+#[derive(Debug, enum_map::Enum, Copy, Clone, AsRefStr)]
 pub(crate) enum RateLimitType {
   Message,
   Register,

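To see where the memory goes, compare the per-bucket layout before and after this hunk. The sketch below is illustrative and not part of the commit; the struct names are made up and exact sizes depend on platform and alignment. The point is that an `Instant` plus an `f64` costs roughly 24 bytes per bucket, while a `u32` second count plus an `f32` fits in 8.

use std::time::Instant;

// Roughly the old layout: a full Instant (16 bytes on typical 64-bit targets) plus an f64.
#[allow(dead_code)]
struct OldBucket {
  last_checked: Instant,
  allowance: f64,
}

// Roughly the new layout: InstantSecs wraps a single u32, and the token count is an f32.
#[allow(dead_code)]
struct NewBucket {
  last_checked_secs: u32,
  tokens: f32,
}

fn main() {
  // Typically prints 24 and 8 on 64-bit Linux; sizes are platform-dependent.
  println!("old: {} bytes", std::mem::size_of::<OldBucket>());
  println!("new: {} bytes", std::mem::size_of::<NewBucket>());
}
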
@@ -19,79 +54,263 @@ pub(crate) enum RateLimitType {
   Search,
 }
 
-/// Rate limiting based on rate type and IP addr
-#[derive(Debug, Clone, Default)]
-pub struct RateLimitStorage {
-  buckets: HashMap<RateLimitType, HashMap<IpAddr, RateLimitBucket>>,
+type Map<K, C> = HashMap<K, RateLimitedGroup<C>>;
+
+#[derive(PartialEq, Debug, Clone)]
+struct RateLimitedGroup<C> {
+  total: EnumMap<RateLimitType, RateLimitBucket>,
+  children: C,
 }
 
-impl RateLimitStorage {
-  fn insert_ip(&mut self, ip: &IpAddr) {
-    for rate_limit_type in RateLimitType::iter() {
-      if self.buckets.get(&rate_limit_type).is_none() {
-        self.buckets.insert(rate_limit_type, HashMap::new());
-      }
-      if let Some(bucket) = self.buckets.get_mut(&rate_limit_type) {
-        if bucket.get(ip).is_none() {
-          bucket.insert(
-            ip.clone(),
-            RateLimitBucket {
-              last_checked: Instant::now(),
-              allowance: -2f64,
-            },
-          );
-        }
-      }
+impl<C: Default> RateLimitedGroup<C> {
+  fn new(now: InstantSecs) -> Self {
+    RateLimitedGroup {
+      total: enum_map! {
+        _ => RateLimitBucket {
+          last_checked: now,
+          tokens: UNINITIALIZED_TOKEN_AMOUNT,
+        },
+      },
+      children: Default::default(),
     }
   }
 
-  /// Rate limiting Algorithm described here: https://stackoverflow.com/a/668327/1655478
-  ///
-  /// Returns true if the request passed the rate limit, false if it failed and should be rejected.
-  #[allow(clippy::float_cmp)]
-  pub(super) fn check_rate_limit_full(
+  fn check_total(
     &mut self,
     type_: RateLimitType,
-    ip: &IpAddr,
-    rate: i32,
-    per: i32,
+    now: InstantSecs,
+    capacity: i32,
+    secs_to_refill: i32,
   ) -> bool {
-    self.insert_ip(ip);
-    if let Some(bucket) = self.buckets.get_mut(&type_) {
-      if let Some(rate_limit) = bucket.get_mut(ip) {
-        let current = Instant::now();
-        let time_passed = current.duration_since(rate_limit.last_checked).as_secs() as f64;
-
-        // The initial value
-        if rate_limit.allowance == -2f64 {
-          rate_limit.allowance = f64::from(rate);
-        };
-
-        rate_limit.last_checked = current;
-        rate_limit.allowance += time_passed * (f64::from(rate) / f64::from(per));
-        if rate_limit.allowance > f64::from(rate) {
-          rate_limit.allowance = f64::from(rate);
-        }
-
-        if rate_limit.allowance < 1.0 {
-          debug!(
-            "Rate limited type: {}, IP: {}, time_passed: {}, allowance: {}",
-            type_.as_ref(),
-            ip,
-            time_passed,
-            rate_limit.allowance
-          );
-          false
-        } else {
-          rate_limit.allowance -= 1.0;
-          true
-        }
-      } else {
-        true
-      }
+    let capacity = capacity as f32;
+    let secs_to_refill = secs_to_refill as f32;
+
+    #[allow(clippy::indexing_slicing)] // `EnumMap` has no `get` funciton
+    let bucket = &mut self.total[type_];
+
+    if bucket.tokens == UNINITIALIZED_TOKEN_AMOUNT {
+      bucket.tokens = capacity;
+    }
+
+    let secs_since_last_checked = now.secs_since(bucket.last_checked) as f32;
+    bucket.last_checked = now;
+
+    // For `secs_since_last_checked` seconds, increase `bucket.tokens`
+    // by `capacity` every `secs_to_refill` seconds
+    bucket.tokens += {
+      let tokens_per_sec = capacity / secs_to_refill;
+      secs_since_last_checked * tokens_per_sec
+    };
+
+    // Prevent `bucket.tokens` from exceeding `capacity`
+    if bucket.tokens > capacity {
+      bucket.tokens = capacity;
+    }
+
+    if bucket.tokens < 1.0 {
+      // Not enough tokens yet
+      debug!(
+        "Rate limited type: {}, time_passed: {}, allowance: {}",
+        type_.as_ref(),
+        secs_since_last_checked,
+        bucket.tokens
+      );
+      false
     } else {
+      // Consume 1 token
+      bucket.tokens -= 1.0;
       true
     }
   }
 }
 
+/// Rate limiting based on rate type and IP addr
+#[derive(PartialEq, Debug, Clone, Default)]
+pub struct RateLimitStorage {
+  /// One bucket per individual IPv4 address
+  ipv4_buckets: Map<Ipv4Addr, ()>,
+  /// Seperate buckets for 48, 56, and 64 bit prefixes of IPv6 addresses
+  ipv6_buckets: Map<[u8; 6], Map<u8, Map<u8, ()>>>,
+}
+
+impl RateLimitStorage {
+  /// Rate limiting Algorithm described here: https://stackoverflow.com/a/668327/1655478
+  ///
+  /// Returns true if the request passed the rate limit, false if it failed and should be rejected.
+  pub(super) fn check_rate_limit_full(
+    &mut self,
+    type_: RateLimitType,
+    ip: IpAddr,
+    capacity: i32,
+    secs_to_refill: i32,
+    now: InstantSecs,
+  ) -> bool {
+    let mut result = true;
+
+    match ip {
+      IpAddr::V4(ipv4) => {
+        // Only used by one address.
+        let group = self
+          .ipv4_buckets
+          .entry(ipv4)
+          .or_insert(RateLimitedGroup::new(now));
+
+        result &= group.check_total(type_, now, capacity, secs_to_refill);
+      }
+
+      IpAddr::V6(ipv6) => {
+        let (key_48, key_56, key_64) = split_ipv6(ipv6);
+
+        // Contains all addresses with the same first 48 bits. These addresses might be part of the same network.
+        let group_48 = self
+          .ipv6_buckets
+          .entry(key_48)
+          .or_insert(RateLimitedGroup::new(now));
+        result &= group_48.check_total(type_, now, capacity.saturating_mul(16), secs_to_refill);
+
+        // Contains all addresses with the same first 56 bits. These addresses might be part of the same network.
+        let group_56 = group_48
+          .children
+          .entry(key_56)
+          .or_insert(RateLimitedGroup::new(now));
+        result &= group_56.check_total(type_, now, capacity.saturating_mul(4), secs_to_refill);
+
+        // A group with no children. It is shared by all addresses with the same first 64 bits. These addresses are always part of the same network.
+        let group_64 = group_56
+          .children
+          .entry(key_64)
+          .or_insert(RateLimitedGroup::new(now));
+
+        result &= group_64.check_total(type_, now, capacity, secs_to_refill);
+      }
+    };
+
+    if !result {
+      debug!("Rate limited IP: {ip}");
+    }
+
+    result
+  }
+
+  /// Remove buckets older than the given duration
+  pub(super) fn remove_older_than(&mut self, duration: Duration, now: InstantSecs) {
+    // Only retain buckets that were last used after `instant`
+    let Some(instant) = now.to_instant().checked_sub(duration) else { return };
+
+    let is_recently_used = |group: &RateLimitedGroup<_>| {
+      group
+        .total
+        .values()
+        .all(|bucket| bucket.last_checked.to_instant() > instant)
+    };
+
+    self.ipv4_buckets.retain(|_, group| is_recently_used(group));
+
+    self.ipv6_buckets.retain(|_, group_48| {
+      group_48.children.retain(|_, group_56| {
+        group_56
+          .children
+          .retain(|_, group_64| is_recently_used(group_64));
+        !group_56.children.is_empty()
+      });
+      !group_48.children.is_empty()
+    })
+  }
+}
+
+fn split_ipv6(ip: Ipv6Addr) -> ([u8; 6], u8, u8) {
+  let [a0, a1, a2, a3, a4, a5, b, c, ..] = ip.octets();
+  ([a0, a1, a2, a3, a4, a5], b, c)
+}
+
+#[cfg(test)]
+mod tests {
+  #[test]
+  fn test_split_ipv6() {
+    let ip = std::net::Ipv6Addr::new(
+      0x0011, 0x2233, 0x4455, 0x6677, 0x8899, 0xAABB, 0xCCDD, 0xEEFF,
+    );
+    assert_eq!(
+      super::split_ipv6(ip),
+      ([0x00, 0x11, 0x22, 0x33, 0x44, 0x55], 0x66, 0x77)
+    );
+  }
+
+  #[test]
+  fn test_rate_limiter() {
+    let mut rate_limiter = super::RateLimitStorage::default();
+    let mut now = super::InstantSecs::now();
+
+    let ips = [
+      "123.123.123.123",
+      "1:2:3::",
+      "1:2:3:0400::",
+      "1:2:3:0405::",
+      "1:2:3:0405:6::",
+    ];
+    for ip in ips {
+      let ip = ip.parse().unwrap();
+      let message_passed =
+        rate_limiter.check_rate_limit_full(super::RateLimitType::Message, ip, 2, 1, now);
+      let post_passed =
+        rate_limiter.check_rate_limit_full(super::RateLimitType::Post, ip, 3, 1, now);
+      assert!(message_passed);
+      assert!(post_passed);
+    }
+
+    #[allow(clippy::indexing_slicing)]
+    let expected_buckets = |factor: f32, tokens_consumed: f32| {
+      let mut buckets = super::RateLimitedGroup::<()>::new(now).total;
+      buckets[super::RateLimitType::Message] = super::RateLimitBucket {
+        last_checked: now,
+        tokens: (2.0 * factor) - tokens_consumed,
+      };
+      buckets[super::RateLimitType::Post] = super::RateLimitBucket {
+        last_checked: now,
+        tokens: (3.0 * factor) - tokens_consumed,
+      };
+      buckets
+    };
+
+    let bottom_group = |tokens_consumed| super::RateLimitedGroup {
+      total: expected_buckets(1.0, tokens_consumed),
+      children: (),
+    };
+
+    assert_eq!(
+      rate_limiter,
+      super::RateLimitStorage {
+        ipv4_buckets: [([123, 123, 123, 123].into(), bottom_group(1.0)),].into(),
+        ipv6_buckets: [(
+          [0, 1, 0, 2, 0, 3],
+          super::RateLimitedGroup {
+            total: expected_buckets(16.0, 4.0),
+            children: [
+              (
+                0,
+                super::RateLimitedGroup {
+                  total: expected_buckets(4.0, 1.0),
+                  children: [(0, bottom_group(1.0)),].into(),
+                }
+              ),
+              (
+                4,
+                super::RateLimitedGroup {
+                  total: expected_buckets(4.0, 3.0),
+                  children: [(0, bottom_group(1.0)), (5, bottom_group(2.0)),].into(),
+                }
+              ),
+            ]
+            .into(),
+          }
+        ),]
+        .into(),
+      }
+    );
+
+    now.secs += 2;
+    rate_limiter.remove_older_than(std::time::Duration::from_secs(1), now);
+    assert!(rate_limiter.ipv4_buckets.is_empty());
+    assert!(rate_limiter.ipv6_buckets.is_empty());
+  }
+}

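Taken on its own, the arithmetic in `check_total` is a plain token bucket: tokens refill linearly at `capacity / secs_to_refill` per second, are capped at `capacity`, and each request spends one. The standalone sketch below uses made-up numbers and a hypothetical helper name; it is not part of the commit, only an illustration of the refill-then-spend step.

/// Stand-alone version of the refill-then-spend step used by `check_total`.
/// Returns the new token count and whether the request passed.
fn refill_and_spend(tokens: f32, capacity: f32, secs_to_refill: f32, secs_elapsed: f32) -> (f32, bool) {
  // Refill linearly: a full `capacity` is restored every `secs_to_refill` seconds,
  // but the bucket never holds more than `capacity`.
  let refilled = (tokens + secs_elapsed * (capacity / secs_to_refill)).min(capacity);
  if refilled < 1.0 {
    (refilled, false) // rejected, no token spent
  } else {
    (refilled - 1.0, true) // one token spent
  }
}

fn main() {
  // capacity 3, refilled over 60 seconds => 0.05 tokens per second.
  // After 20 idle seconds, a bucket holding 0.5 tokens has about 1.5 and can pass one request.
  let (tokens, passed) = refill_and_spend(0.5, 3.0, 60.0, 20.0);
  assert!(passed);
  assert!((tokens - 0.5).abs() < 1e-5);
  // An immediate second request fails because fewer than one token is left.
  let (_, passed_again) = refill_and_spend(tokens, 3.0, 60.0, 0.0);
  assert!(!passed_again);
}

For IPv6, `check_rate_limit_full` runs this same check against the /48, /56, and /64 groups of the address, scaling the group capacity by 16 and 4 for the two wider prefixes, so a whole neighbouring network shares a larger budget while a single /64 gets the normal one.
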
src/lib.rs: 23 changes
@@ -121,23 +121,28 @@ pub async fn start_lemmy_server() -> Result<(), LemmyError> {
     .with(TracingMiddleware::default())
     .build();
 
+  let context = LemmyContext::create(
+    pool.clone(),
+    client.clone(),
+    secret.clone(),
+    rate_limit_cell.clone(),
+  );
+
   if scheduled_tasks_enabled {
     // Schedules various cleanup tasks for the DB
-    thread::spawn(move || {
-      scheduled_tasks::setup(db_url, user_agent).expect("Couldn't set up scheduled_tasks");
+    thread::spawn({
+      let context = context.clone();
+      move || {
+        scheduled_tasks::setup(db_url, user_agent, context)
+          .expect("Couldn't set up scheduled_tasks");
+      }
     });
   }
 
   // Create Http server with websocket support
   let settings_bind = settings.clone();
   HttpServer::new(move || {
-    let context = LemmyContext::create(
-      pool.clone(),
-      client.clone(),
-      secret.clone(),
-      rate_limit_cell.clone(),
-    );
+    let context = context.clone();
 
     let federation_config = FederationConfig::builder()
       .domain(settings.hostname.clone())
       .app_data(context.clone())

scheduled_tasks.rs

@@ -7,6 +7,7 @@ use diesel::{
 };
 // Import week days and WeekDay
 use diesel::{sql_query, PgConnection, RunQueryDsl};
+use lemmy_api_common::context::LemmyContext;
 use lemmy_db_schema::{
   schema::{
     activity,

@@ -27,7 +28,11 @@ use std::{thread, time::Duration};
 use tracing::{error, info};
 
 /// Schedules various cleanup tasks for lemmy in a background thread
-pub fn setup(db_url: String, user_agent: String) -> Result<(), LemmyError> {
+pub fn setup(
+  db_url: String,
+  user_agent: String,
+  context_1: LemmyContext,
+) -> Result<(), LemmyError> {
   // Setup the connections
   let mut scheduler = Scheduler::new();

@@ -55,6 +60,12 @@ pub fn setup(db_url: String, user_agent: String) -> Result<(), LemmyError> {
     clear_old_activities(&mut conn);
   });
 
+  // Remove old rate limit buckets after 1 to 2 hours of inactivity
+  scheduler.every(CTimeUnits::hour(1)).run(move || {
+    let hour = Duration::from_secs(3600);
+    context_1.settings_updated_channel().remove_older_than(hour);
+  });
+
   // Update the Instance Software
   scheduler.every(CTimeUnits::days(1)).run(move || {
     let mut conn = PgConnection::establish(&db_url).expect("could not establish connection");