parent
ba3e6b482b
commit
9505d1d205
10 changed files with 65 additions and 11 deletions
|
@ -25,6 +25,8 @@ pub struct CreateOAuthProvider {
|
|||
#[cfg_attr(feature = "full", ts(optional))]
|
||||
pub account_linking_enabled: Option<bool>,
|
||||
#[cfg_attr(feature = "full", ts(optional))]
|
||||
pub use_pkce: Option<bool>,
|
||||
#[cfg_attr(feature = "full", ts(optional))]
|
||||
pub enabled: Option<bool>,
|
||||
}
|
||||
|
||||
|
@ -54,6 +56,8 @@ pub struct EditOAuthProvider {
|
|||
#[cfg_attr(feature = "full", ts(optional))]
|
||||
pub account_linking_enabled: Option<bool>,
|
||||
#[cfg_attr(feature = "full", ts(optional))]
|
||||
pub use_pkce: Option<bool>,
|
||||
#[cfg_attr(feature = "full", ts(optional))]
|
||||
pub enabled: Option<bool>,
|
||||
}
|
||||
|
||||
|
@ -82,4 +86,6 @@ pub struct AuthenticateWithOauth {
|
|||
/// An answer is mandatory if require application is enabled on the server
|
||||
#[cfg_attr(feature = "full", ts(optional))]
|
||||
pub answer: Option<String>,
|
||||
#[cfg_attr(feature = "full", ts(optional))]
|
||||
pub pkce_code_verifier: Option<String>,
|
||||
}
|
||||
|
|
|
@ -29,6 +29,7 @@ anyhow.workspace = true
|
|||
chrono.workspace = true
|
||||
webmention = "0.6.0"
|
||||
accept-language = "3.1.0"
|
||||
regex = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_with = { workspace = true }
|
||||
|
|
|
@ -35,6 +35,7 @@ pub async fn create_oauth_provider(
|
|||
scopes: data.scopes.to_string(),
|
||||
auto_verify_email: data.auto_verify_email,
|
||||
account_linking_enabled: data.account_linking_enabled,
|
||||
use_pkce: data.use_pkce,
|
||||
enabled: data.enabled,
|
||||
};
|
||||
let oauth_provider = OAuthProvider::create(&mut context.pool(), &oauth_provider_form).await?;
|
||||
|
|
|
@ -33,6 +33,7 @@ pub async fn update_oauth_provider(
|
|||
auto_verify_email: data.auto_verify_email,
|
||||
account_linking_enabled: data.account_linking_enabled,
|
||||
enabled: data.enabled,
|
||||
use_pkce: data.use_pkce,
|
||||
updated: Some(Some(Utc::now())),
|
||||
};
|
||||
|
||||
|
|
|
@ -45,9 +45,10 @@ use lemmy_utils::{
|
|||
validation::is_valid_actor_name,
|
||||
},
|
||||
};
|
||||
use regex::Regex;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_with::skip_serializing_none;
|
||||
use std::collections::HashSet;
|
||||
use std::{collections::HashSet, sync::LazyLock};
|
||||
|
||||
#[skip_serializing_none]
|
||||
#[derive(Debug, Serialize, Deserialize, Clone, Default)]
|
||||
|
@ -225,6 +226,11 @@ pub async fn authenticate_with_oauth(
|
|||
Err(LemmyErrorType::OauthAuthorizationInvalid)?
|
||||
}
|
||||
|
||||
// validate the PKCE challenge
|
||||
if let Some(code_verifier) = &data.pkce_code_verifier {
|
||||
check_code_verifier(code_verifier)?;
|
||||
}
|
||||
|
||||
// Fetch the OAUTH provider and make sure it's enabled
|
||||
let oauth_provider_id = data.oauth_provider_id;
|
||||
let oauth_provider = OAuthProvider::read(&mut context.pool(), oauth_provider_id)
|
||||
|
@ -236,9 +242,14 @@ pub async fn authenticate_with_oauth(
|
|||
return Err(LemmyErrorType::OauthAuthorizationInvalid)?;
|
||||
}
|
||||
|
||||
let token_response =
|
||||
oauth_request_access_token(&context, &oauth_provider, &data.code, redirect_uri.as_str())
|
||||
.await?;
|
||||
let token_response = oauth_request_access_token(
|
||||
&context,
|
||||
&oauth_provider,
|
||||
&data.code,
|
||||
data.pkce_code_verifier.as_deref(),
|
||||
redirect_uri.as_str(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
let user_info = oidc_get_user_info(
|
||||
&context,
|
||||
|
@ -533,20 +544,27 @@ async fn oauth_request_access_token(
|
|||
context: &Data<LemmyContext>,
|
||||
oauth_provider: &OAuthProvider,
|
||||
code: &str,
|
||||
pkce_code_verifier: Option<&str>,
|
||||
redirect_uri: &str,
|
||||
) -> LemmyResult<TokenResponse> {
|
||||
let mut form = vec![
|
||||
("client_id", &*oauth_provider.client_id),
|
||||
("client_secret", &*oauth_provider.client_secret),
|
||||
("code", code),
|
||||
("grant_type", "authorization_code"),
|
||||
("redirect_uri", redirect_uri),
|
||||
];
|
||||
|
||||
if let Some(code_verifier) = pkce_code_verifier {
|
||||
form.push(("code_verifier", code_verifier));
|
||||
}
|
||||
|
||||
// Request an Access Token from the OAUTH provider
|
||||
let response = context
|
||||
.client()
|
||||
.post(oauth_provider.token_endpoint.as_str())
|
||||
.header("Accept", "application/json")
|
||||
.form(&[
|
||||
("grant_type", "authorization_code"),
|
||||
("code", code),
|
||||
("redirect_uri", redirect_uri),
|
||||
("client_id", &oauth_provider.client_id),
|
||||
("client_secret", &oauth_provider.client_secret),
|
||||
])
|
||||
.form(&form[..])
|
||||
.send()
|
||||
.await
|
||||
.with_lemmy_type(LemmyErrorType::OauthLoginFailed)?
|
||||
|
@ -596,3 +614,17 @@ fn read_user_info(user_info: &serde_json::Value, key: &str) -> LemmyResult<Strin
|
|||
}
|
||||
Err(LemmyErrorType::OauthLoginFailed)?
|
||||
}
|
||||
|
||||
#[allow(clippy::expect_used)]
|
||||
fn check_code_verifier(code_verifier: &str) -> LemmyResult<()> {
|
||||
static VALID_CODE_VERIFIER_REGEX: LazyLock<Regex> =
|
||||
LazyLock::new(|| Regex::new(r"^[a-zA-Z0-9\-._~]{43,128}$").expect("compile regex"));
|
||||
|
||||
let check = VALID_CODE_VERIFIER_REGEX.is_match(code_verifier);
|
||||
|
||||
if check {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(LemmyErrorType::InvalidCodeVerifier.into())
|
||||
}
|
||||
}
|
||||
|
|
|
@ -660,6 +660,7 @@ diesel::table! {
|
|||
enabled -> Bool,
|
||||
published -> Timestamptz,
|
||||
updated -> Nullable<Timestamptz>,
|
||||
use_pkce -> Bool,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -62,6 +62,8 @@ pub struct OAuthProvider {
|
|||
pub published: DateTime<Utc>,
|
||||
#[cfg_attr(feature = "full", ts(optional))]
|
||||
pub updated: Option<DateTime<Utc>>,
|
||||
/// switch to enable or disable PKCE
|
||||
pub use_pkce: bool,
|
||||
}
|
||||
|
||||
#[derive(Clone, PartialEq, Eq, Debug, Deserialize)]
|
||||
|
@ -83,6 +85,7 @@ impl Serialize for PublicOAuthProvider {
|
|||
state.serialize_field("authorization_endpoint", &self.0.authorization_endpoint)?;
|
||||
state.serialize_field("client_id", &self.0.client_id)?;
|
||||
state.serialize_field("scopes", &self.0.scopes)?;
|
||||
state.serialize_field("use_pkce", &self.0.use_pkce)?;
|
||||
state.end()
|
||||
}
|
||||
}
|
||||
|
@ -102,6 +105,7 @@ pub struct OAuthProviderInsertForm {
|
|||
pub scopes: String,
|
||||
pub auto_verify_email: Option<bool>,
|
||||
pub account_linking_enabled: Option<bool>,
|
||||
pub use_pkce: Option<bool>,
|
||||
pub enabled: Option<bool>,
|
||||
}
|
||||
|
||||
|
@ -118,6 +122,7 @@ pub struct OAuthProviderUpdateForm {
|
|||
pub scopes: Option<String>,
|
||||
pub auto_verify_email: Option<bool>,
|
||||
pub account_linking_enabled: Option<bool>,
|
||||
pub use_pkce: Option<bool>,
|
||||
pub enabled: Option<bool>,
|
||||
pub updated: Option<Option<DateTime<Utc>>>,
|
||||
}
|
||||
|
|
|
@ -76,6 +76,7 @@ pub enum LemmyErrorType {
|
|||
InvalidEmailAddress(String),
|
||||
RateLimitError,
|
||||
InvalidName,
|
||||
InvalidCodeVerifier,
|
||||
InvalidDisplayName,
|
||||
InvalidMatrixId,
|
||||
InvalidPostTitle,
|
||||
|
|
3
migrations/2024-11-23-234637_oauth_pkce/down.sql
Normal file
3
migrations/2024-11-23-234637_oauth_pkce/down.sql
Normal file
|
@ -0,0 +1,3 @@
|
|||
ALTER TABLE oauth_provider
|
||||
DROP COLUMN use_pkce;
|
||||
|
3
migrations/2024-11-23-234637_oauth_pkce/up.sql
Normal file
3
migrations/2024-11-23-234637_oauth_pkce/up.sql
Normal file
|
@ -0,0 +1,3 @@
|
|||
ALTER TABLE oauth_provider
|
||||
ADD COLUMN use_pkce boolean DEFAULT FALSE NOT NULL;
|
||||
|
Loading…
Reference in a new issue