Include not_found key in repo migration, rework ProcessRead, add timeout to 0.4 migration

asonix 2023-12-17 23:14:39 -06:00
parent 3c6d676e51
commit c9155f7ce7
2 changed files with 192 additions and 64 deletions

@@ -5,7 +5,11 @@ use std::{
     future::Future,
     pin::Pin,
     process::{ExitStatus, Stdio},
-    task::{Context, Poll},
+    sync::{
+        atomic::{AtomicU8, Ordering},
+        Arc, Mutex,
+    },
+    task::{Context, Poll, Wake, Waker},
     time::{Duration, Instant},
 };
 use tokio::{
@@ -72,14 +76,23 @@ struct DropHandle {
     inner: JoinHandle<()>,
 }
 
-pub(crate) struct ProcessRead<I> {
-    inner: I,
+struct ProcessReadState {
+    flags: AtomicU8,
+    parent: Mutex<Option<Waker>>,
+}
+
+struct ProcessReadWaker {
+    state: Arc<ProcessReadState>,
+    flag: u8,
+}
+
+pub(crate) struct ProcessRead {
+    inner: ChildStdout,
     err_recv: RecvFut<'static, std::io::Error>,
-    err_closed: bool,
+    #[allow(dead_code)]
     handle: DropHandle,
-    eof: bool,
-    sleep: Pin<Box<tokio::time::Sleep>>,
+    closed: bool,
+    state: Arc<ProcessReadState>,
+    span: Span,
 }
 
 #[derive(Debug, thiserror::Error)]
@@ -191,21 +204,21 @@ impl Process {
         }
     }
 
-    pub(crate) fn bytes_read(self, input: Bytes) -> ProcessRead<ChildStdout> {
+    pub(crate) fn bytes_read(self, input: Bytes) -> ProcessRead {
         self.spawn_fn(move |mut stdin| {
             let mut input = input;
             async move { stdin.write_all_buf(&mut input).await }
         })
     }
 
-    pub(crate) fn read(self) -> ProcessRead<ChildStdout> {
+    pub(crate) fn read(self) -> ProcessRead {
         self.spawn_fn(|_| async { Ok(()) })
     }
 
     #[allow(unknown_lints)]
     #[allow(clippy::let_with_type_underscore)]
     #[tracing::instrument(level = "trace", skip_all)]
-    fn spawn_fn<F, Fut>(self, f: F) -> ProcessRead<ChildStdout>
+    fn spawn_fn<F, Fut>(self, f: F) -> ProcessRead
     where
         F: FnOnce(ChildStdin) -> Fut + 'static,
         Fut: Future<Output = std::io::Result<()>>,
@@ -223,7 +236,11 @@ impl Process {
         let (tx, rx) = crate::sync::channel::<std::io::Error>(1);
         let rx = rx.into_recv_async();
 
-        let span = tracing::info_span!(parent: None, "Background process task", %command);
+        let background_span =
+            tracing::info_span!(parent: None, "Background process task", %command);
+        background_span.follows_from(Span::current());
+
+        let span = tracing::info_span!(parent: None, "Foreground process task", %command);
         span.follows_from(Span::current());
 
         let handle = crate::sync::spawn(
@@ -250,81 +267,133 @@ impl Process {
                 let _ = tx.send(error);
                 let _ = child.kill().await;
             }
-            .instrument(span),
+            .instrument(background_span),
         );
 
-        let sleep = tokio::time::sleep(timeout);
-
         ProcessRead {
             inner: stdout,
             err_recv: rx,
-            err_closed: false,
             handle: DropHandle { inner: handle },
-            eof: false,
-            sleep: Box::pin(sleep),
+            closed: false,
+            state: ProcessReadState::new_woken(),
+            span,
         }
     }
 }
 
-impl<I> AsyncRead for ProcessRead<I>
-where
-    I: AsyncRead + Unpin,
-{
+impl ProcessReadState {
+    fn new_woken() -> Arc<Self> {
+        Arc::new(Self {
+            flags: AtomicU8::new(0xff),
+            parent: Mutex::new(None),
+        })
+    }
+
+    fn clone_parent(&self) -> Option<Waker> {
+        let guard = self.parent.lock().unwrap();
+        guard.as_ref().map(|w| w.clone())
+    }
+
+    fn into_parts(self) -> (AtomicU8, Option<Waker>) {
+        let ProcessReadState { flags, parent } = self;
+
+        let parent = parent.lock().unwrap().take();
+
+        (flags, parent)
+    }
+}
+
+impl ProcessRead {
+    fn get_waker(&self, flag: u8) -> Option<Waker> {
+        let mask = 0xff ^ flag;
+        let previous = self.state.flags.fetch_and(mask, Ordering::AcqRel);
+        let active = previous & flag;
+
+        if active == flag {
+            Some(
+                Arc::new(ProcessReadWaker {
+                    state: self.state.clone(),
+                    flag,
+                })
+                .into(),
+            )
+        } else {
+            None
+        }
+    }
+
+    fn set_parent_waker(&self, parent: &Waker) {
+        let mut guard = self.state.parent.lock().unwrap();
+        if let Some(waker) = guard.as_mut() {
+            if !waker.will_wake(parent) {
+                *waker = parent.clone();
+            }
+        } else {
+            *guard = Some(parent.clone());
+        }
+    }
+}
+
+const RECV_WAKER: u8 = 0b_0010;
+const HANDLE_WAKER: u8 = 0b_0100;
+
+impl AsyncRead for ProcessRead {
     fn poll_read(
         mut self: Pin<&mut Self>,
         cx: &mut Context<'_>,
         buf: &mut ReadBuf<'_>,
     ) -> Poll<std::io::Result<()>> {
-        if !self.err_closed {
-            if let Poll::Ready(res) = Pin::new(&mut self.err_recv).poll(cx) {
-                self.err_closed = true;
-
-                if let Ok(err) = res {
-                    return Poll::Ready(Err(err));
-                }
-
-                if self.eof {
-                    return Poll::Ready(Ok(()));
-                }
-            }
-
-            if let Poll::Ready(()) = self.sleep.as_mut().poll(cx) {
-                self.err_closed = true;
-
-                return Poll::Ready(Err(std::io::ErrorKind::TimedOut.into()));
-            }
-        }
-
-        if !self.eof {
+        self.set_parent_waker(cx.waker());
+
+        let span = self.span.clone();
+        let guard = span.enter();
+
+        let value = loop {
+            // always poll for bytes when poll_read is called
             let before_size = buf.filled().len();
 
-            return match Pin::new(&mut self.inner).poll_read(cx, buf) {
-                Poll::Ready(Ok(())) => {
-                    if buf.filled().len() == before_size {
-                        self.eof = true;
-
-                        if !self.err_closed {
-                            // reached end of stream & haven't received process signal
-                            return Poll::Pending;
-                        }
-                    }
-
-                    Poll::Ready(Ok(()))
-                }
-                Poll::Ready(Err(e)) => {
-                    self.eof = true;
-
-                    Poll::Ready(Err(e))
-                }
-                Poll::Pending => Poll::Pending,
-            };
-        }
-
-        if self.err_closed && self.eof {
-            return Poll::Ready(Ok(()));
-        }
-
-        Poll::Pending
+            if let Poll::Ready(res) = Pin::new(&mut self.inner).poll_read(cx, buf) {
+                if let Err(e) = res {
+                    self.closed = true;
+
+                    break Poll::Ready(Err(e));
+                } else if buf.filled().len() == before_size {
+                    self.closed = true;
+
+                    break Poll::Ready(Ok(()));
+                } else {
+                    break Poll::Ready(Ok(()));
+                }
+            } else if let Some(waker) = self.get_waker(RECV_WAKER) {
+                // only poll recv if we've been explicitly woken
+                let mut recv_cx = Context::from_waker(&waker);
+
+                if let Poll::Ready(res) = Pin::new(&mut self.err_recv).poll(&mut recv_cx) {
+                    if let Ok(err) = res {
+                        self.closed = true;
+
+                        break Poll::Ready(Err(err));
+                    }
+                }
+            } else if let Some(waker) = self.get_waker(HANDLE_WAKER) {
+                // only poll handle if we've been explicitly woken
+                let mut handle_cx = Context::from_waker(&waker);
+
+                if let Poll::Ready(res) = Pin::new(&mut self.handle.inner).poll(&mut handle_cx) {
+                    if let Err(e) = res {
+                        self.closed = true;
+
+                        break Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::Other, e)));
+                    }
+                }
+            } else if self.closed {
+                break Poll::Ready(Ok(()));
+            } else {
+                break Poll::Pending;
+            }
+        };
+
+        drop(guard);
+
+        value
     }
 }
@@ -334,6 +403,40 @@ impl Drop for DropHandle {
     }
 }
 
+impl Wake for ProcessReadWaker {
+    fn wake(self: Arc<Self>) {
+        match Arc::try_unwrap(self) {
+            Ok(ProcessReadWaker { state, flag }) => match Arc::try_unwrap(state) {
+                Ok(state) => {
+                    let (flags, parent) = state.into_parts();
+
+                    flags.fetch_and(flag, Ordering::AcqRel);
+
+                    if let Some(parent) = parent {
+                        parent.wake();
+                    }
+                }
+                Err(state) => {
+                    state.flags.fetch_or(flag, Ordering::AcqRel);
+
+                    if let Some(waker) = state.clone_parent() {
+                        waker.wake();
+                    }
+                }
+            },
+            Err(this) => this.wake_by_ref(),
+        }
+    }
+
+    fn wake_by_ref(self: &Arc<Self>) {
+        self.state.flags.fetch_or(self.flag, Ordering::AcqRel);
+
+        if let Some(parent) = self.state.clone_parent() {
+            parent.wake();
+        }
+    }
+}
+
 impl std::fmt::Display for StatusError {
     fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
         write!(f, "Command failed with bad status: {}", self.0)
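
Note on the rework above: ProcessRead now multiplexes three readiness sources (child stdout, the error channel, and the background join handle) behind a single poll_read. Each sub-future gets its own waker that flips a dedicated bit in a shared AtomicU8, and the sub-future is only re-polled once its bit is set. Below is a minimal, self-contained sketch of that flag-waker pattern; all names here are illustrative, not taken from the commit.

use std::sync::{
    atomic::{AtomicU8, Ordering},
    Arc, Mutex,
};
use std::task::{Wake, Waker};

// One bit per sub-future, plus the parent task's waker.
struct State {
    flags: AtomicU8,
    parent: Mutex<Option<Waker>>,
}

// Waking a FlagWaker marks its sub-future "ready to poll again"
// and forwards the wake to the parent task.
struct FlagWaker {
    state: Arc<State>,
    flag: u8,
}

impl Wake for FlagWaker {
    fn wake(self: Arc<Self>) {
        self.state.flags.fetch_or(self.flag, Ordering::AcqRel);
        if let Some(parent) = self.state.parent.lock().unwrap().clone() {
            parent.wake();
        }
    }
}

const RECV: u8 = 0b_0010;

// Atomically consume a bit; hand out a fresh waker only if the bit was
// set, i.e. only if this sub-future asked to be polled again.
fn take_waker(state: &Arc<State>, flag: u8) -> Option<Waker> {
    let previous = state.flags.fetch_and(0xff ^ flag, Ordering::AcqRel);
    (previous & flag == flag).then(|| {
        Waker::from(Arc::new(FlagWaker {
            state: state.clone(),
            flag,
        }))
    })
}

fn main() {
    // Start fully "woken" so every sub-future gets polled at least once,
    // matching ProcessReadState::new_woken above.
    let state = Arc::new(State {
        flags: AtomicU8::new(0xff),
        parent: Mutex::new(None),
    });

    let waker = take_waker(&state, RECV).expect("bit starts set");
    assert!(take_waker(&state, RECV).is_none()); // consumed until re-woken

    waker.wake(); // the sub-future signals readiness
    assert!(take_waker(&state, RECV).is_some());
}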

@@ -1,4 +1,7 @@
-use std::sync::{Arc, OnceLock};
+use std::{
+    sync::{Arc, OnceLock},
+    time::Duration,
+};
 
 use streem::IntoStreamer;
 use tokio::{sync::Semaphore, task::JoinSet};
@@ -6,7 +9,7 @@ use tokio::{sync::Semaphore, task::JoinSet};
 use crate::{
     config::Configuration,
     details::Details,
-    error::Error,
+    error::{Error, UploadError},
     repo::{ArcRepo, DeleteToken, Hash},
     repo_04::{
         AliasRepo as _, HashRepo as _, IdentifierRepo as _, SettingsRepo as _,
@@ -41,7 +44,7 @@ pub(crate) async fn migrate_repo(old_repo: ArcRepo, new_repo: ArcRepo) -> Result
     let mut index = 0;
     while let Some(res) = hash_stream.next().await {
         if let Ok(hash) = res {
-            let _ = migrate_hash(old_repo.clone(), new_repo.clone(), hash).await;
+            migrate_hash(old_repo.clone(), new_repo.clone(), hash).await;
         } else {
             tracing::warn!("Failed to read hash, skipping");
         }
@@ -61,6 +64,12 @@ pub(crate) async fn migrate_repo(old_repo: ArcRepo, new_repo: ArcRepo) -> Result
             .await?;
     }
 
+    if let Some(generator_state) = old_repo.get(crate::NOT_FOUND_KEY).await? {
+        new_repo
+            .set(crate::NOT_FOUND_KEY, generator_state.to_vec().into())
+            .await?;
+    }
+
     tracing::info!("Migration complete");
 
     Ok(())
@@ -181,7 +190,7 @@ async fn migrate_hash_04<S: Store>(
 ) {
     let mut hash_failures = 0;
 
-    while let Err(e) = do_migrate_hash_04(
+    while let Err(e) = timed_migrate_hash_04(
         &tmp_dir,
         &old_repo,
         &new_repo,
@@ -275,6 +284,22 @@ async fn do_migrate_hash(old_repo: &ArcRepo, new_repo: &ArcRepo, hash: Hash) ->
     Ok(())
 }
 
+async fn timed_migrate_hash_04<S: Store>(
+    tmp_dir: &TmpDir,
+    old_repo: &OldSledRepo,
+    new_repo: &ArcRepo,
+    store: &S,
+    config: &Configuration,
+    old_hash: sled::IVec,
+) -> Result<(), Error> {
+    tokio::time::timeout(
+        Duration::from_secs(config.media.external_validation_timeout * 6),
+        do_migrate_hash_04(tmp_dir, old_repo, new_repo, store, config, old_hash),
+    )
+    .await
+    .map_err(|_| UploadError::ProcessTimeout)?
+}
+
 #[tracing::instrument(skip_all)]
 async fn do_migrate_hash_04<S: Store>(
     tmp_dir: &TmpDir,
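
Note on the timeout wrapper: it follows the standard tokio::time::timeout shape, with a budget of six times the configured external_validation_timeout. The outer Result's Elapsed error is mapped into the crate's own error type, and the trailing ? then unwraps to the inner future's Result. A self-contained sketch of the same shape, assuming tokio with the time, macros, and rt features enabled; the error type here is invented for illustration:

use std::time::Duration;

#[derive(Debug)]
struct TimedOut;

async fn slow_step(work: Duration) -> Result<(), TimedOut> {
    tokio::time::sleep(work).await;
    Ok(())
}

// Race the inner future against a deadline. `timeout` yields
// Result<Result<(), TimedOut>, Elapsed>; map_err converts Elapsed into
// our own error, and `?` then unwraps to the inner Result.
async fn timed_step(limit: Duration, work: Duration) -> Result<(), TimedOut> {
    tokio::time::timeout(limit, slow_step(work))
        .await
        .map_err(|_| TimedOut)?
}

#[tokio::main]
async fn main() {
    let fast = timed_step(Duration::from_millis(50), Duration::from_millis(1)).await;
    assert!(fast.is_ok());

    let slow = timed_step(Duration::from_millis(1), Duration::from_millis(50)).await;
    assert!(slow.is_err()); // deadline hit before the work finished
}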