Merge remote-tracking branch 'origin/main' into combined_tables_2

This commit is contained in:
Dessalines 2024-12-02 17:07:03 -05:00
commit 46748210ee
12 changed files with 139 additions and 27 deletions

View file

@ -156,7 +156,6 @@ test("Delete a comment", async () => {
commentRes.comment_view.comment.id, commentRes.comment_view.comment.id,
); );
expect(deleteCommentRes.comment_view.comment.deleted).toBe(true); expect(deleteCommentRes.comment_view.comment.deleted).toBe(true);
expect(deleteCommentRes.comment_view.comment.content).toBe("");
// Make sure that comment is deleted on beta // Make sure that comment is deleted on beta
await waitUntil( await waitUntil(
@ -254,7 +253,6 @@ test("Remove a comment from admin and community on different instance", async ()
betaComment.comment.id, betaComment.comment.id,
); );
expect(removeCommentRes.comment_view.comment.removed).toBe(true); expect(removeCommentRes.comment_view.comment.removed).toBe(true);
expect(removeCommentRes.comment_view.comment.content).toBe("");
// Comment text is also hidden from list // Comment text is also hidden from list
let listComments = await getComments( let listComments = await getComments(
@ -263,7 +261,6 @@ test("Remove a comment from admin and community on different instance", async ()
); );
expect(listComments.comments.length).toBe(1); expect(listComments.comments.length).toBe(1);
expect(listComments.comments[0].comment.removed).toBe(true); expect(listComments.comments[0].comment.removed).toBe(true);
expect(listComments.comments[0].comment.content).toBe("");
// Make sure its not removed on alpha // Make sure its not removed on alpha
let refetchedPostComments = await getComments( let refetchedPostComments = await getComments(

View file

@ -25,6 +25,8 @@ pub struct CreateOAuthProvider {
#[cfg_attr(feature = "full", ts(optional))] #[cfg_attr(feature = "full", ts(optional))]
pub account_linking_enabled: Option<bool>, pub account_linking_enabled: Option<bool>,
#[cfg_attr(feature = "full", ts(optional))] #[cfg_attr(feature = "full", ts(optional))]
pub use_pkce: Option<bool>,
#[cfg_attr(feature = "full", ts(optional))]
pub enabled: Option<bool>, pub enabled: Option<bool>,
} }
@ -54,6 +56,8 @@ pub struct EditOAuthProvider {
#[cfg_attr(feature = "full", ts(optional))] #[cfg_attr(feature = "full", ts(optional))]
pub account_linking_enabled: Option<bool>, pub account_linking_enabled: Option<bool>,
#[cfg_attr(feature = "full", ts(optional))] #[cfg_attr(feature = "full", ts(optional))]
pub use_pkce: Option<bool>,
#[cfg_attr(feature = "full", ts(optional))]
pub enabled: Option<bool>, pub enabled: Option<bool>,
} }
@ -82,4 +86,6 @@ pub struct AuthenticateWithOauth {
/// An answer is mandatory if require application is enabled on the server /// An answer is mandatory if require application is enabled on the server
#[cfg_attr(feature = "full", ts(optional))] #[cfg_attr(feature = "full", ts(optional))]
pub answer: Option<String>, pub answer: Option<String>,
#[cfg_attr(feature = "full", ts(optional))]
pub pkce_code_verifier: Option<String>,
} }

View file

@ -29,6 +29,7 @@ anyhow.workspace = true
chrono.workspace = true chrono.workspace = true
webmention = "0.6.0" webmention = "0.6.0"
accept-language = "3.1.0" accept-language = "3.1.0"
regex = { workspace = true }
serde_json = { workspace = true } serde_json = { workspace = true }
serde = { workspace = true } serde = { workspace = true }
serde_with = { workspace = true } serde_with = { workspace = true }

View file

@ -35,6 +35,7 @@ pub async fn create_oauth_provider(
scopes: data.scopes.to_string(), scopes: data.scopes.to_string(),
auto_verify_email: data.auto_verify_email, auto_verify_email: data.auto_verify_email,
account_linking_enabled: data.account_linking_enabled, account_linking_enabled: data.account_linking_enabled,
use_pkce: data.use_pkce,
enabled: data.enabled, enabled: data.enabled,
}; };
let oauth_provider = OAuthProvider::create(&mut context.pool(), &oauth_provider_form).await?; let oauth_provider = OAuthProvider::create(&mut context.pool(), &oauth_provider_form).await?;

