From 66e17117231a3c8f677e0d556c989a85aa4a107e Mon Sep 17 00:00:00 2001 From: asonix Date: Sat, 30 Sep 2023 16:26:43 -0500 Subject: [PATCH] Enable proper draining of dropped request payloads Doing this as the outermost middleware ensures all endpoints are covered. Update request deadline to turn negative deadlines into immediate failures --- src/future.rs | 30 +++++++ src/lib.rs | 4 +- src/middleware.rs | 4 +- src/middleware/payload.rs | 180 ++++++++++++++++++++++++++++++++++++++ 4 files changed, 216 insertions(+), 2 deletions(-) create mode 100644 src/middleware/payload.rs diff --git a/src/future.rs b/src/future.rs index a89d857..c6f5fbd 100644 --- a/src/future.rs +++ b/src/future.rs @@ -1,10 +1,39 @@ use std::{ future::Future, + sync::{Arc, OnceLock}, time::{Duration, Instant}, }; +static NOOP_WAKER: OnceLock = OnceLock::new(); + +fn noop_waker() -> &'static std::task::Waker { + NOOP_WAKER.get_or_init(|| std::task::Waker::from(Arc::new(NoopWaker))) +} + +struct NoopWaker; +impl std::task::Wake for NoopWaker { + fn wake(self: std::sync::Arc) {} + fn wake_by_ref(self: &std::sync::Arc) {} +} + pub(crate) type LocalBoxFuture<'a, T> = std::pin::Pin + 'a>>; +pub(crate) trait NowOrNever: Future { + fn now_or_never(self) -> Option + where + Self: Sized, + { + let fut = std::pin::pin!(self); + + let mut cx = std::task::Context::from_waker(noop_waker()); + + match fut.poll(&mut cx) { + std::task::Poll::Pending => None, + std::task::Poll::Ready(out) => Some(out), + } + } +} + pub(crate) trait WithTimeout: Future { fn with_timeout(self, duration: Duration) -> actix_web::rt::time::Timeout where @@ -30,6 +59,7 @@ pub(crate) trait WithMetrics: Future { } } +impl NowOrNever for F where F: Future {} impl WithMetrics for F where F: Future {} impl WithTimeout for F where F: Future {} diff --git a/src/lib.rs b/src/lib.rs index d07e6a8..365ad80 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -42,7 +42,7 @@ use details::{ApiDetails, HumanDate}; use future::WithTimeout; use futures_core::Stream; use metrics_exporter_prometheus::PrometheusBuilder; -use middleware::Metrics; +use middleware::{Metrics, Payload}; use repo::ArcRepo; use reqwest_middleware::{ClientBuilder, ClientWithMiddleware}; use reqwest_tracing::TracingMiddleware; @@ -1784,6 +1784,7 @@ async fn launch_file_store { @@ -128,7 +130,7 @@ where if now < deadline { Some((deadline - now).try_into().ok()?) } else { - None + Some(std::time::Duration::from_secs(0)) } }); DeadlineFuture::new(self.inner.call(req), duration) diff --git a/src/middleware/payload.rs b/src/middleware/payload.rs new file mode 100644 index 0000000..4a9b4c0 --- /dev/null +++ b/src/middleware/payload.rs @@ -0,0 +1,180 @@ +use std::{ + future::{ready, Ready}, + sync::Arc, +}; + +use actix_web::{ + dev::{Service, ServiceRequest, Transform}, + http::Method, + HttpMessage, +}; +use streem::IntoStreamer; +use tokio::task::JoinSet; + +use crate::{future::NowOrNever, stream::LocalBoxStream}; + +const LIMIT: usize = 256; + +async fn drain(rx: flume::Receiver) { + let mut set = JoinSet::new(); + + while let Ok(payload) = rx.recv_async().await { + set.spawn_local(async move { + let mut streamer = payload.into_streamer(); + while streamer.next().await.is_some() {} + }); + + let mut count = 0; + + // drain completed tasks + while set.join_next().now_or_never().is_some() { + count += 1; + } + + // if we're past the limit, wait for completions + while set.len() > LIMIT { + if set.join_next().await.is_some() { + count += 1; + } + } + + if count > 0 { + tracing::info!("Drained {count} dropped payloads"); + } + } + + // drain set + while set.join_next().await.is_some() {} +} + +#[derive(Clone)] +struct DrainHandle(Option>>); + +pub(crate) struct Payload { + sender: flume::Sender, + handle: DrainHandle, +} +pub(crate) struct PayloadMiddleware { + inner: S, + sender: flume::Sender, + _handle: DrainHandle, +} + +pub(crate) struct PayloadStream { + inner: Option, + sender: flume::Sender, +} + +impl DrainHandle { + fn new(handle: actix_web::rt::task::JoinHandle<()>) -> Self { + Self(Some(Arc::new(handle))) + } +} + +impl Payload { + pub(crate) fn new() -> Self { + let (tx, rx) = crate::sync::channel(LIMIT); + + let handle = DrainHandle::new(crate::sync::spawn(async move { drain(rx).await })); + + Payload { sender: tx, handle } + } +} + +impl Drop for DrainHandle { + fn drop(&mut self) { + if let Some(handle) = self.0.take().and_then(Arc::into_inner) { + handle.abort(); + } + } +} + +impl Drop for PayloadStream { + fn drop(&mut self) { + if let Some(payload) = self.inner.take() { + tracing::warn!("Dropped unclosed payload, draining"); + if self.sender.try_send(payload).is_err() { + tracing::error!("Failed to send unclosed payload for draining"); + } + } + } +} + +impl futures_core::Stream for PayloadStream { + type Item = Result; + + fn poll_next( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + if let Some(inner) = self.inner.as_mut() { + let opt = std::task::ready!(std::pin::Pin::new(inner).poll_next(cx)); + + if opt.is_none() { + self.inner.take(); + } + + std::task::Poll::Ready(opt) + } else { + std::task::Poll::Ready(None) + } + } + + fn size_hint(&self) -> (usize, Option) { + if let Some(inner) = self.inner.as_ref() { + inner.size_hint() + } else { + (0, Some(0)) + } + } +} + +impl Transform for Payload +where + S: Service, + S::Future: 'static, +{ + type Response = S::Response; + type Error = S::Error; + type InitError = (); + type Transform = PayloadMiddleware; + type Future = Ready>; + + fn new_transform(&self, service: S) -> Self::Future { + ready(Ok(PayloadMiddleware { + inner: service, + sender: self.sender.clone(), + _handle: self.handle.clone(), + })) + } +} + +impl Service for PayloadMiddleware +where + S: Service, + S::Future: 'static, +{ + type Response = S::Response; + type Error = S::Error; + type Future = S::Future; + + fn poll_ready( + &self, + ctx: &mut core::task::Context<'_>, + ) -> std::task::Poll> { + self.inner.poll_ready(ctx) + } + + fn call(&self, mut req: ServiceRequest) -> Self::Future { + if matches!(*req.method(), Method::POST | Method::PATCH | Method::PUT) { + let payload = req.take_payload(); + let payload: LocalBoxStream<'static, _> = Box::pin(PayloadStream { + inner: Some(payload), + sender: self.sender.clone(), + }); + req.set_payload(payload.into()); + } + + self.inner.call(req) + } +}