From 04bde6cf20491b896d2dc620329d3a8956b71457 Mon Sep 17 00:00:00 2001 From: asonix Date: Mon, 9 Dec 2024 19:12:53 -0600 Subject: [PATCH] Increase concurrency when polling some streams Specifically when writing streams to files or processes' stdin. This adds the ability to poll the source stream during the write operations which can reduce time waiting for more bytes to write. --- src/file.rs | 17 ++++++---- src/process.rs | 19 ++++++++--- src/stream.rs | 90 ++++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 115 insertions(+), 11 deletions(-) diff --git a/src/file.rs b/src/file.rs index ad39937..58b69e8 100644 --- a/src/file.rs +++ b/src/file.rs @@ -25,11 +25,13 @@ pub(crate) async fn write_from_stream( #[cfg(not(feature = "io-uring"))] mod tokio_file { - use crate::{future::WithPollTimer, store::file_store::FileError, Either}; + use crate::{ + future::WithPollTimer, store::file_store::FileError, stream::IntoProgressableStreamer, + Either, + }; use actix_web::web::{Bytes, BytesMut}; use futures_core::Stream; use std::{io::SeekFrom, path::Path}; - use streem::IntoStreamer; use tokio::io::{AsyncReadExt, AsyncSeekExt, AsyncWriteExt}; use tokio_util::{ bytes::Buf, @@ -62,15 +64,18 @@ mod tokio_file { S: Stream>, { let stream = std::pin::pin!(stream); - let mut stream = stream.into_streamer(); + let mut stream = stream.into_progressable_streamer(); while let Some(mut bytes) = stream.try_next().with_poll_timer("try-next").await? 
{ tracing::trace!("write_from_stream: looping"); while bytes.has_remaining() { - self.inner - .write_buf(&mut bytes) - .with_poll_timer("write-buf") + stream + .make_progress_with( + self.inner + .write_buf(&mut bytes) + .with_poll_timer("write-buf"), + ) .await?; crate::sync::cooperate().await; diff --git a/src/process.rs b/src/process.rs index 5d0b885..66fa9fc 100644 --- a/src/process.rs +++ b/src/process.rs @@ -7,7 +7,6 @@ use std::{ }; use futures_core::Stream; -use streem::IntoStreamer; use tokio::{ io::{AsyncReadExt, AsyncWriteExt}, process::{Child, ChildStdin, Command}, @@ -21,6 +20,7 @@ use crate::{ error_code::ErrorCode, future::{LocalBoxFuture, WithTimeout}, read::BoxRead, + stream::IntoProgressableStreamer, }; struct MetricsGuard { @@ -237,7 +237,7 @@ impl Process { Ok(Ok(status)) => Err(ProcessError::Status(command, status)), Ok(Err(e)) => Err(ProcessError::Other(command, e)), Err(_) => { - let _ = child.kill().await; + let _ = child.kill().with_timeout(Duration::from_secs(1)).await; Err(ProcessError::Timeout(command)) } } @@ -249,10 +249,17 @@ impl Process { { self.drive(move |mut stdin| async move { let stream = std::pin::pin!(input); - let mut stream = stream.into_streamer(); + let mut stream = stream.into_progressable_streamer(); - while let Some(mut bytes) = stream.try_next().await? { - match stdin.write_all_buf(&mut bytes).await { + while let Some(mut bytes) = stream + .try_next() + .with_timeout(Duration::from_secs(5)) + .await?? + { + match stream + .make_progress_with(stdin.write_all_buf(&mut bytes)) + .await + { Ok(()) => {} Err(e) if e.kind() == std::io::ErrorKind::BrokenPipe => break, Err(e) => return Err(e), @@ -306,7 +313,9 @@ impl Process { Err(_) => { child .kill() + .with_timeout(Duration::from_secs(1)) .await + .map_err(|_| ProcessError::Timeout(command2.clone()))? 
.map_err(|e| ProcessError::Other(command2.clone(), e))?; Err(ProcessError::Timeout(command2)) } diff --git a/src/stream.rs b/src/stream.rs index 43dbd6d..a83904f 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -5,6 +5,96 @@ use streem::IntoStreamer; use crate::future::WithMetrics; +pub struct ProgressableStreamer { + inner: S, + next: Option>, +} + +impl ProgressableStreamer +where + S: Stream + Unpin, +{ + /// Produces the next item from the stream + /// + /// Cancel Safety + /// + /// This future is safe to drop and re-create before completion + async fn next(&mut self) -> Option { + std::future::poll_fn(|cx| { + if let Some(item) = self.next.take() { + std::task::Poll::Ready(item) + } else { + Pin::new(&mut self.inner).poll_next(cx) + } + }) + .await + } + + /// Polls the inner stream to make progress on the next item + /// + /// Cancel Safety + /// + /// This future is safe to drop and re-create before completion + pub async fn make_progress(&mut self) { + if self.next.is_none() { + self.next = + Some(std::future::poll_fn(|cx| Pin::new(&mut self.inner).poll_next(cx)).await); + } + } + + /// Polls the provided future along with the progress future + /// + /// Cancel Safety + /// + /// This future consumes the provided future, and does not ensure its cancel safety. If the + /// provided future is also cancel safe, then this method will be cancel safe as well. + /// + /// If this future is polled to completion, the provided future is guaranteed to be polled to + /// completion as well + pub async fn make_progress_with(&mut self, f: F) -> F::Output { + let mut f = std::pin::pin!(f); + + loop { + tokio::select! 
{ + _ = self.make_progress() => { + return f.await; + } + output = &mut f => { + return output; + } + } + } + } +} + +impl ProgressableStreamer> +where + S: Stream> + Unpin, +{ + /// Produces the next item from the stream + /// + /// Cancel Safety + /// + /// This future is safe to drop and re-create before completion + pub async fn try_next(&mut self) -> Result, E> { + self.next().await.transpose() + } +} + +pub trait IntoProgressableStreamer: Stream { + fn into_progressable_streamer(self) -> ProgressableStreamer + where + Self: Sized + Unpin, + { + ProgressableStreamer { + inner: self, + next: None, + } + } +} + +impl IntoProgressableStreamer for S where S: Stream {} + #[cfg(not(feature = "random-errors"))] pub(crate) fn error_injector( stream: impl Stream>,