View file

@ -33,6 +33,7 @@ pub async fn update_oauth_provider(
auto_verify_email: data.auto_verify_email, auto_verify_email: data.auto_verify_email,
account_linking_enabled: data.account_linking_enabled, account_linking_enabled: data.account_linking_enabled,
enabled: data.enabled, enabled: data.enabled,
use_pkce: data.use_pkce,
updated: Some(Some(Utc::now())), updated: Some(Some(Utc::now())),
}; };

View file

@ -45,9 +45,10 @@ use lemmy_utils::{
validation::is_valid_actor_name, validation::is_valid_actor_name,
}, },
}; };
use regex::Regex;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_with::skip_serializing_none; use serde_with::skip_serializing_none;
use std::collections::HashSet; use std::{collections::HashSet, sync::LazyLock};
#[skip_serializing_none] #[skip_serializing_none]
#[derive(Debug, Serialize, Deserialize, Clone, Default)] #[derive(Debug, Serialize, Deserialize, Clone, Default)]
@ -225,6 +226,11 @@ pub async fn authenticate_with_oauth(
Err(LemmyErrorType::OauthAuthorizationInvalid)? Err(LemmyErrorType::OauthAuthorizationInvalid)?
} }
// validate the PKCE challenge
if let Some(code_verifier) = &data.pkce_code_verifier {
check_code_verifier(code_verifier)?;
}
// Fetch the OAUTH provider and make sure it's enabled // Fetch the OAUTH provider and make sure it's enabled
let oauth_provider_id = data.oauth_provider_id; let oauth_provider_id = data.oauth_provider_id;
let oauth_provider = OAuthProvider::read(&mut context.pool(), oauth_provider_id) let oauth_provider = OAuthProvider::read(&mut context.pool(), oauth_provider_id)
@ -236,8 +242,13 @@ pub async fn authenticate_with_oauth(
return Err(LemmyErrorType::OauthAuthorizationInvalid)?; return Err(LemmyErrorType::OauthAuthorizationInvalid)?;
} }
let token_response = let token_response = oauth_request_access_token(
oauth_request_access_token(&context, &oauth_provider, &data.code, redirect_uri.as_str()) &context,
&oauth_provider,
&data.code,
data.pkce_code_verifier.as_deref(),
redirect_uri.as_str(),
)
.await?; .await?;
let user_info = oidc_get_user_info( let user_info = oidc_get_user_info(
@ -533,20 +544,27 @@ async fn oauth_request_access_token(
context: &Data<LemmyContext>, context: &Data<LemmyContext>,
oauth_provider: &OAuthProvider, oauth_provider: &OAuthProvider,
code: &str, code: &str,
pkce_code_verifier: Option<&str>,
redirect_uri: &str, redirect_uri: &str,
) -> LemmyResult<TokenResponse> { ) -> LemmyResult<TokenResponse> {
let mut form = vec![
("client_id", &*oauth_provider.client_id),
("client_secret", &*oauth_provider.client_secret),
("code", code),
("grant_type", "authorization_code"),
("redirect_uri", redirect_uri),
];
if let Some(code_verifier) = pkce_code_verifier {
form.push(("code_verifier", code_verifier));
}
// Request an Access Token from the OAUTH provider // Request an Access Token from the OAUTH provider
let response = context let response = context
.client() .client()
.post(oauth_provider.token_endpoint.as_str()) .post(oauth_provider.token_endpoint.as_str())
.header("Accept", "application/json") .header("Accept", "application/json")
.form(&[ .form(&form[..])
("grant_type", "authorization_code"),
("code", code),
("redirect_uri", redirect_uri),
("client_id", &oauth_provider.client_id),
("client_secret", &oauth_provider.client_secret),
])
.send() .send()
.await .await
.with_lemmy_type(LemmyErrorType::OauthLoginFailed)? .with_lemmy_type(LemmyErrorType::OauthLoginFailed)?
@ -596,3 +614,17 @@ fn read_user_info(user_info: &serde_json::Value, key: &str) -> LemmyResult<Strin
} }
Err(LemmyErrorType::OauthLoginFailed)? Err(LemmyErrorType::OauthLoginFailed)?
} }
#[allow(clippy::expect_used)]
fn check_code_verifier(code_verifier: &str) -> LemmyResult<()> {
static VALID_CODE_VERIFIER_REGEX: LazyLock<Regex> =
LazyLock::new(|| Regex::new(r"^[a-zA-Z0-9\-._~]{43,128}$").expect("compile regex"));
let check = VALID_CODE_VERIFIER_REGEX.is_match(code_verifier);
if check {
Ok(())
} else {
Err(LemmyErrorType::InvalidCodeVerifier.into())
}
}

