diff --git a/src/file.rs b/src/file.rs index ad39937..58b69e8 100644 --- a/src/file.rs +++ b/src/file.rs @@ -25,11 +25,13 @@ pub(crate) async fn write_from_stream( #[cfg(not(feature = "io-uring"))] mod tokio_file { - use crate::{future::WithPollTimer, store::file_store::FileError, Either}; + use crate::{ + future::WithPollTimer, store::file_store::FileError, stream::IntoProgressableStreamer, + Either, + }; use actix_web::web::{Bytes, BytesMut}; use futures_core::Stream; use std::{io::SeekFrom, path::Path}; - use streem::IntoStreamer; use tokio::io::{AsyncReadExt, AsyncSeekExt, AsyncWriteExt}; use tokio_util::{ bytes::Buf, @@ -62,15 +64,18 @@ mod tokio_file { S: Stream>, { let stream = std::pin::pin!(stream); - let mut stream = stream.into_streamer(); + let mut stream = stream.into_progressable_streamer(); while let Some(mut bytes) = stream.try_next().with_poll_timer("try-next").await? { tracing::trace!("write_from_stream: looping"); while bytes.has_remaining() { - self.inner - .write_buf(&mut bytes) - .with_poll_timer("write-buf") + stream + .make_progress_with( + self.inner + .write_buf(&mut bytes) + .with_poll_timer("write-buf"), + ) .await?; crate::sync::cooperate().await; diff --git a/src/process.rs b/src/process.rs index 5d0b885..66fa9fc 100644 --- a/src/process.rs +++ b/src/process.rs @@ -7,7 +7,6 @@ use std::{ }; use futures_core::Stream; -use streem::IntoStreamer; use tokio::{ io::{AsyncReadExt, AsyncWriteExt}, process::{Child, ChildStdin, Command}, @@ -21,6 +20,7 @@ use crate::{ error_code::ErrorCode, future::{LocalBoxFuture, WithTimeout}, read::BoxRead, + stream::IntoProgressableStreamer, }; struct MetricsGuard { @@ -237,7 +237,7 @@ impl Process { Ok(Ok(status)) => Err(ProcessError::Status(command, status)), Ok(Err(e)) => Err(ProcessError::Other(command, e)), Err(_) => { - let _ = child.kill().await; + let _ = child.kill().with_timeout(Duration::from_secs(1)).await; Err(ProcessError::Timeout(command)) } } @@ -249,10 +249,17 @@ impl Process { { self.drive(move |mut stdin| async move { let stream = std::pin::pin!(input); - let mut stream = stream.into_streamer(); + let mut stream = stream.into_progressable_streamer(); - while let Some(mut bytes) = stream.try_next().await? { - match stdin.write_all_buf(&mut bytes).await { + while let Some(mut bytes) = stream + .try_next() + .with_timeout(Duration::from_secs(5)) + .await?? + { + match stream + .make_progress_with(stdin.write_all_buf(&mut bytes)) + .await + { Ok(()) => {} Err(e) if e.kind() == std::io::ErrorKind::BrokenPipe => break, Err(e) => return Err(e), @@ -306,7 +313,9 @@ impl Process { Err(_) => { child .kill() + .with_timeout(Duration::from_secs(1)) .await + .map_err(|_| ProcessError::Timeout(command2.clone()))? .map_err(|e| ProcessError::Other(command2.clone(), e))?; Err(ProcessError::Timeout(command2)) } diff --git a/src/stream.rs b/src/stream.rs index 43dbd6d..a83904f 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -5,6 +5,96 @@ use streem::IntoStreamer; use crate::future::WithMetrics; +pub struct ProgressableStreamer { + inner: S, + next: Option>, +} + +impl ProgressableStreamer +where + S: Stream + Unpin, +{ + /// Produces the next item from the stream + /// + /// Cancel Safety + /// + /// This future is safe to drop and re-create before completion + async fn next(&mut self) -> Option { + std::future::poll_fn(|cx| { + if let Some(item) = self.next.take() { + std::task::Poll::Ready(item) + } else { + Pin::new(&mut self.inner).poll_next(cx) + } + }) + .await + } + + /// Polls the inner stream to make progress on the next item + /// + /// Cancel Safety + /// + /// This future is safe to drop and re-create before completion + pub async fn make_progress(&mut self) { + if self.next.is_none() { + self.next = + Some(std::future::poll_fn(|cx| Pin::new(&mut self.inner).poll_next(cx)).await); + } + } + + /// Polls the provided future along with the progress future + /// + /// Cancel Safety + /// + /// This future consumes the provided future, and does not ensure it's cancel safety. If the + /// provided future is also cancel safe, then this method will be cancel safe as well. + /// + /// If this future is polled to completion, the provided future is guaranteed to be polled to + /// completion as well + pub async fn make_progress_with(&mut self, f: F) -> F::Output { + let mut f = std::pin::pin!(f); + + loop { + tokio::select! { + _ = self.make_progress() => { + return f.await; + } + output = &mut f => { + return output; + } + } + } + } +} + +impl ProgressableStreamer> +where + S: Stream> + Unpin, +{ + /// Produces the next item from the stream + /// + /// Cancel Safety + /// + /// This future is safe to drop and re-create before completion + pub async fn try_next(&mut self) -> Result, E> { + self.next().await.transpose() + } +} + +pub trait IntoProgressableStreamer: Stream { + fn into_progressable_streamer(self) -> ProgressableStreamer + where + Self: Sized + Unpin, + { + ProgressableStreamer { + inner: self, + next: None, + } + } +} + +impl IntoProgressableStreamer for S where S: Stream {} + #[cfg(not(feature = "random-errors"))] pub(crate) fn error_injector( stream: impl Stream>,