Change IdentifierAccess to VariantAccess

This commit is contained in:
asonix 2023-07-22 17:57:52 -05:00
parent e7141c0533
commit b786406ad0
2 changed files with 130 additions and 64 deletions

View File

@ -74,7 +74,7 @@ pub(crate) trait FullRepo:
+ HashRepo + HashRepo
+ MigrationRepo + MigrationRepo
+ AliasAccessRepo + AliasAccessRepo
+ IdentifierAccessRepo + VariantAccessRepo
+ Send + Send
+ Sync + Sync
+ Clone + Clone
@ -183,41 +183,39 @@ where
} }
#[async_trait::async_trait(?Send)] #[async_trait::async_trait(?Send)]
pub(crate) trait IdentifierAccessRepo: BaseRepo { pub(crate) trait VariantAccessRepo: BaseRepo {
type IdentifierAccessStream<I>: Stream<Item = Result<I, StoreError>> type VariantAccessStream: Stream<Item = Result<(Self::Bytes, String), RepoError>>;
where
I: Identifier;
async fn accessed<I: Identifier>(&self, identifier: I) -> Result<(), StoreError>; async fn accessed(&self, hash: Self::Bytes, variant: String) -> Result<(), RepoError>;
async fn older_identifiers<I: Identifier>( async fn older_variants(
&self, &self,
timestamp: time::OffsetDateTime, timestamp: time::OffsetDateTime,
) -> Result<Self::IdentifierAccessStream<I>, RepoError>; ) -> Result<Self::VariantAccessStream, RepoError>;
async fn remove<I: Identifier>(&self, identifier: I) -> Result<(), StoreError>; async fn remove(&self, hash: Self::Bytes, variant: String) -> Result<(), RepoError>;
} }
#[async_trait::async_trait(?Send)] #[async_trait::async_trait(?Send)]
impl<T> IdentifierAccessRepo for actix_web::web::Data<T> impl<T> VariantAccessRepo for actix_web::web::Data<T>
where where
T: IdentifierAccessRepo, T: VariantAccessRepo,
{ {
type IdentifierAccessStream<I> = T::IdentifierAccessStream<I> where I: Identifier; type VariantAccessStream = T::VariantAccessStream;
async fn accessed<I: Identifier>(&self, identifier: I) -> Result<(), StoreError> { async fn accessed(&self, hash: Self::Bytes, variant: String) -> Result<(), RepoError> {
T::accessed(self, identifier).await T::accessed(self, hash, variant).await
} }
async fn older_identifiers<I: Identifier>( async fn older_variants(
&self, &self,
timestamp: time::OffsetDateTime, timestamp: time::OffsetDateTime,
) -> Result<Self::IdentifierAccessStream<I>, RepoError> { ) -> Result<Self::VariantAccessStream, RepoError> {
T::older_identifiers(self, timestamp).await T::older_variants(self, timestamp).await
} }
async fn remove<I: Identifier>(&self, identifier: I) -> Result<(), StoreError> { async fn remove(&self, hash: Self::Bytes, variant: String) -> Result<(), RepoError> {
T::remove(self, identifier).await T::remove(self, hash, variant).await
} }
} }

View File

