From 9a94a863630e3efee2866324bb59bc3682867f89 Mon Sep 17 00:00:00 2001 From: Nutomic Date: Tue, 14 Nov 2023 15:39:13 +0100 Subject: [PATCH] Fix cors wildcard (ref #4095) (#4156) * Fix cors wildcard (ref #4095) * cleanup * clippy --- src/lib.rs | 47 ++++++++++++++++++++++++++++------------------- 1 file changed, 28 insertions(+), 19 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 6bedb97fd..e99e9ce81 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -279,22 +279,11 @@ fn create_http_server( let context: LemmyContext = federation_config.deref().clone(); let rate_limit_cell = federation_config.rate_limit_cell().clone(); - let self_origin = settings.get_protocol_and_hostname(); - let cors_origin_setting = settings.cors_origin(); - // Create Http server with websocket support - let server = HttpServer::new(move || { - let cors_config = match (cors_origin_setting.clone(), cfg!(debug_assertions)) { - (Some(origin), false) => Cors::default() - .allowed_origin(&origin) - .allowed_origin(&self_origin), - _ => Cors::default() - .allow_any_origin() - .allow_any_method() - .allow_any_header() - .expose_any_header() - .max_age(3600), - }; + // Create Http server + let bind = (settings.bind, settings.port); + let server = HttpServer::new(move || { + let cors_config = cors_config(&settings); let app = App::new() .wrap(middleware::Logger::new( // This is the default log format save for the usage of %{r}a over %a to guarantee to record the client's (forwarded) IP and not the last peer address, since the latter is frequently just a reverse proxy @@ -309,9 +298,6 @@ fn create_http_server( .wrap(FederationMiddleware::new(federation_config.clone())) .wrap(SessionMiddleware::new(context.clone())); - #[cfg(feature = "prometheus-metrics")] - let app = app.wrap(prom_api_metrics.clone()); - // The routes app .configure(|cfg| api_routes_http::config(cfg, &rate_limit_cell)) @@ -326,13 +312,36 @@ fn create_http_server( .configure(nodeinfo::config) }) .disable_signals() - .bind((settings.bind, settings.port))? + .bind(bind)? .run(); let handle = server.handle(); tokio::task::spawn(server); Ok(handle) } +fn cors_config(settings: &Settings) -> Cors { + let self_origin = settings.get_protocol_and_hostname(); + let cors_origin_setting = settings.cors_origin(); + match (cors_origin_setting.clone(), cfg!(debug_assertions)) { + (Some(origin), false) => { + // Need to call send_wildcard() explicitly, passing this into allowed_origin() results in error + if cors_origin_setting.as_deref() == Some("*") { + Cors::default().send_wildcard() + } else { + Cors::default() + .allowed_origin(&origin) + .allowed_origin(&self_origin) + } + } + _ => Cors::default() + .allow_any_origin() + .allow_any_method() + .allow_any_header() + .expose_any_header() + .max_age(3600), + } +} + pub fn init_logging(opentelemetry_url: &Option) -> Result<(), LemmyError> { LogTracer::init()?;