Rate limit websocket joins. (#2165)
* Rate limit websocket joins. * Removing async on mutex lock fn. * Removing redundant ip * Return early if check fails.
This commit is contained in:
parent
483e7ab168
commit
f2a0841586
10 changed files with 47 additions and 15 deletions
4
Cargo.lock
generated
4
Cargo.lock
generated
|
@ -1962,6 +1962,7 @@ dependencies = [
|
||||||
"lemmy_utils",
|
"lemmy_utils",
|
||||||
"lemmy_websocket",
|
"lemmy_websocket",
|
||||||
"once_cell",
|
"once_cell",
|
||||||
|
"parking_lot 0.12.0",
|
||||||
"percent-encoding",
|
"percent-encoding",
|
||||||
"rand 0.8.4",
|
"rand 0.8.4",
|
||||||
"reqwest",
|
"reqwest",
|
||||||
|
@ -2129,6 +2130,7 @@ dependencies = [
|
||||||
"openssl",
|
"openssl",
|
||||||
"opentelemetry",
|
"opentelemetry",
|
||||||
"opentelemetry-otlp",
|
"opentelemetry-otlp",
|
||||||
|
"parking_lot 0.12.0",
|
||||||
"reqwest",
|
"reqwest",
|
||||||
"reqwest-middleware",
|
"reqwest-middleware",
|
||||||
"reqwest-tracing",
|
"reqwest-tracing",
|
||||||
|
@ -2166,6 +2168,7 @@ dependencies = [
|
||||||
"lettre",
|
"lettre",
|
||||||
"once_cell",
|
"once_cell",
|
||||||
"openssl",
|
"openssl",
|
||||||
|
"parking_lot 0.12.0",
|
||||||
"percent-encoding",
|
"percent-encoding",
|
||||||
"rand 0.8.4",
|
"rand 0.8.4",
|
||||||
"regex",
|
"regex",
|
||||||
|
@ -2204,6 +2207,7 @@ dependencies = [
|
||||||
"lemmy_db_views_actor",
|
"lemmy_db_views_actor",
|
||||||
"lemmy_utils",
|
"lemmy_utils",
|
||||||
"opentelemetry",
|
"opentelemetry",
|
||||||
|
"parking_lot 0.12.0",
|
||||||
"rand 0.8.4",
|
"rand 0.8.4",
|
||||||
"reqwest",
|
"reqwest",
|
||||||
"reqwest-middleware",
|
"reqwest-middleware",
|
||||||
|
|
|
@ -75,3 +75,4 @@ doku = "0.10.2"
|
||||||
opentelemetry = { version = "0.16", features = ["rt-tokio"] }
|
opentelemetry = { version = "0.16", features = ["rt-tokio"] }
|
||||||
opentelemetry-otlp = "0.9"
|
opentelemetry-otlp = "0.9"
|
||||||
tracing-opentelemetry = "0.16"
|
tracing-opentelemetry = "0.16"
|
||||||
|
parking_lot = "0.12"
|
||||||
|
|
|
@ -50,6 +50,7 @@ background-jobs = "0.11.0"
|
||||||
reqwest = { version = "0.11.7", features = ["json"] }
|
reqwest = { version = "0.11.7", features = ["json"] }
|
||||||
html2md = "0.2.13"
|
html2md = "0.2.13"
|
||||||
once_cell = "1.8.0"
|
once_cell = "1.8.0"
|
||||||
|
parking_lot = "0.12"
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
serial_test = "0.5.1"
|
serial_test = "0.5.1"
|
||||||
|
|
|
@ -58,10 +58,10 @@ pub(crate) mod tests {
|
||||||
LemmyError,
|
LemmyError,
|
||||||
};
|
};
|
||||||
use lemmy_websocket::{chat_server::ChatServer, LemmyContext};
|
use lemmy_websocket::{chat_server::ChatServer, LemmyContext};
|
||||||
|
use parking_lot::Mutex;
|
||||||
use reqwest::Client;
|
use reqwest::Client;
|
||||||
use reqwest_middleware::ClientBuilder;
|
use reqwest_middleware::ClientBuilder;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use tokio::sync::Mutex;
|
|
||||||
|
|
||||||
// TODO: would be nice if we didnt have to use a full context for tests.
|
// TODO: would be nice if we didnt have to use a full context for tests.
|
||||||
// or at least write a helper function so this code is shared with main.rs
|
// or at least write a helper function so this code is shared with main.rs
|
||||||
|
|
|
@ -48,6 +48,7 @@ uuid = { version = "0.8.2", features = ["serde", "v4"] }
|
||||||
encoding = "0.2.33"
|
encoding = "0.2.33"
|
||||||
html2text = "0.2.1"
|
html2text = "0.2.1"
|
||||||
rosetta-i18n = "0.1"
|
rosetta-i18n = "0.1"
|
||||||
|
parking_lot = "0.12"
|
||||||
|
|
||||||
[build-dependencies]
|
[build-dependencies]
|
||||||
rosetta-build = "0.1"
|
rosetta-build = "0.1"
|
||||||
|
|
|
@ -4,6 +4,7 @@ use actix_web::{
|
||||||
HttpResponse,
|
HttpResponse,
|
||||||
};
|
};
|
||||||
use futures::future::{ok, Ready};
|
use futures::future::{ok, Ready};
|
||||||
|
use parking_lot::Mutex;
|
||||||
use rate_limiter::{RateLimitType, RateLimiter};
|
use rate_limiter::{RateLimitType, RateLimiter};
|
||||||
use std::{
|
use std::{
|
||||||
future::Future,
|
future::Future,
|
||||||
|
@ -12,7 +13,6 @@ use std::{
|
||||||
sync::Arc,
|
sync::Arc,
|
||||||
task::{Context, Poll},
|
task::{Context, Poll},
|
||||||
};
|
};
|
||||||
use tokio::sync::Mutex;
|
|
||||||
|
|
||||||
pub mod rate_limiter;
|
pub mod rate_limiter;
|
||||||
|
|
||||||
|
@ -68,13 +68,11 @@ impl RateLimit {
|
||||||
|
|
||||||
impl RateLimited {
|
impl RateLimited {
|
||||||
/// Returns true if the request passed the rate limit, false if it failed and should be rejected.
|
/// Returns true if the request passed the rate limit, false if it failed and should be rejected.
|
||||||
pub async fn check(self, ip_addr: IpAddr) -> bool {
|
pub fn check(self, ip_addr: IpAddr) -> bool {
|
||||||
// Does not need to be blocking because the RwLock in settings never held across await points,
|
// Does not need to be blocking because the RwLock in settings never held across await points,
|
||||||
// and the operation here locks only long enough to clone
|
// and the operation here locks only long enough to clone
|
||||||
let rate_limit = self.rate_limit_config;
|
let rate_limit = self.rate_limit_config;
|
||||||
|
|
||||||
let mut limiter = self.rate_limiter.lock().await;
|
|
||||||
|
|
||||||
let (kind, interval) = match self.type_ {
|
let (kind, interval) = match self.type_ {
|
||||||
RateLimitType::Message => (rate_limit.message, rate_limit.message_per_second),
|
RateLimitType::Message => (rate_limit.message, rate_limit.message_per_second),
|
||||||
RateLimitType::Post => (rate_limit.post, rate_limit.post_per_second),
|
RateLimitType::Post => (rate_limit.post, rate_limit.post_per_second),
|
||||||
|
@ -82,6 +80,8 @@ impl RateLimited {
|
||||||
RateLimitType::Image => (rate_limit.image, rate_limit.image_per_second),
|
RateLimitType::Image => (rate_limit.image, rate_limit.image_per_second),
|
||||||
RateLimitType::Comment => (rate_limit.comment, rate_limit.comment_per_second),
|
RateLimitType::Comment => (rate_limit.comment, rate_limit.comment_per_second),
|
||||||
};
|
};
|
||||||
|
let mut limiter = self.rate_limiter.lock();
|
||||||
|
|
||||||
limiter.check_rate_limit_full(self.type_, &ip_addr, kind, interval)
|
limiter.check_rate_limit_full(self.type_, &ip_addr, kind, interval)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -127,7 +127,7 @@ where
|
||||||
let service = self.service.clone();
|
let service = self.service.clone();
|
||||||
|
|
||||||
Box::pin(async move {
|
Box::pin(async move {
|
||||||
if rate_limited.check(ip_addr).await {
|
if rate_limited.check(ip_addr) {
|
||||||
service.call(req).await
|
service.call(req).await
|
||||||
} else {
|
} else {
|
||||||
let (http_req, _) = req.into_parts();
|
let (http_req, _) = req.into_parts();
|
||||||
|
|
|
@ -36,3 +36,4 @@ actix-web = { version = "4.0.0", default-features = false, features = ["rustls"]
|
||||||
actix-web-actors = { version = "4.1.0", default-features = false }
|
actix-web-actors = { version = "4.1.0", default-features = false }
|
||||||
opentelemetry = "0.16"
|
opentelemetry = "0.16"
|
||||||
tracing-opentelemetry = "0.16"
|
tracing-opentelemetry = "0.16"
|
||||||
|
parking_lot = "0.12"
|
||||||
|
|
|
@ -481,19 +481,19 @@ impl ChatServer {
|
||||||
// check if api call passes the rate limit, and generate future for later execution
|
// check if api call passes the rate limit, and generate future for later execution
|
||||||
let (passed, fut) = if let Ok(user_operation_crud) = UserOperationCrud::from_str(op) {
|
let (passed, fut) = if let Ok(user_operation_crud) = UserOperationCrud::from_str(op) {
|
||||||
let passed = match user_operation_crud {
|
let passed = match user_operation_crud {
|
||||||
UserOperationCrud::Register => rate_limiter.register().check(ip).await,
|
UserOperationCrud::Register => rate_limiter.register().check(ip),
|
||||||
UserOperationCrud::CreatePost => rate_limiter.post().check(ip).await,
|
UserOperationCrud::CreatePost => rate_limiter.post().check(ip),
|
||||||
UserOperationCrud::CreateCommunity => rate_limiter.register().check(ip).await,
|
UserOperationCrud::CreateCommunity => rate_limiter.register().check(ip),
|
||||||
UserOperationCrud::CreateComment => rate_limiter.comment().check(ip).await,
|
UserOperationCrud::CreateComment => rate_limiter.comment().check(ip),
|
||||||
_ => rate_limiter.message().check(ip).await,
|
_ => rate_limiter.message().check(ip),
|
||||||
};
|
};
|
||||||
let fut = (message_handler_crud)(context, msg.id, user_operation_crud, data);
|
let fut = (message_handler_crud)(context, msg.id, user_operation_crud, data);
|
||||||
(passed, fut)
|
(passed, fut)
|
||||||
} else {
|
} else {
|
||||||
let user_operation = UserOperation::from_str(op)?;
|
let user_operation = UserOperation::from_str(op)?;
|
||||||
let passed = match user_operation {
|
let passed = match user_operation {
|
||||||
UserOperation::GetCaptcha => rate_limiter.post().check(ip).await,
|
UserOperation::GetCaptcha => rate_limiter.post().check(ip),
|
||||||
_ => rate_limiter.message().check(ip).await,
|
_ => rate_limiter.message().check(ip),
|
||||||
};
|
};
|
||||||
let fut = (message_handler)(context, msg.id, user_operation, data);
|
let fut = (message_handler)(context, msg.id, user_operation, data);
|
||||||
(passed, fut)
|
(passed, fut)
|
||||||
|
|
|
@ -6,7 +6,7 @@ use crate::{
|
||||||
use actix::prelude::*;
|
use actix::prelude::*;
|
||||||
use actix_web::{web, Error, HttpRequest, HttpResponse};
|
use actix_web::{web, Error, HttpRequest, HttpResponse};
|
||||||
use actix_web_actors::ws;
|
use actix_web_actors::ws;
|
||||||
use lemmy_utils::{utils::get_ip, ConnectionId, IpAddr};
|
use lemmy_utils::{rate_limit::RateLimit, utils::get_ip, ConnectionId, IpAddr};
|
||||||
use std::time::{Duration, Instant};
|
use std::time::{Duration, Instant};
|
||||||
use tracing::{debug, error, info};
|
use tracing::{debug, error, info};
|
||||||
|
|
||||||
|
@ -20,6 +20,7 @@ pub async fn chat_route(
|
||||||
req: HttpRequest,
|
req: HttpRequest,
|
||||||
stream: web::Payload,
|
stream: web::Payload,
|
||||||
context: web::Data<LemmyContext>,
|
context: web::Data<LemmyContext>,
|
||||||
|
rate_limiter: web::Data<RateLimit>,
|
||||||
) -> Result<HttpResponse, Error> {
|
) -> Result<HttpResponse, Error> {
|
||||||
ws::start(
|
ws::start(
|
||||||
WsSession {
|
WsSession {
|
||||||
|
@ -27,6 +28,7 @@ pub async fn chat_route(
|
||||||
id: 0,
|
id: 0,
|
||||||
hb: Instant::now(),
|
hb: Instant::now(),
|
||||||
ip: get_ip(&req.connection_info()),
|
ip: get_ip(&req.connection_info()),
|
||||||
|
rate_limiter: rate_limiter.as_ref().to_owned(),
|
||||||
},
|
},
|
||||||
&req,
|
&req,
|
||||||
stream,
|
stream,
|
||||||
|
@ -41,6 +43,8 @@ struct WsSession {
|
||||||
/// Client must send ping at least once per 10 seconds (CLIENT_TIMEOUT),
|
/// Client must send ping at least once per 10 seconds (CLIENT_TIMEOUT),
|
||||||
/// otherwise we drop connection.
|
/// otherwise we drop connection.
|
||||||
hb: Instant,
|
hb: Instant,
|
||||||
|
/// A rate limiter for websocket joins
|
||||||
|
rate_limiter: RateLimit,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Actor for WsSession {
|
impl Actor for WsSession {
|
||||||
|
@ -57,6 +61,11 @@ impl Actor for WsSession {
|
||||||
// before processing any other events.
|
// before processing any other events.
|
||||||
// across all routes within application
|
// across all routes within application
|
||||||
let addr = ctx.address();
|
let addr = ctx.address();
|
||||||
|
|
||||||
|
if !self.rate_limit_check(ctx) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
self
|
self
|
||||||
.cs_addr
|
.cs_addr
|
||||||
.send(Connect {
|
.send(Connect {
|
||||||
|
@ -98,6 +107,10 @@ impl Handler<WsMessage> for WsSession {
|
||||||
/// WebSocket message handler
|
/// WebSocket message handler
|
||||||
impl StreamHandler<Result<ws::Message, ws::ProtocolError>> for WsSession {
|
impl StreamHandler<Result<ws::Message, ws::ProtocolError>> for WsSession {
|
||||||
fn handle(&mut self, result: Result<ws::Message, ws::ProtocolError>, ctx: &mut Self::Context) {
|
fn handle(&mut self, result: Result<ws::Message, ws::ProtocolError>, ctx: &mut Self::Context) {
|
||||||
|
if !self.rate_limit_check(ctx) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
let message = match result {
|
let message = match result {
|
||||||
Ok(m) => m,
|
Ok(m) => m,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
|
@ -169,4 +182,14 @@ impl WsSession {
|
||||||
ctx.ping(b"");
|
ctx.ping(b"");
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Check the rate limit, and stop the ctx if it fails
|
||||||
|
fn rate_limit_check(&mut self, ctx: &mut ws::WebsocketContext<Self>) -> bool {
|
||||||
|
let check = self.rate_limiter.message().check(self.ip.to_owned());
|
||||||
|
if !check {
|
||||||
|
debug!("Websocket join with IP: {} has been rate limited.", self.ip);
|
||||||
|
ctx.stop()
|
||||||
|
}
|
||||||
|
check
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -29,11 +29,11 @@ use lemmy_utils::{
|
||||||
REQWEST_TIMEOUT,
|
REQWEST_TIMEOUT,
|
||||||
};
|
};
|
||||||
use lemmy_websocket::{chat_server::ChatServer, LemmyContext};
|
use lemmy_websocket::{chat_server::ChatServer, LemmyContext};
|
||||||
|
use parking_lot::Mutex;
|
||||||
use reqwest::Client;
|
use reqwest::Client;
|
||||||
use reqwest_middleware::ClientBuilder;
|
use reqwest_middleware::ClientBuilder;
|
||||||
use reqwest_tracing::TracingMiddleware;
|
use reqwest_tracing::TracingMiddleware;
|
||||||
use std::{env, sync::Arc, thread};
|
use std::{env, sync::Arc, thread};
|
||||||
use tokio::sync::Mutex;
|
|
||||||
use tracing_actix_web::TracingLogger;
|
use tracing_actix_web::TracingLogger;
|
||||||
|
|
||||||
embed_migrations!();
|
embed_migrations!();
|
||||||
|
@ -136,6 +136,7 @@ async fn main() -> Result<(), LemmyError> {
|
||||||
.wrap(actix_web::middleware::Logger::default())
|
.wrap(actix_web::middleware::Logger::default())
|
||||||
.wrap(TracingLogger::<QuieterRootSpanBuilder>::new())
|
.wrap(TracingLogger::<QuieterRootSpanBuilder>::new())
|
||||||
.app_data(Data::new(context))
|
.app_data(Data::new(context))
|
||||||
|
.app_data(Data::new(rate_limiter.clone()))
|
||||||
// The routes
|
// The routes
|
||||||
.configure(|cfg| api_routes::config(cfg, &rate_limiter))
|
.configure(|cfg| api_routes::config(cfg, &rate_limiter))
|
||||||
.configure(|cfg| lemmy_apub::http::routes::config(cfg, &settings))
|
.configure(|cfg| lemmy_apub::http::routes::config(cfg, &settings))
|
||||||
|
|
Loading…
Reference in a new issue