use actix_web::web::Bytes; use futures::stream::{LocalBoxStream, Stream, StreamExt}; use std::{ future::Future, pin::Pin, task::{Context, Poll}, }; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt, ReadBuf}; use tokio_stream::wrappers::ReceiverStream; pub(crate) struct ReadAdapter { inner: S, } pub(crate) struct Process { child: tokio::process::Child, } pub(crate) struct ProcessRead { inner: I, err_recv: tokio::sync::oneshot::Receiver, err_closed: bool, } pub(crate) struct ProcessSink { stdin: tokio::process::ChildStdin, } pub(crate) struct ProcessStream { stream: LocalBoxStream<'static, std::io::Result>, } pub(crate) struct ProcessSinkStream { stream: LocalBoxStream<'static, Result>, } pub(crate) struct TryDuplicateStream { inner: ReceiverStream>, } #[derive(Debug)] pub(crate) struct StringError(String); impl ReadAdapter { pub(crate) fn new_unsync( mut stream: S, ) -> ReadAdapter>> where S: Stream> + Unpin + 'static, E: std::fmt::Display, { let (tx, rx) = tokio::sync::mpsc::channel(1); actix_rt::spawn(async move { while let Some(res) = stream.next().await { if tx .send(res.map_err(|e| StringError(e.to_string()))) .await .is_err() { break; } } }); ReadAdapter::new_sync(ReceiverStream::new(rx)) } fn new_sync(stream: S) -> Self where S: Stream> + Unpin, E: Into>, { ReadAdapter { inner: stream } } } impl Process { fn new(child: tokio::process::Child) -> Self { Process { child } } pub(crate) fn spawn(cmd: &mut tokio::process::Command) -> std::io::Result { cmd.stdin(std::process::Stdio::piped()) .stdout(std::process::Stdio::piped()) .spawn() .map(Process::new) } pub(crate) fn take_sink(&mut self) -> Option { self.child.stdin.take().map(ProcessSink::new) } pub(crate) fn take_stream(&mut self) -> Option { self.child.stdout.take().map(ProcessStream::new) } pub(crate) fn bytes_read(mut self, mut input: Bytes) -> Option { let mut stdin = self.child.stdin.take()?; let stdout = self.child.stdout.take()?; let (tx, rx) = tokio::sync::oneshot::channel(); actix_rt::spawn(async move { if let Err(e) = stdin.write_all_buf(&mut input).await { let _ = tx.send(e); } }); Some(Box::pin(ProcessRead { inner: stdout, err_recv: rx, err_closed: false, })) } pub(crate) fn write_read( mut self, mut input_reader: impl AsyncRead + Unpin + 'static, ) -> Option { let mut stdin = self.child.stdin.take()?; let stdout = self.child.stdout.take()?; let (tx, rx) = tokio::sync::oneshot::channel(); actix_rt::spawn(async move { if let Err(e) = tokio::io::copy(&mut input_reader, &mut stdin).await { let _ = tx.send(e); } }); Some(Box::pin(ProcessRead { inner: stdout, err_recv: rx, err_closed: false, })) } pub(crate) fn sink_stream(mut self, input_stream: S) -> Option> where S: Stream> + Unpin + 'static, E: From + 'static, { let mut stdin = self.take_sink()?; let mut stdout = self.take_stream()?; let (tx, mut rx) = tokio::sync::mpsc::channel(1); actix_rt::spawn(async move { if let Err(e) = stdin.send(input_stream).await { let _ = tx.send(e).await; } }); Some(ProcessSinkStream { stream: Box::pin(async_stream::stream! { loop { tokio::select! { opt = rx.recv() => { if let Some(e) = opt { yield Err(e); break; } } res = stdout.next() => { match res { Some(Ok(bytes)) => yield Ok(bytes), Some(Err(e)) => { yield Err(e.into()); break; } None => break, } } } } drop(stdout); match self.child.wait().await { Ok(status) if status.success() => return, Ok(_) => yield Err(std::io::Error::from(std::io::ErrorKind::Other).into()), Err(e) => yield Err(e.into()), } }), }) } } impl ProcessSink { fn new(stdin: tokio::process::ChildStdin) -> Self { ProcessSink { stdin } } pub(crate) async fn send(&mut self, mut stream: S) -> Result<(), E> where S: Stream> + Unpin, E: From, { while let Some(res) = stream.next().await { let mut bytes = res?; self.stdin.write_all_buf(&mut bytes).await?; } Ok(()) } } impl ProcessStream { fn new(mut stdout: tokio::process::ChildStdout) -> ProcessStream { let s = async_stream::stream! { loop { let mut buf = actix_web::web::BytesMut::with_capacity(65_536); match stdout.read_buf(&mut buf).await { Ok(len) if len == 0 => { break; } Ok(_) => { yield Ok(buf.freeze()); } Err(e) => { yield Err(e); break; } } } }; ProcessStream { stream: Box::pin(s), } } } pub(crate) fn try_duplicate( mut stream: S, buffer: usize, ) -> (impl Stream>, TryDuplicateStream) where S: Stream> + Unpin, T: Clone, { let (tx, rx) = tokio::sync::mpsc::channel(buffer); let s = async_stream::stream! { while let Some(value) = stream.next().await { match value { Ok(t) => { let _ = tx.send(Ok(t.clone())).await; yield Ok(t); } Err(e) => yield Err(e), } } }; ( s, TryDuplicateStream { inner: ReceiverStream::new(rx), }, ) } impl AsyncRead for ReadAdapter where S: Stream> + Unpin, E: Into>, { fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll> { match Pin::new(&mut self.inner).poll_next(cx) { Poll::Ready(Some(Ok(bytes))) => { buf.put_slice(&bytes); Poll::Ready(Ok(())) } Poll::Ready(None) => Poll::Ready(Ok(())), Poll::Ready(Some(Err(e))) => { Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::Other, e))) } Poll::Pending => Poll::Pending, } } } impl AsyncRead for ProcessRead where I: AsyncRead + Unpin, { fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll> { if !self.err_closed { if let Poll::Ready(res) = Pin::new(&mut self.err_recv).poll(cx) { self.err_closed = true; if let Ok(err) = res { return Poll::Ready(Err(err)); } } } if let Poll::Ready(res) = Pin::new(&mut self.inner).poll_read(cx, buf) { return Poll::Ready(res); } Poll::Pending } } impl Stream for ProcessStream { type Item = std::io::Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { Pin::new(&mut self.stream).poll_next(cx) } } impl Stream for ProcessSinkStream { type Item = Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { Pin::new(&mut self.stream).poll_next(cx) } } impl Stream for TryDuplicateStream { type Item = Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { Pin::new(&mut self.inner).poll_next(cx) } } impl std::fmt::Display for StringError { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { write!(f, "{}", self.0) } } impl std::error::Error for StringError {}