Fix poll after completion, misused compare_and_swap

This commit is contained in:
asonix 2023-08-14 21:17:57 -05:00
parent 26ca3a7195
commit 09236d731d
4 changed files with 154 additions and 73 deletions

View File

@ -278,9 +278,10 @@ where
+ Copy, + Copy,
{ {
loop { loop {
let (job_id, bytes) = repo.pop(queue).await?; let fut = async {
let (job_id, bytes) = repo.pop(queue, worker_id).await?;
let span = tracing::info_span!("Running Job", worker_id = ?worker_id); let span = tracing::info_span!("Running Job");
let guard = MetricsGuard::guard(worker_id, queue); let guard = MetricsGuard::guard(worker_id, queue);
@ -289,6 +290,7 @@ where
heartbeat( heartbeat(
repo, repo,
queue, queue,
worker_id,
job_id, job_id,
(callback)(repo, store, config, bytes.as_ref()), (callback)(repo, store, config, bytes.as_ref()),
) )
@ -296,11 +298,17 @@ where
.instrument(span) .instrument(span)
.await; .await;
repo.complete_job(queue, job_id).await?; repo.complete_job(queue, worker_id, job_id).await?;
res?; res?;
guard.disarm(); guard.disarm();
Ok(()) as Result<(), Error>
};
fut.instrument(tracing::info_span!("tick", worker_id = %worker_id))
.await?;
} }
} }
@ -361,9 +369,10 @@ where
+ Copy, + Copy,
{ {
loop { loop {
let (job_id, bytes) = repo.pop(queue).await?; let fut = async {
let (job_id, bytes) = repo.pop(queue, worker_id).await?;
let span = tracing::info_span!("Running Job", worker_id = ?worker_id); let span = tracing::info_span!("Running Job");
let guard = MetricsGuard::guard(worker_id, queue); let guard = MetricsGuard::guard(worker_id, queue);
@ -372,6 +381,7 @@ where
heartbeat( heartbeat(
repo, repo,
queue, queue,
worker_id,
job_id, job_id,
(callback)(repo, store, process_map, config, bytes.as_ref()), (callback)(repo, store, process_map, config, bytes.as_ref()),
) )
@ -379,20 +389,32 @@ where
.instrument(span) .instrument(span)
.await; .await;
repo.complete_job(queue, job_id).await?; repo.complete_job(queue, worker_id, job_id).await?;
res?; res?;
guard.disarm(); guard.disarm();
Ok(()) as Result<(), Error>
};
fut.instrument(tracing::info_span!("tick", worker_id = %worker_id))
.await?;
} }
} }
async fn heartbeat<R, Fut>(repo: &R, queue: &'static str, job_id: JobId, fut: Fut) -> Fut::Output async fn heartbeat<R, Fut>(
repo: &R,
queue: &'static str,
worker_id: uuid::Uuid,
job_id: JobId,
fut: Fut,
) -> Fut::Output
where where
R: QueueRepo, R: QueueRepo,
Fut: std::future::Future, Fut: std::future::Future,
{ {
let mut fut = std::pin::pin!(fut); let mut fut =
std::pin::pin!(fut.instrument(tracing::info_span!("job-future", job_id = ?job_id)));
let mut interval = actix_rt::time::interval(Duration::from_secs(5)); let mut interval = actix_rt::time::interval(Duration::from_secs(5));
@ -405,10 +427,12 @@ where
} }
_ = interval.tick() => { _ = interval.tick() => {
if hb.is_none() { if hb.is_none() {
hb = Some(repo.heartbeat(queue, job_id)); hb = Some(repo.heartbeat(queue, worker_id, job_id));
} }
} }
opt = poll_opt(hb.as_mut()), if hb.is_some() => { opt = poll_opt(hb.as_mut()), if hb.is_some() => {
hb.take();
if let Some(Err(e)) = opt { if let Some(Err(e)) = opt {
tracing::warn!("Failed heartbeat\n{}", format!("{e:?}")); tracing::warn!("Failed heartbeat\n{}", format!("{e:?}"));
} }
@ -423,6 +447,6 @@ where
{ {
match opt { match opt {
None => None, None => None,
Some(fut) => std::future::poll_fn(|cx| Pin::new(&mut *fut).poll(cx).map(Some)).await, Some(fut) => Some(fut.await),
} }
} }

View File

@ -73,13 +73,8 @@ where
errors.push(e); errors.push(e);
} }
if !errors.is_empty() {
let span = tracing::error_span!("Error deleting files");
span.in_scope(|| {
for error in errors { for error in errors {
tracing::error!("{}", format!("{error}")); tracing::error!("{}", format!("{error:?}"));
}
});
} }
Ok(()) Ok(())

