Merge branch 'abstract_websocket_sends'

This commit is contained in:
Dessalines 2020-04-20 15:39:34 -04:00
commit d0c38b0927
22 changed files with 2527 additions and 1198 deletions

778
server/Cargo.lock generated vendored

File diff suppressed because it is too large Load diff

2
server/Cargo.toml vendored
View file

@ -37,3 +37,5 @@ hjson = "0.8.2"
percent-encoding = "2.1.0" percent-encoding = "2.1.0"
isahc = "0.9" isahc = "0.9"
comrak = "0.7" comrak = "0.7"
tokio = "0.2.18"
futures = "0.3.4"

View file

@ -1,9 +1,4 @@
use super::*; use super::*;
use crate::send_email;
use crate::settings::Settings;
use diesel::PgConnection;
use log::error;
use std::str::FromStr;
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
pub struct CreateComment { pub struct CreateComment {
@ -64,8 +59,14 @@ pub struct GetCommentsResponse {
comments: Vec<CommentView>, comments: Vec<CommentView>,
} }
impl Perform<CommentResponse> for Oper<CreateComment> { impl Perform for Oper<CreateComment> {
fn perform(&self, conn: &PgConnection) -> Result<CommentResponse, Error> { type Response = CommentResponse;
fn perform(
&self,
pool: Pool<ConnectionManager<PgConnection>>,
websocket_info: Option<WebsocketInfo>,
) -> Result<CommentResponse, Error> {
let data: &CreateComment = &self.data; let data: &CreateComment = &self.data;
let claims = match Claims::decode(&data.auth) { let claims = match Claims::decode(&data.auth) {
@ -77,6 +78,8 @@ impl Perform<CommentResponse> for Oper<CreateComment> {
let hostname = &format!("https://{}", Settings::get().hostname); let hostname = &format!("https://{}", Settings::get().hostname);
let conn = pool.get()?;
// Check for a community ban // Check for a community ban
let post = Post::read(&conn, data.post_id)?; let post = Post::read(&conn, data.post_id)?;
if CommunityUserBanView::get(&conn, user_id, post.community_id).is_ok() { if CommunityUserBanView::get(&conn, user_id, post.community_id).is_ok() {
@ -223,15 +226,35 @@ impl Perform<CommentResponse> for Oper<CreateComment> {
let comment_view = CommentView::read(&conn, inserted_comment.id, Some(user_id))?; let comment_view = CommentView::read(&conn, inserted_comment.id, Some(user_id))?;
Ok(CommentResponse { let mut res = CommentResponse {
comment: comment_view, comment: comment_view,
recipient_ids, recipient_ids,
}) };
if let Some(ws) = websocket_info {
ws.chatserver.do_send(SendComment {
op: UserOperation::CreateComment,
comment: res.clone(),
my_id: ws.id,
});
// strip out the recipient_ids, so that
// users don't get double notifs
res.recipient_ids = Vec::new();
}
Ok(res)
} }
} }
impl Perform<CommentResponse> for Oper<EditComment> { impl Perform for Oper<EditComment> {
fn perform(&self, conn: &PgConnection) -> Result<CommentResponse, Error> { type Response = CommentResponse;
fn perform(
&self,
pool: Pool<ConnectionManager<PgConnection>>,
websocket_info: Option<WebsocketInfo>,
) -> Result<CommentResponse, Error> {
let data: &EditComment = &self.data; let data: &EditComment = &self.data;
let claims = match Claims::decode(&data.auth) { let claims = match Claims::decode(&data.auth) {
@ -241,6 +264,8 @@ impl Perform<CommentResponse> for Oper<EditComment> {
let user_id = claims.id; let user_id = claims.id;
let conn = pool.get()?;
let orig_comment = CommentView::read(&conn, data.edit_id, None)?; let orig_comment = CommentView::read(&conn, data.edit_id, None)?;
// You are allowed to mark the comment as read even if you're banned. // You are allowed to mark the comment as read even if you're banned.
@ -353,15 +378,35 @@ impl Perform<CommentResponse> for Oper<EditComment> {
let comment_view = CommentView::read(&conn, data.edit_id, Some(user_id))?; let comment_view = CommentView::read(&conn, data.edit_id, Some(user_id))?;
Ok(CommentResponse { let mut res = CommentResponse {
comment: comment_view, comment: comment_view,
recipient_ids, recipient_ids,
}) };
if let Some(ws) = websocket_info {
ws.chatserver.do_send(SendComment {
op: UserOperation::EditComment,
comment: res.clone(),
my_id: ws.id,
});
// strip out the recipient_ids, so that
// users don't get double notifs
res.recipient_ids = Vec::new();
}
Ok(res)
} }
} }
impl Perform<CommentResponse> for Oper<SaveComment> { impl Perform for Oper<SaveComment> {
fn perform(&self, conn: &PgConnection) -> Result<CommentResponse, Error> { type Response = CommentResponse;
fn perform(
&self,
pool: Pool<ConnectionManager<PgConnection>>,
_websocket_info: Option<WebsocketInfo>,
) -> Result<CommentResponse, Error> {
let data: &SaveComment = &self.data; let data: &SaveComment = &self.data;
let claims = match Claims::decode(&data.auth) { let claims = match Claims::decode(&data.auth) {
@ -376,6 +421,8 @@ impl Perform<CommentResponse> for Oper<SaveComment> {
user_id, user_id,
}; };
let conn = pool.get()?;
if data.save { if data.save {
match CommentSaved::save(&conn, &comment_saved_form) { match CommentSaved::save(&conn, &comment_saved_form) {
Ok(comment) => comment, Ok(comment) => comment,
@ -397,8 +444,14 @@ impl Perform<CommentResponse> for Oper<SaveComment> {
} }
} }
impl Perform<CommentResponse> for Oper<CreateCommentLike> { impl Perform for Oper<CreateCommentLike> {
fn perform(&self, conn: &PgConnection) -> Result<CommentResponse, Error> { type Response = CommentResponse;
fn perform(
&self,
pool: Pool<ConnectionManager<PgConnection>>,
websocket_info: Option<WebsocketInfo>,
) -> Result<CommentResponse, Error> {
let data: &CreateCommentLike = &self.data; let data: &CreateCommentLike = &self.data;
let claims = match Claims::decode(&data.auth) { let claims = match Claims::decode(&data.auth) {
@ -410,6 +463,8 @@ impl Perform<CommentResponse> for Oper<CreateCommentLike> {
let mut recipient_ids = Vec::new(); let mut recipient_ids = Vec::new();
let conn = pool.get()?;
// Don't do a downvote if site has downvotes disabled // Don't do a downvote if site has downvotes disabled
if data.score == -1 { if data.score == -1 {
let site = SiteView::read(&conn)?; let site = SiteView::read(&conn)?;
@ -467,15 +522,35 @@ impl Perform<CommentResponse> for Oper<CreateCommentLike> {
// Have to refetch the comment to get the current state // Have to refetch the comment to get the current state
let liked_comment = CommentView::read(&conn, data.comment_id, Some(user_id))?; let liked_comment = CommentView::read(&conn, data.comment_id, Some(user_id))?;
Ok(CommentResponse { let mut res = CommentResponse {
comment: liked_comment, comment: liked_comment,
recipient_ids, recipient_ids,
}) };
if let Some(ws) = websocket_info {
ws.chatserver.do_send(SendComment {
op: UserOperation::CreateCommentLike,
comment: res.clone(),
my_id: ws.id,
});
// strip out the recipient_ids, so that
// users don't get double notifs
res.recipient_ids = Vec::new();
}
Ok(res)
} }
} }
impl Perform<GetCommentsResponse> for Oper<GetComments> { impl Perform for Oper<GetComments> {
fn perform(&self, conn: &PgConnection) -> Result<GetCommentsResponse, Error> { type Response = GetCommentsResponse;
fn perform(
&self,
pool: Pool<ConnectionManager<PgConnection>>,
websocket_info: Option<WebsocketInfo>,
) -> Result<GetCommentsResponse, Error> {
let data: &GetComments = &self.data; let data: &GetComments = &self.data;
let user_claims: Option<Claims> = match &data.auth { let user_claims: Option<Claims> = match &data.auth {
@ -494,6 +569,8 @@ impl Perform<GetCommentsResponse> for Oper<GetComments> {
let type_ = ListingType::from_str(&data.type_)?; let type_ = ListingType::from_str(&data.type_)?;
let sort = SortType::from_str(&data.sort)?; let sort = SortType::from_str(&data.sort)?;
let conn = pool.get()?;
let comments = match CommentQueryBuilder::create(&conn) let comments = match CommentQueryBuilder::create(&conn)
.listing_type(type_) .listing_type(type_)
.sort(&sort) .sort(&sort)
@ -507,6 +584,20 @@ impl Perform<GetCommentsResponse> for Oper<GetComments> {
Err(_e) => return Err(APIError::err("couldnt_get_comments").into()), Err(_e) => return Err(APIError::err("couldnt_get_comments").into()),
}; };
if let Some(ws) = websocket_info {
// You don't need to join the specific community room, bc this is already handled by
// GetCommunity
if data.community_id.is_none() {
if let Some(id) = ws.id {
// 0 is the "all" community
ws.chatserver.do_send(JoinCommunityRoom {
community_id: 0,
id,
});
}
}
}
Ok(GetCommentsResponse { comments }) Ok(GetCommentsResponse { comments })
} }
} }

View file

@ -1,6 +1,4 @@
use super::*; use super::*;
use diesel::PgConnection;
use std::str::FromStr;
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
pub struct GetCommunity { pub struct GetCommunity {
@ -55,7 +53,7 @@ pub struct BanFromCommunity {
auth: String, auth: String,
} }
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize, Clone)]
pub struct BanFromCommunityResponse { pub struct BanFromCommunityResponse {
user: UserView, user: UserView,
banned: bool, banned: bool,
@ -69,7 +67,7 @@ pub struct AddModToCommunity {
auth: String, auth: String,
} }
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize, Clone)]
pub struct AddModToCommunityResponse { pub struct AddModToCommunityResponse {
moderators: Vec<CommunityModeratorView>, moderators: Vec<CommunityModeratorView>,
} }
@ -113,8 +111,14 @@ pub struct TransferCommunity {
auth: String, auth: String,
} }
impl Perform<GetCommunityResponse> for Oper<GetCommunity> { impl Perform for Oper<GetCommunity> {
fn perform(&self, conn: &PgConnection) -> Result<GetCommunityResponse, Error> { type Response = GetCommunityResponse;
fn perform(
&self,
pool: Pool<ConnectionManager<PgConnection>>,
websocket_info: Option<WebsocketInfo>,
) -> Result<GetCommunityResponse, Error> {
let data: &GetCommunity = &self.data; let data: &GetCommunity = &self.data;
let user_id: Option<i32> = match &data.auth { let user_id: Option<i32> = match &data.auth {
@ -128,6 +132,8 @@ impl Perform<GetCommunityResponse> for Oper<GetCommunity> {
None => None, None => None,
}; };
let conn = pool.get()?;
let community_id = match data.id { let community_id = match data.id {
Some(id) => id, Some(id) => id,
None => { None => {
@ -157,18 +163,42 @@ impl Perform<GetCommunityResponse> for Oper<GetCommunity> {
let creator_user = admins.remove(creator_index); let creator_user = admins.remove(creator_index);
admins.insert(0, creator_user); admins.insert(0, creator_user);
// Return the jwt let online = if let Some(ws) = websocket_info {
Ok(GetCommunityResponse { if let Some(id) = ws.id {
ws.chatserver
.do_send(JoinCommunityRoom { community_id, id });
}
// TODO
1
// let fut = async {
// ws.chatserver.send(GetCommunityUsersOnline {community_id}).await.unwrap()
// };
// Runtime::new().unwrap().block_on(fut)
} else {
0
};
let res = GetCommunityResponse {
community: community_view, community: community_view,
moderators, moderators,
admins, admins,
online: 0, online,
}) };
// Return the jwt
Ok(res)
} }
} }
impl Perform<CommunityResponse> for Oper<CreateCommunity> { impl Perform for Oper<CreateCommunity> {
fn perform(&self, conn: &PgConnection) -> Result<CommunityResponse, Error> { type Response = CommunityResponse;
fn perform(
&self,
pool: Pool<ConnectionManager<PgConnection>>,
_websocket_info: Option<WebsocketInfo>,
) -> Result<CommunityResponse, Error> {
let data: &CreateCommunity = &self.data; let data: &CreateCommunity = &self.data;
let claims = match Claims::decode(&data.auth) { let claims = match Claims::decode(&data.auth) {
@ -192,6 +222,8 @@ impl Perform<CommunityResponse> for Oper<CreateCommunity> {
let user_id = claims.id; let user_id = claims.id;
let conn = pool.get()?;
// Check for a site ban // Check for a site ban
if UserView::read(&conn, user_id)?.banned { if UserView::read(&conn, user_id)?.banned {
return Err(APIError::err("site_ban").into()); return Err(APIError::err("site_ban").into());
@ -245,8 +277,14 @@ impl Perform<CommunityResponse> for Oper<CreateCommunity> {
} }
} }
impl Perform<CommunityResponse> for Oper<EditCommunity> { impl Perform for Oper<EditCommunity> {
fn perform(&self, conn: &PgConnection) -> Result<CommunityResponse, Error> { type Response = CommunityResponse;
fn perform(
&self,
pool: Pool<ConnectionManager<PgConnection>>,
websocket_info: Option<WebsocketInfo>,
) -> Result<CommunityResponse, Error> {
let data: &EditCommunity = &self.data; let data: &EditCommunity = &self.data;
if let Err(slurs) = slur_check(&data.name) { if let Err(slurs) = slur_check(&data.name) {
@ -270,6 +308,8 @@ impl Perform<CommunityResponse> for Oper<EditCommunity> {
let user_id = claims.id; let user_id = claims.id;
let conn = pool.get()?;
// Check for a site ban // Check for a site ban
if UserView::read(&conn, user_id)?.banned { if UserView::read(&conn, user_id)?.banned {
return Err(APIError::err("site_ban").into()); return Err(APIError::err("site_ban").into());
@ -323,14 +363,36 @@ impl Perform<CommunityResponse> for Oper<EditCommunity> {
let community_view = CommunityView::read(&conn, data.edit_id, Some(user_id))?; let community_view = CommunityView::read(&conn, data.edit_id, Some(user_id))?;
Ok(CommunityResponse { let res = CommunityResponse {
community: community_view, community: community_view,
}) };
if let Some(ws) = websocket_info {
// Strip out the user id and subscribed when sending to others
let mut res_sent = res.clone();
res_sent.community.user_id = None;
res_sent.community.subscribed = None;
ws.chatserver.do_send(SendCommunityRoomMessage {
op: UserOperation::EditCommunity,
response: res_sent,
community_id: data.edit_id,
my_id: ws.id,
});
}
Ok(res)
} }
} }
impl Perform<ListCommunitiesResponse> for Oper<ListCommunities> { impl Perform for Oper<ListCommunities> {
fn perform(&self, conn: &PgConnection) -> Result<ListCommunitiesResponse, Error> { type Response = ListCommunitiesResponse;
fn perform(
&self,
pool: Pool<ConnectionManager<PgConnection>>,
_websocket_info: Option<WebsocketInfo>,
) -> Result<ListCommunitiesResponse, Error> {
let data: &ListCommunities = &self.data; let data: &ListCommunities = &self.data;
let user_claims: Option<Claims> = match &data.auth { let user_claims: Option<Claims> = match &data.auth {
@ -353,6 +415,8 @@ impl Perform<ListCommunitiesResponse> for Oper<ListCommunities> {
let sort = SortType::from_str(&data.sort)?; let sort = SortType::from_str(&data.sort)?;
let conn = pool.get()?;
let communities = CommunityQueryBuilder::create(&conn) let communities = CommunityQueryBuilder::create(&conn)
.sort(&sort) .sort(&sort)
.for_user(user_id) .for_user(user_id)
@ -366,8 +430,14 @@ impl Perform<ListCommunitiesResponse> for Oper<ListCommunities> {
} }
} }
impl Perform<CommunityResponse> for Oper<FollowCommunity> { impl Perform for Oper<FollowCommunity> {
fn perform(&self, conn: &PgConnection) -> Result<CommunityResponse, Error> { type Response = CommunityResponse;
fn perform(
&self,
pool: Pool<ConnectionManager<PgConnection>>,
_websocket_info: Option<WebsocketInfo>,
) -> Result<CommunityResponse, Error> {
let data: &FollowCommunity = &self.data; let data: &FollowCommunity = &self.data;
let claims = match Claims::decode(&data.auth) { let claims = match Claims::decode(&data.auth) {
@ -382,6 +452,8 @@ impl Perform<CommunityResponse> for Oper<FollowCommunity> {
user_id, user_id,
}; };
let conn = pool.get()?;
if data.follow { if data.follow {
match CommunityFollower::follow(&conn, &community_follower_form) { match CommunityFollower::follow(&conn, &community_follower_form) {
Ok(user) => user, Ok(user) => user,
@ -402,8 +474,14 @@ impl Perform<CommunityResponse> for Oper<FollowCommunity> {
} }
} }
impl Perform<GetFollowedCommunitiesResponse> for Oper<GetFollowedCommunities> { impl Perform for Oper<GetFollowedCommunities> {
fn perform(&self, conn: &PgConnection) -> Result<GetFollowedCommunitiesResponse, Error> { type Response = GetFollowedCommunitiesResponse;
fn perform(
&self,
pool: Pool<ConnectionManager<PgConnection>>,
_websocket_info: Option<WebsocketInfo>,
) -> Result<GetFollowedCommunitiesResponse, Error> {
let data: &GetFollowedCommunities = &self.data; let data: &GetFollowedCommunities = &self.data;
let claims = match Claims::decode(&data.auth) { let claims = match Claims::decode(&data.auth) {
@ -413,6 +491,8 @@ impl Perform<GetFollowedCommunitiesResponse> for Oper<GetFollowedCommunities> {
let user_id = claims.id; let user_id = claims.id;
let conn = pool.get()?;
let communities: Vec<CommunityFollowerView> = let communities: Vec<CommunityFollowerView> =
match CommunityFollowerView::for_user(&conn, user_id) { match CommunityFollowerView::for_user(&conn, user_id) {
Ok(communities) => communities, Ok(communities) => communities,
@ -424,8 +504,14 @@ impl Perform<GetFollowedCommunitiesResponse> for Oper<GetFollowedCommunities> {
} }
} }
impl Perform<BanFromCommunityResponse> for Oper<BanFromCommunity> { impl Perform for Oper<BanFromCommunity> {
fn perform(&self, conn: &PgConnection) -> Result<BanFromCommunityResponse, Error> { type Response = BanFromCommunityResponse;
fn perform(
&self,
pool: Pool<ConnectionManager<PgConnection>>,
websocket_info: Option<WebsocketInfo>,
) -> Result<BanFromCommunityResponse, Error> {
let data: &BanFromCommunity = &self.data; let data: &BanFromCommunity = &self.data;
let claims = match Claims::decode(&data.auth) { let claims = match Claims::decode(&data.auth) {
@ -440,6 +526,8 @@ impl Perform<BanFromCommunityResponse> for Oper<BanFromCommunity> {
user_id: data.user_id, user_id: data.user_id,
}; };
let conn = pool.get()?;
if data.ban { if data.ban {
match CommunityUserBan::ban(&conn, &community_user_ban_form) { match CommunityUserBan::ban(&conn, &community_user_ban_form) {
Ok(user) => user, Ok(user) => user,
@ -470,15 +558,32 @@ impl Perform<BanFromCommunityResponse> for Oper<BanFromCommunity> {
let user_view = UserView::read(&conn, data.user_id)?; let user_view = UserView::read(&conn, data.user_id)?;
Ok(BanFromCommunityResponse { let res = BanFromCommunityResponse {
user: user_view, user: user_view,
banned: data.ban, banned: data.ban,
}) };
if let Some(ws) = websocket_info {
ws.chatserver.do_send(SendCommunityRoomMessage {
op: UserOperation::BanFromCommunity,
response: res.clone(),
community_id: data.community_id,
my_id: ws.id,
});
}
Ok(res)
} }
} }
impl Perform<AddModToCommunityResponse> for Oper<AddModToCommunity> { impl Perform for Oper<AddModToCommunity> {
fn perform(&self, conn: &PgConnection) -> Result<AddModToCommunityResponse, Error> { type Response = AddModToCommunityResponse;
fn perform(
&self,
pool: Pool<ConnectionManager<PgConnection>>,
websocket_info: Option<WebsocketInfo>,
) -> Result<AddModToCommunityResponse, Error> {
let data: &AddModToCommunity = &self.data; let data: &AddModToCommunity = &self.data;
let claims = match Claims::decode(&data.auth) { let claims = match Claims::decode(&data.auth) {
@ -493,6 +598,8 @@ impl Perform<AddModToCommunityResponse> for Oper<AddModToCommunity> {
user_id: data.user_id, user_id: data.user_id,
}; };
let conn = pool.get()?;
if data.added { if data.added {
match CommunityModerator::join(&conn, &community_moderator_form) { match CommunityModerator::join(&conn, &community_moderator_form) {
Ok(user) => user, Ok(user) => user,
@ -516,12 +623,29 @@ impl Perform<AddModToCommunityResponse> for Oper<AddModToCommunity> {
let moderators = CommunityModeratorView::for_community(&conn, data.community_id)?; let moderators = CommunityModeratorView::for_community(&conn, data.community_id)?;
Ok(AddModToCommunityResponse { moderators }) let res = AddModToCommunityResponse { moderators };
if let Some(ws) = websocket_info {
ws.chatserver.do_send(SendCommunityRoomMessage {
op: UserOperation::AddModToCommunity,
response: res.clone(),
community_id: data.community_id,
my_id: ws.id,
});
}
Ok(res)
} }
} }
impl Perform<GetCommunityResponse> for Oper<TransferCommunity> { impl Perform for Oper<TransferCommunity> {
fn perform(&self, conn: &PgConnection) -> Result<GetCommunityResponse, Error> { type Response = GetCommunityResponse;
fn perform(
&self,
pool: Pool<ConnectionManager<PgConnection>>,
_websocket_info: Option<WebsocketInfo>,
) -> Result<GetCommunityResponse, Error> {
let data: &TransferCommunity = &self.data; let data: &TransferCommunity = &self.data;
let claims = match Claims::decode(&data.auth) { let claims = match Claims::decode(&data.auth) {
@ -531,6 +655,8 @@ impl Perform<GetCommunityResponse> for Oper<TransferCommunity> {
let user_id = claims.id; let user_id = claims.id;
let conn = pool.get()?;
let read_community = Community::read(&conn, data.community_id)?; let read_community = Community::read(&conn, data.community_id)?;
let site_creator_id = Site::read(&conn, 1)?.creator_id; let site_creator_id = Site::read(&conn, 1)?.creator_id;

View file

@ -18,12 +18,25 @@ use crate::db::user_mention_view::*;
use crate::db::user_view::*; use crate::db::user_view::*;
use crate::db::*; use crate::db::*;
use crate::{ use crate::{
extract_usernames, fetch_iframely_and_pictshare_data, naive_from_unix, naive_now, remove_slurs, extract_usernames, fetch_iframely_and_pictshare_data, generate_random_string, naive_from_unix,
slur_check, slurs_vec_to_str, naive_now, remove_slurs, send_email, slur_check, slurs_vec_to_str,
}; };
use crate::settings::Settings;
use crate::websocket::UserOperation;
use crate::websocket::{
server::{
JoinCommunityRoom, JoinPostRoom, JoinUserRoom, SendAllMessage, SendComment,
SendCommunityRoomMessage, SendPost, SendUserRoomMessage,
},
WebsocketInfo,
};
use diesel::r2d2::{ConnectionManager, Pool};
use diesel::PgConnection; use diesel::PgConnection;
use failure::Error; use failure::Error;
use log::{error, info};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::str::FromStr;
pub mod comment; pub mod comment;
pub mod community; pub mod community;
@ -55,8 +68,12 @@ impl<T> Oper<T> {
} }
} }
pub trait Perform<T> { pub trait Perform {
fn perform(&self, conn: &PgConnection) -> Result<T, Error> type Response: serde::ser::Serialize;
where
T: Sized; fn perform(
&self,
pool: Pool<ConnectionManager<PgConnection>>,
websocket_info: Option<WebsocketInfo>,
) -> Result<Self::Response, Error>;
} }

View file

@ -1,6 +1,4 @@
use super::*; use super::*;
use diesel::PgConnection;
use std::str::FromStr;
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
pub struct CreatePost { pub struct CreatePost {
@ -79,8 +77,14 @@ pub struct SavePost {
auth: String, auth: String,
} }
impl Perform<PostResponse> for Oper<CreatePost> { impl Perform for Oper<CreatePost> {
fn perform(&self, conn: &PgConnection) -> Result<PostResponse, Error> { type Response = PostResponse;
fn perform(
&self,
pool: Pool<ConnectionManager<PgConnection>>,
websocket_info: Option<WebsocketInfo>,
) -> Result<PostResponse, Error> {
let data: &CreatePost = &self.data; let data: &CreatePost = &self.data;
let claims = match Claims::decode(&data.auth) { let claims = match Claims::decode(&data.auth) {
@ -100,6 +104,8 @@ impl Perform<PostResponse> for Oper<CreatePost> {
let user_id = claims.id; let user_id = claims.id;
let conn = pool.get()?;
// Check for a community ban // Check for a community ban
if CommunityUserBanView::get(&conn, user_id, data.community_id).is_ok() { if CommunityUserBanView::get(&conn, user_id, data.community_id).is_ok() {
return Err(APIError::err("community_ban").into()); return Err(APIError::err("community_ban").into());
@ -164,12 +170,28 @@ impl Perform<PostResponse> for Oper<CreatePost> {
Err(_e) => return Err(APIError::err("couldnt_find_post").into()), Err(_e) => return Err(APIError::err("couldnt_find_post").into()),
}; };
Ok(PostResponse { post: post_view }) let res = PostResponse { post: post_view };
if let Some(ws) = websocket_info {
ws.chatserver.do_send(SendPost {
op: UserOperation::CreatePost,
post: res.clone(),
my_id: ws.id,
});
}
Ok(res)
} }
} }
impl Perform<GetPostResponse> for Oper<GetPost> { impl Perform for Oper<GetPost> {
fn perform(&self, conn: &PgConnection) -> Result<GetPostResponse, Error> { type Response = GetPostResponse;
fn perform(
&self,
pool: Pool<ConnectionManager<PgConnection>>,
websocket_info: Option<WebsocketInfo>,
) -> Result<GetPostResponse, Error> {
let data: &GetPost = &self.data; let data: &GetPost = &self.data;
let user_id: Option<i32> = match &data.auth { let user_id: Option<i32> = match &data.auth {
@ -183,6 +205,8 @@ impl Perform<GetPostResponse> for Oper<GetPost> {
None => None, None => None,
}; };
let conn = pool.get()?;
let post_view = match PostView::read(&conn, data.id, user_id) { let post_view = match PostView::read(&conn, data.id, user_id) {
Ok(post) => post, Ok(post) => post,
Err(_e) => return Err(APIError::err("couldnt_find_post").into()), Err(_e) => return Err(APIError::err("couldnt_find_post").into()),
@ -204,6 +228,24 @@ impl Perform<GetPostResponse> for Oper<GetPost> {
let creator_user = admins.remove(creator_index); let creator_user = admins.remove(creator_index);
admins.insert(0, creator_user); admins.insert(0, creator_user);
let online = if let Some(ws) = websocket_info {
if let Some(id) = ws.id {
ws.chatserver.do_send(JoinPostRoom {
post_id: data.id,
id,
});
}
// TODO
1
// let fut = async {
// ws.chatserver.send(GetPostUsersOnline {post_id: data.id}).await.unwrap()
// };
// Runtime::new().unwrap().block_on(fut)
} else {
0
};
// Return the jwt // Return the jwt
Ok(GetPostResponse { Ok(GetPostResponse {
post: post_view, post: post_view,
@ -211,13 +253,19 @@ impl Perform<GetPostResponse> for Oper<GetPost> {
community, community,
moderators, moderators,
admins, admins,
online: 0, online,
}) })
} }
} }
impl Perform<GetPostsResponse> for Oper<GetPosts> { impl Perform for Oper<GetPosts> {
fn perform(&self, conn: &PgConnection) -> Result<GetPostsResponse, Error> { type Response = GetPostsResponse;
fn perform(
&self,
pool: Pool<ConnectionManager<PgConnection>>,
websocket_info: Option<WebsocketInfo>,
) -> Result<GetPostsResponse, Error> {
let data: &GetPosts = &self.data; let data: &GetPosts = &self.data;
let user_claims: Option<Claims> = match &data.auth { let user_claims: Option<Claims> = match &data.auth {
@ -241,6 +289,8 @@ impl Perform<GetPostsResponse> for Oper<GetPosts> {
let type_ = ListingType::from_str(&data.type_)?; let type_ = ListingType::from_str(&data.type_)?;
let sort = SortType::from_str(&data.sort)?; let sort = SortType::from_str(&data.sort)?;
let conn = pool.get()?;
let posts = match PostQueryBuilder::create(&conn) let posts = match PostQueryBuilder::create(&conn)
.listing_type(type_) .listing_type(type_)
.sort(&sort) .sort(&sort)
@ -255,12 +305,32 @@ impl Perform<GetPostsResponse> for Oper<GetPosts> {
Err(_e) => return Err(APIError::err("couldnt_get_posts").into()), Err(_e) => return Err(APIError::err("couldnt_get_posts").into()),
}; };
if let Some(ws) = websocket_info {
// You don't need to join the specific community room, bc this is already handled by
// GetCommunity
if data.community_id.is_none() {
if let Some(id) = ws.id {
// 0 is the "all" community
ws.chatserver.do_send(JoinCommunityRoom {
community_id: 0,
id,
});
}
}
}
Ok(GetPostsResponse { posts }) Ok(GetPostsResponse { posts })
} }
} }
impl Perform<PostResponse> for Oper<CreatePostLike> { impl Perform for Oper<CreatePostLike> {
fn perform(&self, conn: &PgConnection) -> Result<PostResponse, Error> { type Response = PostResponse;
fn perform(
&self,
pool: Pool<ConnectionManager<PgConnection>>,
websocket_info: Option<WebsocketInfo>,
) -> Result<PostResponse, Error> {
let data: &CreatePostLike = &self.data; let data: &CreatePostLike = &self.data;
let claims = match Claims::decode(&data.auth) { let claims = match Claims::decode(&data.auth) {
@ -270,6 +340,8 @@ impl Perform<PostResponse> for Oper<CreatePostLike> {
let user_id = claims.id; let user_id = claims.id;
let conn = pool.get()?;
// Don't do a downvote if site has downvotes disabled // Don't do a downvote if site has downvotes disabled
if data.score == -1 { if data.score == -1 {
let site = SiteView::read(&conn)?; let site = SiteView::read(&conn)?;
@ -312,13 +384,28 @@ impl Perform<PostResponse> for Oper<CreatePostLike> {
Err(_e) => return Err(APIError::err("couldnt_find_post").into()), Err(_e) => return Err(APIError::err("couldnt_find_post").into()),
}; };
// just output the score let res = PostResponse { post: post_view };
Ok(PostResponse { post: post_view })
if let Some(ws) = websocket_info {
ws.chatserver.do_send(SendPost {
op: UserOperation::CreatePostLike,
post: res.clone(),
my_id: ws.id,
});
}
Ok(res)
} }
} }
impl Perform<PostResponse> for Oper<EditPost> { impl Perform for Oper<EditPost> {
fn perform(&self, conn: &PgConnection) -> Result<PostResponse, Error> { type Response = PostResponse;
fn perform(
&self,
pool: Pool<ConnectionManager<PgConnection>>,
websocket_info: Option<WebsocketInfo>,
) -> Result<PostResponse, Error> {
let data: &EditPost = &self.data; let data: &EditPost = &self.data;
if let Err(slurs) = slur_check(&data.name) { if let Err(slurs) = slur_check(&data.name) {
@ -338,6 +425,8 @@ impl Perform<PostResponse> for Oper<EditPost> {
let user_id = claims.id; let user_id = claims.id;
let conn = pool.get()?;
// Verify its the creator or a mod or admin // Verify its the creator or a mod or admin
let mut editors: Vec<i32> = vec![data.creator_id]; let mut editors: Vec<i32> = vec![data.creator_id];
editors.append( editors.append(
@ -427,12 +516,28 @@ impl Perform<PostResponse> for Oper<EditPost> {
let post_view = PostView::read(&conn, data.edit_id, Some(user_id))?; let post_view = PostView::read(&conn, data.edit_id, Some(user_id))?;
Ok(PostResponse { post: post_view }) let res = PostResponse { post: post_view };
if let Some(ws) = websocket_info {
ws.chatserver.do_send(SendPost {
op: UserOperation::EditPost,
post: res.clone(),
my_id: ws.id,
});
}
Ok(res)
} }
} }
impl Perform<PostResponse> for Oper<SavePost> { impl Perform for Oper<SavePost> {
fn perform(&self, conn: &PgConnection) -> Result<PostResponse, Error> { type Response = PostResponse;
fn perform(
&self,
pool: Pool<ConnectionManager<PgConnection>>,
_websocket_info: Option<WebsocketInfo>,
) -> Result<PostResponse, Error> {
let data: &SavePost = &self.data; let data: &SavePost = &self.data;
let claims = match Claims::decode(&data.auth) { let claims = match Claims::decode(&data.auth) {
@ -447,6 +552,8 @@ impl Perform<PostResponse> for Oper<SavePost> {
user_id, user_id,
}; };
let conn = pool.get()?;
if data.save { if data.save {
match PostSaved::save(&conn, &post_saved_form) { match PostSaved::save(&conn, &post_saved_form) {
Ok(post) => post, Ok(post) => post,

View file

@ -1,10 +1,5 @@
use super::user::Register;
use super::*; use super::*;
use crate::api::user::Register;
use crate::api::{Oper, Perform};
use crate::settings::Settings;
use diesel::PgConnection;
use log::info;
use std::str::FromStr;
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
pub struct ListCategories {} pub struct ListCategories {}
@ -78,7 +73,7 @@ pub struct EditSite {
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
pub struct GetSite {} pub struct GetSite {}
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize, Clone)]
pub struct SiteResponse { pub struct SiteResponse {
site: SiteView, site: SiteView,
} }
@ -113,10 +108,18 @@ pub struct SaveSiteConfig {
auth: String, auth: String,
} }
impl Perform<ListCategoriesResponse> for Oper<ListCategories> { impl Perform for Oper<ListCategories> {
fn perform(&self, conn: &PgConnection) -> Result<ListCategoriesResponse, Error> { type Response = ListCategoriesResponse;
fn perform(
&self,
pool: Pool<ConnectionManager<PgConnection>>,
_websocket_info: Option<WebsocketInfo>,
) -> Result<ListCategoriesResponse, Error> {
let _data: &ListCategories = &self.data; let _data: &ListCategories = &self.data;
let conn = pool.get()?;
let categories: Vec<Category> = Category::list_all(&conn)?; let categories: Vec<Category> = Category::list_all(&conn)?;
// Return the jwt // Return the jwt
@ -124,10 +127,18 @@ impl Perform<ListCategoriesResponse> for Oper<ListCategories> {
} }
} }
impl Perform<GetModlogResponse> for Oper<GetModlog> { impl Perform for Oper<GetModlog> {
fn perform(&self, conn: &PgConnection) -> Result<GetModlogResponse, Error> { type Response = GetModlogResponse;
fn perform(
&self,
pool: Pool<ConnectionManager<PgConnection>>,
_websocket_info: Option<WebsocketInfo>,
) -> Result<GetModlogResponse, Error> {
let data: &GetModlog = &self.data; let data: &GetModlog = &self.data;
let conn = pool.get()?;
let removed_posts = ModRemovePostView::list( let removed_posts = ModRemovePostView::list(
&conn, &conn,
data.community_id, data.community_id,
@ -197,8 +208,14 @@ impl Perform<GetModlogResponse> for Oper<GetModlog> {
} }
} }
impl Perform<SiteResponse> for Oper<CreateSite> { impl Perform for Oper<CreateSite> {
fn perform(&self, conn: &PgConnection) -> Result<SiteResponse, Error> { type Response = SiteResponse;
fn perform(
&self,
pool: Pool<ConnectionManager<PgConnection>>,
_websocket_info: Option<WebsocketInfo>,
) -> Result<SiteResponse, Error> {
let data: &CreateSite = &self.data; let data: &CreateSite = &self.data;
let claims = match Claims::decode(&data.auth) { let claims = match Claims::decode(&data.auth) {
@ -218,6 +235,8 @@ impl Perform<SiteResponse> for Oper<CreateSite> {
let user_id = claims.id; let user_id = claims.id;
let conn = pool.get()?;
// Make sure user is an admin // Make sure user is an admin
if !UserView::read(&conn, user_id)?.admin { if !UserView::read(&conn, user_id)?.admin {
return Err(APIError::err("not_an_admin").into()); return Err(APIError::err("not_an_admin").into());
@ -244,8 +263,13 @@ impl Perform<SiteResponse> for Oper<CreateSite> {
} }
} }
impl Perform<SiteResponse> for Oper<EditSite> { impl Perform for Oper<EditSite> {
fn perform(&self, conn: &PgConnection) -> Result<SiteResponse, Error> { type Response = SiteResponse;
fn perform(
&self,
pool: Pool<ConnectionManager<PgConnection>>,
websocket_info: Option<WebsocketInfo>,
) -> Result<SiteResponse, Error> {
let data: &EditSite = &self.data; let data: &EditSite = &self.data;
let claims = match Claims::decode(&data.auth) { let claims = match Claims::decode(&data.auth) {
@ -265,6 +289,8 @@ impl Perform<SiteResponse> for Oper<EditSite> {
let user_id = claims.id; let user_id = claims.id;
let conn = pool.get()?;
// Make sure user is an admin // Make sure user is an admin
if !UserView::read(&conn, user_id)?.admin { if !UserView::read(&conn, user_id)?.admin {
return Err(APIError::err("not_an_admin").into()); return Err(APIError::err("not_an_admin").into());
@ -289,14 +315,33 @@ impl Perform<SiteResponse> for Oper<EditSite> {
let site_view = SiteView::read(&conn)?; let site_view = SiteView::read(&conn)?;
Ok(SiteResponse { site: site_view }) let res = SiteResponse { site: site_view };
if let Some(ws) = websocket_info {
ws.chatserver.do_send(SendAllMessage {
op: UserOperation::EditSite,
response: res.clone(),
my_id: ws.id,
});
}
Ok(res)
} }
} }
impl Perform<GetSiteResponse> for Oper<GetSite> { impl Perform for Oper<GetSite> {
fn perform(&self, conn: &PgConnection) -> Result<GetSiteResponse, Error> { type Response = GetSiteResponse;
fn perform(
&self,
pool: Pool<ConnectionManager<PgConnection>>,
websocket_info: Option<WebsocketInfo>,
) -> Result<GetSiteResponse, Error> {
let _data: &GetSite = &self.data; let _data: &GetSite = &self.data;
let conn = pool.get()?;
// TODO refactor this a little
let site = Site::read(&conn, 1); let site = Site::read(&conn, 1);
let site_view = if site.is_ok() { let site_view = if site.is_ok() {
Some(SiteView::read(&conn)?) Some(SiteView::read(&conn)?)
@ -309,7 +354,7 @@ impl Perform<GetSiteResponse> for Oper<GetSite> {
admin: true, admin: true,
show_nsfw: true, show_nsfw: true,
}; };
let login_response = Oper::new(register).perform(&conn)?; let login_response = Oper::new(register).perform(pool.clone(), websocket_info.clone())?;
info!("Admin {} created", setup.admin_username); info!("Admin {} created", setup.admin_username);
let create_site = CreateSite { let create_site = CreateSite {
@ -320,7 +365,7 @@ impl Perform<GetSiteResponse> for Oper<GetSite> {
enable_nsfw: false, enable_nsfw: false,
auth: login_response.jwt, auth: login_response.jwt,
}; };
Oper::new(create_site).perform(&conn)?; Oper::new(create_site).perform(pool, websocket_info.clone())?;
info!("Site {} created", setup.site_name); info!("Site {} created", setup.site_name);
Some(SiteView::read(&conn)?) Some(SiteView::read(&conn)?)
} else { } else {
@ -337,17 +382,34 @@ impl Perform<GetSiteResponse> for Oper<GetSite> {
let banned = UserView::banned(&conn)?; let banned = UserView::banned(&conn)?;
let online = if let Some(_ws) = websocket_info {
// TODO
1
// let fut = async {
// ws.chatserver.send(GetUsersOnline).await.unwrap()
// };
// Runtime::new().unwrap().block_on(fut)
} else {
0
};
Ok(GetSiteResponse { Ok(GetSiteResponse {
site: site_view, site: site_view,
admins, admins,
banned, banned,
online: 0, online,
}) })
} }
} }
impl Perform<SearchResponse> for Oper<Search> { impl Perform for Oper<Search> {
fn perform(&self, conn: &PgConnection) -> Result<SearchResponse, Error> { type Response = SearchResponse;
fn perform(
&self,
pool: Pool<ConnectionManager<PgConnection>>,
_websocket_info: Option<WebsocketInfo>,
) -> Result<SearchResponse, Error> {
let data: &Search = &self.data; let data: &Search = &self.data;
let user_id: Option<i32> = match &data.auth { let user_id: Option<i32> = match &data.auth {
@ -371,6 +433,8 @@ impl Perform<SearchResponse> for Oper<Search> {
// TODO no clean / non-nsfw searching rn // TODO no clean / non-nsfw searching rn
let conn = pool.get()?;
match type_ { match type_ {
SearchType::Posts => { SearchType::Posts => {
posts = PostQueryBuilder::create(&conn) posts = PostQueryBuilder::create(&conn)
@ -464,8 +528,14 @@ impl Perform<SearchResponse> for Oper<Search> {
} }
} }
impl Perform<GetSiteResponse> for Oper<TransferSite> { impl Perform for Oper<TransferSite> {
fn perform(&self, conn: &PgConnection) -> Result<GetSiteResponse, Error> { type Response = GetSiteResponse;
fn perform(
&self,
pool: Pool<ConnectionManager<PgConnection>>,
_websocket_info: Option<WebsocketInfo>,
) -> Result<GetSiteResponse, Error> {
let data: &TransferSite = &self.data; let data: &TransferSite = &self.data;
let claims = match Claims::decode(&data.auth) { let claims = match Claims::decode(&data.auth) {
@ -475,6 +545,8 @@ impl Perform<GetSiteResponse> for Oper<TransferSite> {
let user_id = claims.id; let user_id = claims.id;
let conn = pool.get()?;
let read_site = Site::read(&conn, 1)?; let read_site = Site::read(&conn, 1)?;
// Make sure user is the creator // Make sure user is the creator
@ -527,8 +599,14 @@ impl Perform<GetSiteResponse> for Oper<TransferSite> {
} }
} }
impl Perform<GetSiteConfigResponse> for Oper<GetSiteConfig> { impl Perform for Oper<GetSiteConfig> {
fn perform(&self, conn: &PgConnection) -> Result<GetSiteConfigResponse, Error> { type Response = GetSiteConfigResponse;
fn perform(
&self,
pool: Pool<ConnectionManager<PgConnection>>,
_websocket_info: Option<WebsocketInfo>,
) -> Result<GetSiteConfigResponse, Error> {
let data: &GetSiteConfig = &self.data; let data: &GetSiteConfig = &self.data;
let claims = match Claims::decode(&data.auth) { let claims = match Claims::decode(&data.auth) {
@ -538,6 +616,8 @@ impl Perform<GetSiteConfigResponse> for Oper<GetSiteConfig> {
let user_id = claims.id; let user_id = claims.id;
let conn = pool.get()?;
// Only let admins read this // Only let admins read this
let admins = UserView::admins(&conn)?; let admins = UserView::admins(&conn)?;
let admin_ids: Vec<i32> = admins.into_iter().map(|m| m.id).collect(); let admin_ids: Vec<i32> = admins.into_iter().map(|m| m.id).collect();
@ -552,8 +632,14 @@ impl Perform<GetSiteConfigResponse> for Oper<GetSiteConfig> {
} }
} }
impl Perform<GetSiteConfigResponse> for Oper<SaveSiteConfig> { impl Perform for Oper<SaveSiteConfig> {
fn perform(&self, conn: &PgConnection) -> Result<GetSiteConfigResponse, Error> { type Response = GetSiteConfigResponse;
fn perform(
&self,
pool: Pool<ConnectionManager<PgConnection>>,
_websocket_info: Option<WebsocketInfo>,
) -> Result<GetSiteConfigResponse, Error> {
let data: &SaveSiteConfig = &self.data; let data: &SaveSiteConfig = &self.data;
let claims = match Claims::decode(&data.auth) { let claims = match Claims::decode(&data.auth) {
@ -563,6 +649,8 @@ impl Perform<GetSiteConfigResponse> for Oper<SaveSiteConfig> {
let user_id = claims.id; let user_id = claims.id;
let conn = pool.get()?;
// Only let admins read this // Only let admins read this
let admins = UserView::admins(&conn)?; let admins = UserView::admins(&conn)?;
let admin_ids: Vec<i32> = admins.into_iter().map(|m| m.id).collect(); let admin_ids: Vec<i32> = admins.into_iter().map(|m| m.id).collect();

View file

@ -1,10 +1,5 @@
use super::*; use super::*;
use crate::settings::Settings;
use crate::{generate_random_string, send_email};
use bcrypt::verify; use bcrypt::verify;
use diesel::PgConnection;
use log::error;
use std::str::FromStr;
#[derive(Serialize, Deserialize, Debug)] #[derive(Serialize, Deserialize, Debug)]
pub struct Login { pub struct Login {
@ -89,7 +84,7 @@ pub struct AddAdmin {
auth: String, auth: String,
} }
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize, Clone)]
pub struct AddAdminResponse { pub struct AddAdminResponse {
admins: Vec<UserView>, admins: Vec<UserView>,
} }
@ -103,7 +98,7 @@ pub struct BanUser {
auth: String, auth: String,
} }
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize, Clone)]
pub struct BanUserResponse { pub struct BanUserResponse {
user: UserView, user: UserView,
banned: bool, banned: bool,
@ -204,10 +199,18 @@ pub struct UserJoinResponse {
pub user_id: i32, pub user_id: i32,
} }
impl Perform<LoginResponse> for Oper<Login> { impl Perform for Oper<Login> {
fn perform(&self, conn: &PgConnection) -> Result<LoginResponse, Error> { type Response = LoginResponse;
fn perform(
&self,
pool: Pool<ConnectionManager<PgConnection>>,
_websocket_info: Option<WebsocketInfo>,
) -> Result<LoginResponse, Error> {
let data: &Login = &self.data; let data: &Login = &self.data;
let conn = pool.get()?;
// Fetch that username / email // Fetch that username / email
let user: User_ = match User_::find_by_email_or_username(&conn, &data.username_or_email) { let user: User_ = match User_::find_by_email_or_username(&conn, &data.username_or_email) {
Ok(user) => user, Ok(user) => user,
@ -225,10 +228,18 @@ impl Perform<LoginResponse> for Oper<Login> {
} }
} }
impl Perform<LoginResponse> for Oper<Register> { impl Perform for Oper<Register> {
fn perform(&self, conn: &PgConnection) -> Result<LoginResponse, Error> { type Response = LoginResponse;
fn perform(
&self,
pool: Pool<ConnectionManager<PgConnection>>,
_websocket_info: Option<WebsocketInfo>,
) -> Result<LoginResponse, Error> {
let data: &Register = &self.data; let data: &Register = &self.data;
let conn = pool.get()?;
// Make sure site has open registration // Make sure site has open registration
if let Ok(site) = SiteView::read(&conn) { if let Ok(site) = SiteView::read(&conn) {
if !site.open_registration { if !site.open_registration {
@ -339,8 +350,14 @@ impl Perform<LoginResponse> for Oper<Register> {
} }
} }
impl Perform<LoginResponse> for Oper<SaveUserSettings> { impl Perform for Oper<SaveUserSettings> {
fn perform(&self, conn: &PgConnection) -> Result<LoginResponse, Error> { type Response = LoginResponse;
fn perform(
&self,
pool: Pool<ConnectionManager<PgConnection>>,
_websocket_info: Option<WebsocketInfo>,
) -> Result<LoginResponse, Error> {
let data: &SaveUserSettings = &self.data; let data: &SaveUserSettings = &self.data;
let claims = match Claims::decode(&data.auth) { let claims = match Claims::decode(&data.auth) {
@ -350,6 +367,8 @@ impl Perform<LoginResponse> for Oper<SaveUserSettings> {
let user_id = claims.id; let user_id = claims.id;
let conn = pool.get()?;
let read_user = User_::read(&conn, user_id)?; let read_user = User_::read(&conn, user_id)?;
let email = match &data.email { let email = match &data.email {
@ -427,10 +446,18 @@ impl Perform<LoginResponse> for Oper<SaveUserSettings> {
} }
} }
impl Perform<GetUserDetailsResponse> for Oper<GetUserDetails> { impl Perform for Oper<GetUserDetails> {
fn perform(&self, conn: &PgConnection) -> Result<GetUserDetailsResponse, Error> { type Response = GetUserDetailsResponse;
fn perform(
&self,
pool: Pool<ConnectionManager<PgConnection>>,
_websocket_info: Option<WebsocketInfo>,
) -> Result<GetUserDetailsResponse, Error> {
let data: &GetUserDetails = &self.data; let data: &GetUserDetails = &self.data;
let conn = pool.get()?;
let user_claims: Option<Claims> = match &data.auth { let user_claims: Option<Claims> = match &data.auth {
Some(auth) => match Claims::decode(&auth) { Some(auth) => match Claims::decode(&auth) {
Ok(claims) => Some(claims.claims), Ok(claims) => Some(claims.claims),
@ -524,8 +551,14 @@ impl Perform<GetUserDetailsResponse> for Oper<GetUserDetails> {
} }
} }
impl Perform<AddAdminResponse> for Oper<AddAdmin> { impl Perform for Oper<AddAdmin> {
fn perform(&self, conn: &PgConnection) -> Result<AddAdminResponse, Error> { type Response = AddAdminResponse;
fn perform(
&self,
pool: Pool<ConnectionManager<PgConnection>>,
websocket_info: Option<WebsocketInfo>,
) -> Result<AddAdminResponse, Error> {
let data: &AddAdmin = &self.data; let data: &AddAdmin = &self.data;
let claims = match Claims::decode(&data.auth) { let claims = match Claims::decode(&data.auth) {
@ -535,6 +568,8 @@ impl Perform<AddAdminResponse> for Oper<AddAdmin> {
let user_id = claims.id; let user_id = claims.id;
let conn = pool.get()?;
// Make sure user is an admin // Make sure user is an admin
if !UserView::read(&conn, user_id)?.admin { if !UserView::read(&conn, user_id)?.admin {
return Err(APIError::err("not_an_admin").into()); return Err(APIError::err("not_an_admin").into());
@ -583,12 +618,28 @@ impl Perform<AddAdminResponse> for Oper<AddAdmin> {
let creator_user = admins.remove(creator_index); let creator_user = admins.remove(creator_index);
admins.insert(0, creator_user); admins.insert(0, creator_user);
Ok(AddAdminResponse { admins }) let res = AddAdminResponse { admins };
if let Some(ws) = websocket_info {
ws.chatserver.do_send(SendAllMessage {
op: UserOperation::AddAdmin,
response: res.clone(),
my_id: ws.id,
});
}
Ok(res)
} }
} }
impl Perform<BanUserResponse> for Oper<BanUser> { impl Perform for Oper<BanUser> {
fn perform(&self, conn: &PgConnection) -> Result<BanUserResponse, Error> { type Response = BanUserResponse;
fn perform(
&self,
pool: Pool<ConnectionManager<PgConnection>>,
websocket_info: Option<WebsocketInfo>,
) -> Result<BanUserResponse, Error> {
let data: &BanUser = &self.data; let data: &BanUser = &self.data;
let claims = match Claims::decode(&data.auth) { let claims = match Claims::decode(&data.auth) {
@ -598,6 +649,8 @@ impl Perform<BanUserResponse> for Oper<BanUser> {
let user_id = claims.id; let user_id = claims.id;
let conn = pool.get()?;
// Make sure user is an admin // Make sure user is an admin
if !UserView::read(&conn, user_id)?.admin { if !UserView::read(&conn, user_id)?.admin {
return Err(APIError::err("not_an_admin").into()); return Err(APIError::err("not_an_admin").into());
@ -649,15 +702,31 @@ impl Perform<BanUserResponse> for Oper<BanUser> {
let user_view = UserView::read(&conn, data.user_id)?; let user_view = UserView::read(&conn, data.user_id)?;
Ok(BanUserResponse { let res = BanUserResponse {
user: user_view, user: user_view,
banned: data.ban, banned: data.ban,
}) };
if let Some(ws) = websocket_info {
ws.chatserver.do_send(SendAllMessage {
op: UserOperation::BanUser,
response: res.clone(),
my_id: ws.id,
});
}
Ok(res)
} }
} }
impl Perform<GetRepliesResponse> for Oper<GetReplies> { impl Perform for Oper<GetReplies> {
fn perform(&self, conn: &PgConnection) -> Result<GetRepliesResponse, Error> { type Response = GetRepliesResponse;
fn perform(
&self,
pool: Pool<ConnectionManager<PgConnection>>,
_websocket_info: Option<WebsocketInfo>,
) -> Result<GetRepliesResponse, Error> {
let data: &GetReplies = &self.data; let data: &GetReplies = &self.data;
let claims = match Claims::decode(&data.auth) { let claims = match Claims::decode(&data.auth) {
@ -669,6 +738,8 @@ impl Perform<GetRepliesResponse> for Oper<GetReplies> {
let sort = SortType::from_str(&data.sort)?; let sort = SortType::from_str(&data.sort)?;
let conn = pool.get()?;
let replies = ReplyQueryBuilder::create(&conn, user_id) let replies = ReplyQueryBuilder::create(&conn, user_id)
.sort(&sort) .sort(&sort)
.unread_only(data.unread_only) .unread_only(data.unread_only)
@ -680,8 +751,14 @@ impl Perform<GetRepliesResponse> for Oper<GetReplies> {
} }
} }
impl Perform<GetUserMentionsResponse> for Oper<GetUserMentions> { impl Perform for Oper<GetUserMentions> {
fn perform(&self, conn: &PgConnection) -> Result<GetUserMentionsResponse, Error> { type Response = GetUserMentionsResponse;
fn perform(
&self,
pool: Pool<ConnectionManager<PgConnection>>,
_websocket_info: Option<WebsocketInfo>,
) -> Result<GetUserMentionsResponse, Error> {
let data: &GetUserMentions = &self.data; let data: &GetUserMentions = &self.data;
let claims = match Claims::decode(&data.auth) { let claims = match Claims::decode(&data.auth) {
@ -693,6 +770,8 @@ impl Perform<GetUserMentionsResponse> for Oper<GetUserMentions> {
let sort = SortType::from_str(&data.sort)?; let sort = SortType::from_str(&data.sort)?;
let conn = pool.get()?;
let mentions = UserMentionQueryBuilder::create(&conn, user_id) let mentions = UserMentionQueryBuilder::create(&conn, user_id)
.sort(&sort) .sort(&sort)
.unread_only(data.unread_only) .unread_only(data.unread_only)
@ -704,8 +783,14 @@ impl Perform<GetUserMentionsResponse> for Oper<GetUserMentions> {
} }
} }
impl Perform<UserMentionResponse> for Oper<EditUserMention> { impl Perform for Oper<EditUserMention> {
fn perform(&self, conn: &PgConnection) -> Result<UserMentionResponse, Error> { type Response = UserMentionResponse;
fn perform(
&self,
pool: Pool<ConnectionManager<PgConnection>>,
_websocket_info: Option<WebsocketInfo>,
) -> Result<UserMentionResponse, Error> {
let data: &EditUserMention = &self.data; let data: &EditUserMention = &self.data;
let claims = match Claims::decode(&data.auth) { let claims = match Claims::decode(&data.auth) {
@ -715,6 +800,8 @@ impl Perform<UserMentionResponse> for Oper<EditUserMention> {
let user_id = claims.id; let user_id = claims.id;
let conn = pool.get()?;
let user_mention = UserMention::read(&conn, data.user_mention_id)?; let user_mention = UserMention::read(&conn, data.user_mention_id)?;
let user_mention_form = UserMentionForm { let user_mention_form = UserMentionForm {
@ -737,8 +824,14 @@ impl Perform<UserMentionResponse> for Oper<EditUserMention> {
} }
} }
impl Perform<GetRepliesResponse> for Oper<MarkAllAsRead> { impl Perform for Oper<MarkAllAsRead> {
fn perform(&self, conn: &PgConnection) -> Result<GetRepliesResponse, Error> { type Response = GetRepliesResponse;
fn perform(
&self,
pool: Pool<ConnectionManager<PgConnection>>,
_websocket_info: Option<WebsocketInfo>,
) -> Result<GetRepliesResponse, Error> {
let data: &MarkAllAsRead = &self.data; let data: &MarkAllAsRead = &self.data;
let claims = match Claims::decode(&data.auth) { let claims = match Claims::decode(&data.auth) {
@ -748,6 +841,8 @@ impl Perform<GetRepliesResponse> for Oper<MarkAllAsRead> {
let user_id = claims.id; let user_id = claims.id;
let conn = pool.get()?;
let replies = ReplyQueryBuilder::create(&conn, user_id) let replies = ReplyQueryBuilder::create(&conn, user_id)
.unread_only(true) .unread_only(true)
.page(1) .page(1)
@ -821,8 +916,14 @@ impl Perform<GetRepliesResponse> for Oper<MarkAllAsRead> {
} }
} }
impl Perform<LoginResponse> for Oper<DeleteAccount> { impl Perform for Oper<DeleteAccount> {
fn perform(&self, conn: &PgConnection) -> Result<LoginResponse, Error> { type Response = LoginResponse;
fn perform(
&self,
pool: Pool<ConnectionManager<PgConnection>>,
_websocket_info: Option<WebsocketInfo>,
) -> Result<LoginResponse, Error> {
let data: &DeleteAccount = &self.data; let data: &DeleteAccount = &self.data;
let claims = match Claims::decode(&data.auth) { let claims = match Claims::decode(&data.auth) {
@ -832,6 +933,8 @@ impl Perform<LoginResponse> for Oper<DeleteAccount> {
let user_id = claims.id; let user_id = claims.id;
let conn = pool.get()?;
let user: User_ = User_::read(&conn, user_id)?; let user: User_ = User_::read(&conn, user_id)?;
// Verify the password // Verify the password
@ -902,10 +1005,18 @@ impl Perform<LoginResponse> for Oper<DeleteAccount> {
} }
} }
impl Perform<PasswordResetResponse> for Oper<PasswordReset> { impl Perform for Oper<PasswordReset> {
fn perform(&self, conn: &PgConnection) -> Result<PasswordResetResponse, Error> { type Response = PasswordResetResponse;
fn perform(
&self,
pool: Pool<ConnectionManager<PgConnection>>,
_websocket_info: Option<WebsocketInfo>,
) -> Result<PasswordResetResponse, Error> {
let data: &PasswordReset = &self.data; let data: &PasswordReset = &self.data;
let conn = pool.get()?;
// Fetch that email // Fetch that email
let user: User_ = match User_::find_by_email(&conn, &data.email) { let user: User_ = match User_::find_by_email(&conn, &data.email) {
Ok(user) => user, Ok(user) => user,
@ -933,10 +1044,18 @@ impl Perform<PasswordResetResponse> for Oper<PasswordReset> {
} }
} }
impl Perform<LoginResponse> for Oper<PasswordChange> { impl Perform for Oper<PasswordChange> {
fn perform(&self, conn: &PgConnection) -> Result<LoginResponse, Error> { type Response = LoginResponse;
fn perform(
&self,
pool: Pool<ConnectionManager<PgConnection>>,
_websocket_info: Option<WebsocketInfo>,
) -> Result<LoginResponse, Error> {
let data: &PasswordChange = &self.data; let data: &PasswordChange = &self.data;
let conn = pool.get()?;
// Fetch the user_id from the token // Fetch the user_id from the token
let user_id = PasswordResetRequest::read_from_token(&conn, &data.token)?.user_id; let user_id = PasswordResetRequest::read_from_token(&conn, &data.token)?.user_id;
@ -958,8 +1077,14 @@ impl Perform<LoginResponse> for Oper<PasswordChange> {
} }
} }
impl Perform<PrivateMessageResponse> for Oper<CreatePrivateMessage> { impl Perform for Oper<CreatePrivateMessage> {
fn perform(&self, conn: &PgConnection) -> Result<PrivateMessageResponse, Error> { type Response = PrivateMessageResponse;
fn perform(
&self,
pool: Pool<ConnectionManager<PgConnection>>,
websocket_info: Option<WebsocketInfo>,
) -> Result<PrivateMessageResponse, Error> {
let data: &CreatePrivateMessage = &self.data; let data: &CreatePrivateMessage = &self.data;
let claims = match Claims::decode(&data.auth) { let claims = match Claims::decode(&data.auth) {
@ -971,6 +1096,8 @@ impl Perform<PrivateMessageResponse> for Oper<CreatePrivateMessage> {
let hostname = &format!("https://{}", Settings::get().hostname); let hostname = &format!("https://{}", Settings::get().hostname);
let conn = pool.get()?;
// Check for a site ban // Check for a site ban
if UserView::read(&conn, user_id)?.banned { if UserView::read(&conn, user_id)?.banned {
return Err(APIError::err("site_ban").into()); return Err(APIError::err("site_ban").into());
@ -1016,12 +1143,29 @@ impl Perform<PrivateMessageResponse> for Oper<CreatePrivateMessage> {
let message = PrivateMessageView::read(&conn, inserted_private_message.id)?; let message = PrivateMessageView::read(&conn, inserted_private_message.id)?;
Ok(PrivateMessageResponse { message }) let res = PrivateMessageResponse { message };
if let Some(ws) = websocket_info {
ws.chatserver.do_send(SendUserRoomMessage {
op: UserOperation::CreatePrivateMessage,
response: res.clone(),
recipient_id: recipient_user.id,
my_id: ws.id,
});
}
Ok(res)
} }
} }
impl Perform<PrivateMessageResponse> for Oper<EditPrivateMessage> { impl Perform for Oper<EditPrivateMessage> {
fn perform(&self, conn: &PgConnection) -> Result<PrivateMessageResponse, Error> { type Response = PrivateMessageResponse;
fn perform(
&self,
pool: Pool<ConnectionManager<PgConnection>>,
_websocket_info: Option<WebsocketInfo>,
) -> Result<PrivateMessageResponse, Error> {
let data: &EditPrivateMessage = &self.data; let data: &EditPrivateMessage = &self.data;
let claims = match Claims::decode(&data.auth) { let claims = match Claims::decode(&data.auth) {
@ -1031,6 +1175,8 @@ impl Perform<PrivateMessageResponse> for Oper<EditPrivateMessage> {
let user_id = claims.id; let user_id = claims.id;
let conn = pool.get()?;
let orig_private_message = PrivateMessage::read(&conn, data.edit_id)?; let orig_private_message = PrivateMessage::read(&conn, data.edit_id)?;
// Check for a site ban // Check for a site ban
@ -1075,8 +1221,14 @@ impl Perform<PrivateMessageResponse> for Oper<EditPrivateMessage> {
} }
} }
impl Perform<PrivateMessagesResponse> for Oper<GetPrivateMessages> { impl Perform for Oper<GetPrivateMessages> {
fn perform(&self, conn: &PgConnection) -> Result<PrivateMessagesResponse, Error> { type Response = PrivateMessagesResponse;
fn perform(
&self,
pool: Pool<ConnectionManager<PgConnection>>,
_websocket_info: Option<WebsocketInfo>,
) -> Result<PrivateMessagesResponse, Error> {
let data: &GetPrivateMessages = &self.data; let data: &GetPrivateMessages = &self.data;
let claims = match Claims::decode(&data.auth) { let claims = match Claims::decode(&data.auth) {
@ -1086,6 +1238,8 @@ impl Perform<PrivateMessagesResponse> for Oper<GetPrivateMessages> {
let user_id = claims.id; let user_id = claims.id;
let conn = pool.get()?;
let messages = PrivateMessageQueryBuilder::create(&conn, user_id) let messages = PrivateMessageQueryBuilder::create(&conn, user_id)
.page(data.page) .page(data.page)
.limit(data.limit) .limit(data.limit)
@ -1096,8 +1250,14 @@ impl Perform<PrivateMessagesResponse> for Oper<GetPrivateMessages> {
} }
} }
impl Perform<UserJoinResponse> for Oper<UserJoin> { impl Perform for Oper<UserJoin> {
fn perform(&self, _conn: &PgConnection) -> Result<UserJoinResponse, Error> { type Response = UserJoinResponse;
fn perform(
&self,
_pool: Pool<ConnectionManager<PgConnection>>,
websocket_info: Option<WebsocketInfo>,
) -> Result<UserJoinResponse, Error> {
let data: &UserJoin = &self.data; let data: &UserJoin = &self.data;
let claims = match Claims::decode(&data.auth) { let claims = match Claims::decode(&data.auth) {
@ -1106,6 +1266,13 @@ impl Perform<UserJoinResponse> for Oper<UserJoin> {
}; };
let user_id = claims.id; let user_id = claims.id;
if let Some(ws) = websocket_info {
if let Some(id) = ws.id {
ws.chatserver.do_send(JoinUserRoom { user_id, id });
}
}
Ok(UserJoinResponse { user_id }) Ok(UserJoinResponse { user_id })
} }
} }

View file

@ -27,13 +27,14 @@ pub extern crate strum;
pub mod api; pub mod api;
pub mod apub; pub mod apub;
pub mod db; pub mod db;
pub mod rate_limit;
pub mod routes; pub mod routes;
pub mod schema; pub mod schema;
pub mod settings; pub mod settings;
pub mod version; pub mod version;
pub mod websocket; pub mod websocket;
use crate::settings::Settings; use actix_web::dev::ConnectionInfo;
use chrono::{DateTime, NaiveDateTime, Utc}; use chrono::{DateTime, NaiveDateTime, Utc};
use isahc::prelude::*; use isahc::prelude::*;
use lettre::smtp::authentication::{Credentials, Mechanism}; use lettre::smtp::authentication::{Credentials, Mechanism};
@ -48,6 +49,14 @@ use rand::{thread_rng, Rng};
use regex::{Regex, RegexBuilder}; use regex::{Regex, RegexBuilder};
use serde::Deserialize; use serde::Deserialize;
use crate::settings::Settings;
pub type ConnectionId = usize;
pub type PostId = i32;
pub type CommunityId = i32;
pub type UserId = i32;
pub type IPAddr = String;
pub fn to_datetime_utc(ndt: NaiveDateTime) -> DateTime<Utc> { pub fn to_datetime_utc(ndt: NaiveDateTime) -> DateTime<Utc> {
DateTime::<Utc>::from_utc(ndt, Utc) DateTime::<Utc>::from_utc(ndt, Utc)
} }
@ -224,6 +233,16 @@ pub fn markdown_to_html(text: &str) -> String {
comrak::markdown_to_html(text, &comrak::ComrakOptions::default()) comrak::markdown_to_html(text, &comrak::ComrakOptions::default())
} }
pub fn get_ip(conn_info: &ConnectionInfo) -> String {
conn_info
.remote()
.unwrap_or("127.0.0.1:12345")
.split(':')
.next()
.unwrap_or("127.0.0.1")
.to_string()
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::{extract_usernames, is_email_regex, remove_slurs, slur_check, slurs_vec_to_str}; use crate::{extract_usernames, is_email_regex, remove_slurs, slur_check, slurs_vec_to_str};

View file

@ -6,10 +6,14 @@ use actix::prelude::*;
use actix_web::*; use actix_web::*;
use diesel::r2d2::{ConnectionManager, Pool}; use diesel::r2d2::{ConnectionManager, Pool};
use diesel::PgConnection; use diesel::PgConnection;
use lemmy_server::routes::{api, federation, feeds, index, nodeinfo, webfinger, websocket}; use lemmy_server::{
use lemmy_server::settings::Settings; rate_limit::{rate_limiter::RateLimiter, RateLimit},
use lemmy_server::websocket::server::*; routes::{api, federation, feeds, index, nodeinfo, webfinger},
use std::io; settings::Settings,
websocket::server::*,
};
use std::{io, sync::Arc};
use tokio::sync::Mutex;
embed_migrations!(); embed_migrations!();
@ -29,8 +33,13 @@ async fn main() -> io::Result<()> {
let conn = pool.get().unwrap(); let conn = pool.get().unwrap();
embedded_migrations::run(&conn).unwrap(); embedded_migrations::run(&conn).unwrap();
// Set up the rate limiter
let rate_limiter = RateLimit {
rate_limiter: Arc::new(Mutex::new(RateLimiter::default())),
};
// Set up websocket server // Set up websocket server
let server = ChatServer::startup(pool.clone()).start(); let server = ChatServer::startup(pool.clone(), rate_limiter.clone()).start();
println!( println!(
"Starting http server at {}:{}", "Starting http server at {}:{}",
@ -40,18 +49,18 @@ async fn main() -> io::Result<()> {
// Create Http server with websocket support // Create Http server with websocket support
HttpServer::new(move || { HttpServer::new(move || {
let settings = Settings::get(); let settings = Settings::get();
let rate_limiter = rate_limiter.clone();
App::new() App::new()
.wrap(middleware::Logger::default()) .wrap(middleware::Logger::default())
.data(pool.clone()) .data(pool.clone())
.data(server.clone()) .data(server.clone())
// The routes // The routes
.configure(api::config) .configure(move |cfg| api::config(cfg, &rate_limiter))
.configure(federation::config) .configure(federation::config)
.configure(feeds::config) .configure(feeds::config)
.configure(index::config) .configure(index::config)
.configure(nodeinfo::config) .configure(nodeinfo::config)
.configure(webfinger::config) .configure(webfinger::config)
.configure(websocket::config)
// static files // static files
.service(actix_files::Files::new( .service(actix_files::Files::new(
"/static", "/static",

View file

@ -0,0 +1,194 @@
pub mod rate_limiter;
use super::{IPAddr, Settings};
use crate::api::APIError;
use crate::get_ip;
use crate::settings::RateLimitConfig;
use actix_web::dev::{Service, ServiceRequest, ServiceResponse, Transform};
use failure::Error;
use futures::future::{ok, Ready};
use log::debug;
use rate_limiter::{RateLimitType, RateLimiter};
use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::SystemTime;
use strum::IntoEnumIterator;
use tokio::sync::Mutex;
#[derive(Debug, Clone)]
pub struct RateLimit {
pub rate_limiter: Arc<Mutex<RateLimiter>>,
}
#[derive(Debug, Clone)]
pub struct RateLimited {
rate_limiter: Arc<Mutex<RateLimiter>>,
type_: RateLimitType,
}
pub struct RateLimitedMiddleware<S> {
rate_limited: RateLimited,
service: S,
}
impl RateLimit {
pub fn message(&self) -> RateLimited {
self.kind(RateLimitType::Message)
}
pub fn post(&self) -> RateLimited {
self.kind(RateLimitType::Post)
}
pub fn register(&self) -> RateLimited {
self.kind(RateLimitType::Register)
}
fn kind(&self, type_: RateLimitType) -> RateLimited {
RateLimited {
rate_limiter: self.rate_limiter.clone(),
type_,
}
}
}
impl RateLimited {
pub async fn wrap<T, E>(
self,
ip_addr: String,
fut: impl Future<Output = Result<T, E>>,
) -> Result<T, E>
where
E: From<failure::Error>,
{
let rate_limit: RateLimitConfig = actix_web::web::block(move || {
// needs to be in a web::block because the RwLock in settings is from stdlib
Ok(Settings::get().rate_limit) as Result<_, failure::Error>
})
.await
.map_err(|e| match e {
actix_web::error::BlockingError::Error(e) => e,
_ => APIError::err("Operation canceled").into(),
})?;
// before
{
let mut limiter = self.rate_limiter.lock().await;
match self.type_ {
RateLimitType::Message => {
limiter.check_rate_limit_full(
self.type_,
&ip_addr,
rate_limit.message,
rate_limit.message_per_second,
false,
)?;
return fut.await;
}
RateLimitType::Post => {
limiter.check_rate_limit_full(
self.type_.clone(),
&ip_addr,
rate_limit.post,
rate_limit.post_per_second,
true,
)?;
}
RateLimitType::Register => {
limiter.check_rate_limit_full(
self.type_,
&ip_addr,
rate_limit.register,
rate_limit.register_per_second,
true,
)?;
}
};
}
let res = fut.await;
// after
{
let mut limiter = self.rate_limiter.lock().await;
if res.is_ok() {
match self.type_ {
RateLimitType::Post => {
limiter.check_rate_limit_full(
self.type_,
&ip_addr,
rate_limit.post,
rate_limit.post_per_second,
false,
)?;
}
RateLimitType::Register => {
limiter.check_rate_limit_full(
self.type_,
&ip_addr,
rate_limit.register,
rate_limit.register_per_second,
false,
)?;
}
_ => (),
};
}
}
res
}
}
impl<S> Transform<S> for RateLimited
where
S: Service<Request = ServiceRequest, Response = ServiceResponse, Error = actix_web::Error>,
S::Future: 'static,
{
type Request = S::Request;
type Response = S::Response;
type Error = actix_web::Error;
type InitError = ();
type Transform = RateLimitedMiddleware<S>;
type Future = Ready<Result<Self::Transform, Self::InitError>>;
fn new_transform(&self, service: S) -> Self::Future {
ok(RateLimitedMiddleware {
rate_limited: self.clone(),
service,
})
}
}
type FutResult<T, E> = dyn Future<Output = Result<T, E>>;
impl<S> Service for RateLimitedMiddleware<S>
where
S: Service<Request = ServiceRequest, Response = ServiceResponse, Error = actix_web::Error>,
S::Future: 'static,
{
type Request = S::Request;
type Response = S::Response;
type Error = actix_web::Error;
type Future = Pin<Box<FutResult<Self::Response, Self::Error>>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.service.poll_ready(cx)
}
fn call(&mut self, req: S::Request) -> Self::Future {
let ip_addr = get_ip(&req.connection_info());
let fut = self
.rate_limited
.clone()
.wrap(ip_addr, self.service.call(req));
Box::pin(async move { fut.await.map_err(actix_web::Error::from) })
}
}

View file

@ -0,0 +1,101 @@
use super::*;
#[derive(Debug, Clone)]
pub struct RateLimitBucket {
last_checked: SystemTime,
allowance: f64,
}
#[derive(Eq, PartialEq, Hash, Debug, EnumIter, Copy, Clone)]
pub enum RateLimitType {
Message,
Register,
Post,
}
/// Rate limiting based on rate type and IP addr
#[derive(Debug, Clone)]
pub struct RateLimiter {
pub buckets: HashMap<RateLimitType, HashMap<IPAddr, RateLimitBucket>>,
}
impl Default for RateLimiter {
fn default() -> Self {
Self {
buckets: HashMap::new(),
}
}
}
impl RateLimiter {
fn insert_ip(&mut self, ip: &str) {
for rate_limit_type in RateLimitType::iter() {
if self.buckets.get(&rate_limit_type).is_none() {
self.buckets.insert(rate_limit_type, HashMap::new());
}
if let Some(bucket) = self.buckets.get_mut(&rate_limit_type) {
if bucket.get(ip).is_none() {
bucket.insert(
ip.to_string(),
RateLimitBucket {
last_checked: SystemTime::now(),
allowance: -2f64,
},
);
}
}
}
}
#[allow(clippy::float_cmp)]
pub(super) fn check_rate_limit_full(
&mut self,
type_: RateLimitType,
ip: &str,
rate: i32,
per: i32,
check_only: bool,
) -> Result<(), Error> {
self.insert_ip(ip);
if let Some(bucket) = self.buckets.get_mut(&type_) {
if let Some(rate_limit) = bucket.get_mut(ip) {
let current = SystemTime::now();
let time_passed = current.duration_since(rate_limit.last_checked)?.as_secs() as f64;
// The initial value
if rate_limit.allowance == -2f64 {
rate_limit.allowance = rate as f64;
};
rate_limit.last_checked = current;
rate_limit.allowance += time_passed * (rate as f64 / per as f64);
if !check_only && rate_limit.allowance > rate as f64 {
rate_limit.allowance = rate as f64;
}
if rate_limit.allowance < 1.0 {
debug!(
"Rate limited IP: {}, time_passed: {}, allowance: {}",
ip, time_passed, rate_limit.allowance
);
Err(
APIError {
message: format!("Too many requests. {} per {} seconds", rate, per),
}
.into(),
)
} else {
if !check_only {
rate_limit.allowance -= 1.0;
}
Ok(())
}
} else {
Ok(())
}
} else {
Ok(())
}
}
}

View file

@ -1,105 +1,185 @@
use super::*;
use crate::api::comment::*; use crate::api::comment::*;
use crate::api::community::*; use crate::api::community::*;
use crate::api::post::*; use crate::api::post::*;
use crate::api::site::*; use crate::api::site::*;
use crate::api::user::*; use crate::api::user::*;
use crate::api::{Oper, Perform}; use crate::rate_limit::RateLimit;
use actix_web::{web, HttpResponse}; use actix_web::guard;
use diesel::r2d2::{ConnectionManager, Pool};
use diesel::PgConnection;
use failure::Error;
use serde::Serialize;
type DbParam = web::Data<Pool<ConnectionManager<PgConnection>>>; pub fn config(cfg: &mut web::ServiceConfig, rate_limit: &RateLimit) {
cfg.service(
#[rustfmt::skip] web::scope("/api/v1")
pub fn config(cfg: &mut web::ServiceConfig) { // Websockets
cfg .service(web::resource("/ws").to(super::websocket::chat_route))
// Site // Site
.route("/api/v1/site", web::get().to(route_get::<GetSite, GetSiteResponse>)) .service(
.route("/api/v1/categories", web::get().to(route_get::<ListCategories, ListCategoriesResponse>)) web::scope("/site")
.route("/api/v1/modlog", web::get().to(route_get::<GetModlog, GetModlogResponse>)) .wrap(rate_limit.message())
.route("/api/v1/search", web::get().to(route_get::<Search, SearchResponse>)) .route("", web::get().to(route_get::<GetSite>))
// Community // Admin Actions
.route("/api/v1/community", web::post().to(route_post::<CreateCommunity, CommunityResponse>)) .route("", web::post().to(route_post::<CreateSite>))
.route("/api/v1/community", web::get().to(route_get::<GetCommunity, GetCommunityResponse>)) .route("", web::put().to(route_post::<EditSite>))
.route("/api/v1/community", web::put().to(route_post::<EditCommunity, CommunityResponse>)) .route("/transfer", web::post().to(route_post::<TransferSite>))
.route("/api/v1/community/list", web::get().to(route_get::<ListCommunities, ListCommunitiesResponse>)) .route("/config", web::get().to(route_get::<GetSiteConfig>))
.route("/api/v1/community/follow", web::post().to(route_post::<FollowCommunity, CommunityResponse>)) .route("/config", web::put().to(route_post::<SaveSiteConfig>)),
// Post )
.route("/api/v1/post", web::post().to(route_post::<CreatePost, PostResponse>)) .service(
.route("/api/v1/post", web::put().to(route_post::<EditPost, PostResponse>)) web::resource("/categories")
.route("/api/v1/post", web::get().to(route_get::<GetPost, GetPostResponse>)) .wrap(rate_limit.message())
.route("/api/v1/post/list", web::get().to(route_get::<GetPosts, GetPostsResponse>)) .route(web::get().to(route_get::<ListCategories>)),
.route("/api/v1/post/like", web::post().to(route_post::<CreatePostLike, PostResponse>)) )
.route("/api/v1/post/save", web::put().to(route_post::<SavePost, PostResponse>)) .service(
// Comment web::resource("/modlog")
.route("/api/v1/comment", web::post().to(route_post::<CreateComment, CommentResponse>)) .wrap(rate_limit.message())
.route("/api/v1/comment", web::put().to(route_post::<EditComment, CommentResponse>)) .route(web::get().to(route_get::<GetModlog>)),
.route("/api/v1/comment/like", web::post().to(route_post::<CreateCommentLike, CommentResponse>)) )
.route("/api/v1/comment/save", web::put().to(route_post::<SaveComment, CommentResponse>)) .service(
// User web::resource("/search")
.route("/api/v1/user", web::get().to(route_get::<GetUserDetails, GetUserDetailsResponse>)) .wrap(rate_limit.message())
.route("/api/v1/user/mention", web::get().to(route_get::<GetUserMentions, GetUserMentionsResponse>)) .route(web::get().to(route_get::<Search>)),
.route("/api/v1/user/mention", web::put().to(route_post::<EditUserMention, UserMentionResponse>)) )
.route("/api/v1/user/replies", web::get().to(route_get::<GetReplies, GetRepliesResponse>)) // Community
.route("/api/v1/user/followed_communities", web::get().to(route_get::<GetFollowedCommunities, GetFollowedCommunitiesResponse>)) .service(
// Mod actions web::resource("/community")
.route("/api/v1/community/transfer", web::post().to(route_post::<TransferCommunity, GetCommunityResponse>)) .guard(guard::Post())
.route("/api/v1/community/ban_user", web::post().to(route_post::<BanFromCommunity, BanFromCommunityResponse>)) .wrap(rate_limit.register())
.route("/api/v1/community/mod", web::post().to(route_post::<AddModToCommunity, AddModToCommunityResponse>)) .route(web::post().to(route_post::<CreateCommunity>)),
// Admin actions )
.route("/api/v1/site", web::post().to(route_post::<CreateSite, SiteResponse>)) .service(
.route("/api/v1/site", web::put().to(route_post::<EditSite, SiteResponse>)) web::scope("/community")
.route("/api/v1/site/transfer", web::post().to(route_post::<TransferSite, GetSiteResponse>)) .wrap(rate_limit.message())
.route("/api/v1/site/config", web::get().to(route_get::<GetSiteConfig, GetSiteConfigResponse>)) .route("", web::get().to(route_get::<GetCommunity>))
.route("/api/v1/site/config", web::put().to(route_post::<SaveSiteConfig, GetSiteConfigResponse>)) .route("", web::put().to(route_post::<EditCommunity>))
.route("/api/v1/admin/add", web::post().to(route_post::<AddAdmin, AddAdminResponse>)) .route("/list", web::get().to(route_get::<ListCommunities>))
.route("/api/v1/user/ban", web::post().to(route_post::<BanUser, BanUserResponse>)) .route("/follow", web::post().to(route_post::<FollowCommunity>))
// User account actions // Mod Actions
.route("/api/v1/user/login", web::post().to(route_post::<Login, LoginResponse>)) .route("/transfer", web::post().to(route_post::<TransferCommunity>))
.route("/api/v1/user/register", web::post().to(route_post::<Register, LoginResponse>)) .route("/ban_user", web::post().to(route_post::<BanFromCommunity>))
.route("/api/v1/user/delete_account", web::post().to(route_post::<DeleteAccount, LoginResponse>)) .route("/mod", web::post().to(route_post::<AddModToCommunity>)),
.route("/api/v1/user/password_reset", web::post().to(route_post::<PasswordReset, PasswordResetResponse>)) )
.route("/api/v1/user/password_change", web::post().to(route_post::<PasswordChange, LoginResponse>)) // Post
.route("/api/v1/user/mark_all_as_read", web::post().to(route_post::<MarkAllAsRead, GetRepliesResponse>)) .service(
.route("/api/v1/user/save_user_settings", web::put().to(route_post::<SaveUserSettings, LoginResponse>)); // Handle POST to /post separately to add the post() rate limitter
web::resource("/post")
.guard(guard::Post())
.wrap(rate_limit.post())
.route(web::post().to(route_post::<CreatePost>)),
)
.service(
web::scope("/post")
.wrap(rate_limit.message())
.route("", web::get().to(route_get::<GetPost>))
.route("", web::put().to(route_post::<EditPost>))
.route("/list", web::get().to(route_get::<GetPosts>))
.route("/like", web::post().to(route_post::<CreatePostLike>))
.route("/save", web::put().to(route_post::<SavePost>)),
)
// Comment
.service(
web::scope("/comment")
.wrap(rate_limit.message())
.route("", web::post().to(route_post::<CreateComment>))
.route("", web::put().to(route_post::<EditComment>))
.route("/like", web::post().to(route_post::<CreateCommentLike>))
.route("/save", web::put().to(route_post::<SaveComment>)),
)
// User
.service(
// Account action, I don't like that it's in /user maybe /accounts
// Handle /user/register separately to add the register() rate limitter
web::resource("/user/register")
.guard(guard::Post())
.wrap(rate_limit.register())
.route(web::post().to(route_post::<Register>)),
)
// User actions
.service(
web::scope("/user")
.wrap(rate_limit.message())
.route("", web::get().to(route_get::<GetUserDetails>))
.route("/mention", web::get().to(route_get::<GetUserMentions>))
.route("/mention", web::put().to(route_post::<EditUserMention>))
.route("/replies", web::get().to(route_get::<GetReplies>))
.route(
"/followed_communities",
web::get().to(route_get::<GetFollowedCommunities>),
)
// Admin action. I don't like that it's in /user
.route("/ban", web::post().to(route_post::<BanUser>))
// Account actions. I don't like that they're in /user maybe /accounts
.route("/login", web::post().to(route_post::<Login>))
.route(
"/delete_account",
web::post().to(route_post::<DeleteAccount>),
)
.route(
"/password_reset",
web::post().to(route_post::<PasswordReset>),
)
.route(
"/password_change",
web::post().to(route_post::<PasswordChange>),
)
// mark_all_as_read feels off being in this section as well
.route(
"/mark_all_as_read",
web::post().to(route_post::<MarkAllAsRead>),
)
.route(
"/save_user_settings",
web::put().to(route_post::<SaveUserSettings>),
),
)
// Admin Actions
.service(
web::resource("/admin/add")
.wrap(rate_limit.message())
.route(web::post().to(route_post::<AddAdmin>)),
),
);
} }
fn perform<Request, Response>(data: Request, db: DbParam) -> Result<HttpResponse, Error> fn perform<Request>(
data: Request,
db: DbPoolParam,
chat_server: ChatServerParam,
) -> Result<HttpResponse, Error>
where where
Response: Serialize, Oper<Request>: Perform,
Oper<Request>: Perform<Response>,
{ {
let conn = match db.get() { let ws_info = WebsocketInfo {
Ok(c) => c, chatserver: chat_server.get_ref().to_owned(),
Err(e) => return Err(format_err!("{}", e)), id: None,
}; };
let oper: Oper<Request> = Oper::new(data); let oper: Oper<Request> = Oper::new(data);
let response = oper.perform(&conn);
Ok(HttpResponse::Ok().json(response?)) let res = oper.perform(db.get_ref().to_owned(), Some(ws_info));
Ok(HttpResponse::Ok().json(res?))
} }
async fn route_get<Data, Response>( async fn route_get<Data>(
data: web::Query<Data>, data: web::Query<Data>,
db: DbParam, db: DbPoolParam,
chat_server: ChatServerParam,
) -> Result<HttpResponse, Error> ) -> Result<HttpResponse, Error>
where where
Data: Serialize, Data: Serialize,
Response: Serialize, Oper<Data>: Perform,
Oper<Data>: Perform<Response>,
{ {
perform::<Data, Response>(data.0, db) perform::<Data>(data.0, db, chat_server)
} }
async fn route_post<Data, Response>( async fn route_post<Data>(
data: web::Json<Data>, data: web::Json<Data>,
db: DbParam, db: DbPoolParam,
chat_server: ChatServerParam,
) -> Result<HttpResponse, Error> ) -> Result<HttpResponse, Error>
where where
Data: Serialize, Data: Serialize,
Response: Serialize, Oper<Data>: Perform,
Oper<Data>: Perform<Response>,
{ {
perform::<Data, Response>(data.0, db) perform::<Data>(data.0, db, chat_server)
} }

View file

@ -1,5 +1,5 @@
use super::*;
use crate::apub; use crate::apub;
use actix_web::web;
pub fn config(cfg: &mut web::ServiceConfig) { pub fn config(cfg: &mut web::ServiceConfig) {
cfg cfg

View file

@ -6,16 +6,6 @@ use crate::db::site_view::SiteView;
use crate::db::user::{Claims, User_}; use crate::db::user::{Claims, User_};
use crate::db::user_mention_view::{UserMentionQueryBuilder, UserMentionView}; use crate::db::user_mention_view::{UserMentionQueryBuilder, UserMentionView};
use crate::db::{ListingType, SortType}; use crate::db::{ListingType, SortType};
use crate::{markdown_to_html, Settings};
use actix_web::{web, HttpResponse, Result};
use chrono::{DateTime, NaiveDateTime, Utc};
use diesel::r2d2::{ConnectionManager, Pool};
use diesel::PgConnection;
use failure::Error;
use rss::{CategoryBuilder, ChannelBuilder, GuidBuilder, Item, ItemBuilder};
use serde::Deserialize;
use std::str::FromStr;
use strum::ParseError;
#[derive(Deserialize)] #[derive(Deserialize)]
pub struct Params { pub struct Params {

View file

@ -1,6 +1,4 @@
use crate::settings::Settings; use super::*;
use actix_files::NamedFile;
use actix_web::web;
pub fn config(cfg: &mut web::ServiceConfig) { pub fn config(cfg: &mut web::ServiceConfig) {
cfg cfg

View file

@ -1,3 +1,32 @@
use crate::api::{Oper, Perform};
use crate::db::site_view::SiteView;
use crate::rate_limit::rate_limiter::RateLimiter;
use crate::websocket::{server::ChatServer, WebsocketInfo};
use crate::{get_ip, markdown_to_html, version, Settings};
use actix::prelude::*;
use actix_files::NamedFile;
use actix_web::{body::Body, web::Query, *};
use actix_web_actors::ws;
use chrono::{DateTime, NaiveDateTime, Utc};
use diesel::{
r2d2::{ConnectionManager, Pool},
PgConnection,
};
use failure::Error;
use log::{error, info};
use regex::Regex;
use rss::{CategoryBuilder, ChannelBuilder, GuidBuilder, Item, ItemBuilder};
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::str::FromStr;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use strum::ParseError;
pub type DbPoolParam = web::Data<Pool<ConnectionManager<PgConnection>>>;
pub type RateLimitParam = web::Data<Arc<Mutex<RateLimiter>>>;
pub type ChatServerParam = web::Data<Addr<ChatServer>>;
pub mod api; pub mod api;
pub mod federation; pub mod federation;
pub mod feeds; pub mod feeds;

View file

@ -1,12 +1,4 @@
use crate::db::site_view::SiteView; use super::*;
use crate::version;
use crate::Settings;
use actix_web::body::Body;
use actix_web::web;
use actix_web::HttpResponse;
use diesel::r2d2::{ConnectionManager, Pool};
use diesel::PgConnection;
use serde::Serialize;
pub fn config(cfg: &mut web::ServiceConfig) { pub fn config(cfg: &mut web::ServiceConfig) {
cfg cfg

View file

@ -1,13 +1,5 @@
use super::*;
use crate::db::community::Community; use crate::db::community::Community;
use crate::Settings;
use actix_web::web;
use actix_web::web::Query;
use actix_web::HttpResponse;
use diesel::r2d2::{ConnectionManager, Pool};
use diesel::PgConnection;
use regex::Regex;
use serde::Deserialize;
use serde_json::json;
#[derive(Deserialize)] #[derive(Deserialize)]
pub struct Params { pub struct Params {

View file

@ -1,14 +1,6 @@
use super::*;
use crate::websocket::server::*; use crate::websocket::server::*;
use actix::prelude::*; use actix_web::{Error, Result};
use actix_web::web;
use actix_web::*;
use actix_web_actors::ws;
use log::{error, info};
use std::time::{Duration, Instant};
pub fn config(cfg: &mut web::ServiceConfig) {
cfg.service(web::resource("/api/v1/ws").to(chat_route));
}
/// How often heartbeat pings are sent /// How often heartbeat pings are sent
const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5); const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5);
@ -16,25 +8,17 @@ const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5);
const CLIENT_TIMEOUT: Duration = Duration::from_secs(10); const CLIENT_TIMEOUT: Duration = Duration::from_secs(10);
/// Entry point for our route /// Entry point for our route
async fn chat_route( pub async fn chat_route(
req: HttpRequest, req: HttpRequest,
stream: web::Payload, stream: web::Payload,
chat_server: web::Data<Addr<ChatServer>>, chat_server: web::Data<Addr<ChatServer>>,
) -> Result<HttpResponse, Error> { ) -> Result<HttpResponse, Error> {
// TODO not sure if the blocking should be here or not
ws::start( ws::start(
WSSession { WSSession {
cs_addr: chat_server.get_ref().to_owned(), cs_addr: chat_server.get_ref().to_owned(),
id: 0, id: 0,
hb: Instant::now(), hb: Instant::now(),
ip: req ip: get_ip(&req.connection_info()),
.connection_info()
.remote()
.unwrap_or("127.0.0.1:12345")
.split(':')
.next()
.unwrap_or("127.0.0.1")
.to_string(),
}, },
&req, &req,
stream, stream,
@ -135,10 +119,9 @@ impl StreamHandler<Result<ws::Message, ws::ProtocolError>> for WSSession {
.into_actor(self) .into_actor(self)
.then(|res, _, ctx| { .then(|res, _, ctx| {
match res { match res {
Ok(res) => ctx.text(res), Ok(Ok(res)) => ctx.text(res),
Err(e) => { Ok(Err(e)) => match e {},
error!("{}", &e); Err(e) => error!("{}", &e),
}
} }
actix::fut::ready(()) actix::fut::ready(())
}) })

View file

@ -1,6 +1,19 @@
pub mod server; pub mod server;
#[derive(EnumString, ToString, Debug)] use crate::ConnectionId;
use actix::prelude::*;
use diesel::r2d2::{ConnectionManager, Pool};
use diesel::PgConnection;
use failure::Error;
use log::{error, info};
use rand::{rngs::ThreadRng, Rng};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use server::ChatServer;
use std::collections::{HashMap, HashSet};
use std::str::FromStr;
#[derive(EnumString, ToString, Debug, Clone)]
pub enum UserOperation { pub enum UserOperation {
Login, Login,
Register, Register,
@ -49,3 +62,9 @@ pub enum UserOperation {
GetSiteConfig, GetSiteConfig,
SaveSiteConfig, SaveSiteConfig,
} }
#[derive(Clone)]
pub struct WebsocketInfo {
pub chatserver: Addr<ChatServer>,
pub id: Option<ConnectionId>,
}

File diff suppressed because it is too large Load diff