diff --git a/src/middleware.rs b/src/middleware.rs index 452add9..2d991de 100644 --- a/src/middleware.rs +++ b/src/middleware.rs @@ -1,10 +1,10 @@ +mod deadline; mod metrics; mod payload; use actix_web::{ dev::{Service, ServiceRequest, Transform}, http::StatusCode, - rt::time::Timeout, HttpResponse, ResponseError, }; use std::{ @@ -13,41 +13,10 @@ use std::{ task::{Context, Poll}, }; -use crate::future::WithTimeout; - +pub(crate) use self::deadline::Deadline; pub(crate) use self::metrics::Metrics; pub(crate) use self::payload::Payload; -pub(crate) struct Deadline; -pub(crate) struct DeadlineMiddleware { - inner: S, -} - -#[derive(Debug)] -struct DeadlineExceeded; - -pin_project_lite::pin_project! { - pub(crate) struct DeadlineFuture { - #[pin] - inner: DeadlineFutureInner, - } -} - -pin_project_lite::pin_project! { - #[project = DeadlineFutureInnerProj] - #[project_replace = DeadlineFutureInnerProjReplace] - enum DeadlineFutureInner { - Timed { - #[pin] - timeout: Timeout, - }, - Untimed { - #[pin] - future: F, - }, - } -} - pub(crate) struct Internal(pub(crate) Option); pub(crate) struct InternalMiddleware(Option, S); #[derive(Clone, Debug, thiserror::Error)] @@ -83,124 +52,6 @@ impl ResponseError for ApiError { } } -impl Transform for Deadline -where - S: Service, - S::Future: 'static, - actix_web::Error: From, -{ - type Response = S::Response; - type Error = actix_web::Error; - type InitError = (); - type Transform = DeadlineMiddleware; - type Future = Ready>; - - fn new_transform(&self, service: S) -> Self::Future { - ready(Ok(DeadlineMiddleware { inner: service })) - } -} - -impl Service for DeadlineMiddleware -where - S: Service, - S::Future: 'static, - actix_web::Error: From, -{ - type Response = S::Response; - type Error = actix_web::Error; - type Future = DeadlineFuture; - - fn poll_ready(&self, cx: &mut core::task::Context<'_>) -> Poll> { - self.inner - .poll_ready(cx) - .map(|res| res.map_err(actix_web::Error::from)) - } - - fn call(&self, req: ServiceRequest) -> Self::Future { - let duration = req - .headers() - .get("X-Request-Deadline") - .and_then(|deadline| { - let deadline = time::OffsetDateTime::from_unix_timestamp_nanos( - deadline.to_str().ok()?.parse().ok()?, - ) - .ok()?; - let now = time::OffsetDateTime::now_utc(); - - if now < deadline { - Some((deadline - now).try_into().ok()?) - } else { - Some(std::time::Duration::from_secs(0)) - } - }); - DeadlineFuture::new(self.inner.call(req), duration) - } -} - -impl std::fmt::Display for DeadlineExceeded { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "Deadline exceeded") - } -} - -impl std::error::Error for DeadlineExceeded {} -impl actix_web::error::ResponseError for DeadlineExceeded { - fn status_code(&self) -> StatusCode { - StatusCode::REQUEST_TIMEOUT - } - - fn error_response(&self) -> HttpResponse { - HttpResponse::build(self.status_code()) - .content_type("application/json") - .body( - serde_json::to_string( - &serde_json::json!({ "msg": self.to_string(), "code": "request-timeout" }), - ) - .unwrap_or_else(|_| { - r#"{"msg":"request timeout","code":"request-timeout"}"#.to_string() - }), - ) - } -} - -impl DeadlineFuture -where - F: Future, -{ - fn new(future: F, timeout: Option) -> Self { - DeadlineFuture { - inner: match timeout { - Some(duration) => DeadlineFutureInner::Timed { - timeout: future.with_timeout(duration), - }, - None => DeadlineFutureInner::Untimed { future }, - }, - } - } -} - -impl Future for DeadlineFuture -where - F: Future>, - actix_web::Error: From, -{ - type Output = Result; - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.as_mut().project(); - - match this.inner.project() { - DeadlineFutureInnerProj::Timed { timeout } => timeout.poll(cx).map(|res| match res { - Ok(res) => res.map_err(actix_web::Error::from), - Err(_) => Err(DeadlineExceeded.into()), - }), - DeadlineFutureInnerProj::Untimed { future } => future - .poll(cx) - .map(|res| res.map_err(actix_web::Error::from)), - } - } -} - impl Transform for Internal where S: Service, diff --git a/src/middleware/deadline.rs b/src/middleware/deadline.rs new file mode 100644 index 0000000..c6ddb43 --- /dev/null +++ b/src/middleware/deadline.rs @@ -0,0 +1,195 @@ +use actix_web::{ + dev::{Service, ServiceRequest, Transform}, + http::StatusCode, + rt::time::Timeout, + HttpResponse, ResponseError, +}; +use std::{ + future::{ready, Future, Ready}, + pin::Pin, + task::{Context, Poll}, + time::Duration, +}; + +use crate::future::WithTimeout; + +pub(crate) struct Deadline; +pub(crate) struct DeadlineMiddleware { + inner: S, +} + +#[derive(Debug, thiserror::Error)] +#[error("Deadline exceeded")] +struct DeadlineExceeded; + +#[derive(Debug, thiserror::Error)] +enum ParseDeadlineError { + #[error("Invalid header string")] + HeaderString, + + #[error("Invalid deadline format")] + HeaderFormat, + + #[error("Invalid deadline timestmap")] + Timestamp, + + #[error("Invalid deadline duration")] + Duration, +} + +pin_project_lite::pin_project! { + pub(crate) struct DeadlineFuture { + #[pin] + inner: DeadlineFutureInner, + } +} + +pin_project_lite::pin_project! { + #[project = DeadlineFutureInnerProj] + #[project_replace = DeadlineFutureInnerProjReplace] + enum DeadlineFutureInner { + Timed { + #[pin] + timeout: Timeout, + }, + Untimed { + #[pin] + future: F, + }, + Error { + error: Option, + }, + } +} + +impl Transform for Deadline +where + S: Service, + S::Future: 'static, + actix_web::Error: From, +{ + type Response = S::Response; + type Error = actix_web::Error; + type InitError = (); + type Transform = DeadlineMiddleware; + type Future = Ready>; + + fn new_transform(&self, service: S) -> Self::Future { + ready(Ok(DeadlineMiddleware { inner: service })) + } +} + +impl Service for DeadlineMiddleware +where + S: Service, + S::Future: 'static, + actix_web::Error: From, +{ + type Response = S::Response; + type Error = actix_web::Error; + type Future = DeadlineFuture; + + fn poll_ready(&self, cx: &mut core::task::Context<'_>) -> Poll> { + self.inner + .poll_ready(cx) + .map(|res| res.map_err(actix_web::Error::from)) + } + + fn call(&self, req: ServiceRequest) -> Self::Future { + let duration: Result, ParseDeadlineError> = req + .headers() + .get("X-Request-Deadline") + .map(|deadline| { + let deadline_str = deadline + .to_str() + .map_err(|_| ParseDeadlineError::HeaderString)?; + + let deadline_i128 = deadline_str + .parse() + .map_err(|_| ParseDeadlineError::HeaderFormat)?; + + let deadline = time::OffsetDateTime::from_unix_timestamp_nanos(deadline_i128) + .map_err(|_| ParseDeadlineError::Timestamp)?; + + let now = time::OffsetDateTime::now_utc(); + + if now < deadline { + (deadline - now) + .try_into() + .map_err(|_| ParseDeadlineError::Duration) + } else { + Ok(Duration::from_secs(0)) + } + }) + .transpose(); + DeadlineFuture::new(self.inner.call(req), duration) + } +} + +impl ResponseError for DeadlineExceeded { + fn status_code(&self) -> StatusCode { + StatusCode::REQUEST_TIMEOUT + } + + fn error_response(&self) -> HttpResponse { + HttpResponse::build(self.status_code()).json(serde_json::json!({ + "msg": self.to_string(), + "code": "request-timeout" + })) + } +} + +impl ResponseError for ParseDeadlineError { + fn status_code(&self) -> StatusCode { + StatusCode::BAD_REQUEST + } + + fn error_response(&self) -> HttpResponse { + HttpResponse::build(self.status_code()).json(serde_json::json!({ + "msg": self.to_string(), + "code": "parse-deadline" + })) + } +} + +impl DeadlineFuture +where + F: Future, +{ + fn new(future: F, timeout: Result, ParseDeadlineError>) -> Self { + DeadlineFuture { + inner: match timeout { + Ok(Some(duration)) => DeadlineFutureInner::Timed { + timeout: future.with_timeout(duration), + }, + Ok(None) => DeadlineFutureInner::Untimed { future }, + Err(e) => DeadlineFutureInner::Error { error: Some(e) }, + }, + } + } +} + +impl Future for DeadlineFuture +where + F: Future>, + actix_web::Error: From, +{ + type Output = Result; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.as_mut().project(); + + match this.inner.project() { + DeadlineFutureInnerProj::Timed { timeout } => timeout.poll(cx).map(|res| match res { + Ok(res) => res.map_err(actix_web::Error::from), + Err(_) => Err(DeadlineExceeded.into()), + }), + DeadlineFutureInnerProj::Untimed { future } => future + .poll(cx) + .map(|res| res.map_err(actix_web::Error::from)), + DeadlineFutureInnerProj::Error { error } => { + Poll::Ready(Err(error.take().expect("Polled after completion").into())) + } + } + } +}