@ -13,7 +13,6 @@ use futures_util::{Future, Stream};
use sled::{CompareAndSwapError, Db, IVec, Tree}; use sled::{CompareAndSwapError, Db, IVec, Tree};
use std::{ use std::{
collections::HashMap, collections::HashMap,
marker::PhantomData,
path::PathBuf, path::PathBuf,
pin::Pin, pin::Pin,
sync::{ sync::{
@ -23,7 +22,7 @@ use std::{
}; };
use tokio::{sync::Notify, task::JoinHandle}; use tokio::{sync::Notify, task::JoinHandle};
use super::{AliasAccessRepo, IdentifierAccessRepo, RepoError}; use super::{AliasAccessRepo, RepoError, VariantAccessRepo};
macro_rules! b { macro_rules! b {
($self:ident.$ident:ident, $expr:expr) => {{ ($self:ident.$ident:ident, $expr:expr) => {{
@ -51,6 +50,9 @@ pub(crate) enum SledError {
#[error("Error formatting timestamp")] #[error("Error formatting timestamp")]
Format(#[source] time::error::Format), Format(#[source] time::error::Format),
#[error("Error parsing variant key")]
VariantKey(#[from] VariantKeyError),
#[error("Operation panicked")] #[error("Operation panicked")]
Panic, Panic,
} }
@ -72,8 +74,8 @@ pub(crate) struct SledRepo {
queue: Tree, queue: Tree,
alias_access: Tree, alias_access: Tree,
inverse_alias_access: Tree, inverse_alias_access: Tree,
identifier_access: Tree, variant_access: Tree,
inverse_identifier_access: Tree, inverse_variant_access: Tree,
in_progress_queue: Tree, in_progress_queue: Tree,
queue_notifier: Arc<RwLock<HashMap<&'static str, Arc<Notify>>>>, queue_notifier: Arc<RwLock<HashMap<&'static str, Arc<Notify>>>>,
uploads: Tree, uploads: Tree,
@ -108,8 +110,8 @@ impl SledRepo {
queue: db.open_tree("pict-rs-queue-tree")?, queue: db.open_tree("pict-rs-queue-tree")?,
alias_access: db.open_tree("pict-rs-alias-access-tree")?, alias_access: db.open_tree("pict-rs-alias-access-tree")?,
inverse_alias_access: db.open_tree("pict-rs-inverse-alias-access-tree")?, inverse_alias_access: db.open_tree("pict-rs-inverse-alias-access-tree")?,
identifier_access: db.open_tree("pict-rs-identifier-access-tree")?, variant_access: db.open_tree("pict-rs-variant-access-tree")?,
inverse_identifier_access: db.open_tree("pict-rs-inverse-identifier-access-tree")?, inverse_variant_access: db.open_tree("pict-rs-inverse-variant-access-tree")?,
in_progress_queue: db.open_tree("pict-rs-in-progress-queue-tree")?, in_progress_queue: db.open_tree("pict-rs-in-progress-queue-tree")?,
queue_notifier: Arc::new(RwLock::new(HashMap::new())), queue_notifier: Arc::new(RwLock::new(HashMap::new())),
uploads: db.open_tree("pict-rs-uploads-tree")?, uploads: db.open_tree("pict-rs-uploads-tree")?,
@ -169,7 +171,7 @@ impl FullRepo for SledRepo {
} }
} }
type IterValue = Option<(sled::Iter, Result<sled::IVec, RepoError>)>; type IterValue = Option<(sled::Iter, Result<IVec, RepoError>)>;
pub(crate) struct IterStream { pub(crate) struct IterStream {
iter: Option<sled::Iter>, iter: Option<sled::Iter>,
@ -177,7 +179,7 @@ pub(crate) struct IterStream {
} }
impl futures_util::Stream for IterStream { impl futures_util::Stream for IterStream {
type Item = Result<sled::IVec, RepoError>; type Item = Result<IVec, RepoError>;
fn poll_next( fn poll_next(
mut self: Pin<&mut Self>, mut self: Pin<&mut Self>,
@ -219,9 +221,8 @@ pub(crate) struct AliasAccessStream {
iter: IterStream, iter: IterStream,
} }
pub(crate) struct IdentifierAccessStream<I> { pub(crate) struct VariantAccessStream {
iter: IterStream, iter: IterStream,
identifier: PhantomData<fn() -> I>,
} }
impl futures_util::Stream for AliasAccessStream { impl futures_util::Stream for AliasAccessStream {
@ -245,19 +246,20 @@ impl futures_util::Stream for AliasAccessStream {
} }
} }
impl<I> futures_util::Stream for IdentifierAccessStream<I> impl futures_util::Stream for VariantAccessStream {
where type Item = Result<(IVec, String), RepoError>;
I: Identifier,
{
type Item = Result<I, StoreError>;
fn poll_next( fn poll_next(
mut self: Pin<&mut Self>, mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>, cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> { ) -> std::task::Poll<Option<Self::Item>> {
match std::task::ready!(Pin::new(&mut self.iter).poll_next(cx)) { match std::task::ready!(Pin::new(&mut self.iter).poll_next(cx)) {
Some(Ok(bytes)) => std::task::Poll::Ready(Some(I::from_bytes(bytes.to_vec()))), Some(Ok(bytes)) => std::task::Poll::Ready(Some(
Some(Err(e)) => std::task::Poll::Ready(Some(Err(e.into()))), parse_variant_access_key(bytes)
.map_err(SledError::from)
.map_err(RepoError::from),
)),
Some(Err(e)) => std::task::Poll::Ready(Some(Err(e))),
None => std::task::Poll::Ready(None), None => std::task::Poll::Ready(None),
} }
} }
@ -327,75 +329,69 @@ impl AliasAccessRepo for SledRepo {
} }
#[async_trait::async_trait(?Send)] #[async_trait::async_trait(?Send)]
impl IdentifierAccessRepo for SledRepo { impl VariantAccessRepo for SledRepo {
type IdentifierAccessStream<I> = IdentifierAccessStream<I> where I: Identifier; type VariantAccessStream = VariantAccessStream;
async fn accessed(&self, hash: Self::Bytes, variant: String) -> Result<(), RepoError> {
let key = variant_access_key(&hash, &variant);
async fn accessed<I: Identifier>(&self, identifier: I) -> Result<(), StoreError> {
let now_string = time::OffsetDateTime::now_utc() let now_string = time::OffsetDateTime::now_utc()
.format(&time::format_description::well_known::Rfc3339) .format(&time::format_description::well_known::Rfc3339)
.map_err(SledError::Format) .map_err(SledError::Format)?;
.map_err(RepoError::from)?;
let identifier_access = self.identifier_access.clone(); let variant_access = self.variant_access.clone();
let inverse_identifier_access = self.inverse_identifier_access.clone(); let inverse_variant_access = self.inverse_variant_access.clone();
let identifier = identifier.to_bytes()?;
actix_rt::task::spawn_blocking(move || { actix_rt::task::spawn_blocking(move || {
if let Some(old) = if let Some(old) = variant_access.insert(&key, now_string.as_bytes())? {
identifier_access.insert(identifier.clone(), now_string.as_bytes())? inverse_variant_access.remove(old)?;
{
inverse_identifier_access.remove(old)?;
} }
inverse_identifier_access.insert(now_string, identifier)?; inverse_variant_access.insert(now_string, key)?;
Ok(()) as Result<(), SledError> Ok(()) as Result<(), SledError>
}) })
.await .await
.map_err(|_| RepoError::Canceled)? .map_err(|_| RepoError::Canceled)?
.map_err(RepoError::from) .map_err(RepoError::from)
.map_err(StoreError::from)
} }
async fn older_identifiers<I: Identifier>( async fn older_variants(
&self, &self,
timestamp: time::OffsetDateTime, timestamp: time::OffsetDateTime,
) -> Result<Self::IdentifierAccessStream<I>, RepoError> { ) -> Result<Self::VariantAccessStream, RepoError> {
let time_string = timestamp let time_string = timestamp
.format(&time::format_description::well_known::Rfc3339) .format(&time::format_description::well_known::Rfc3339)
.map_err(SledError::Format)?; .map_err(SledError::Format)?;
let inverse_identifier_access = self.inverse_identifier_access.clone(); let inverse_variant_access = self.inverse_variant_access.clone();
let iter = let iter =
actix_rt::task::spawn_blocking(move || inverse_identifier_access.range(..=time_string)) actix_rt::task::spawn_blocking(move || inverse_variant_access.range(..=time_string))
.await .await
.map_err(|_| RepoError::Canceled)?; .map_err(|_| RepoError::Canceled)?;
Ok(IdentifierAccessStream { Ok(VariantAccessStream {
iter: IterStream { iter: IterStream {
iter: Some(iter), iter: Some(iter),
next: None, next: None,
}, },
identifier: PhantomData,
}) })
} }
async fn remove<I: Identifier>(&self, identifier: I) -> Result<(), StoreError> { async fn remove(&self, hash: Self::Bytes, variant: String) -> Result<(), RepoError> {
let identifier_access = self.identifier_access.clone(); let key = variant_access_key(&hash, &variant);
let inverse_identifier_access = self.inverse_identifier_access.clone();
let identifier = identifier.to_bytes()?; let variant_access = self.variant_access.clone();
let inverse_variant_access = self.inverse_variant_access.clone();
actix_rt::task::spawn_blocking(move || { actix_rt::task::spawn_blocking(move || {
if let Some(old) = identifier_access.remove(identifier)? { if let Some(old) = variant_access.remove(key)? {
inverse_identifier_access.remove(old)?; inverse_variant_access.remove(old)?;
} }
Ok(()) as Result<(), SledError> Ok(()) as Result<(), SledError>
}) })
.await .await
.map_err(|_| RepoError::Canceled)? .map_err(|_| RepoError::Canceled)?
.map_err(RepoError::from) .map_err(RepoError::from)
.map_err(StoreError::from)
} }
} }
@ -648,6 +644,61 @@ impl SettingsRepo for SledRepo {
} }
} }
fn variant_access_key(hash: &[u8], variant: &str) -> Vec<u8> {
let variant = variant.as_bytes();
let hash_len: u64 = u64::try_from(hash.len()).expect("Length is reasonable");
let mut out = Vec::with_capacity(8 + hash.len() + variant.len());
let hash_length_bytes: [u8; 8] = hash_len.to_be_bytes();
out.extend(hash_length_bytes);
out.extend(hash);
out.extend(variant);
out
}
#[derive(Debug, thiserror::Error)]
pub(crate) enum VariantKeyError {
#[error("Bytes too short to be VariantAccessKey")]
TooShort,
#[error("Prefix Length is longer than backing bytes")]
InvalidLength,
#[error("Invalid utf8 in Variant")]
Utf8,
}
fn parse_variant_access_key(bytes: IVec) -> Result<(IVec, String), VariantKeyError> {
if bytes.len() < 8 {
return Err(VariantKeyError::TooShort);
}
let hash_len = u64::from_be_bytes(bytes[..8].try_into().expect("Verified length"));
let hash_len: usize = usize::try_from(hash_len).expect("Length is reasonable");
if (hash_len + 8) > bytes.len() {
return Err(VariantKeyError::InvalidLength);
}
let hash = bytes.subslice(8, hash_len);
let variant_len = bytes.len().saturating_sub(8).saturating_sub(hash_len);
if variant_len == 0 {
return Ok((hash, String::new()));
}
let variant_start = 8 + hash_len;
let variant = std::str::from_utf8(&bytes[variant_start..])
.map_err(|_| VariantKeyError::Utf8)?
.to_string();
Ok((hash, variant))
}
fn variant_key(hash: &[u8], variant: &str) -> Vec<u8> { fn variant_key(hash: &[u8], variant: &str) -> Vec<u8> {
let mut bytes = hash.to_vec(); let mut bytes = hash.to_vec();
bytes.push(b'/'); bytes.push(b'/');
@ -1106,3 +1157,20 @@ impl From<actix_rt::task::JoinError> for SledError {
SledError::Panic SledError::Panic
} }
} }
#[cfg(test)]
mod tests {
#[test]
fn round_trip() {
let hash = sled::IVec::from(b"some hash value");
let variant = String::from("some string value");
let key = super::variant_access_key(&hash, &variant);
let (out_hash, out_variant) =
super::parse_variant_access_key(sled::IVec::from(key)).expect("Parsed bytes");
assert_eq!(out_hash, hash);
assert_eq!(out_variant, variant);
}
}