diff --git a/src/process.rs b/src/process.rs index 46c2023..ab84d41 100644 --- a/src/process.rs +++ b/src/process.rs @@ -1,5 +1,4 @@ use actix_web::web::Bytes; -use flume::r#async::RecvFut; use std::{ ffi::OsStr, future::Future, @@ -73,7 +72,7 @@ impl std::fmt::Debug for Process { } struct DropHandle { - inner: JoinHandle<()>, + inner: JoinHandle>, } struct ProcessReadState { @@ -88,7 +87,6 @@ struct ProcessReadWaker { pub(crate) struct ProcessRead { inner: ChildStdout, - err_recv: RecvFut<'static, std::io::Error>, handle: DropHandle, closed: bool, state: Arc, @@ -233,9 +231,6 @@ impl Process { let stdin = child.stdin.take().expect("stdin exists"); let stdout = child.stdout.take().expect("stdout exists"); - let (tx, rx) = crate::sync::channel::(1); - let rx = rx.into_recv_async(); - let background_span = tracing::info_span!(parent: None, "Background process task", %command); background_span.follows_from(Span::current()); @@ -255,7 +250,7 @@ impl Process { let error = match child_fut.with_timeout(timeout).await { Ok(Ok(status)) if status.success() => { guard.disarm(); - return; + return Ok(()); } Ok(Ok(status)) => { std::io::Error::new(std::io::ErrorKind::Other, StatusError(status)) @@ -264,15 +259,15 @@ impl Process { Err(_) => std::io::ErrorKind::TimedOut.into(), }; - let _ = tx.send(error); - let _ = child.kill().await; + child.kill().await?; + + Err(error) } .instrument(background_span), ); ProcessRead { inner: stdout, - err_recv: rx, handle: DropHandle { inner: handle }, closed: false, state: ProcessReadState::new_woken(), @@ -291,7 +286,7 @@ impl ProcessReadState { fn clone_parent(&self) -> Option { let guard = self.parent.lock().unwrap(); - guard.as_ref().map(|w| w.clone()) + guard.as_ref().cloned() } fn into_parts(self) -> (AtomicU8, Option) { @@ -322,19 +317,26 @@ impl ProcessRead { } } - fn set_parent_waker(&self, parent: &Waker) { + fn set_parent_waker(&self, parent: &Waker) -> bool { let mut guard = self.state.parent.lock().unwrap(); if let Some(waker) = guard.as_mut() { if !waker.will_wake(parent) { *waker = parent.clone(); + true + } else { + false } } else { *guard = Some(parent.clone()); + true } } + + fn mark_all_woken(&self) { + self.state.flags.store(0xff, Ordering::Release); + } } -const RECV_WAKER: u8 = 0b_0010; const HANDLE_WAKER: u8 = 0b_0100; impl AsyncRead for ProcessRead { @@ -343,8 +345,6 @@ impl AsyncRead for ProcessRead { cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll> { - self.set_parent_waker(cx.waker()); - let span = self.span.clone(); let guard = span.enter(); @@ -364,29 +364,30 @@ impl AsyncRead for ProcessRead { } else { break Poll::Ready(Ok(())); } - } else if let Some(waker) = self.get_waker(RECV_WAKER) { - // only poll recv if we've been explicitly woken - let mut recv_cx = Context::from_waker(&waker); - - if let Poll::Ready(res) = Pin::new(&mut self.err_recv).poll(&mut recv_cx) { - if let Ok(err) = res { - self.closed = true; - break Poll::Ready(Err(err)); - } - } } else if let Some(waker) = self.get_waker(HANDLE_WAKER) { // only poll handle if we've been explicitly woken let mut handle_cx = Context::from_waker(&waker); if let Poll::Ready(res) = Pin::new(&mut self.handle.inner).poll(&mut handle_cx) { - if let Err(e) = res { - self.closed = true; - break Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::Other, e))); - } + let error = match res { + Ok(Ok(())) => continue, + Ok(Err(e)) => e, + Err(e) => std::io::Error::new(std::io::ErrorKind::Other, e), + }; + + self.closed = true; + break Poll::Ready(Err(error)); } } else if self.closed { + // Stop if we're closed break Poll::Ready(Ok(())); + } else if self.set_parent_waker(cx.waker()) { + // if we updated the stored waker, mark all as woken an try polling again + // This doesn't actually "wake" the waker, it just allows the handle to be polled + // again next iteration + self.mark_all_woken(); } else { + // if the waker hasn't changed and nothing polled ready, return pending break Poll::Pending; } };