diff --git a/crates/api_common/src/oauth_provider.rs b/crates/api_common/src/oauth_provider.rs index 36fef3b18..2f3344802 100644 --- a/crates/api_common/src/oauth_provider.rs +++ b/crates/api_common/src/oauth_provider.rs @@ -25,6 +25,8 @@ pub struct CreateOAuthProvider { #[cfg_attr(feature = "full", ts(optional))] pub account_linking_enabled: Option, #[cfg_attr(feature = "full", ts(optional))] + pub use_pkce: Option, + #[cfg_attr(feature = "full", ts(optional))] pub enabled: Option, } @@ -54,6 +56,8 @@ pub struct EditOAuthProvider { #[cfg_attr(feature = "full", ts(optional))] pub account_linking_enabled: Option, #[cfg_attr(feature = "full", ts(optional))] + pub use_pkce: Option, + #[cfg_attr(feature = "full", ts(optional))] pub enabled: Option, } @@ -82,4 +86,6 @@ pub struct AuthenticateWithOauth { /// An answer is mandatory if require application is enabled on the server #[cfg_attr(feature = "full", ts(optional))] pub answer: Option, + #[cfg_attr(feature = "full", ts(optional))] + pub pkce_code_verifier: Option, } diff --git a/crates/api_crud/Cargo.toml b/crates/api_crud/Cargo.toml index 3f1a00ccd..a05a4deed 100644 --- a/crates/api_crud/Cargo.toml +++ b/crates/api_crud/Cargo.toml @@ -29,6 +29,7 @@ anyhow.workspace = true chrono.workspace = true webmention = "0.6.0" accept-language = "3.1.0" +regex = { workspace = true } serde_json = { workspace = true } serde = { workspace = true } serde_with = { workspace = true } diff --git a/crates/api_crud/src/oauth_provider/create.rs b/crates/api_crud/src/oauth_provider/create.rs index fe44ae56e..c1e30066a 100644 --- a/crates/api_crud/src/oauth_provider/create.rs +++ b/crates/api_crud/src/oauth_provider/create.rs @@ -35,6 +35,7 @@ pub async fn create_oauth_provider( scopes: data.scopes.to_string(), auto_verify_email: data.auto_verify_email, account_linking_enabled: data.account_linking_enabled, + use_pkce: data.use_pkce, enabled: data.enabled, }; let oauth_provider = OAuthProvider::create(&mut context.pool(), &oauth_provider_form).await?; diff --git a/crates/api_crud/src/oauth_provider/update.rs b/crates/api_crud/src/oauth_provider/update.rs index 29ba19b49..f8631a487 100644 --- a/crates/api_crud/src/oauth_provider/update.rs +++ b/crates/api_crud/src/oauth_provider/update.rs @@ -33,6 +33,7 @@ pub async fn update_oauth_provider( auto_verify_email: data.auto_verify_email, account_linking_enabled: data.account_linking_enabled, enabled: data.enabled, + use_pkce: data.use_pkce, updated: Some(Some(Utc::now())), }; diff --git a/crates/api_crud/src/user/create.rs b/crates/api_crud/src/user/create.rs index deb65ec38..16156abe4 100644 --- a/crates/api_crud/src/user/create.rs +++ b/crates/api_crud/src/user/create.rs @@ -45,9 +45,10 @@ use lemmy_utils::{ validation::is_valid_actor_name, }, }; +use regex::Regex; use serde::{Deserialize, Serialize}; use serde_with::skip_serializing_none; -use std::collections::HashSet; +use std::{collections::HashSet, sync::LazyLock}; #[skip_serializing_none] #[derive(Debug, Serialize, Deserialize, Clone, Default)] @@ -225,6 +226,11 @@ pub async fn authenticate_with_oauth( Err(LemmyErrorType::OauthAuthorizationInvalid)? } + // validate the PKCE challenge + if let Some(code_verifier) = &data.pkce_code_verifier { + check_code_verifier(code_verifier)?; + } + // Fetch the OAUTH provider and make sure it's enabled let oauth_provider_id = data.oauth_provider_id; let oauth_provider = OAuthProvider::read(&mut context.pool(), oauth_provider_id) @@ -236,9 +242,14 @@ pub async fn authenticate_with_oauth( return Err(LemmyErrorType::OauthAuthorizationInvalid)?; } - let token_response = - oauth_request_access_token(&context, &oauth_provider, &data.code, redirect_uri.as_str()) - .await?; + let token_response = oauth_request_access_token( + &context, + &oauth_provider, + &data.code, + data.pkce_code_verifier.as_deref(), + redirect_uri.as_str(), + ) + .await?; let user_info = oidc_get_user_info( &context, @@ -533,20 +544,27 @@ async fn oauth_request_access_token( context: &Data, oauth_provider: &OAuthProvider, code: &str, + pkce_code_verifier: Option<&str>, redirect_uri: &str, ) -> LemmyResult { + let mut form = vec![ + ("client_id", &*oauth_provider.client_id), + ("client_secret", &*oauth_provider.client_secret), + ("code", code), + ("grant_type", "authorization_code"), + ("redirect_uri", redirect_uri), + ]; + + if let Some(code_verifier) = pkce_code_verifier { + form.push(("code_verifier", code_verifier)); + } + // Request an Access Token from the OAUTH provider let response = context .client() .post(oauth_provider.token_endpoint.as_str()) .header("Accept", "application/json") - .form(&[ - ("grant_type", "authorization_code"), - ("code", code), - ("redirect_uri", redirect_uri), - ("client_id", &oauth_provider.client_id), - ("client_secret", &oauth_provider.client_secret), - ]) + .form(&form[..]) .send() .await .with_lemmy_type(LemmyErrorType::OauthLoginFailed)? @@ -596,3 +614,17 @@ fn read_user_info(user_info: &serde_json::Value, key: &str) -> LemmyResult LemmyResult<()> { + static VALID_CODE_VERIFIER_REGEX: LazyLock = + LazyLock::new(|| Regex::new(r"^[a-zA-Z0-9\-._~]{43,128}$").expect("compile regex")); + + let check = VALID_CODE_VERIFIER_REGEX.is_match(code_verifier); + + if check { + Ok(()) + } else { + Err(LemmyErrorType::InvalidCodeVerifier.into()) + } +} diff --git a/crates/db_schema/src/schema.rs b/crates/db_schema/src/schema.rs index 0bc046ece..66a65d143 100644 --- a/crates/db_schema/src/schema.rs +++ b/crates/db_schema/src/schema.rs @@ -660,6 +660,7 @@ diesel::table! { enabled -> Bool, published -> Timestamptz, updated -> Nullable, + use_pkce -> Bool, } } diff --git a/crates/db_schema/src/source/oauth_provider.rs b/crates/db_schema/src/source/oauth_provider.rs index a70405a5e..0a82ab9a9 100644 --- a/crates/db_schema/src/source/oauth_provider.rs +++ b/crates/db_schema/src/source/oauth_provider.rs @@ -62,6 +62,8 @@ pub struct OAuthProvider { pub published: DateTime, #[cfg_attr(feature = "full", ts(optional))] pub updated: Option>, + /// switch to enable or disable PKCE + pub use_pkce: bool, } #[derive(Clone, PartialEq, Eq, Debug, Deserialize)] @@ -83,6 +85,7 @@ impl Serialize for PublicOAuthProvider { state.serialize_field("authorization_endpoint", &self.0.authorization_endpoint)?; state.serialize_field("client_id", &self.0.client_id)?; state.serialize_field("scopes", &self.0.scopes)?; + state.serialize_field("use_pkce", &self.0.use_pkce)?; state.end() } } @@ -102,6 +105,7 @@ pub struct OAuthProviderInsertForm { pub scopes: String, pub auto_verify_email: Option, pub account_linking_enabled: Option, + pub use_pkce: Option, pub enabled: Option, } @@ -118,6 +122,7 @@ pub struct OAuthProviderUpdateForm { pub scopes: Option, pub auto_verify_email: Option, pub account_linking_enabled: Option, + pub use_pkce: Option, pub enabled: Option, pub updated: Option>>, } diff --git a/crates/utils/src/error.rs b/crates/utils/src/error.rs index 40f878747..f45bc271f 100644 --- a/crates/utils/src/error.rs +++ b/crates/utils/src/error.rs @@ -76,6 +76,7 @@ pub enum LemmyErrorType { InvalidEmailAddress(String), RateLimitError, InvalidName, + InvalidCodeVerifier, InvalidDisplayName, InvalidMatrixId, InvalidPostTitle, diff --git a/migrations/2024-11-23-234637_oauth_pkce/down.sql b/migrations/2024-11-23-234637_oauth_pkce/down.sql new file mode 100644 index 000000000..50c09050a --- /dev/null +++ b/migrations/2024-11-23-234637_oauth_pkce/down.sql @@ -0,0 +1,3 @@ +ALTER TABLE oauth_provider + DROP COLUMN use_pkce; + diff --git a/migrations/2024-11-23-234637_oauth_pkce/up.sql b/migrations/2024-11-23-234637_oauth_pkce/up.sql new file mode 100644 index 000000000..b03d74f7f --- /dev/null +++ b/migrations/2024-11-23-234637_oauth_pkce/up.sql @@ -0,0 +1,3 @@ +ALTER TABLE oauth_provider + ADD COLUMN use_pkce boolean DEFAULT FALSE NOT NULL; +