diff --git a/README.md b/README.md index 028940d..13f6357 100644 --- a/README.md +++ b/README.md @@ -216,6 +216,21 @@ A secure API key can be generated by any password generator. } ``` +Additionally, all endpoints support setting deadlines, after which the request will cease +processing. To enable deadlines for your requests, you can set the `X-Request-Deadline` header to an +i128 value representing the number of nanoseconds since the UNIX Epoch. A simple way to calculate +this value is to use the `time` crate's `OffsetDateTime::unix_timestamp_nanos` method. For example, +```rust +// set deadline of 1ms +let deadline = time::OffsetDateTime::now_utc() + time::Duration::new(0, 1_000); + +let request = client + .get("http://pict-rs:8080/image/details/original/asdfghjkla.png") + .insert_header(("X-Request-Deadline", deadline.unix_timestamp_nanos().to_string()))) + .send() + .await; +``` + ## Contributing Feel free to open issues for anything you find an issue with. Please note that any contributed code will be licensed under the AGPLv3. diff --git a/src/main.rs b/src/main.rs index be767b1..c8bc42e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -44,7 +44,7 @@ mod validate; use self::{ config::{Config, Format}, error::UploadError, - middleware::{Internal, Tracing}, + middleware::{Deadline, Internal, Tracing}, upload_manager::{Details, UploadManager}, validate::{image_webp, video_mp4}, }; @@ -859,6 +859,7 @@ async fn main() -> Result<(), anyhow::Error> { App::new() .wrap(Logger::default()) .wrap(Tracing) + .wrap(Deadline) .app_data(web::Data::new(manager.clone())) .app_data(web::Data::new(client)) .app_data(web::Data::new(CONFIG.filter_whitelist())) diff --git a/src/middleware.rs b/src/middleware.rs index 1184a9d..a69c3fc 100644 --- a/src/middleware.rs +++ b/src/middleware.rs @@ -5,12 +5,30 @@ use actix_web::{ }; use futures_util::future::LocalBoxFuture; use std::{ - future::{ready, Ready}, + future::{ready, Future, Ready}, + pin::Pin, task::{Context, Poll}, }; +use actix_rt::time::Timeout; use tracing_futures::{Instrument, Instrumented}; use uuid::Uuid; +pub(crate) struct Deadline; +pub(crate) struct DeadlineMiddleware { + inner: S, +} + +#[derive(Debug)] +struct DeadlineExceeded; + +enum DeadlineFutureInner { + Timed(Pin>>), + Untimed(Pin>), +} +pub(crate) struct DeadlineFuture { + inner: DeadlineFutureInner, +} + pub(crate) struct Tracing; pub(crate) struct TracingMiddleware { @@ -38,6 +56,121 @@ impl ResponseError for ApiError { } } +impl Transform for Deadline +where + S: Service, + S::Future: 'static, + actix_web::Error: From, +{ + type Response = S::Response; + type Error = actix_web::Error; + type InitError = (); + type Transform = DeadlineMiddleware; + type Future = Ready>; + + fn new_transform(&self, service: S) -> Self::Future { + ready(Ok(DeadlineMiddleware { inner: service })) + } +} + +impl Service for DeadlineMiddleware +where + S: Service, + S::Future: 'static, + actix_web::Error: From, +{ + type Response = S::Response; + type Error = actix_web::Error; + type Future = DeadlineFuture; + + fn poll_ready(&self, cx: &mut core::task::Context<'_>) -> Poll> { + self.inner + .poll_ready(cx) + .map(|res| res.map_err(actix_web::Error::from)) + } + + fn call(&self, req: ServiceRequest) -> Self::Future { + let duration = req + .headers() + .get("X-Request-Deadline") + .and_then(|deadline| { + use std::convert::TryInto; + let deadline = time::OffsetDateTime::from_unix_timestamp_nanos( + deadline.to_str().ok()?.parse().ok()?, + ) + .ok()?; + let now = time::OffsetDateTime::now_utc(); + + if now < deadline { + Some((deadline - now).try_into().ok()?) + } else { + None + } + }); + DeadlineFuture::new(self.inner.call(req), duration) + } +} + +impl std::fmt::Display for DeadlineExceeded { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Deadline exceeded") + } +} + +impl std::error::Error for DeadlineExceeded {} +impl actix_web::error::ResponseError for DeadlineExceeded { + fn status_code(&self) -> StatusCode { + StatusCode::REQUEST_TIMEOUT + } + + fn error_response(&self) -> HttpResponse { + HttpResponse::build(self.status_code()) + .content_type("application/json") + .body( + serde_json::to_string(&serde_json::json!({ "msg": self.to_string() })) + .unwrap_or(r#"{"msg":"request timeout"}"#.to_string()), + ) + } +} + +impl DeadlineFuture +where + F: Future, +{ + fn new(inner: F, timeout: Option) -> Self { + DeadlineFuture { + inner: match timeout { + Some(duration) => { + DeadlineFutureInner::Timed(Box::pin(actix_rt::time::timeout(duration, inner))) + } + None => DeadlineFutureInner::Untimed(Box::pin(inner)), + }, + } + } +} + +impl Future for DeadlineFuture +where + F: Future>, + actix_web::Error: From, +{ + type Output = Result; + + fn poll(mut self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match self.inner { + DeadlineFutureInner::Timed(ref mut fut) => { + Pin::new(fut).poll(cx).map(|res| match res { + Ok(res) => res.map_err(actix_web::Error::from), + Err(_) => Err(DeadlineExceeded.into()), + }) + } + DeadlineFutureInner::Untimed(ref mut fut) => Pin::new(fut) + .poll(cx) + .map(|res| res.map_err(actix_web::Error::from)), + } + } +} + impl Transform for Tracing where S: Service,