View file

@ -660,6 +660,7 @@ diesel::table! {
enabled -> Bool, enabled -> Bool,
published -> Timestamptz, published -> Timestamptz,
updated -> Nullable<Timestamptz>, updated -> Nullable<Timestamptz>,
use_pkce -> Bool,
} }
} }

View file

@ -62,6 +62,8 @@ pub struct OAuthProvider {
pub published: DateTime<Utc>, pub published: DateTime<Utc>,
#[cfg_attr(feature = "full", ts(optional))] #[cfg_attr(feature = "full", ts(optional))]
pub updated: Option<DateTime<Utc>>, pub updated: Option<DateTime<Utc>>,
/// switch to enable or disable PKCE
pub use_pkce: bool,
} }
#[derive(Clone, PartialEq, Eq, Debug, Deserialize)] #[derive(Clone, PartialEq, Eq, Debug, Deserialize)]
@ -83,6 +85,7 @@ impl Serialize for PublicOAuthProvider {
state.serialize_field("authorization_endpoint", &self.0.authorization_endpoint)?; state.serialize_field("authorization_endpoint", &self.0.authorization_endpoint)?;
state.serialize_field("client_id", &self.0.client_id)?; state.serialize_field("client_id", &self.0.client_id)?;
state.serialize_field("scopes", &self.0.scopes)?; state.serialize_field("scopes", &self.0.scopes)?;
state.serialize_field("use_pkce", &self.0.use_pkce)?;
state.end() state.end()
} }
} }
@ -102,6 +105,7 @@ pub struct OAuthProviderInsertForm {
pub scopes: String, pub scopes: String,
pub auto_verify_email: Option<bool>, pub auto_verify_email: Option<bool>,
pub account_linking_enabled: Option<bool>, pub account_linking_enabled: Option<bool>,
pub use_pkce: Option<bool>,
pub enabled: Option<bool>, pub enabled: Option<bool>,
} }
@ -118,6 +122,7 @@ pub struct OAuthProviderUpdateForm {
pub scopes: Option<String>, pub scopes: Option<String>,
pub auto_verify_email: Option<bool>, pub auto_verify_email: Option<bool>,
pub account_linking_enabled: Option<bool>, pub account_linking_enabled: Option<bool>,
pub use_pkce: Option<bool>,
pub enabled: Option<bool>, pub enabled: Option<bool>,
pub updated: Option<Option<DateTime<Utc>>>, pub updated: Option<Option<DateTime<Utc>>>,
} }

View file

