use actix_rt::time::Timeout; use actix_web::{ dev::{Service, ServiceRequest, Transform}, http::StatusCode, HttpResponse, ResponseError, }; use std::{ future::{ready, Future, Ready}, pin::Pin, task::{Context, Poll}, }; pub(crate) struct Deadline; pub(crate) struct DeadlineMiddleware { inner: S, } #[derive(Debug)] struct DeadlineExceeded; pin_project_lite::pin_project! { pub(crate) struct DeadlineFuture { #[pin] inner: DeadlineFutureInner, } } pin_project_lite::pin_project! { #[project = DeadlineFutureInnerProj] #[project_replace = DeadlineFutureInnerProjReplace] enum DeadlineFutureInner { Timed { #[pin] timeout: Timeout, }, Untimed { #[pin] future: F, }, } } pub(crate) struct Internal(pub(crate) Option); pub(crate) struct InternalMiddleware(Option, S); #[derive(Clone, Debug, thiserror::Error)] #[error("Invalid API Key")] pub(crate) struct ApiError; pin_project_lite::pin_project! { #[project = InternalFutureProj] #[project_replace = InternalFutureProjReplace] pub(crate) enum InternalFuture { Internal { #[pin] future: F, }, Error { error: Option, }, } } impl ResponseError for ApiError { fn status_code(&self) -> StatusCode { StatusCode::UNAUTHORIZED } fn error_response(&self) -> HttpResponse { HttpResponse::build(self.status_code()) .content_type("application/json") .body( serde_json::to_string(&serde_json::json!({ "msg": self.to_string() })) .unwrap_or_else(|_| r#"{"msg":"unauthorized"}"#.to_string()), ) } } impl Transform for Deadline where S: Service, S::Future: 'static, actix_web::Error: From, { type Response = S::Response; type Error = actix_web::Error; type InitError = (); type Transform = DeadlineMiddleware; type Future = Ready>; fn new_transform(&self, service: S) -> Self::Future { ready(Ok(DeadlineMiddleware { inner: service })) } } impl Service for DeadlineMiddleware where S: Service, S::Future: 'static, actix_web::Error: From, { type Response = S::Response; type Error = actix_web::Error; type Future = DeadlineFuture; fn poll_ready(&self, cx: &mut core::task::Context<'_>) -> Poll> { self.inner .poll_ready(cx) .map(|res| res.map_err(actix_web::Error::from)) } fn call(&self, req: ServiceRequest) -> Self::Future { let duration = req .headers() .get("X-Request-Deadline") .and_then(|deadline| { let deadline = time::OffsetDateTime::from_unix_timestamp_nanos( deadline.to_str().ok()?.parse().ok()?, ) .ok()?; let now = time::OffsetDateTime::now_utc(); if now < deadline { Some((deadline - now).try_into().ok()?) } else { None } }); DeadlineFuture::new(self.inner.call(req), duration) } } impl std::fmt::Display for DeadlineExceeded { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "Deadline exceeded") } } impl std::error::Error for DeadlineExceeded {} impl actix_web::error::ResponseError for DeadlineExceeded { fn status_code(&self) -> StatusCode { StatusCode::REQUEST_TIMEOUT } fn error_response(&self) -> HttpResponse { HttpResponse::build(self.status_code()) .content_type("application/json") .body( serde_json::to_string(&serde_json::json!({ "msg": self.to_string() })) .unwrap_or_else(|_| r#"{"msg":"request timeout"}"#.to_string()), ) } } impl DeadlineFuture where F: Future, { fn new(future: F, timeout: Option) -> Self { DeadlineFuture { inner: match timeout { Some(duration) => DeadlineFutureInner::Timed { timeout: actix_rt::time::timeout(duration, future), }, None => DeadlineFutureInner::Untimed { future }, }, } } } impl Future for DeadlineFuture where F: Future>, actix_web::Error: From, { type Output = Result; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.as_mut().project(); match this.inner.project() { DeadlineFutureInnerProj::Timed { timeout } => timeout.poll(cx).map(|res| match res { Ok(res) => res.map_err(actix_web::Error::from), Err(_) => Err(DeadlineExceeded.into()), }), DeadlineFutureInnerProj::Untimed { future } => future .poll(cx) .map(|res| res.map_err(actix_web::Error::from)), } } } impl Transform for Internal where S: Service, S::Future: 'static, { type Response = S::Response; type Error = S::Error; type InitError = (); type Transform = InternalMiddleware; type Future = Ready>; fn new_transform(&self, service: S) -> Self::Future { ready(Ok(InternalMiddleware(self.0.clone(), service))) } } impl Service for InternalMiddleware where S: Service, S::Future: 'static, { type Response = S::Response; type Error = S::Error; type Future = InternalFuture; fn poll_ready(&self, cx: &mut Context<'_>) -> Poll> { self.1.poll_ready(cx) } fn call(&self, req: ServiceRequest) -> Self::Future { if let Some(value) = req.headers().get("x-api-token") { if let (Ok(header), Some(api_key)) = (value.to_str(), &self.0) { if header == api_key { return InternalFuture::Internal { future: self.1.call(req), }; } } } InternalFuture::Error { error: Some(ApiError), } } } impl Future for InternalFuture where F: Future>, E: From, { type Output = F::Output; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { match self.as_mut().project() { InternalFutureProj::Internal { future } => future.poll(cx), InternalFutureProj::Error { error } => Poll::Ready(Err(error.take().unwrap().into())), } } }