mirror of
https://github.com/LemmyNet/lemmy.git
synced 2024-12-23 03:11:32 +00:00
Merge remote-tracking branch 'origin/main' into combined_tables_2
This commit is contained in:
commit
46748210ee
12 changed files with 139 additions and 27 deletions
|
@ -156,7 +156,6 @@ test("Delete a comment", async () => {
|
||||||
commentRes.comment_view.comment.id,
|
commentRes.comment_view.comment.id,
|
||||||
);
|
);
|
||||||
expect(deleteCommentRes.comment_view.comment.deleted).toBe(true);
|
expect(deleteCommentRes.comment_view.comment.deleted).toBe(true);
|
||||||
expect(deleteCommentRes.comment_view.comment.content).toBe("");
|
|
||||||
|
|
||||||
// Make sure that comment is deleted on beta
|
// Make sure that comment is deleted on beta
|
||||||
await waitUntil(
|
await waitUntil(
|
||||||
|
@ -254,7 +253,6 @@ test("Remove a comment from admin and community on different instance", async ()
|
||||||
betaComment.comment.id,
|
betaComment.comment.id,
|
||||||
);
|
);
|
||||||
expect(removeCommentRes.comment_view.comment.removed).toBe(true);
|
expect(removeCommentRes.comment_view.comment.removed).toBe(true);
|
||||||
expect(removeCommentRes.comment_view.comment.content).toBe("");
|
|
||||||
|
|
||||||
// Comment text is also hidden from list
|
// Comment text is also hidden from list
|
||||||
let listComments = await getComments(
|
let listComments = await getComments(
|
||||||
|
@ -263,7 +261,6 @@ test("Remove a comment from admin and community on different instance", async ()
|
||||||
);
|
);
|
||||||
expect(listComments.comments.length).toBe(1);
|
expect(listComments.comments.length).toBe(1);
|
||||||
expect(listComments.comments[0].comment.removed).toBe(true);
|
expect(listComments.comments[0].comment.removed).toBe(true);
|
||||||
expect(listComments.comments[0].comment.content).toBe("");
|
|
||||||
|
|
||||||
// Make sure its not removed on alpha
|
// Make sure its not removed on alpha
|
||||||
let refetchedPostComments = await getComments(
|
let refetchedPostComments = await getComments(
|
||||||
|
|
|
@ -25,6 +25,8 @@ pub struct CreateOAuthProvider {
|
||||||
#[cfg_attr(feature = "full", ts(optional))]
|
#[cfg_attr(feature = "full", ts(optional))]
|
||||||
pub account_linking_enabled: Option<bool>,
|
pub account_linking_enabled: Option<bool>,
|
||||||
#[cfg_attr(feature = "full", ts(optional))]
|
#[cfg_attr(feature = "full", ts(optional))]
|
||||||
|
pub use_pkce: Option<bool>,
|
||||||
|
#[cfg_attr(feature = "full", ts(optional))]
|
||||||
pub enabled: Option<bool>,
|
pub enabled: Option<bool>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -54,6 +56,8 @@ pub struct EditOAuthProvider {
|
||||||
#[cfg_attr(feature = "full", ts(optional))]
|
#[cfg_attr(feature = "full", ts(optional))]
|
||||||
pub account_linking_enabled: Option<bool>,
|
pub account_linking_enabled: Option<bool>,
|
||||||
#[cfg_attr(feature = "full", ts(optional))]
|
#[cfg_attr(feature = "full", ts(optional))]
|
||||||
|
pub use_pkce: Option<bool>,
|
||||||
|
#[cfg_attr(feature = "full", ts(optional))]
|
||||||
pub enabled: Option<bool>,
|
pub enabled: Option<bool>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -82,4 +86,6 @@ pub struct AuthenticateWithOauth {
|
||||||
/// An answer is mandatory if require application is enabled on the server
|
/// An answer is mandatory if require application is enabled on the server
|
||||||
#[cfg_attr(feature = "full", ts(optional))]
|
#[cfg_attr(feature = "full", ts(optional))]
|
||||||
pub answer: Option<String>,
|
pub answer: Option<String>,
|
||||||
|
#[cfg_attr(feature = "full", ts(optional))]
|
||||||
|
pub pkce_code_verifier: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
|
@ -29,6 +29,7 @@ anyhow.workspace = true
|
||||||
chrono.workspace = true
|
chrono.workspace = true
|
||||||
webmention = "0.6.0"
|
webmention = "0.6.0"
|
||||||
accept-language = "3.1.0"
|
accept-language = "3.1.0"
|
||||||
|
regex = { workspace = true }
|
||||||
serde_json = { workspace = true }
|
serde_json = { workspace = true }
|
||||||
serde = { workspace = true }
|
serde = { workspace = true }
|
||||||
serde_with = { workspace = true }
|
serde_with = { workspace = true }
|
||||||
|
|
|
@ -35,6 +35,7 @@ pub async fn create_oauth_provider(
|
||||||
scopes: data.scopes.to_string(),
|
scopes: data.scopes.to_string(),
|
||||||
auto_verify_email: data.auto_verify_email,
|
auto_verify_email: data.auto_verify_email,
|
||||||
account_linking_enabled: data.account_linking_enabled,
|
account_linking_enabled: data.account_linking_enabled,
|
||||||
|
use_pkce: data.use_pkce,
|
||||||
enabled: data.enabled,
|
enabled: data.enabled,
|
||||||
};
|
};
|
||||||
let oauth_provider = OAuthProvider::create(&mut context.pool(), &oauth_provider_form).await?;
|
let oauth_provider = OAuthProvider::create(&mut context.pool(), &oauth_provider_form).await?;
|
||||||
|
|
|
@ -33,6 +33,7 @@ pub async fn update_oauth_provider(
|
||||||
auto_verify_email: data.auto_verify_email,
|
auto_verify_email: data.auto_verify_email,
|
||||||
account_linking_enabled: data.account_linking_enabled,
|
account_linking_enabled: data.account_linking_enabled,
|
||||||
enabled: data.enabled,
|
enabled: data.enabled,
|
||||||
|
use_pkce: data.use_pkce,
|
||||||
updated: Some(Some(Utc::now())),
|
updated: Some(Some(Utc::now())),
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -45,9 +45,10 @@ use lemmy_utils::{
|
||||||
validation::is_valid_actor_name,
|
validation::is_valid_actor_name,
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
use regex::Regex;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_with::skip_serializing_none;
|
use serde_with::skip_serializing_none;
|
||||||
use std::collections::HashSet;
|
use std::{collections::HashSet, sync::LazyLock};
|
||||||
|
|
||||||
#[skip_serializing_none]
|
#[skip_serializing_none]
|
||||||
#[derive(Debug, Serialize, Deserialize, Clone, Default)]
|
#[derive(Debug, Serialize, Deserialize, Clone, Default)]
|
||||||
|
@ -225,6 +226,11 @@ pub async fn authenticate_with_oauth(
|
||||||
Err(LemmyErrorType::OauthAuthorizationInvalid)?
|
Err(LemmyErrorType::OauthAuthorizationInvalid)?
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// validate the PKCE challenge
|
||||||
|
if let Some(code_verifier) = &data.pkce_code_verifier {
|
||||||
|
check_code_verifier(code_verifier)?;
|
||||||
|
}
|
||||||
|
|
||||||
// Fetch the OAUTH provider and make sure it's enabled
|
// Fetch the OAUTH provider and make sure it's enabled
|
||||||
let oauth_provider_id = data.oauth_provider_id;
|
let oauth_provider_id = data.oauth_provider_id;
|
||||||
let oauth_provider = OAuthProvider::read(&mut context.pool(), oauth_provider_id)
|
let oauth_provider = OAuthProvider::read(&mut context.pool(), oauth_provider_id)
|
||||||
|
@ -236,8 +242,13 @@ pub async fn authenticate_with_oauth(
|
||||||
return Err(LemmyErrorType::OauthAuthorizationInvalid)?;
|
return Err(LemmyErrorType::OauthAuthorizationInvalid)?;
|
||||||
}
|
}
|
||||||
|
|
||||||
let token_response =
|
let token_response = oauth_request_access_token(
|
||||||
oauth_request_access_token(&context, &oauth_provider, &data.code, redirect_uri.as_str())
|
&context,
|
||||||
|
&oauth_provider,
|
||||||
|
&data.code,
|
||||||
|
data.pkce_code_verifier.as_deref(),
|
||||||
|
redirect_uri.as_str(),
|
||||||
|
)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let user_info = oidc_get_user_info(
|
let user_info = oidc_get_user_info(
|
||||||
|
@ -533,20 +544,27 @@ async fn oauth_request_access_token(
|
||||||
context: &Data<LemmyContext>,
|
context: &Data<LemmyContext>,
|
||||||
oauth_provider: &OAuthProvider,
|
oauth_provider: &OAuthProvider,
|
||||||
code: &str,
|
code: &str,
|
||||||
|
pkce_code_verifier: Option<&str>,
|
||||||
redirect_uri: &str,
|
redirect_uri: &str,
|
||||||
) -> LemmyResult<TokenResponse> {
|
) -> LemmyResult<TokenResponse> {
|
||||||
|
let mut form = vec![
|
||||||
|
("client_id", &*oauth_provider.client_id),
|
||||||
|
("client_secret", &*oauth_provider.client_secret),
|
||||||
|
("code", code),
|
||||||
|
("grant_type", "authorization_code"),
|
||||||
|
("redirect_uri", redirect_uri),
|
||||||
|
];
|
||||||
|
|
||||||
|
if let Some(code_verifier) = pkce_code_verifier {
|
||||||
|
form.push(("code_verifier", code_verifier));
|
||||||
|
}
|
||||||
|
|
||||||
// Request an Access Token from the OAUTH provider
|
// Request an Access Token from the OAUTH provider
|
||||||
let response = context
|
let response = context
|
||||||
.client()
|
.client()
|
||||||
.post(oauth_provider.token_endpoint.as_str())
|
.post(oauth_provider.token_endpoint.as_str())
|
||||||
.header("Accept", "application/json")
|
.header("Accept", "application/json")
|
||||||
.form(&[
|
.form(&form[..])
|
||||||
("grant_type", "authorization_code"),
|
|
||||||
("code", code),
|
|
||||||
("redirect_uri", redirect_uri),
|
|
||||||
("client_id", &oauth_provider.client_id),
|
|
||||||
("client_secret", &oauth_provider.client_secret),
|
|
||||||
])
|
|
||||||
.send()
|
.send()
|
||||||
.await
|
.await
|
||||||
.with_lemmy_type(LemmyErrorType::OauthLoginFailed)?
|
.with_lemmy_type(LemmyErrorType::OauthLoginFailed)?
|
||||||
|
@ -596,3 +614,17 @@ fn read_user_info(user_info: &serde_json::Value, key: &str) -> LemmyResult<Strin
|
||||||
}
|
}
|
||||||
Err(LemmyErrorType::OauthLoginFailed)?
|
Err(LemmyErrorType::OauthLoginFailed)?
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::expect_used)]
|
||||||
|
fn check_code_verifier(code_verifier: &str) -> LemmyResult<()> {
|
||||||
|
static VALID_CODE_VERIFIER_REGEX: LazyLock<Regex> =
|
||||||
|
LazyLock::new(|| Regex::new(r"^[a-zA-Z0-9\-._~]{43,128}$").expect("compile regex"));
|
||||||
|
|
||||||
|
let check = VALID_CODE_VERIFIER_REGEX.is_match(code_verifier);
|
||||||
|
|
||||||
|
if check {
|
||||||
|
Ok(())
|
||||||
|
} else {
|
||||||
|
Err(LemmyErrorType::InvalidCodeVerifier.into())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -660,6 +660,7 @@ diesel::table! {
|
||||||
enabled -> Bool,
|
enabled -> Bool,
|
||||||
published -> Timestamptz,
|
published -> Timestamptz,
|
||||||
updated -> Nullable<Timestamptz>,
|
updated -> Nullable<Timestamptz>,
|
||||||
|
use_pkce -> Bool,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -62,6 +62,8 @@ pub struct OAuthProvider {
|
||||||
pub published: DateTime<Utc>,
|
pub published: DateTime<Utc>,
|
||||||
#[cfg_attr(feature = "full", ts(optional))]
|
#[cfg_attr(feature = "full", ts(optional))]
|
||||||
pub updated: Option<DateTime<Utc>>,
|
pub updated: Option<DateTime<Utc>>,
|
||||||
|
/// switch to enable or disable PKCE
|
||||||
|
pub use_pkce: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, PartialEq, Eq, Debug, Deserialize)]
|
#[derive(Clone, PartialEq, Eq, Debug, Deserialize)]
|
||||||
|
@ -83,6 +85,7 @@ impl Serialize for PublicOAuthProvider {
|
||||||
state.serialize_field("authorization_endpoint", &self.0.authorization_endpoint)?;
|
state.serialize_field("authorization_endpoint", &self.0.authorization_endpoint)?;
|
||||||
state.serialize_field("client_id", &self.0.client_id)?;
|
state.serialize_field("client_id", &self.0.client_id)?;
|
||||||
state.serialize_field("scopes", &self.0.scopes)?;
|
state.serialize_field("scopes", &self.0.scopes)?;
|
||||||
|
state.serialize_field("use_pkce", &self.0.use_pkce)?;
|
||||||
state.end()
|
state.end()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -102,6 +105,7 @@ pub struct OAuthProviderInsertForm {
|
||||||
pub scopes: String,
|
pub scopes: String,
|
||||||
pub auto_verify_email: Option<bool>,
|
pub auto_verify_email: Option<bool>,
|
||||||
pub account_linking_enabled: Option<bool>,
|
pub account_linking_enabled: Option<bool>,
|
||||||
|
pub use_pkce: Option<bool>,
|
||||||
pub enabled: Option<bool>,
|
pub enabled: Option<bool>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -118,6 +122,7 @@ pub struct OAuthProviderUpdateForm {
|
||||||
pub scopes: Option<String>,
|
pub scopes: Option<String>,
|
||||||
pub auto_verify_email: Option<bool>,
|
pub auto_verify_email: Option<bool>,
|
||||||
pub account_linking_enabled: Option<bool>,
|
pub account_linking_enabled: Option<bool>,
|
||||||
|
pub use_pkce: Option<bool>,
|
||||||
pub enabled: Option<bool>,
|
pub enabled: Option<bool>,
|
||||||
pub updated: Option<Option<DateTime<Utc>>>,
|
pub updated: Option<Option<DateTime<Utc>>>,
|
||||||
}
|
}
|
||||||
|
|
|
@ -316,17 +316,14 @@ impl CommentView {
|
||||||
comment_id: CommentId,
|
comment_id: CommentId,
|
||||||
my_local_user: Option<&'_ LocalUser>,
|
my_local_user: Option<&'_ LocalUser>,
|
||||||
) -> Result<Self, Error> {
|
) -> Result<Self, Error> {
|
||||||
|
let is_admin = my_local_user.map(|u| u.admin).unwrap_or(false);
|
||||||
// If a person is given, then my_vote (res.9), if None, should be 0, not null
|
// If a person is given, then my_vote (res.9), if None, should be 0, not null
|
||||||
// Necessary to differentiate between other person's votes
|
// Necessary to differentiate between other person's votes
|
||||||
let res = queries().read(pool, (comment_id, my_local_user)).await?;
|
let mut res = queries().read(pool, (comment_id, my_local_user)).await?;
|
||||||
let mut new_view = res.clone();
|
|
||||||
if my_local_user.is_some() && res.my_vote.is_none() {
|
if my_local_user.is_some() && res.my_vote.is_none() {
|
||||||
new_view.my_vote = Some(0);
|
res.my_vote = Some(0);
|
||||||
}
|
}
|
||||||
if res.comment.deleted || res.comment.removed {
|
Ok(handle_deleted(res, is_admin))
|
||||||
new_view.comment.content = String::new();
|
|
||||||
}
|
|
||||||
Ok(new_view)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -350,22 +347,25 @@ pub struct CommentQuery<'a> {
|
||||||
|
|
||||||
impl CommentQuery<'_> {
|
impl CommentQuery<'_> {
|
||||||
pub async fn list(self, site: &Site, pool: &mut DbPool<'_>) -> Result<Vec<CommentView>, Error> {
|
pub async fn list(self, site: &Site, pool: &mut DbPool<'_>) -> Result<Vec<CommentView>, Error> {
|
||||||
|
let is_admin = self.local_user.map(|u| u.admin).unwrap_or(false);
|
||||||
Ok(
|
Ok(
|
||||||
queries()
|
queries()
|
||||||
.list(pool, (self, site))
|
.list(pool, (self, site))
|
||||||
.await?
|
.await?
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|mut c| {
|
.map(|c| handle_deleted(c, is_admin))
|
||||||
if c.comment.deleted || c.comment.removed {
|
|
||||||
c.comment.content = String::new();
|
|
||||||
}
|
|
||||||
c
|
|
||||||
})
|
|
||||||
.collect(),
|
.collect(),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn handle_deleted(mut c: CommentView, is_admin: bool) -> CommentView {
|
||||||
|
if !is_admin && (c.comment.deleted || c.comment.removed) {
|
||||||
|
c.comment.content = String::new();
|
||||||
|
}
|
||||||
|
c
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
#[expect(clippy::indexing_slicing)]
|
#[expect(clippy::indexing_slicing)]
|
||||||
mod tests {
|
mod tests {
|
||||||
|
@ -1301,4 +1301,65 @@ mod tests {
|
||||||
|
|
||||||
cleanup(data, pool).await
|
cleanup(data, pool).await
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
#[serial]
|
||||||
|
async fn comment_removed() -> LemmyResult<()> {
|
||||||
|
let pool = &build_db_pool_for_tests();
|
||||||
|
let pool = &mut pool.into();
|
||||||
|
let mut data = init_data(pool).await?;
|
||||||
|
|
||||||
|
// Mark a comment as removed
|
||||||
|
let form = CommentUpdateForm {
|
||||||
|
removed: Some(true),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
Comment::update(pool, data.inserted_comment_0.id, &form).await?;
|
||||||
|
|
||||||
|
// Read as normal user, content is cleared
|
||||||
|
data.timmy_local_user_view.local_user.admin = false;
|
||||||
|
let comment_view = CommentView::read(
|
||||||
|
pool,
|
||||||
|
data.inserted_comment_0.id,
|
||||||
|
Some(&data.timmy_local_user_view.local_user),
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
assert_eq!("", comment_view.comment.content);
|
||||||
|
let comment_listing = CommentQuery {
|
||||||
|
community_id: Some(data.inserted_community.id),
|
||||||
|
local_user: Some(&data.timmy_local_user_view.local_user),
|
||||||
|
sort: Some(CommentSortType::Old),
|
||||||
|
..Default::default()
|
||||||
|
}
|
||||||
|
.list(&data.site, pool)
|
||||||
|
.await?;
|
||||||
|
assert_eq!("", comment_listing[0].comment.content);
|
||||||
|
|
||||||
|
// Read as admin, content is returned
|
||||||
|
data.timmy_local_user_view.local_user.admin = true;
|
||||||
|
let comment_view = CommentView::read(
|
||||||
|
pool,
|
||||||
|
data.inserted_comment_0.id,
|
||||||
|
Some(&data.timmy_local_user_view.local_user),
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
assert_eq!(
|
||||||
|
data.inserted_comment_0.content,
|
||||||
|
comment_view.comment.content
|
||||||
|
);
|
||||||
|
let comment_listing = CommentQuery {
|
||||||
|
community_id: Some(data.inserted_community.id),
|
||||||
|
local_user: Some(&data.timmy_local_user_view.local_user),
|
||||||
|
sort: Some(CommentSortType::Old),
|
||||||
|
..Default::default()
|
||||||
|
}
|
||||||
|
.list(&data.site, pool)
|
||||||
|
.await?;
|
||||||
|
assert_eq!(
|
||||||
|
data.inserted_comment_0.content,
|
||||||
|
comment_listing[0].comment.content
|
||||||
|
);
|
||||||
|
|
||||||
|
cleanup(data, pool).await
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -76,6 +76,7 @@ pub enum LemmyErrorType {
|
||||||
InvalidEmailAddress(String),
|
InvalidEmailAddress(String),
|
||||||
RateLimitError,
|
RateLimitError,
|
||||||
InvalidName,
|
InvalidName,
|
||||||
|
InvalidCodeVerifier,
|
||||||
InvalidDisplayName,
|
InvalidDisplayName,
|
||||||
InvalidMatrixId,
|
InvalidMatrixId,
|
||||||
InvalidPostTitle,
|
InvalidPostTitle,
|
||||||
|
|
3
migrations/2024-11-23-234637_oauth_pkce/down.sql
Normal file
3
migrations/2024-11-23-234637_oauth_pkce/down.sql
Normal file
|
@ -0,0 +1,3 @@
|
||||||
|
ALTER TABLE oauth_provider
|
||||||
|
DROP COLUMN use_pkce;
|
||||||
|
|
3
migrations/2024-11-23-234637_oauth_pkce/up.sql
Normal file
3
migrations/2024-11-23-234637_oauth_pkce/up.sql
Normal file
|
@ -0,0 +1,3 @@
|
||||||
|
ALTER TABLE oauth_provider
|
||||||
|
ADD COLUMN use_pkce boolean DEFAULT FALSE NOT NULL;
|
||||||
|
|
Loading…
Reference in a new issue