View File

@ -296,11 +296,25 @@ impl JobId {
pub(crate) trait QueueRepo: BaseRepo { pub(crate) trait QueueRepo: BaseRepo {
async fn push(&self, queue: &'static str, job: Arc<[u8]>) -> Result<JobId, RepoError>; async fn push(&self, queue: &'static str, job: Arc<[u8]>) -> Result<JobId, RepoError>;
async fn pop(&self, queue: &'static str) -> Result<(JobId, Arc<[u8]>), RepoError>; async fn pop(
&self,
queue: &'static str,
worker_id: Uuid,
) -> Result<(JobId, Arc<[u8]>), RepoError>;
async fn heartbeat(&self, queue: &'static str, job_id: JobId) -> Result<(), RepoError>; async fn heartbeat(
&self,
queue: &'static str,
worker_id: Uuid,
job_id: JobId,
) -> Result<(), RepoError>;
async fn complete_job(&self, queue: &'static str, job_id: JobId) -> Result<(), RepoError>; async fn complete_job(
&self,
queue: &'static str,
worker_id: Uuid,
job_id: JobId,
) -> Result<(), RepoError>;
} }
#[async_trait::async_trait(?Send)] #[async_trait::async_trait(?Send)]
@ -312,16 +326,30 @@ where
T::push(self, queue, job).await T::push(self, queue, job).await
} }
async fn pop(&self, queue: &'static str) -> Result<(JobId, Arc<[u8]>), RepoError> { async fn pop(
T::pop(self, queue).await &self,
queue: &'static str,
worker_id: Uuid,
) -> Result<(JobId, Arc<[u8]>), RepoError> {
T::pop(self, queue, worker_id).await
} }
async fn heartbeat(&self, queue: &'static str, job_id: JobId) -> Result<(), RepoError> { async fn heartbeat(
T::heartbeat(self, queue, job_id).await &self,
queue: &'static str,
worker_id: Uuid,
job_id: JobId,
) -> Result<(), RepoError> {
T::heartbeat(self, queue, worker_id, job_id).await
} }
async fn complete_job(&self, queue: &'static str, job_id: JobId) -> Result<(), RepoError> { async fn complete_job(
T::complete_job(self, queue, job_id).await &self,
queue: &'static str,
worker_id: Uuid,
job_id: JobId,
) -> Result<(), RepoError> {
T::complete_job(self, queue, worker_id, job_id).await
} }
} }

View File

