use std::{ future::{ready, Ready}, rc::Rc, }; use actix_web::{ dev::{Service, ServiceRequest, Transform}, HttpMessage, }; use streem::IntoStreamer; use tokio::task::JoinSet; use crate::{future::NowOrNever, stream::LocalBoxStream}; const LIMIT: usize = 256; async fn drain(rx: flume::Receiver) { let mut set = JoinSet::new(); while let Ok(payload) = rx.recv_async().await { set.spawn_local(async move { let mut streamer = payload.into_streamer(); while streamer.next().await.is_some() {} }); let mut count = 0; // drain completed tasks while set.join_next().now_or_never().is_some() { count += 1; } // if we're past the limit, wait for completions while set.len() > LIMIT { if set.join_next().await.is_some() { count += 1; } } if count > 0 { tracing::info!("Drained {count} dropped payloads"); } } // drain set while set.join_next().await.is_some() {} } #[derive(Clone)] struct DrainHandle(Option>>); pub(crate) struct Payload { sender: flume::Sender, handle: DrainHandle, } pub(crate) struct PayloadMiddleware { inner: S, sender: flume::Sender, _handle: DrainHandle, } pub(crate) struct PayloadStream { inner: Option, sender: flume::Sender, } impl DrainHandle { fn new(handle: tokio::task::JoinHandle<()>) -> Self { Self(Some(Rc::new(handle))) } } impl Payload { pub(crate) fn new() -> Self { let (tx, rx) = crate::sync::channel(LIMIT); let handle = DrainHandle::new(crate::sync::spawn("drain-payloads", drain(rx))); Payload { sender: tx, handle } } } impl Drop for DrainHandle { fn drop(&mut self) { if let Some(handle) = self.0.take().and_then(Rc::into_inner) { handle.abort(); } } } impl Drop for PayloadStream { fn drop(&mut self) { if let Some(payload) = self.inner.take() { tracing::warn!("Dropped unclosed payload, draining"); if self.sender.try_send(payload).is_err() { tracing::error!("Failed to send unclosed payload for draining"); } } } } impl futures_core::Stream for PayloadStream { type Item = Result; fn poll_next( mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { if let Some(inner) = self.inner.as_mut() { let opt = std::task::ready!(std::pin::Pin::new(inner).poll_next(cx)); if opt.is_none() { self.inner.take(); } std::task::Poll::Ready(opt) } else { std::task::Poll::Ready(None) } } fn size_hint(&self) -> (usize, Option) { if let Some(inner) = self.inner.as_ref() { inner.size_hint() } else { (0, Some(0)) } } } impl Transform for Payload where S: Service, S::Future: 'static, { type Response = S::Response; type Error = S::Error; type InitError = (); type Transform = PayloadMiddleware; type Future = Ready>; fn new_transform(&self, service: S) -> Self::Future { ready(Ok(PayloadMiddleware { inner: service, sender: self.sender.clone(), _handle: self.handle.clone(), })) } } impl Service for PayloadMiddleware where S: Service, S::Future: 'static, { type Response = S::Response; type Error = S::Error; type Future = S::Future; fn poll_ready( &self, ctx: &mut core::task::Context<'_>, ) -> std::task::Poll> { self.inner.poll_ready(ctx) } fn call(&self, mut req: ServiceRequest) -> Self::Future { let payload = req.take_payload(); if !matches!(payload, actix_web::dev::Payload::None) { let payload: LocalBoxStream<'static, _> = Box::pin(PayloadStream { inner: Some(payload), sender: self.sender.clone(), }); req.set_payload(payload.into()); } self.inner.call(req) } }