@ -316,17 +316,14 @@ impl CommentView {
comment_id: CommentId, comment_id: CommentId,
my_local_user: Option<&'_ LocalUser>, my_local_user: Option<&'_ LocalUser>,
) -> Result<Self, Error> { ) -> Result<Self, Error> {
let is_admin = my_local_user.map(|u| u.admin).unwrap_or(false);
// If a person is given, then my_vote (res.9), if None, should be 0, not null // If a person is given, then my_vote (res.9), if None, should be 0, not null
// Necessary to differentiate between other person's votes // Necessary to differentiate between other person's votes
let res = queries().read(pool, (comment_id, my_local_user)).await?; let mut res = queries().read(pool, (comment_id, my_local_user)).await?;
let mut new_view = res.clone();
if my_local_user.is_some() && res.my_vote.is_none() { if my_local_user.is_some() && res.my_vote.is_none() {
new_view.my_vote = Some(0); res.my_vote = Some(0);
} }
if res.comment.deleted || res.comment.removed { Ok(handle_deleted(res, is_admin))
new_view.comment.content = String::new();
}
Ok(new_view)
} }
} }
@ -350,22 +347,25 @@ pub struct CommentQuery<'a> {
impl CommentQuery<'_> { impl CommentQuery<'_> {
pub async fn list(self, site: &Site, pool: &mut DbPool<'_>) -> Result<Vec<CommentView>, Error> { pub async fn list(self, site: &Site, pool: &mut DbPool<'_>) -> Result<Vec<CommentView>, Error> {
let is_admin = self.local_user.map(|u| u.admin).unwrap_or(false);
Ok( Ok(
queries() queries()
.list(pool, (self, site)) .list(pool, (self, site))
.await? .await?
.into_iter() .into_iter()
.map(|mut c| { .map(|c| handle_deleted(c, is_admin))
if c.comment.deleted || c.comment.removed {
c.comment.content = String::new();
}
c
})
.collect(), .collect(),
) )
} }
} }
fn handle_deleted(mut c: CommentView, is_admin: bool) -> CommentView {
if !is_admin && (c.comment.deleted || c.comment.removed) {
c.comment.content = String::new();
}
c
}
#[cfg(test)] #[cfg(test)]
#[expect(clippy::indexing_slicing)] #[expect(clippy::indexing_slicing)]
mod tests { mod tests {
@ -1301,4 +1301,65 @@ mod tests {
cleanup(data, pool).await cleanup(data, pool).await
} }
#[tokio::test]
#[serial]
async fn comment_removed() -> LemmyResult<()> {
let pool = &build_db_pool_for_tests();
let pool = &mut pool.into();
let mut data = init_data(pool).await?;
// Mark a comment as removed
let form = CommentUpdateForm {
removed: Some(true),
..Default::default()
};
Comment::update(pool, data.inserted_comment_0.id, &form).await?;
// Read as normal user, content is cleared
data.timmy_local_user_view.local_user.admin = false;
let comment_view = CommentView::read(
pool,
data.inserted_comment_0.id,
Some(&data.timmy_local_user_view.local_user),
)
.await?;
assert_eq!("", comment_view.comment.content);
let comment_listing = CommentQuery {
community_id: Some(data.inserted_community.id),
local_user: Some(&data.timmy_local_user_view.local_user),
sort: Some(CommentSortType::Old),
..Default::default()
}
.list(&data.site, pool)
.await?;
assert_eq!("", comment_listing[0].comment.content);
// Read as admin, content is returned
data.timmy_local_user_view.local_user.admin = true;
let comment_view = CommentView::read(
pool,
data.inserted_comment_0.id,
Some(&data.timmy_local_user_view.local_user),
)
.await?;
assert_eq!(
data.inserted_comment_0.content,
comment_view.comment.content
);
let comment_listing = CommentQuery {
community_id: Some(data.inserted_community.id),
local_user: Some(&data.timmy_local_user_view.local_user),
sort: Some(CommentSortType::Old),
..Default::default()
}
.list(&data.site, pool)
.await?;
assert_eq!(
data.inserted_comment_0.content,
comment_listing[0].comment.content
);
cleanup(data, pool).await
}
} }

View file

@ -76,6 +76,7 @@ pub enum LemmyErrorType {
InvalidEmailAddress(String), InvalidEmailAddress(String),
RateLimitError, RateLimitError,
InvalidName, InvalidName,
InvalidCodeVerifier,
InvalidDisplayName, InvalidDisplayName,
InvalidMatrixId, InvalidMatrixId,
InvalidPostTitle, InvalidPostTitle,

View file

@ -0,0 +1,3 @@
ALTER TABLE oauth_provider
DROP COLUMN use_pkce;

View file

@ -0,0 +1,3 @@
ALTER TABLE oauth_provider
ADD COLUMN use_pkce boolean DEFAULT FALSE NOT NULL;