@ -24,6 +24,7 @@ use std::{
}; };
use tokio::{sync::Notify, task::JoinHandle}; use tokio::{sync::Notify, task::JoinHandle};
use url::Url; use url::Url;
use uuid::Uuid;
macro_rules! b { macro_rules! b {
($self:ident.$ident:ident, $expr:expr) => {{ ($self:ident.$ident:ident, $expr:expr) => {{
@ -625,7 +626,7 @@ impl UploadRepo for SledRepo {
enum JobState { enum JobState {
Pending, Pending,
Running([u8; 8]), Running([u8; 24]),
} }
impl JobState { impl JobState {
@ -633,12 +634,26 @@ impl JobState {
Self::Pending Self::Pending
} }
fn running() -> Self { fn running(worker_id: Uuid) -> Self {
Self::Running( let first_eight = time::OffsetDateTime::now_utc()
time::OffsetDateTime::now_utc()
.unix_timestamp() .unix_timestamp()
.to_be_bytes(), .to_be_bytes();
)
let next_sixteen = worker_id.into_bytes();
let mut bytes = [0u8; 24];
bytes[0..8]
.iter_mut()
.zip(&first_eight)
.for_each(|(dest, src)| *dest = *src);
bytes[8..24]
.iter_mut()
.zip(&next_sixteen)
.for_each(|(dest, src)| *dest = *src);
Self::Running(bytes)
} }
fn as_bytes(&self) -> &[u8] { fn as_bytes(&self) -> &[u8] {
@ -703,8 +718,12 @@ impl QueueRepo for SledRepo {
Ok(id) Ok(id)
} }
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self, worker_id), fields(job_id))]
async fn pop(&self, queue_name: &'static str) -> Result<(JobId, Arc<[u8]>), RepoError> { async fn pop(
&self,
queue_name: &'static str,
worker_id: Uuid,
) -> Result<(JobId, Arc<[u8]>), RepoError> {
let metrics_guard = PopMetricsGuard::guard(queue_name); let metrics_guard = PopMetricsGuard::guard(queue_name);
let now = time::OffsetDateTime::now_utc(); let now = time::OffsetDateTime::now_utc();
@ -713,13 +732,15 @@ impl QueueRepo for SledRepo {
let queue = self.queue.clone(); let queue = self.queue.clone();
let job_state = self.job_state.clone(); let job_state = self.job_state.clone();
let span = tracing::Span::current();
let opt = actix_rt::task::spawn_blocking(move || { let opt = actix_rt::task::spawn_blocking(move || {
let _guard = span.enter();
// Job IDs are generated with Uuid version 7 - defining their first bits as a // Job IDs are generated with Uuid version 7 - defining their first bits as a
// timestamp. Scanning a prefix should give us jobs in the order they were queued. // timestamp. Scanning a prefix should give us jobs in the order they were queued.
for res in job_state.scan_prefix(queue_name) { for res in job_state.scan_prefix(queue_name) {
let (key, value) = res?; let (key, value) = res?;
if value.len() == 8 { if value.len() > 8 {
let unix_timestamp = let unix_timestamp =
i64::from_be_bytes(value[0..8].try_into().expect("Verified length")); i64::from_be_bytes(value[0..8].try_into().expect("Verified length"));
@ -734,13 +755,14 @@ impl QueueRepo for SledRepo {
} }
} }
let state = JobState::running(); let state = JobState::running(worker_id);
match job_state.compare_and_swap(&key, Some(value), Some(state.as_bytes())) { match job_state.compare_and_swap(&key, Some(value), Some(state.as_bytes()))? {
Ok(_) => { Ok(()) => {
// acquired job // acquired job
} }
Err(_) => { Err(_) => {
tracing::debug!("Contested");
// someone else acquired job // someone else acquired job
continue; continue;
} }
@ -752,6 +774,8 @@ impl QueueRepo for SledRepo {
let job_id = JobId::from_bytes(id_bytes); let job_id = JobId::from_bytes(id_bytes);
tracing::Span::current().record("job_id", &format!("{job_id:?}"));
let opt = queue let opt = queue
.get(&key)? .get(&key)?
.map(|job_bytes| (job_id, Arc::from(job_bytes.to_vec()))); .map(|job_bytes| (job_id, Arc::from(job_bytes.to_vec())));
@ -790,18 +814,23 @@ impl QueueRepo for SledRepo {
} }
} }
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self, worker_id))]
async fn heartbeat(&self, queue_name: &'static str, job_id: JobId) -> Result<(), RepoError> { async fn heartbeat(
&self,
queue_name: &'static str,
worker_id: Uuid,
job_id: JobId,
) -> Result<(), RepoError> {
let key = job_key(queue_name, job_id); let key = job_key(queue_name, job_id);
let job_state = self.job_state.clone(); let job_state = self.job_state.clone();
actix_rt::task::spawn_blocking(move || { actix_rt::task::spawn_blocking(move || {
if let Some(state) = job_state.get(&key)? { if let Some(state) = job_state.get(&key)? {
let new_state = JobState::running(); let new_state = JobState::running(worker_id);
match job_state.compare_and_swap(&key, Some(state), Some(new_state.as_bytes()))? { match job_state.compare_and_swap(&key, Some(state), Some(new_state.as_bytes()))? {
Ok(_) => Ok(()), Ok(()) => Ok(()),
Err(_) => Err(SledError::Conflict), Err(_) => Err(SledError::Conflict),
} }
} else { } else {
@ -814,8 +843,13 @@ impl QueueRepo for SledRepo {
Ok(()) Ok(())
} }
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self, _worker_id))]
async fn complete_job(&self, queue_name: &'static str, job_id: JobId) -> Result<(), RepoError> { async fn complete_job(
&self,
queue_name: &'static str,
_worker_id: Uuid,
job_id: JobId,
) -> Result<(), RepoError> {
let key = job_key(queue_name, job_id); let key = job_key(queue_name, job_id);
let queue = self.queue.clone(); let queue = self.queue.clone();