From b26aaac523b9d03874d320889d93606dc5bbcce9 Mon Sep 17 00:00:00 2001 From: privacyguard <92675882+privacyguard@users.noreply.github.com> Date: Wed, 18 Sep 2024 15:52:33 +0300 Subject: [PATCH] SSO Support (#4881) * Added OAUTH2 OIDC support * Fixes and improvements based on review feedback * use derive_new::new instead of TypedBuilder * merge migrations into a single file * fixes based on review feedback * remove unnecessary hostname_ui config * improvement based on review feedback * improvements based on review feedback * delete user oauth accounts at account deletion * fixes and improvements based on review feedback * removed auto_approve_application * support registration application with sso * improvements based on review feedback * making the TokenResponse an internal struct as it should be * remove duplicate struct * prevent oauth linking to unverified accounts * switched to manually entered username and removed the oauth name claim * fix cargo fmt * fix compile error * improvements based on review feedback * fixes and improvements based on review feedback --------- Co-authored-by: privacyguard --- Cargo.lock | 7 +- crates/api/src/local_user/change_password.rs | 12 +- crates/api/src/local_user/login.rs | 45 +- crates/api/src/local_user/mod.rs | 15 - crates/api/src/local_user/reset_password.rs | 3 +- crates/api/src/site/leave_admin.rs | 4 + crates/api_common/src/lib.rs | 1 + crates/api_common/src/oauth_provider.rs | 69 +++ crates/api_common/src/request.rs | 1 + crates/api_common/src/site.rs | 7 + crates/api_common/src/utils.rs | 50 +- crates/api_crud/Cargo.toml | 3 + crates/api_crud/src/lib.rs | 1 + crates/api_crud/src/oauth_provider/create.rs | 42 ++ crates/api_crud/src/oauth_provider/delete.rs | 25 + crates/api_crud/src/oauth_provider/mod.rs | 3 + crates/api_crud/src/oauth_provider/update.rs | 44 ++ crates/api_crud/src/site/create.rs | 1 + crates/api_crud/src/site/read.rs | 19 +- crates/api_crud/src/site/update.rs | 13 + crates/api_crud/src/user/create.rs | 510 +++++++++++++++--- crates/api_crud/src/user/delete.rs | 18 +- crates/db_schema/src/impls/local_user.rs | 10 +- crates/db_schema/src/impls/mod.rs | 2 + crates/db_schema/src/impls/oauth_account.rs | 59 ++ crates/db_schema/src/impls/oauth_provider.rs | 71 +++ crates/db_schema/src/impls/person.rs | 12 + crates/db_schema/src/newtypes.rs | 6 + crates/db_schema/src/schema.rs | 37 +- crates/db_schema/src/source/local_site.rs | 4 + crates/db_schema/src/source/local_user.rs | 4 +- crates/db_schema/src/source/mod.rs | 2 + crates/db_schema/src/source/oauth_account.rs | 32 ++ crates/db_schema/src/source/oauth_provider.rs | 131 +++++ crates/db_schema/src/utils.rs | 26 +- crates/db_views/src/local_user_view.rs | 34 +- crates/routes/src/images.rs | 3 +- crates/utils/src/error.rs | 7 + .../down.sql | 10 + .../up.sql | 34 ++ src/api_routes_http.rs | 22 +- src/code_migrations.rs | 2 +- 42 files changed, 1235 insertions(+), 166 deletions(-) create mode 100644 crates/api_common/src/oauth_provider.rs create mode 100644 crates/api_crud/src/oauth_provider/create.rs create mode 100644 crates/api_crud/src/oauth_provider/delete.rs create mode 100644 crates/api_crud/src/oauth_provider/mod.rs create mode 100644 crates/api_crud/src/oauth_provider/update.rs create mode 100644 crates/db_schema/src/impls/oauth_account.rs create mode 100644 crates/db_schema/src/impls/oauth_provider.rs create mode 100644 crates/db_schema/src/source/oauth_account.rs create mode 100644 crates/db_schema/src/source/oauth_provider.rs create mode 100644 migrations/2024-09-16-174833_create_oauth_provider/down.sql create mode 100644 migrations/2024-09-16-174833_create_oauth_provider/up.sql diff --git a/Cargo.lock b/Cargo.lock index 31815bc00..c6fa6255b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2545,6 +2545,9 @@ dependencies = [ "lemmy_db_views_actor", "lemmy_utils", "moka", + "serde", + "serde_json", + "serde_with", "tracing", "url", "uuid", @@ -3314,9 +3317,9 @@ dependencies = [ [[package]] name = "once_cell" -version = "1.20.0" +version = "1.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "33ea5043e58958ee56f3e15a90aee535795cd7dfd319846288d93c5b57d85cbe" +checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" [[package]] name = "overload" diff --git a/crates/api/src/local_user/change_password.rs b/crates/api/src/local_user/change_password.rs index 50ee10bb6..03f873a0f 100644 --- a/crates/api/src/local_user/change_password.rs +++ b/crates/api/src/local_user/change_password.rs @@ -28,11 +28,13 @@ pub async fn change_password( } // Check the old password - let valid: bool = verify( - &data.old_password, - &local_user_view.local_user.password_encrypted, - ) - .unwrap_or(false); + let valid: bool = if let Some(password_encrypted) = &local_user_view.local_user.password_encrypted + { + verify(&data.old_password, password_encrypted).unwrap_or(false) + } else { + data.old_password.is_empty() + }; + if !valid { Err(LemmyErrorType::IncorrectLogin)? } diff --git a/crates/api/src/local_user/login.rs b/crates/api/src/local_user/login.rs index e6ae38510..a8f65d758 100644 --- a/crates/api/src/local_user/login.rs +++ b/crates/api/src/local_user/login.rs @@ -1,4 +1,4 @@ -use crate::{check_totp_2fa_valid, local_user::check_email_verified}; +use crate::check_totp_2fa_valid; use actix_web::{ web::{Data, Json}, HttpRequest, @@ -8,12 +8,7 @@ use lemmy_api_common::{ claims::Claims, context::LemmyContext, person::{Login, LoginResponse}, - utils::check_user_valid, -}; -use lemmy_db_schema::{ - source::{local_site::LocalSite, registration_application::RegistrationApplication}, - utils::DbPool, - RegistrationMode, + utils::{check_email_verified, check_registration_application, check_user_valid}, }; use lemmy_db_views::structs::{LocalUserView, SiteView}; use lemmy_utils::error::{LemmyErrorType, LemmyResult}; @@ -34,11 +29,12 @@ pub async fn login( .ok_or(LemmyErrorType::IncorrectLogin)?; // Verify the password - let valid: bool = verify( - &data.password, - &local_user_view.local_user.password_encrypted, - ) - .unwrap_or(false); + let valid: bool = local_user_view + .local_user + .password_encrypted + .as_ref() + .and_then(|password_encrypted| verify(&data.password, password_encrypted).ok()) + .unwrap_or(false); if !valid { Err(LemmyErrorType::IncorrectLogin)? } @@ -65,28 +61,3 @@ pub async fn login( registration_created: false, })) } - -async fn check_registration_application( - local_user_view: &LocalUserView, - local_site: &LocalSite, - pool: &mut DbPool<'_>, -) -> LemmyResult<()> { - if (local_site.registration_mode == RegistrationMode::RequireApplication - || local_site.registration_mode == RegistrationMode::Closed) - && !local_user_view.local_user.accepted_application - && !local_user_view.local_user.admin - { - // Fetch the registration application. If no admin id is present its still pending. Otherwise it - // was processed (either accepted or denied). - let local_user_id = local_user_view.local_user.id; - let registration = RegistrationApplication::find_by_local_user_id(pool, local_user_id) - .await? - .ok_or(LemmyErrorType::CouldntFindRegistrationApplication)?; - if registration.admin_id.is_some() { - Err(LemmyErrorType::RegistrationDenied(registration.deny_reason))? - } else { - Err(LemmyErrorType::RegistrationApplicationIsPending)? - } - } - Ok(()) -} diff --git a/crates/api/src/local_user/mod.rs b/crates/api/src/local_user/mod.rs index c00a4516e..b1ee7c0b6 100644 --- a/crates/api/src/local_user/mod.rs +++ b/crates/api/src/local_user/mod.rs @@ -1,6 +1,3 @@ -use lemmy_db_views::structs::{LocalUserView, SiteView}; -use lemmy_utils::{error::LemmyResult, LemmyErrorType}; - pub mod add_admin; pub mod ban_person; pub mod block; @@ -20,15 +17,3 @@ pub mod save_settings; pub mod update_totp; pub mod validate_auth; pub mod verify_email; - -/// Check if the user's email is verified if email verification is turned on -/// However, skip checking verification if the user is an admin -fn check_email_verified(local_user_view: &LocalUserView, site_view: &SiteView) -> LemmyResult<()> { - if !local_user_view.local_user.admin - && site_view.local_site.require_email_verification - && !local_user_view.local_user.email_verified - { - Err(LemmyErrorType::EmailNotVerified)? - } - Ok(()) -} diff --git a/crates/api/src/local_user/reset_password.rs b/crates/api/src/local_user/reset_password.rs index 1c47e6c4e..4854d1376 100644 --- a/crates/api/src/local_user/reset_password.rs +++ b/crates/api/src/local_user/reset_password.rs @@ -1,9 +1,8 @@ -use crate::local_user::check_email_verified; use actix_web::web::{Data, Json}; use lemmy_api_common::{ context::LemmyContext, person::PasswordReset, - utils::send_password_reset_email, + utils::{check_email_verified, send_password_reset_email}, SuccessResponse, }; use lemmy_db_views::structs::{LocalUserView, SiteView}; diff --git a/crates/api/src/site/leave_admin.rs b/crates/api/src/site/leave_admin.rs index 52b8a32ef..d3581995a 100644 --- a/crates/api/src/site/leave_admin.rs +++ b/crates/api/src/site/leave_admin.rs @@ -7,6 +7,7 @@ use lemmy_db_schema::{ local_site_url_blocklist::LocalSiteUrlBlocklist, local_user::{LocalUser, LocalUserUpdateForm}, moderator::{ModAdd, ModAddForm}, + oauth_provider::OAuthProvider, tagline::Tagline, }, traits::Crud, @@ -63,6 +64,7 @@ pub async fn leave_admin( let taglines = Tagline::get_all(&mut context.pool(), site_view.local_site.id).await?; let custom_emojis = CustomEmojiView::get_all(&mut context.pool(), site_view.local_site.id).await?; + let oauth_providers = OAuthProvider::get_all_public(&mut context.pool()).await?; let blocked_urls = LocalSiteUrlBlocklist::get_all(&mut context.pool()).await?; Ok(Json(GetSiteResponse { @@ -74,6 +76,8 @@ pub async fn leave_admin( discussion_languages, taglines, custom_emojis, + oauth_providers: Some(oauth_providers), + admin_oauth_providers: None, blocked_urls, })) } diff --git a/crates/api_common/src/lib.rs b/crates/api_common/src/lib.rs index 9d12d2e13..48acaad21 100644 --- a/crates/api_common/src/lib.rs +++ b/crates/api_common/src/lib.rs @@ -7,6 +7,7 @@ pub mod community; #[cfg(feature = "full")] pub mod context; pub mod custom_emoji; +pub mod oauth_provider; pub mod person; pub mod post; pub mod private_message; diff --git a/crates/api_common/src/oauth_provider.rs b/crates/api_common/src/oauth_provider.rs new file mode 100644 index 000000000..c51edc7a4 --- /dev/null +++ b/crates/api_common/src/oauth_provider.rs @@ -0,0 +1,69 @@ +use lemmy_db_schema::newtypes::OAuthProviderId; +use serde::{Deserialize, Serialize}; +use serde_with::skip_serializing_none; +#[cfg(feature = "full")] +use ts_rs::TS; +use url::Url; + +#[derive(Debug, Serialize, Deserialize, Clone)] +#[cfg_attr(feature = "full", derive(TS))] +#[cfg_attr(feature = "full", ts(export))] +/// Create an external auth method. +pub struct CreateOAuthProvider { + pub display_name: String, + pub issuer: String, + pub authorization_endpoint: String, + pub token_endpoint: String, + pub userinfo_endpoint: String, + pub id_claim: String, + pub client_id: String, + pub client_secret: String, + pub scopes: String, + pub auto_verify_email: bool, + pub account_linking_enabled: bool, + pub enabled: bool, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +#[cfg_attr(feature = "full", derive(TS))] +#[cfg_attr(feature = "full", ts(export))] +/// Edit an external auth method. +pub struct EditOAuthProvider { + pub id: OAuthProviderId, + pub display_name: Option, + pub authorization_endpoint: Option, + pub token_endpoint: Option, + pub userinfo_endpoint: Option, + pub id_claim: Option, + pub client_secret: Option, + pub scopes: Option, + pub auto_verify_email: Option, + pub account_linking_enabled: Option, + pub enabled: Option, +} + +#[derive(Debug, Serialize, Deserialize, Clone, Default)] +#[cfg_attr(feature = "full", derive(TS))] +#[cfg_attr(feature = "full", ts(export))] +/// Delete an external auth method. +pub struct DeleteOAuthProvider { + pub id: OAuthProviderId, +} + +#[skip_serializing_none] +#[derive(Debug, Serialize, Deserialize, Clone)] +#[cfg_attr(feature = "full", derive(TS))] +#[cfg_attr(feature = "full", ts(export))] +/// Logging in with an OAuth 2.0 authorization +pub struct AuthenticateWithOauth { + pub code: String, + #[cfg_attr(feature = "full", ts(type = "string"))] + pub oauth_provider_id: OAuthProviderId, + #[cfg_attr(feature = "full", ts(type = "string"))] + pub redirect_uri: Url, + pub show_nsfw: Option, + /// Username is mandatory at registration time + pub username: Option, + /// An answer is mandatory if require application is enabled on the server + pub answer: Option, +} diff --git a/crates/api_common/src/request.rs b/crates/api_common/src/request.rs index 90a626e4f..970aa17df 100644 --- a/crates/api_common/src/request.rs +++ b/crates/api_common/src/request.rs @@ -44,6 +44,7 @@ pub fn client_builder(settings: &Settings) -> ClientBuilder { .user_agent(user_agent.clone()) .timeout(REQWEST_TIMEOUT) .connect_timeout(REQWEST_TIMEOUT) + .use_rustls_tls() } /// Fetches metadata for the given link and optionally generates thumbnail. diff --git a/crates/api_common/src/site.rs b/crates/api_common/src/site.rs index fa43f2a39..6fa6e3700 100644 --- a/crates/api_common/src/site.rs +++ b/crates/api_common/src/site.rs @@ -16,6 +16,7 @@ use lemmy_db_schema::{ instance::Instance, language::Language, local_site_url_blocklist::LocalSiteUrlBlocklist, + oauth_provider::{OAuthProvider, PublicOAuthProvider}, person::Person, tagline::Tagline, }, @@ -200,6 +201,7 @@ pub struct CreateSite { pub blocked_instances: Option>, pub taglines: Option>, pub registration_mode: Option, + pub oauth_registration: Option, pub content_warning: Option, pub default_post_listing_mode: Option, } @@ -282,6 +284,8 @@ pub struct EditSite { /// A list of taglines shown at the top of the front page. pub taglines: Option>, pub registration_mode: Option, + /// Whether or not external auth methods can auto-register users. + pub oauth_registration: Option, /// Whether to email admins for new reports. pub reports_email_admins: Option, /// If present, nsfw content is visible by default. Should be displayed by frontends/clients @@ -316,6 +320,9 @@ pub struct GetSiteResponse { pub taglines: Vec, /// A list of custom emojis your site supports. pub custom_emojis: Vec, + /// A list of external auth methods your site supports. + pub oauth_providers: Option>, + pub admin_oauth_providers: Option>, pub blocked_urls: Vec, } diff --git a/crates/api_common/src/utils.rs b/crates/api_common/src/utils.rs index ebcc237e5..e41b574c5 100644 --- a/crates/api_common/src/utils.rs +++ b/crates/api_common/src/utils.rs @@ -23,18 +23,21 @@ use lemmy_db_schema::{ local_site::LocalSite, local_site_rate_limit::LocalSiteRateLimit, local_site_url_blocklist::LocalSiteUrlBlocklist, + oauth_account::OAuthAccount, password_reset_request::PasswordResetRequest, person::{Person, PersonUpdateForm}, person_block::PersonBlock, post::{Post, PostRead}, + registration_application::RegistrationApplication, site::Site, }, traits::Crud, utils::DbPool, + RegistrationMode, }; use lemmy_db_views::{ comment_view::CommentQuery, - structs::{LocalImageView, LocalUserView}, + structs::{LocalImageView, LocalUserView, SiteView}, }; use lemmy_db_views_actor::structs::{ CommunityModeratorView, @@ -192,6 +195,46 @@ pub fn check_user_valid(person: &Person) -> LemmyResult<()> { } } +/// Check if the user's email is verified if email verification is turned on +/// However, skip checking verification if the user is an admin +pub fn check_email_verified( + local_user_view: &LocalUserView, + site_view: &SiteView, +) -> LemmyResult<()> { + if !local_user_view.local_user.admin + && site_view.local_site.require_email_verification + && !local_user_view.local_user.email_verified + { + Err(LemmyErrorType::EmailNotVerified)? + } + Ok(()) +} + +pub async fn check_registration_application( + local_user_view: &LocalUserView, + local_site: &LocalSite, + pool: &mut DbPool<'_>, +) -> LemmyResult<()> { + if (local_site.registration_mode == RegistrationMode::RequireApplication + || local_site.registration_mode == RegistrationMode::Closed) + && !local_user_view.local_user.accepted_application + && !local_user_view.local_user.admin + { + // Fetch the registration application. If no admin id is present its still pending. Otherwise it + // was processed (either accepted or denied). + let local_user_id = local_user_view.local_user.id; + let registration = RegistrationApplication::find_by_local_user_id(pool, local_user_id) + .await? + .ok_or(LemmyErrorType::CouldntFindRegistrationApplication)?; + if registration.admin_id.is_some() { + Err(LemmyErrorType::RegistrationDenied(registration.deny_reason))? + } else { + Err(LemmyErrorType::RegistrationApplicationIsPending)? + } + } + Ok(()) +} + /// Checks that a normal user action (eg posting or voting) is allowed in a given community. /// /// In particular it checks that neither the user nor community are banned or deleted, and that @@ -852,6 +895,11 @@ pub async fn purge_user_account(person_id: PersonId, context: &LemmyContext) -> // Leave communities they mod CommunityModerator::leave_all_communities(pool, person_id).await?; + // Delete the oauth accounts linked to the local user + if let Ok(Some(local_user)) = LocalUserView::read_person(pool, person_id).await { + OAuthAccount::delete_user_accounts(pool, local_user.local_user.id).await?; + } + Person::delete_account(pool, person_id).await?; Ok(()) diff --git a/crates/api_crud/Cargo.toml b/crates/api_crud/Cargo.toml index 5114eb8cf..259116a38 100644 --- a/crates/api_crud/Cargo.toml +++ b/crates/api_crud/Cargo.toml @@ -29,6 +29,9 @@ moka.workspace = true anyhow.workspace = true webmention = "0.6.0" accept-language = "3.1.0" +serde_json = { workspace = true } +serde = { workspace = true } +serde_with = { workspace = true } [package.metadata.cargo-machete] ignored = ["futures"] diff --git a/crates/api_crud/src/lib.rs b/crates/api_crud/src/lib.rs index aee3e8134..b138fbd30 100644 --- a/crates/api_crud/src/lib.rs +++ b/crates/api_crud/src/lib.rs @@ -1,6 +1,7 @@ pub mod comment; pub mod community; pub mod custom_emoji; +pub mod oauth_provider; pub mod post; pub mod private_message; pub mod site; diff --git a/crates/api_crud/src/oauth_provider/create.rs b/crates/api_crud/src/oauth_provider/create.rs new file mode 100644 index 000000000..fe44ae56e --- /dev/null +++ b/crates/api_crud/src/oauth_provider/create.rs @@ -0,0 +1,42 @@ +use activitypub_federation::config::Data; +use actix_web::web::Json; +use lemmy_api_common::{ + context::LemmyContext, + oauth_provider::CreateOAuthProvider, + utils::is_admin, +}; +use lemmy_db_schema::{ + source::oauth_provider::{OAuthProvider, OAuthProviderInsertForm}, + traits::Crud, +}; +use lemmy_db_views::structs::LocalUserView; +use lemmy_utils::error::LemmyError; +use url::Url; + +#[tracing::instrument(skip(context))] +pub async fn create_oauth_provider( + data: Json, + context: Data, + local_user_view: LocalUserView, +) -> Result, LemmyError> { + // Make sure user is an admin + is_admin(&local_user_view)?; + + let cloned_data = data.clone(); + let oauth_provider_form = OAuthProviderInsertForm { + display_name: cloned_data.display_name, + issuer: Url::parse(&cloned_data.issuer)?.into(), + authorization_endpoint: Url::parse(&cloned_data.authorization_endpoint)?.into(), + token_endpoint: Url::parse(&cloned_data.token_endpoint)?.into(), + userinfo_endpoint: Url::parse(&cloned_data.userinfo_endpoint)?.into(), + id_claim: cloned_data.id_claim, + client_id: data.client_id.to_string(), + client_secret: data.client_secret.to_string(), + scopes: data.scopes.to_string(), + auto_verify_email: data.auto_verify_email, + account_linking_enabled: data.account_linking_enabled, + enabled: data.enabled, + }; + let oauth_provider = OAuthProvider::create(&mut context.pool(), &oauth_provider_form).await?; + Ok(Json(oauth_provider)) +} diff --git a/crates/api_crud/src/oauth_provider/delete.rs b/crates/api_crud/src/oauth_provider/delete.rs new file mode 100644 index 000000000..0d4d616cc --- /dev/null +++ b/crates/api_crud/src/oauth_provider/delete.rs @@ -0,0 +1,25 @@ +use activitypub_federation::config::Data; +use actix_web::web::Json; +use lemmy_api_common::{ + context::LemmyContext, + oauth_provider::DeleteOAuthProvider, + utils::is_admin, + SuccessResponse, +}; +use lemmy_db_schema::{source::oauth_provider::OAuthProvider, traits::Crud}; +use lemmy_db_views::structs::LocalUserView; +use lemmy_utils::error::{LemmyError, LemmyErrorExt, LemmyErrorType}; + +#[tracing::instrument(skip(context))] +pub async fn delete_oauth_provider( + data: Json, + context: Data, + local_user_view: LocalUserView, +) -> Result, LemmyError> { + // Make sure user is an admin + is_admin(&local_user_view)?; + OAuthProvider::delete(&mut context.pool(), data.id) + .await + .with_lemmy_type(LemmyErrorType::CouldntDeleteOauthProvider)?; + Ok(Json(SuccessResponse::default())) +} diff --git a/crates/api_crud/src/oauth_provider/mod.rs b/crates/api_crud/src/oauth_provider/mod.rs new file mode 100644 index 000000000..fdb2f5561 --- /dev/null +++ b/crates/api_crud/src/oauth_provider/mod.rs @@ -0,0 +1,3 @@ +pub mod create; +pub mod delete; +pub mod update; diff --git a/crates/api_crud/src/oauth_provider/update.rs b/crates/api_crud/src/oauth_provider/update.rs new file mode 100644 index 000000000..61d5b0adc --- /dev/null +++ b/crates/api_crud/src/oauth_provider/update.rs @@ -0,0 +1,44 @@ +use activitypub_federation::config::Data; +use actix_web::web::Json; +use lemmy_api_common::{context::LemmyContext, oauth_provider::EditOAuthProvider, utils::is_admin}; +use lemmy_db_schema::{ + source::oauth_provider::{OAuthProvider, OAuthProviderUpdateForm}, + traits::Crud, + utils::{diesel_required_string_update, diesel_required_url_update, naive_now}, +}; +use lemmy_db_views::structs::LocalUserView; +use lemmy_utils::{error::LemmyError, LemmyErrorType}; + +#[tracing::instrument(skip(context))] +pub async fn update_oauth_provider( + data: Json, + context: Data, + local_user_view: LocalUserView, +) -> Result, LemmyError> { + // Make sure user is an admin + is_admin(&local_user_view)?; + + let cloned_data = data.clone(); + let oauth_provider_form = OAuthProviderUpdateForm { + display_name: diesel_required_string_update(cloned_data.display_name.as_deref()), + authorization_endpoint: diesel_required_url_update( + cloned_data.authorization_endpoint.as_deref(), + )?, + token_endpoint: diesel_required_url_update(cloned_data.token_endpoint.as_deref())?, + userinfo_endpoint: diesel_required_url_update(cloned_data.userinfo_endpoint.as_deref())?, + id_claim: diesel_required_string_update(data.id_claim.as_deref()), + client_secret: diesel_required_string_update(data.client_secret.as_deref()), + scopes: diesel_required_string_update(data.scopes.as_deref()), + auto_verify_email: data.auto_verify_email, + account_linking_enabled: data.account_linking_enabled, + enabled: data.enabled, + updated: Some(Some(naive_now())), + }; + + let update_result = + OAuthProvider::update(&mut context.pool(), data.id, &oauth_provider_form).await?; + let oauth_provider = OAuthProvider::read(&mut context.pool(), update_result.id) + .await? + .ok_or(LemmyErrorType::CouldntFindOauthProvider)?; + Ok(Json(oauth_provider)) +} diff --git a/crates/api_crud/src/site/create.rs b/crates/api_crud/src/site/create.rs index 6566a7a9f..3d96d20cf 100644 --- a/crates/api_crud/src/site/create.rs +++ b/crates/api_crud/src/site/create.rs @@ -591,6 +591,7 @@ mod tests { blocked_instances: None, taglines: None, registration_mode: site_registration_mode, + oauth_registration: None, content_warning: None, default_post_listing_mode: None, } diff --git a/crates/api_crud/src/site/read.rs b/crates/api_crud/src/site/read.rs index 94a28a4ad..6f524dd7d 100644 --- a/crates/api_crud/src/site/read.rs +++ b/crates/api_crud/src/site/read.rs @@ -9,6 +9,7 @@ use lemmy_db_schema::source::{ instance_block::InstanceBlock, language::Language, local_site_url_blocklist::LocalSiteUrlBlocklist, + oauth_provider::OAuthProvider, person_block::PersonBlock, tagline::Tagline, }; @@ -45,6 +46,10 @@ pub async fn get_site( let custom_emojis = CustomEmojiView::get_all(&mut context.pool(), site_view.local_site.id).await?; let blocked_urls = LocalSiteUrlBlocklist::get_all(&mut context.pool()).await?; + let admin_oauth_providers = OAuthProvider::get_all(&mut context.pool()).await?; + let oauth_providers = + OAuthProvider::convert_providers_to_public(admin_oauth_providers.clone()); + Ok(GetSiteResponse { site_view, admins, @@ -55,13 +60,15 @@ pub async fn get_site( taglines, custom_emojis, blocked_urls, + oauth_providers: Some(oauth_providers), + admin_oauth_providers: Some(admin_oauth_providers), }) }) .await .map_err(|e| anyhow::anyhow!("Failed to construct site response: {e}"))?; // Build the local user with parallel queries and add it to site response - site_response.my_user = if let Some(local_user_view) = local_user_view { + site_response.my_user = if let Some(ref local_user_view) = local_user_view { let person_id = local_user_view.person.id; let local_user_id = local_user_view.local_user.id; let pool = &mut context.pool(); @@ -84,7 +91,7 @@ pub async fn get_site( .with_lemmy_type(LemmyErrorType::SystemErrLogin)?; Some(MyUserInfo { - local_user_view, + local_user_view: local_user_view.clone(), follows, moderates, community_blocks, @@ -96,5 +103,13 @@ pub async fn get_site( None }; + // filter oauth_providers for public access + if !local_user_view + .map(|l| l.local_user.admin) + .unwrap_or_default() + { + site_response.admin_oauth_providers = None; + } + Ok(Json(site_response)) } diff --git a/crates/api_crud/src/site/update.rs b/crates/api_crud/src/site/update.rs index f68b00c04..7e9dc8f03 100644 --- a/crates/api_crud/src/site/update.rs +++ b/crates/api_crud/src/site/update.rs @@ -119,6 +119,7 @@ pub async fn update_site( captcha_difficulty: data.captcha_difficulty.clone(), reports_email_admins: data.reports_email_admins, default_post_listing_mode: data.default_post_listing_mode, + oauth_registration: data.oauth_registration, ..Default::default() }; @@ -278,6 +279,7 @@ mod tests { None::, None::, None::, + None::, ), ), ( @@ -301,6 +303,7 @@ mod tests { None::, None::, None::, + None::, ), ), ( @@ -324,6 +327,7 @@ mod tests { None::, None::, None::, + None::, ), ), ( @@ -347,6 +351,7 @@ mod tests { Some(true), None::, None::, + None::, ), ), ( @@ -370,6 +375,7 @@ mod tests { Some(true), None::, None::, + None::, ), ), ( @@ -393,6 +399,7 @@ mod tests { None::, None::, Some(RegistrationMode::RequireApplication), + None::, ), ), ]; @@ -447,6 +454,7 @@ mod tests { None::, None::, None::, + None::, ), ), ( @@ -469,6 +477,7 @@ mod tests { Some(true), Some(String::new()), Some(RegistrationMode::Open), + None::, ), ), ( @@ -491,6 +500,7 @@ mod tests { None::, None::, None::, + None::, ), ), ( @@ -513,6 +523,7 @@ mod tests { None::, None::, Some(RegistrationMode::RequireApplication), + None::, ), ), ]; @@ -561,6 +572,7 @@ mod tests { site_is_federated: Option, site_application_question: Option, site_registration_mode: Option, + site_oauth_registration: Option, ) -> EditSite { EditSite { name: site_name, @@ -607,6 +619,7 @@ mod tests { reports_email_admins: None, content_warning: None, default_post_listing_mode: None, + oauth_registration: site_oauth_registration, } } } diff --git a/crates/api_crud/src/user/create.rs b/crates/api_crud/src/user/create.rs index b717d8816..1fb14a6e2 100644 --- a/crates/api_crud/src/user/create.rs +++ b/crates/api_crud/src/user/create.rs @@ -3,8 +3,12 @@ use actix_web::{web::Json, HttpRequest}; use lemmy_api_common::{ claims::Claims, context::LemmyContext, + oauth_provider::AuthenticateWithOauth, person::{LoginResponse, Register}, utils::{ + check_email_verified, + check_registration_application, + check_user_valid, generate_inbox_url, generate_local_apub_endpoint, generate_shared_inbox_url, @@ -18,11 +22,15 @@ use lemmy_api_common::{ }; use lemmy_db_schema::{ aggregates::structs::PersonAggregates, + newtypes::{InstanceId, OAuthProviderId}, source::{ captcha_answer::{CaptchaAnswer, CheckCaptchaAnswer}, language::Language, + local_site::LocalSite, local_user::{LocalUser, LocalUserInsertForm}, local_user_vote_display_mode::LocalUserVoteDisplayMode, + oauth_account::{OAuthAccount, OAuthAccountInsertForm}, + oauth_provider::OAuthProvider, person::{Person, PersonInsertForm}, registration_application::{RegistrationApplication, RegistrationApplicationInsertForm}, }, @@ -31,15 +39,27 @@ use lemmy_db_schema::{ }; use lemmy_db_views::structs::{LocalUserView, SiteView}; use lemmy_utils::{ - error::{LemmyErrorExt, LemmyErrorType, LemmyResult}, + error::{LemmyError, LemmyErrorExt, LemmyErrorType, LemmyResult}, utils::{ slurs::{check_slurs, check_slurs_opt}, validation::is_valid_actor_name, }, }; +use serde::{Deserialize, Serialize}; +use serde_with::skip_serializing_none; use std::collections::HashSet; -#[tracing::instrument(skip(context))] +#[skip_serializing_none] +#[derive(Debug, Serialize, Deserialize, Clone, Default)] +/// Response from OAuth token endpoint +struct TokenResponse { + pub access_token: String, + pub token_type: String, + pub expires_in: Option, + pub refresh_token: Option, + pub scope: Option, +} + pub async fn register( data: Json, req: HttpRequest, @@ -61,8 +81,9 @@ pub async fn register( Err(LemmyErrorType::EmailRequired)? } - if local_site.site_setup && require_registration_application && data.answer.is_none() { - Err(LemmyErrorType::RegistrationApplicationAnswerRequired)? + // make sure the registration answer is provided when the registration application is required + if local_site.site_setup { + validate_registration_answer(require_registration_application, &data.answer)?; } // Make sure passwords match @@ -93,13 +114,9 @@ pub async fn register( check_slurs(&data.username, &slur_regex)?; check_slurs_opt(&data.answer, &slur_regex)?; - let actor_keypair = generate_actor_keypair()?; - is_valid_actor_name(&data.username, local_site.actor_name_max_length as usize)?; - let actor_id = generate_local_apub_endpoint( - EndpointType::Person, - &data.username, - &context.settings().get_protocol_and_hostname(), - )?; + if Person::is_username_taken(&mut context.pool(), &data.username).await? { + return Err(LemmyErrorType::UsernameAlreadyExists)?; + } if let Some(email) = &data.email { if LocalUser::is_email_taken(&mut context.pool(), email).await? { @@ -108,49 +125,28 @@ pub async fn register( } // We have to create both a person, and local_user - - // Register the new person - let person_form = PersonInsertForm { - actor_id: Some(actor_id.clone()), - inbox_url: Some(generate_inbox_url(&actor_id)?), - shared_inbox_url: Some(generate_shared_inbox_url(context.settings())?), - private_key: Some(actor_keypair.private_key), - ..PersonInsertForm::new( - data.username.clone(), - actor_keypair.public_key, - site_view.site.instance_id, - ) - }; - - // insert the person - let inserted_person = Person::create(&mut context.pool(), &person_form) - .await - .with_lemmy_type(LemmyErrorType::UserAlreadyExists)?; + let inserted_person = create_person( + data.username.clone(), + &local_site, + site_view.site.instance_id, + &context, + ) + .await?; // Automatically set their application as accepted, if they created this with open registration. // Also fixes a bug which allows users to log in when registrations are changed to closed. let accepted_application = Some(!require_registration_application); - // Get the user's preferred language using the Accept-Language header - let language_tags: Vec = req - .headers() - .get("Accept-Language") - .map(|hdr| accept_language::parse(hdr.to_str().unwrap_or_default())) - .iter() - .flatten() - // Remove the optional region code - .map(|lang_str| lang_str.split('-').next().unwrap_or_default().to_string()) - .collect(); - // Show nsfw content if param is true, or if content_warning exists let show_nsfw = data .show_nsfw .unwrap_or(site_view.site.content_warning.is_some()); + let language_tags = get_language_tags(&req); + // Create the local user let local_user_form = LocalUserInsertForm { email: data.email.as_deref().map(str::to_lowercase), - password_encrypted: data.password.to_string(), show_nsfw: Some(show_nsfw), accepted_application, default_listing_type: Some(local_site.default_post_listing_type), @@ -158,21 +154,10 @@ pub async fn register( interface_language: language_tags.first().cloned(), // If its the initial site setup, they are an admin admin: Some(!local_site.site_setup), - ..LocalUserInsertForm::new(inserted_person.id, data.password.to_string()) + ..LocalUserInsertForm::new(inserted_person.id, Some(data.password.to_string())) }; - let all_languages = Language::read_all(&mut context.pool()).await?; - // use hashset to avoid duplicates - let mut language_ids = HashSet::new(); - for l in language_tags { - if let Some(found) = all_languages.iter().find(|all| all.code == l) { - language_ids.insert(found.id); - } - } - let language_ids = language_ids.into_iter().collect(); - - let inserted_local_user = - LocalUser::create(&mut context.pool(), &local_user_form, language_ids).await?; + let inserted_local_user = create_local_user(&context, language_tags, &local_user_form).await?; if local_site.site_setup && require_registration_application { // Create the registration application @@ -205,29 +190,13 @@ pub async fn register( let jwt = Claims::generate(inserted_local_user.id, req, &context).await?; login_response.jwt = Some(jwt); } else { - if local_site.require_email_verification { - let local_user_view = LocalUserView { - local_user: inserted_local_user, - local_user_vote_display_mode: LocalUserVoteDisplayMode::default(), - person: inserted_person, - counts: PersonAggregates::default(), - }; - // we check at the beginning of this method that email is set - let email = local_user_view - .local_user - .email - .clone() - .expect("email was provided"); - - send_verification_email( - &local_user_view, - &email, - &mut context.pool(), - context.settings(), - ) - .await?; - login_response.verify_email_sent = true; - } + login_response.verify_email_sent = send_verification_email_if_required( + &context, + &local_site, + &inserted_local_user, + &inserted_person, + ) + .await?; if require_registration_application { login_response.registration_created = true; @@ -236,3 +205,390 @@ pub async fn register( Ok(Json(login_response)) } + +#[tracing::instrument(skip(context))] +pub async fn authenticate_with_oauth( + data: Json, + req: HttpRequest, + context: Data, +) -> LemmyResult> { + let site_view = SiteView::read_local(&mut context.pool()).await?; + let local_site = site_view.local_site.clone(); + + // validate inputs + if data.oauth_provider_id == OAuthProviderId(0) || data.code.is_empty() || data.code.len() > 300 { + return Err(LemmyErrorType::OauthAuthorizationInvalid)?; + } + + // validate the redirect_uri + let redirect_uri = &data.redirect_uri; + if redirect_uri.host_str().unwrap_or("").is_empty() + || !redirect_uri.path().eq(&String::from("/oauth/callback")) + || !redirect_uri.query().unwrap_or("").is_empty() + { + Err(LemmyErrorType::OauthAuthorizationInvalid)? + } + + // Fetch the OAUTH provider and make sure it's enabled + let oauth_provider_id = data.oauth_provider_id; + let oauth_provider = OAuthProvider::read(&mut context.pool(), oauth_provider_id) + .await + .ok() + .flatten() + .ok_or(LemmyErrorType::OauthAuthorizationInvalid)?; + + if !oauth_provider.enabled { + return Err(LemmyErrorType::OauthAuthorizationInvalid)?; + } + + let token_response = + oauth_request_access_token(&context, &oauth_provider, &data.code, redirect_uri.as_str()) + .await?; + + let user_info = oidc_get_user_info( + &context, + &oauth_provider, + token_response.access_token.as_str(), + ) + .await?; + + let oauth_user_id = read_user_info(&user_info, oauth_provider.id_claim.as_str())?; + + let mut login_response = LoginResponse { + jwt: None, + registration_created: false, + verify_email_sent: false, + }; + + // Lookup user by oauth_user_id + let mut local_user_view = + LocalUserView::find_by_oauth_id(&mut context.pool(), oauth_provider.id, &oauth_user_id).await?; + + let local_user: LocalUser; + if let Some(user_view) = local_user_view { + // user found by oauth_user_id => Login user + local_user = user_view.clone().local_user; + + check_user_valid(&user_view.person)?; + check_email_verified(&user_view, &site_view)?; + check_registration_application(&user_view, &site_view.local_site, &mut context.pool()).await?; + } else { + // user has never previously registered using oauth + + // prevent registration if registration is closed + if local_site.registration_mode == RegistrationMode::Closed { + Err(LemmyErrorType::RegistrationClosed)? + } + + // prevent registration if registration is closed for OAUTH providers + if !local_site.oauth_registration { + return Err(LemmyErrorType::OauthRegistrationClosed)?; + } + + // Extract the OAUTH email claim from the returned user_info + let email = read_user_info(&user_info, "email")?; + + let require_registration_application = + local_site.registration_mode == RegistrationMode::RequireApplication; + + // Lookup user by OAUTH email and link accounts + local_user_view = LocalUserView::find_by_email(&mut context.pool(), &email).await?; + + let person; + if let Some(user_view) = local_user_view { + // user found by email => link and login if linking is allowed + + // we only allow linking by email when email_verification is required otherwise emails cannot + // be trusted + if oauth_provider.account_linking_enabled && site_view.local_site.require_email_verification { + // WARNING: + // If an admin switches the require_email_verification config from false to true, + // users who signed up before the switch could have accounts with unverified emails falsely + // marked as verified. + + check_user_valid(&user_view.person)?; + check_email_verified(&user_view, &site_view)?; + check_registration_application(&user_view, &site_view.local_site, &mut context.pool()) + .await?; + + // Link with OAUTH => Login user + let oauth_account_form = + OAuthAccountInsertForm::new(user_view.local_user.id, oauth_provider.id, oauth_user_id); + + OAuthAccount::create(&mut context.pool(), &oauth_account_form) + .await + .map_err(|_| LemmyErrorType::OauthLoginFailed)?; + + local_user = user_view.local_user.clone(); + } else { + return Err(LemmyErrorType::EmailAlreadyExists)?; + } + } else { + // No user was found by email => Register as new user + + // make sure the registration answer is provided when the registration application is required + validate_registration_answer(require_registration_application, &data.answer)?; + + // make sure the username is provided + let username = data + .username + .as_ref() + .ok_or(LemmyErrorType::RegistrationUsernameRequired)?; + + let slur_regex = local_site_to_slur_regex(&local_site); + check_slurs(username, &slur_regex)?; + check_slurs_opt(&data.answer, &slur_regex)?; + + if Person::is_username_taken(&mut context.pool(), username).await? { + return Err(LemmyErrorType::UsernameAlreadyExists)?; + } + + // We have to create a person, a local_user, and an oauth_account + person = create_person( + username.clone(), + &local_site, + site_view.site.instance_id, + &context, + ) + .await?; + + // Show nsfw content if param is true, or if content_warning exists + let show_nsfw = data + .show_nsfw + .unwrap_or(site_view.site.content_warning.is_some()); + + let language_tags = get_language_tags(&req); + + // Create the local user + let local_user_form = LocalUserInsertForm { + email: Some(str::to_lowercase(&email)), + show_nsfw: Some(show_nsfw), + accepted_application: Some(!require_registration_application), + email_verified: Some(oauth_provider.auto_verify_email), + post_listing_mode: Some(local_site.default_post_listing_mode), + interface_language: language_tags.first().cloned(), + // If its the initial site setup, they are an admin + admin: Some(!local_site.site_setup), + ..LocalUserInsertForm::new(person.id, None) + }; + + local_user = create_local_user(&context, language_tags, &local_user_form).await?; + + // Create the oauth account + let oauth_account_form = + OAuthAccountInsertForm::new(local_user.id, oauth_provider.id, oauth_user_id); + + OAuthAccount::create(&mut context.pool(), &oauth_account_form) + .await + .map_err(|_| LemmyErrorType::IncorrectLogin)?; + + // prevent sign in until application is accepted + if local_site.site_setup + && require_registration_application + && !local_user.accepted_application + && !local_user.admin + { + // Create the registration application + RegistrationApplication::create( + &mut context.pool(), + &RegistrationApplicationInsertForm { + local_user_id: local_user.id, + answer: data.answer.clone().expect("must have an answer"), + }, + ) + .await?; + + login_response.registration_created = true; + } + + // Check email is verified when required + login_response.verify_email_sent = + send_verification_email_if_required(&context, &local_site, &local_user, &person).await?; + } + } + + if !login_response.registration_created && !login_response.verify_email_sent { + let jwt = Claims::generate(local_user.id, req, &context).await?; + login_response.jwt = Some(jwt); + } + + return Ok(Json(login_response)); +} + +async fn create_person( + username: String, + local_site: &LocalSite, + instance_id: InstanceId, + context: &Data, +) -> Result { + let actor_keypair = generate_actor_keypair()?; + is_valid_actor_name(&username, local_site.actor_name_max_length as usize)?; + let actor_id = generate_local_apub_endpoint( + EndpointType::Person, + &username, + &context.settings().get_protocol_and_hostname(), + )?; + + // Register the new person + let person_form = PersonInsertForm { + actor_id: Some(actor_id.clone()), + inbox_url: Some(generate_inbox_url(&actor_id)?), + shared_inbox_url: Some(generate_shared_inbox_url(context.settings())?), + private_key: Some(actor_keypair.private_key), + ..PersonInsertForm::new(username.clone(), actor_keypair.public_key, instance_id) + }; + + // insert the person + let inserted_person = Person::create(&mut context.pool(), &person_form) + .await + .with_lemmy_type(LemmyErrorType::UserAlreadyExists)?; + + Ok(inserted_person) +} + +fn get_language_tags(req: &HttpRequest) -> Vec { + req + .headers() + .get("Accept-Language") + .map(|hdr| accept_language::parse(hdr.to_str().unwrap_or_default())) + .iter() + .flatten() + // Remove the optional region code + .map(|lang_str| lang_str.split('-').next().unwrap_or_default().to_string()) + .collect::>() +} + +async fn create_local_user( + context: &Data, + language_tags: Vec, + local_user_form: &LocalUserInsertForm, +) -> Result { + let all_languages = Language::read_all(&mut context.pool()).await?; + // use hashset to avoid duplicates + let mut language_ids = HashSet::new(); + for l in language_tags { + if let Some(found) = all_languages.iter().find(|all| all.code == l) { + language_ids.insert(found.id); + } + } + let language_ids = language_ids.into_iter().collect(); + + let inserted_local_user = + LocalUser::create(&mut context.pool(), local_user_form, language_ids).await?; + + Ok(inserted_local_user) +} + +async fn send_verification_email_if_required( + context: &Data, + local_site: &LocalSite, + local_user: &LocalUser, + person: &Person, +) -> LemmyResult { + let mut sent = false; + if !local_user.admin && local_site.require_email_verification && !local_user.email_verified { + let local_user_view = LocalUserView { + local_user: local_user.clone(), + local_user_vote_display_mode: LocalUserVoteDisplayMode::default(), + person: person.clone(), + counts: PersonAggregates::default(), + }; + + send_verification_email( + &local_user_view, + &local_user + .email + .clone() + .expect("invalid verification email"), + &mut context.pool(), + context.settings(), + ) + .await?; + + sent = true; + } + Ok(sent) +} + +fn validate_registration_answer( + require_registration_application: bool, + answer: &Option, +) -> LemmyResult<()> { + if require_registration_application && answer.is_none() { + Err(LemmyErrorType::RegistrationApplicationAnswerRequired)? + } + + Ok(()) +} + +async fn oauth_request_access_token( + context: &Data, + oauth_provider: &OAuthProvider, + code: &str, + redirect_uri: &str, +) -> LemmyResult { + // Request an Access Token from the OAUTH provider + let response = context + .client() + .post(oauth_provider.token_endpoint.as_str()) + .header("Accept", "application/json") + .form(&[ + ("grant_type", "authorization_code"), + ("code", code), + ("redirect_uri", redirect_uri), + ("client_id", &oauth_provider.client_id), + ("client_secret", &oauth_provider.client_secret), + ]) + .send() + .await; + + let response = response.map_err(|_| LemmyErrorType::OauthLoginFailed)?; + if !response.status().is_success() { + Err(LemmyErrorType::OauthLoginFailed)?; + } + + // Extract the access token + let token_response = response + .json::() + .await + .map_err(|_| LemmyErrorType::OauthLoginFailed)?; + + Ok(token_response) +} + +async fn oidc_get_user_info( + context: &Data, + oauth_provider: &OAuthProvider, + access_token: &str, +) -> LemmyResult { + // Request the user info from the OAUTH provider + let response = context + .client() + .get(oauth_provider.userinfo_endpoint.as_str()) + .header("Accept", "application/json") + .bearer_auth(access_token) + .send() + .await; + + let response = response.map_err(|_| LemmyErrorType::OauthLoginFailed)?; + if !response.status().is_success() { + Err(LemmyErrorType::OauthLoginFailed)?; + } + + // Extract the OAUTH user_id claim from the returned user_info + let user_info = response + .json::() + .await + .map_err(|_| LemmyErrorType::OauthLoginFailed)?; + + Ok(user_info) +} + +fn read_user_info(user_info: &serde_json::Value, key: &str) -> LemmyResult { + if let Some(value) = user_info.get(key) { + let result = serde_json::from_value::(value.clone()) + .map_err(|_| LemmyErrorType::OauthLoginFailed)?; + return Ok(result); + } + Err(LemmyErrorType::OauthLoginFailed)? +} diff --git a/crates/api_crud/src/user/delete.rs b/crates/api_crud/src/user/delete.rs index 363230d83..d1825425c 100644 --- a/crates/api_crud/src/user/delete.rs +++ b/crates/api_crud/src/user/delete.rs @@ -8,7 +8,11 @@ use lemmy_api_common::{ utils::purge_user_account, SuccessResponse, }; -use lemmy_db_schema::source::{login_token::LoginToken, person::Person}; +use lemmy_db_schema::source::{ + login_token::LoginToken, + oauth_account::OAuthAccount, + person::Person, +}; use lemmy_db_views::structs::LocalUserView; use lemmy_utils::error::{LemmyErrorType, LemmyResult}; @@ -19,11 +23,12 @@ pub async fn delete_account( local_user_view: LocalUserView, ) -> LemmyResult> { // Verify the password - let valid: bool = verify( - &data.password, - &local_user_view.local_user.password_encrypted, - ) - .unwrap_or(false); + let valid: bool = local_user_view + .local_user + .password_encrypted + .as_ref() + .and_then(|password_encrypted| verify(&data.password, password_encrypted).ok()) + .unwrap_or(false); if !valid { Err(LemmyErrorType::IncorrectLogin)? } @@ -31,6 +36,7 @@ pub async fn delete_account( if data.delete_content { purge_user_account(local_user_view.person.id, &context).await?; } else { + OAuthAccount::delete_user_accounts(&mut context.pool(), local_user_view.local_user.id).await?; Person::delete_account(&mut context.pool(), local_user_view.person.id).await?; } diff --git a/crates/db_schema/src/impls/local_user.rs b/crates/db_schema/src/impls/local_user.rs index acff6af2a..87f2ac638 100644 --- a/crates/db_schema/src/impls/local_user.rs +++ b/crates/db_schema/src/impls/local_user.rs @@ -35,9 +35,11 @@ impl LocalUser { ) -> Result { let conn = &mut get_conn(pool).await?; let mut form_with_encrypted_password = form.clone(); - let password_hash = - hash(&form.password_encrypted, DEFAULT_COST).expect("Couldn't hash password"); - form_with_encrypted_password.password_encrypted = password_hash; + + if let Some(password_encrypted) = &form.password_encrypted { + let password_hash = hash(password_encrypted, DEFAULT_COST).expect("Couldn't hash password"); + form_with_encrypted_password.password_encrypted = Some(password_hash); + } let local_user_ = insert_into(local_user::table) .values(form_with_encrypted_password) @@ -346,7 +348,7 @@ impl LocalUserOptionHelper for Option<&LocalUser> { impl LocalUserInsertForm { pub fn test_form(person_id: PersonId) -> Self { - Self::new(person_id, String::new()) + Self::new(person_id, Some(String::new())) } pub fn test_form_admin(person_id: PersonId) -> Self { diff --git a/crates/db_schema/src/impls/mod.rs b/crates/db_schema/src/impls/mod.rs index 3a4e71307..f115a101f 100644 --- a/crates/db_schema/src/impls/mod.rs +++ b/crates/db_schema/src/impls/mod.rs @@ -22,6 +22,8 @@ pub mod local_user; pub mod local_user_vote_display_mode; pub mod login_token; pub mod moderator; +pub mod oauth_account; +pub mod oauth_provider; pub mod password_reset_request; pub mod person; pub mod person_block; diff --git a/crates/db_schema/src/impls/oauth_account.rs b/crates/db_schema/src/impls/oauth_account.rs new file mode 100644 index 000000000..921a21d3d --- /dev/null +++ b/crates/db_schema/src/impls/oauth_account.rs @@ -0,0 +1,59 @@ +use crate::{ + newtypes::{LocalUserId, OAuthProviderId}, + schema::{oauth_account, oauth_account::dsl::local_user_id}, + source::oauth_account::{OAuthAccount, OAuthAccountInsertForm}, + utils::{get_conn, DbPool}, +}; +use diesel::{ + dsl::{exists, insert_into}, + result::Error, + select, + ExpressionMethods, + QueryDsl, +}; +use diesel_async::RunQueryDsl; + +impl OAuthAccount { + pub async fn read( + pool: &mut DbPool<'_>, + for_oauth_provider_id: OAuthProviderId, + for_local_user_id: LocalUserId, + ) -> Result { + let conn = &mut get_conn(pool).await?; + select(exists( + oauth_account::table.find((for_oauth_provider_id, for_local_user_id)), + )) + .get_result(conn) + .await + } + + pub async fn create(pool: &mut DbPool<'_>, form: &OAuthAccountInsertForm) -> Result { + let conn = &mut get_conn(pool).await?; + insert_into(oauth_account::table) + .values(form) + .get_result::(conn) + .await + } + + pub async fn delete( + pool: &mut DbPool<'_>, + for_oauth_provider_id: OAuthProviderId, + for_local_user_id: LocalUserId, + ) -> Result { + let conn = &mut get_conn(pool).await?; + diesel::delete(oauth_account::table.find((for_oauth_provider_id, for_local_user_id))) + .execute(conn) + .await + } + + pub async fn delete_user_accounts( + pool: &mut DbPool<'_>, + for_local_user_id: LocalUserId, + ) -> Result { + let conn = &mut get_conn(pool).await?; + + diesel::delete(oauth_account::table.filter(local_user_id.eq(for_local_user_id))) + .execute(conn) + .await + } +} diff --git a/crates/db_schema/src/impls/oauth_provider.rs b/crates/db_schema/src/impls/oauth_provider.rs new file mode 100644 index 000000000..9d7e791e7 --- /dev/null +++ b/crates/db_schema/src/impls/oauth_provider.rs @@ -0,0 +1,71 @@ +use crate::{ + newtypes::OAuthProviderId, + schema::oauth_provider, + source::oauth_provider::{ + OAuthProvider, + OAuthProviderInsertForm, + OAuthProviderUpdateForm, + PublicOAuthProvider, + }, + traits::Crud, + utils::{get_conn, DbPool}, +}; +use diesel::{dsl::insert_into, result::Error, QueryDsl}; +use diesel_async::RunQueryDsl; + +#[async_trait] +impl Crud for OAuthProvider { + type InsertForm = OAuthProviderInsertForm; + type UpdateForm = OAuthProviderUpdateForm; + type IdType = OAuthProviderId; + + async fn create(pool: &mut DbPool<'_>, form: &Self::InsertForm) -> Result { + let conn = &mut get_conn(pool).await?; + insert_into(oauth_provider::table) + .values(form) + .get_result::(conn) + .await + } + + async fn update( + pool: &mut DbPool<'_>, + oauth_provider_id: OAuthProviderId, + form: &Self::UpdateForm, + ) -> Result { + let conn = &mut get_conn(pool).await?; + diesel::update(oauth_provider::table.find(oauth_provider_id)) + .set(form) + .get_result::(conn) + .await + } +} + +impl OAuthProvider { + pub async fn get_all(pool: &mut DbPool<'_>) -> Result, Error> { + let conn = &mut get_conn(pool).await?; + let oauth_providers = oauth_provider::table + .order(oauth_provider::id) + .select(oauth_provider::all_columns) + .load::(conn) + .await?; + + Ok(oauth_providers) + } + + pub fn convert_providers_to_public( + oauth_providers: Vec, + ) -> Vec { + let mut result = Vec::::new(); + for oauth_provider in &oauth_providers { + if oauth_provider.enabled { + result.push(PublicOAuthProvider(oauth_provider.clone())); + } + } + result + } + + pub async fn get_all_public(pool: &mut DbPool<'_>) -> Result, Error> { + let oauth_providers = OAuthProvider::get_all(pool).await?; + Ok(Self::convert_providers_to_public(oauth_providers)) + } +} diff --git a/crates/db_schema/src/impls/person.rs b/crates/db_schema/src/impls/person.rs index f2909218c..312bbcf21 100644 --- a/crates/db_schema/src/impls/person.rs +++ b/crates/db_schema/src/impls/person.rs @@ -121,6 +121,18 @@ impl Person { .load::(conn) .await } + + pub async fn is_username_taken(pool: &mut DbPool<'_>, username: &str) -> Result { + use diesel::dsl::{exists, select}; + let conn = &mut get_conn(pool).await?; + select(exists( + person::table + .filter(lower(person::name).eq(username.to_lowercase())) + .filter(person::local.eq(true)), + )) + .get_result(conn) + .await + } } impl PersonInsertForm { diff --git a/crates/db_schema/src/newtypes.rs b/crates/db_schema/src/newtypes.rs index c715305bb..d90b1f3f6 100644 --- a/crates/db_schema/src/newtypes.rs +++ b/crates/db_schema/src/newtypes.rs @@ -154,6 +154,12 @@ pub struct CustomEmojiId(i32); /// The registration application id. pub struct RegistrationApplicationId(i32); +#[derive(Debug, Copy, Clone, Hash, Eq, PartialEq, Serialize, Deserialize, Default)] +#[cfg_attr(feature = "full", derive(DieselNewType, TS))] +#[cfg_attr(feature = "full", ts(export))] +/// The oauth provider id. +pub struct OAuthProviderId(pub i32); + #[cfg(feature = "full")] #[derive(Serialize, Deserialize)] #[serde(remote = "Ltree")] diff --git a/crates/db_schema/src/schema.rs b/crates/db_schema/src/schema.rs index aa143c4c9..de3e4fa48 100644 --- a/crates/db_schema/src/schema.rs +++ b/crates/db_schema/src/schema.rs @@ -392,6 +392,7 @@ diesel::table! { federation_signed_fetch -> Bool, default_post_listing_mode -> PostListingModeEnum, default_sort_type -> SortTypeEnum, + oauth_registration -> Bool, } } @@ -435,7 +436,7 @@ diesel::table! { local_user (id) { id -> Int4, person_id -> Int4, - password_encrypted -> Text, + password_encrypted -> Nullable, email -> Nullable, show_nsfw -> Bool, theme -> Text, @@ -611,6 +612,36 @@ diesel::table! { } } +diesel::table! { + oauth_account (oauth_provider_id, local_user_id) { + local_user_id -> Int4, + oauth_provider_id -> Int4, + oauth_user_id -> Text, + published -> Timestamptz, + updated -> Nullable, + } +} + +diesel::table! { + oauth_provider (id) { + id -> Int4, + display_name -> Text, + issuer -> Text, + authorization_endpoint -> Text, + token_endpoint -> Text, + userinfo_endpoint -> Text, + id_claim -> Text, + client_id -> Text, + client_secret -> Text, + scopes -> Text, + auto_verify_email -> Bool, + account_linking_enabled -> Bool, + enabled -> Bool, + published -> Timestamptz, + updated -> Nullable, + } +} + diesel::table! { password_reset_request (id) { id -> Int4, @@ -1003,6 +1034,8 @@ diesel::joinable!(mod_remove_community -> person (mod_person_id)); diesel::joinable!(mod_remove_post -> person (mod_person_id)); diesel::joinable!(mod_remove_post -> post (post_id)); diesel::joinable!(mod_transfer_community -> community (community_id)); +diesel::joinable!(oauth_account -> local_user (local_user_id)); +diesel::joinable!(oauth_account -> oauth_provider (oauth_provider_id)); diesel::joinable!(password_reset_request -> local_user (local_user_id)); diesel::joinable!(person -> instance (instance_id)); diesel::joinable!(person_aggregates -> person (person_id)); @@ -1084,6 +1117,8 @@ diesel::allow_tables_to_appear_in_same_query!( mod_remove_community, mod_remove_post, mod_transfer_community, + oauth_account, + oauth_provider, password_reset_request, person, person_aggregates, diff --git a/crates/db_schema/src/source/local_site.rs b/crates/db_schema/src/source/local_site.rs index 21af1f6ca..8dc81a9a5 100644 --- a/crates/db_schema/src/source/local_site.rs +++ b/crates/db_schema/src/source/local_site.rs @@ -68,6 +68,8 @@ pub struct LocalSite { pub default_post_listing_mode: PostListingMode, /// Default value for [LocalUser.post_listing_mode] pub default_sort_type: SortType, + /// Whether or not external auth methods can auto-register users. + pub oauth_registration: bool, } #[derive(Clone, TypedBuilder)] @@ -94,6 +96,7 @@ pub struct LocalSiteInsertForm { pub captcha_enabled: Option, pub captcha_difficulty: Option, pub registration_mode: Option, + pub oauth_registration: Option, pub reports_email_admins: Option, pub federation_signed_fetch: Option, pub default_post_listing_mode: Option, @@ -121,6 +124,7 @@ pub struct LocalSiteUpdateForm { pub captcha_enabled: Option, pub captcha_difficulty: Option, pub registration_mode: Option, + pub oauth_registration: Option, pub reports_email_admins: Option, pub updated: Option>>, pub federation_signed_fetch: Option, diff --git a/crates/db_schema/src/source/local_user.rs b/crates/db_schema/src/source/local_user.rs index 89bdb1b55..e184d3605 100644 --- a/crates/db_schema/src/source/local_user.rs +++ b/crates/db_schema/src/source/local_user.rs @@ -24,7 +24,7 @@ pub struct LocalUser { /// The person_id for the local user. pub person_id: PersonId, #[serde(skip)] - pub password_encrypted: SensitiveString, + pub password_encrypted: Option, pub email: Option, /// Whether to show NSFW content. pub show_nsfw: bool, @@ -70,7 +70,7 @@ pub struct LocalUser { #[cfg_attr(feature = "full", diesel(table_name = local_user))] pub struct LocalUserInsertForm { pub person_id: PersonId, - pub password_encrypted: String, + pub password_encrypted: Option, #[new(default)] pub email: Option, #[new(default)] diff --git a/crates/db_schema/src/source/mod.rs b/crates/db_schema/src/source/mod.rs index bbc8aafa2..377c1aaef 100644 --- a/crates/db_schema/src/source/mod.rs +++ b/crates/db_schema/src/source/mod.rs @@ -27,6 +27,8 @@ pub mod local_user; pub mod local_user_vote_display_mode; pub mod login_token; pub mod moderator; +pub mod oauth_account; +pub mod oauth_provider; pub mod password_reset_request; pub mod person; pub mod person_block; diff --git a/crates/db_schema/src/source/oauth_account.rs b/crates/db_schema/src/source/oauth_account.rs new file mode 100644 index 000000000..83b578e22 --- /dev/null +++ b/crates/db_schema/src/source/oauth_account.rs @@ -0,0 +1,32 @@ +use crate::newtypes::{LocalUserId, OAuthProviderId}; +#[cfg(feature = "full")] +use crate::schema::oauth_account; +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use serde_with::skip_serializing_none; +#[cfg(feature = "full")] +use ts_rs::TS; + +#[skip_serializing_none] +#[derive(Clone, PartialEq, Eq, Debug, Serialize, Deserialize)] +#[cfg_attr(feature = "full", derive(Queryable, Selectable, TS))] +#[cfg_attr(feature = "full", diesel(table_name = oauth_account))] +#[cfg_attr(feature = "full", diesel(check_for_backend(diesel::pg::Pg)))] +#[cfg_attr(feature = "full", ts(export))] +/// An auth account method. +pub struct OAuthAccount { + pub local_user_id: LocalUserId, + pub oauth_provider_id: OAuthProviderId, + pub oauth_user_id: String, + pub published: DateTime, + pub updated: Option>, +} + +#[derive(Debug, Clone, derive_new::new)] +#[cfg_attr(feature = "full", derive(Insertable, AsChangeset))] +#[cfg_attr(feature = "full", diesel(table_name = oauth_account))] +pub struct OAuthAccountInsertForm { + pub local_user_id: LocalUserId, + pub oauth_provider_id: OAuthProviderId, + pub oauth_user_id: String, +} diff --git a/crates/db_schema/src/source/oauth_provider.rs b/crates/db_schema/src/source/oauth_provider.rs new file mode 100644 index 000000000..40046c83c --- /dev/null +++ b/crates/db_schema/src/source/oauth_provider.rs @@ -0,0 +1,131 @@ +#[cfg(feature = "full")] +use crate::schema::oauth_provider; +use crate::{ + newtypes::{DbUrl, OAuthProviderId}, + sensitive::SensitiveString, +}; +use chrono::{DateTime, Utc}; +use serde::{ + ser::{SerializeStruct, Serializer}, + Deserialize, + Serialize, +}; +use serde_with::skip_serializing_none; +#[cfg(feature = "full")] +use ts_rs::TS; + +#[skip_serializing_none] +#[derive(Clone, PartialEq, Eq, Debug, Serialize, Deserialize)] +#[cfg_attr(feature = "full", derive(Queryable, Selectable, Identifiable, TS))] +#[cfg_attr(feature = "full", diesel(table_name = oauth_provider))] +#[cfg_attr(feature = "full", diesel(check_for_backend(diesel::pg::Pg)))] +#[cfg_attr(feature = "full", ts(export))] +/// oauth provider with client_secret - should never be sent to the client +pub struct OAuthProvider { + pub id: OAuthProviderId, + /// The OAuth 2.0 provider name displayed to the user on the Login page + pub display_name: String, + /// The issuer url of the OAUTH provider. + #[cfg_attr(feature = "full", ts(type = "string"))] + pub issuer: DbUrl, + /// The authorization endpoint is used to interact with the resource owner and obtain an + /// authorization grant. This is usually provided by the OAUTH provider. + #[cfg_attr(feature = "full", ts(type = "string"))] + pub authorization_endpoint: DbUrl, + /// The token endpoint is used by the client to obtain an access token by presenting its + /// authorization grant or refresh token. This is usually provided by the OAUTH provider. + #[cfg_attr(feature = "full", ts(type = "string"))] + pub token_endpoint: DbUrl, + /// The UserInfo Endpoint is an OAuth 2.0 Protected Resource that returns Claims about the + /// authenticated End-User. This is defined in the OIDC specification. + #[cfg_attr(feature = "full", ts(type = "string"))] + pub userinfo_endpoint: DbUrl, + /// The OAuth 2.0 claim containing the unique user ID returned by the provider. Usually this + /// should be set to "sub". + pub id_claim: String, + /// The client_id is provided by the OAuth 2.0 provider and is a unique identifier to this + /// service + pub client_id: String, + /// The client_secret is provided by the OAuth 2.0 provider and is used to authenticate this + /// service with the provider + #[serde(skip)] + pub client_secret: SensitiveString, + /// Lists the scopes requested from users. Users will have to grant access to the requested scope + /// at sign up. + pub scopes: String, + /// Automatically sets email as verified on registration + pub auto_verify_email: bool, + /// Allows linking an OAUTH account to an existing user account by matching emails + pub account_linking_enabled: bool, + /// switch to enable or disable an oauth provider + pub enabled: bool, + pub published: DateTime, + pub updated: Option>, +} + +#[derive(Clone, PartialEq, Eq, Debug, Deserialize)] +#[serde(transparent)] +#[cfg_attr(feature = "full", derive(TS))] +#[cfg_attr(feature = "full", ts(export))] +// A subset of OAuthProvider used for public requests, for example to display the OAUTH buttons on +// the login page +pub struct PublicOAuthProvider(pub OAuthProvider); + +impl Serialize for PublicOAuthProvider { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let mut state = serializer.serialize_struct("PublicOAuthProvider", 5)?; + state.serialize_field("id", &self.0.id)?; + state.serialize_field("display_name", &self.0.display_name)?; + state.serialize_field("authorization_endpoint", &self.0.authorization_endpoint)?; + state.serialize_field("client_id", &self.0.client_id)?; + state.serialize_field("scopes", &self.0.scopes)?; + state.end() + } +} + +#[derive(Debug, Clone)] +#[cfg_attr(feature = "full", derive(Insertable, AsChangeset, TS))] +#[cfg_attr(feature = "full", diesel(table_name = oauth_provider))] +#[cfg_attr(feature = "full", ts(export))] +pub struct OAuthProviderInsertForm { + pub display_name: String, + #[cfg_attr(feature = "full", ts(type = "string"))] + pub issuer: DbUrl, + #[cfg_attr(feature = "full", ts(type = "string"))] + pub authorization_endpoint: DbUrl, + #[cfg_attr(feature = "full", ts(type = "string"))] + pub token_endpoint: DbUrl, + #[cfg_attr(feature = "full", ts(type = "string"))] + pub userinfo_endpoint: DbUrl, + pub id_claim: String, + pub client_id: String, + pub client_secret: String, + pub scopes: String, + pub auto_verify_email: bool, + pub account_linking_enabled: bool, + pub enabled: bool, +} + +#[derive(Debug, Clone)] +#[cfg_attr(feature = "full", derive(Insertable, AsChangeset, TS))] +#[cfg_attr(feature = "full", diesel(table_name = oauth_provider))] +#[cfg_attr(feature = "full", ts(export))] +pub struct OAuthProviderUpdateForm { + pub display_name: Option, + #[cfg_attr(feature = "full", ts(type = "string"))] + pub authorization_endpoint: Option, + #[cfg_attr(feature = "full", ts(type = "string"))] + pub token_endpoint: Option, + #[cfg_attr(feature = "full", ts(type = "string"))] + pub userinfo_endpoint: Option, + pub id_claim: Option, + pub client_secret: Option, + pub scopes: Option, + pub auto_verify_email: Option, + pub account_linking_enabled: Option, + pub enabled: Option, + pub updated: Option>>, +} diff --git a/crates/db_schema/src/utils.rs b/crates/db_schema/src/utils.rs index 8e4e35006..a174e3cb9 100644 --- a/crates/db_schema/src/utils.rs +++ b/crates/db_schema/src/utils.rs @@ -288,7 +288,7 @@ pub fn is_email_regex(test: &str) -> bool { EMAIL_REGEX.is_match(test) } -/// Takes an API text input, and converts it to an optional diesel DB update. +/// Takes an API optional text input, and converts it to an optional diesel DB update. pub fn diesel_string_update(opt: Option<&str>) -> Option> { match opt { // An empty string is an erase @@ -298,6 +298,17 @@ pub fn diesel_string_update(opt: Option<&str>) -> Option> { } } +/// Takes an API optional text input, and converts it to an optional diesel DB update (for non +/// nullable properties). +pub fn diesel_required_string_update(opt: Option<&str>) -> Option { + match opt { + // An empty string is no change + Some("") => None, + Some(str) => Some(str.into()), + None => None, + } +} + /// Takes an optional API URL-type input, and converts it to an optional diesel DB update. /// Also cleans the url params. pub fn diesel_url_update(opt: Option<&str>) -> LemmyResult>> { @@ -311,6 +322,19 @@ pub fn diesel_url_update(opt: Option<&str>) -> LemmyResult> } } +/// Takes an optional API URL-type input, and converts it to an optional diesel DB update (for non +/// nullable properties). Also cleans the url params. +pub fn diesel_required_url_update(opt: Option<&str>) -> LemmyResult> { + match opt { + // An empty string is no change + Some("") => Ok(None), + Some(str_url) => Url::parse(str_url) + .map(|u| Some(clean_url(&u).into())) + .with_lemmy_type(LemmyErrorType::InvalidUrl), + None => Ok(None), + } +} + /// Takes an optional API URL-type input, and converts it to an optional diesel DB create. /// Also cleans the url params. pub fn diesel_url_create(opt: Option<&str>) -> LemmyResult> { diff --git a/crates/db_views/src/local_user_view.rs b/crates/db_views/src/local_user_view.rs index 0c13b0a68..b20dfe235 100644 --- a/crates/db_views/src/local_user_view.rs +++ b/crates/db_views/src/local_user_view.rs @@ -3,8 +3,8 @@ use actix_web::{dev::Payload, FromRequest, HttpMessage, HttpRequest}; use diesel::{result::Error, BoolExpressionMethods, ExpressionMethods, JoinOnDsl, QueryDsl}; use diesel_async::RunQueryDsl; use lemmy_db_schema::{ - newtypes::{LocalUserId, PersonId}, - schema::{local_user, local_user_vote_display_mode, person, person_aggregates}, + newtypes::{LocalUserId, OAuthProviderId, PersonId}, + schema::{local_user, local_user_vote_display_mode, oauth_account, person, person_aggregates}, utils::{ functions::{coalesce, lower}, DbConn, @@ -23,6 +23,7 @@ enum ReadBy<'a> { Name(&'a str), NameOrEmail(&'a str), Email(&'a str), + OAuthID(OAuthProviderId, &'a str), } enum ListMode { @@ -58,12 +59,21 @@ fn queries<'a>( ), _ => query, }; - query + let query = query .inner_join(local_user_vote_display_mode::table) - .inner_join(person_aggregates::table.on(person::id.eq(person_aggregates::person_id))) - .select(selection) - .first(&mut conn) - .await + .inner_join(person_aggregates::table.on(person::id.eq(person_aggregates::person_id))); + + if let ReadBy::OAuthID(oauth_provider_id, oauth_user_id) = search { + query + .inner_join(oauth_account::table) + .filter(oauth_account::oauth_provider_id.eq(oauth_provider_id)) + .filter(oauth_account::oauth_user_id.eq(oauth_user_id)) + .select(selection) + .first(&mut conn) + .await + } else { + query.select(selection).first(&mut conn).await + } }; let list = move |mut conn: DbConn<'a>, mode: ListMode| async move { @@ -120,6 +130,16 @@ impl LocalUserView { queries().read(pool, ReadBy::Email(from_email)).await } + pub async fn find_by_oauth_id( + pool: &mut DbPool<'_>, + oauth_provider_id: OAuthProviderId, + oauth_user_id: &str, + ) -> Result, Error> { + queries() + .read(pool, ReadBy::OAuthID(oauth_provider_id, oauth_user_id)) + .await + } + pub async fn list_admins_with_emails(pool: &mut DbPool<'_>) -> Result, Error> { queries().list(pool, ListMode::AdminsWithEmails).await } diff --git a/crates/routes/src/images.rs b/crates/routes/src/images.rs index 768a607c2..a0f804b6b 100644 --- a/crates/routes/src/images.rs +++ b/crates/routes/src/images.rs @@ -5,8 +5,7 @@ use actix_web::{ Method, StatusCode, }, - web, - web::Query, + web::{self, Query}, HttpRequest, HttpResponse, }; diff --git a/crates/utils/src/error.rs b/crates/utils/src/error.rs index 4e634bde3..1935e4132 100644 --- a/crates/utils/src/error.rs +++ b/crates/utils/src/error.rs @@ -55,6 +55,7 @@ pub enum LemmyErrorType { CouldntFindCommentReply, CouldntFindPrivateMessage, CouldntFindActivity, + CouldntFindOauthProvider, PersonIsBlocked, CommunityIsBlocked, InstanceIsBlocked, @@ -83,7 +84,9 @@ pub enum LemmyErrorType { InvalidDefaultPostListingType, RegistrationClosed, RegistrationApplicationAnswerRequired, + RegistrationUsernameRequired, EmailAlreadyExists, + UsernameAlreadyExists, FederationForbiddenByStrictAllowList, PersonIsBannedFromCommunity, ObjectIsNotPublic, @@ -178,6 +181,10 @@ pub enum LemmyErrorType { CantBlockLocalInstance, UrlWithoutDomain, InboxTimeout, + OauthAuthorizationInvalid, + OauthLoginFailed, + OauthRegistrationClosed, + CouldntDeleteOauthProvider, Unknown(String), CantDeleteSite, UrlLengthOverflow, diff --git a/migrations/2024-09-16-174833_create_oauth_provider/down.sql b/migrations/2024-09-16-174833_create_oauth_provider/down.sql new file mode 100644 index 000000000..d1e62bc46 --- /dev/null +++ b/migrations/2024-09-16-174833_create_oauth_provider/down.sql @@ -0,0 +1,10 @@ +DROP TABLE oauth_account; + +DROP TABLE oauth_provider; + +ALTER TABLE local_site + DROP COLUMN oauth_registration; + +ALTER TABLE local_user + ALTER COLUMN password_encrypted SET NOT NULL; + diff --git a/migrations/2024-09-16-174833_create_oauth_provider/up.sql b/migrations/2024-09-16-174833_create_oauth_provider/up.sql new file mode 100644 index 000000000..a75f01228 --- /dev/null +++ b/migrations/2024-09-16-174833_create_oauth_provider/up.sql @@ -0,0 +1,34 @@ +ALTER TABLE local_user + ALTER COLUMN password_encrypted DROP NOT NULL; + +CREATE TABLE oauth_provider ( + id serial PRIMARY KEY, + display_name text NOT NULL, + issuer text NOT NULL, + authorization_endpoint text NOT NULL, + token_endpoint text NOT NULL, + userinfo_endpoint text NOT NULL, + id_claim text NOT NULL, + client_id text NOT NULL UNIQUE, + client_secret text NOT NULL, + scopes text NOT NULL, + auto_verify_email boolean DEFAULT TRUE NOT NULL, + account_linking_enabled boolean DEFAULT FALSE NOT NULL, + enabled boolean DEFAULT FALSE NOT NULL, + published timestamp with time zone DEFAULT now() NOT NULL, + updated timestamp with time zone +); + +ALTER TABLE local_site + ADD COLUMN oauth_registration boolean DEFAULT FALSE NOT NULL; + +CREATE TABLE oauth_account ( + local_user_id int REFERENCES local_user ON UPDATE CASCADE ON DELETE CASCADE NOT NULL, + oauth_provider_id int REFERENCES oauth_provider ON UPDATE CASCADE ON DELETE RESTRICT NOT NULL, + oauth_user_id text NOT NULL, + published timestamp with time zone DEFAULT now() NOT NULL, + updated timestamp with time zone, + UNIQUE (oauth_provider_id, oauth_user_id), + PRIMARY KEY (oauth_provider_id, local_user_id) +); + diff --git a/src/api_routes_http.rs b/src/api_routes_http.rs index 7b4b34158..faa7d78f2 100644 --- a/src/api_routes_http.rs +++ b/src/api_routes_http.rs @@ -109,6 +109,11 @@ use lemmy_api_crud::{ delete::delete_custom_emoji, update::update_custom_emoji, }, + oauth_provider::{ + create::create_oauth_provider, + delete::delete_oauth_provider, + update::update_oauth_provider, + }, post::{ create::create_post, delete::delete_post, @@ -123,7 +128,10 @@ use lemmy_api_crud::{ update::update_private_message, }, site::{create::create_site, read::get_site, update::update_site}, - user::{create::register, delete::delete_account}, + user::{ + create::{authenticate_with_oauth, register}, + delete::delete_account, + }, }; use lemmy_apub::api::{ list_comments::list_comments, @@ -381,6 +389,18 @@ pub fn config(cfg: &mut web::ServiceConfig, rate_limit: &RateLimitCell) { .route("", web::post().to(create_custom_emoji)) .route("", web::put().to(update_custom_emoji)) .route("/delete", web::post().to(delete_custom_emoji)), + ) + .service( + web::scope("/oauth_provider") + .wrap(rate_limit.message()) + .route("", web::post().to(create_oauth_provider)) + .route("", web::put().to(update_oauth_provider)) + .route("/delete", web::post().to(delete_oauth_provider)), + ) + .service( + web::scope("/oauth") + .wrap(rate_limit.register()) + .route("/authenticate", web::post().to(authenticate_with_oauth)), ), ); cfg.service( diff --git a/src/code_migrations.rs b/src/code_migrations.rs index 12a688b80..bc03d513a 100644 --- a/src/code_migrations.rs +++ b/src/code_migrations.rs @@ -471,7 +471,7 @@ async fn initialize_local_site_2022_10_10( let local_user_form = LocalUserInsertForm { email: setup.admin_email.clone(), admin: Some(true), - ..LocalUserInsertForm::new(person_inserted.id, setup.admin_password.clone()) + ..LocalUserInsertForm::new(person_inserted.id, Some(setup.admin_password.clone())) }; LocalUser::create(pool, &local_user_form, vec![]).await?; };