diff --git a/src/repo.rs b/src/repo.rs index 24a38ae..898e0d5 100644 --- a/src/repo.rs +++ b/src/repo.rs @@ -74,7 +74,7 @@ pub(crate) trait FullRepo: + HashRepo + MigrationRepo + AliasAccessRepo - + IdentifierAccessRepo + + VariantAccessRepo + Send + Sync + Clone @@ -183,41 +183,39 @@ where } #[async_trait::async_trait(?Send)] -pub(crate) trait IdentifierAccessRepo: BaseRepo { - type IdentifierAccessStream: Stream> - where - I: Identifier; +pub(crate) trait VariantAccessRepo: BaseRepo { + type VariantAccessStream: Stream>; - async fn accessed(&self, identifier: I) -> Result<(), StoreError>; + async fn accessed(&self, hash: Self::Bytes, variant: String) -> Result<(), RepoError>; - async fn older_identifiers( + async fn older_variants( &self, timestamp: time::OffsetDateTime, - ) -> Result, RepoError>; + ) -> Result; - async fn remove(&self, identifier: I) -> Result<(), StoreError>; + async fn remove(&self, hash: Self::Bytes, variant: String) -> Result<(), RepoError>; } #[async_trait::async_trait(?Send)] -impl IdentifierAccessRepo for actix_web::web::Data +impl VariantAccessRepo for actix_web::web::Data where - T: IdentifierAccessRepo, + T: VariantAccessRepo, { - type IdentifierAccessStream = T::IdentifierAccessStream where I: Identifier; + type VariantAccessStream = T::VariantAccessStream; - async fn accessed(&self, identifier: I) -> Result<(), StoreError> { - T::accessed(self, identifier).await + async fn accessed(&self, hash: Self::Bytes, variant: String) -> Result<(), RepoError> { + T::accessed(self, hash, variant).await } - async fn older_identifiers( + async fn older_variants( &self, timestamp: time::OffsetDateTime, - ) -> Result, RepoError> { - T::older_identifiers(self, timestamp).await + ) -> Result { + T::older_variants(self, timestamp).await } - async fn remove(&self, identifier: I) -> Result<(), StoreError> { - T::remove(self, identifier).await + async fn remove(&self, hash: Self::Bytes, variant: String) -> Result<(), RepoError> { + T::remove(self, hash, variant).await } } diff --git a/src/repo/sled.rs b/src/repo/sled.rs index 99cb1f8..2ef0329 100644 --- a/src/repo/sled.rs +++ b/src/repo/sled.rs @@ -13,7 +13,6 @@ use futures_util::{Future, Stream}; use sled::{CompareAndSwapError, Db, IVec, Tree}; use std::{ collections::HashMap, - marker::PhantomData, path::PathBuf, pin::Pin, sync::{ @@ -23,7 +22,7 @@ use std::{ }; use tokio::{sync::Notify, task::JoinHandle}; -use super::{AliasAccessRepo, IdentifierAccessRepo, RepoError}; +use super::{AliasAccessRepo, RepoError, VariantAccessRepo}; macro_rules! b { ($self:ident.$ident:ident, $expr:expr) => {{ @@ -51,6 +50,9 @@ pub(crate) enum SledError { #[error("Error formatting timestamp")] Format(#[source] time::error::Format), + #[error("Error parsing variant key")] + VariantKey(#[from] VariantKeyError), + #[error("Operation panicked")] Panic, } @@ -72,8 +74,8 @@ pub(crate) struct SledRepo { queue: Tree, alias_access: Tree, inverse_alias_access: Tree, - identifier_access: Tree, - inverse_identifier_access: Tree, + variant_access: Tree, + inverse_variant_access: Tree, in_progress_queue: Tree, queue_notifier: Arc>>>, uploads: Tree, @@ -108,8 +110,8 @@ impl SledRepo { queue: db.open_tree("pict-rs-queue-tree")?, alias_access: db.open_tree("pict-rs-alias-access-tree")?, inverse_alias_access: db.open_tree("pict-rs-inverse-alias-access-tree")?, - identifier_access: db.open_tree("pict-rs-identifier-access-tree")?, - inverse_identifier_access: db.open_tree("pict-rs-inverse-identifier-access-tree")?, + variant_access: db.open_tree("pict-rs-variant-access-tree")?, + inverse_variant_access: db.open_tree("pict-rs-inverse-variant-access-tree")?, in_progress_queue: db.open_tree("pict-rs-in-progress-queue-tree")?, queue_notifier: Arc::new(RwLock::new(HashMap::new())), uploads: db.open_tree("pict-rs-uploads-tree")?, @@ -169,7 +171,7 @@ impl FullRepo for SledRepo { } } -type IterValue = Option<(sled::Iter, Result)>; +type IterValue = Option<(sled::Iter, Result)>; pub(crate) struct IterStream { iter: Option, @@ -177,7 +179,7 @@ pub(crate) struct IterStream { } impl futures_util::Stream for IterStream { - type Item = Result; + type Item = Result; fn poll_next( mut self: Pin<&mut Self>, @@ -219,9 +221,8 @@ pub(crate) struct AliasAccessStream { iter: IterStream, } -pub(crate) struct IdentifierAccessStream { +pub(crate) struct VariantAccessStream { iter: IterStream, - identifier: PhantomData I>, } impl futures_util::Stream for AliasAccessStream { @@ -245,19 +246,20 @@ impl futures_util::Stream for AliasAccessStream { } } -impl futures_util::Stream for IdentifierAccessStream -where - I: Identifier, -{ - type Item = Result; +impl futures_util::Stream for VariantAccessStream { + type Item = Result<(IVec, String), RepoError>; fn poll_next( mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { match std::task::ready!(Pin::new(&mut self.iter).poll_next(cx)) { - Some(Ok(bytes)) => std::task::Poll::Ready(Some(I::from_bytes(bytes.to_vec()))), - Some(Err(e)) => std::task::Poll::Ready(Some(Err(e.into()))), + Some(Ok(bytes)) => std::task::Poll::Ready(Some( + parse_variant_access_key(bytes) + .map_err(SledError::from) + .map_err(RepoError::from), + )), + Some(Err(e)) => std::task::Poll::Ready(Some(Err(e))), None => std::task::Poll::Ready(None), } } @@ -327,75 +329,69 @@ impl AliasAccessRepo for SledRepo { } #[async_trait::async_trait(?Send)] -impl IdentifierAccessRepo for SledRepo { - type IdentifierAccessStream = IdentifierAccessStream where I: Identifier; +impl VariantAccessRepo for SledRepo { + type VariantAccessStream = VariantAccessStream; + + async fn accessed(&self, hash: Self::Bytes, variant: String) -> Result<(), RepoError> { + let key = variant_access_key(&hash, &variant); - async fn accessed(&self, identifier: I) -> Result<(), StoreError> { let now_string = time::OffsetDateTime::now_utc() .format(&time::format_description::well_known::Rfc3339) - .map_err(SledError::Format) - .map_err(RepoError::from)?; + .map_err(SledError::Format)?; - let identifier_access = self.identifier_access.clone(); - let inverse_identifier_access = self.inverse_identifier_access.clone(); - - let identifier = identifier.to_bytes()?; + let variant_access = self.variant_access.clone(); + let inverse_variant_access = self.inverse_variant_access.clone(); actix_rt::task::spawn_blocking(move || { - if let Some(old) = - identifier_access.insert(identifier.clone(), now_string.as_bytes())? - { - inverse_identifier_access.remove(old)?; + if let Some(old) = variant_access.insert(&key, now_string.as_bytes())? { + inverse_variant_access.remove(old)?; } - inverse_identifier_access.insert(now_string, identifier)?; + inverse_variant_access.insert(now_string, key)?; Ok(()) as Result<(), SledError> }) .await .map_err(|_| RepoError::Canceled)? .map_err(RepoError::from) - .map_err(StoreError::from) } - async fn older_identifiers( + async fn older_variants( &self, timestamp: time::OffsetDateTime, - ) -> Result, RepoError> { + ) -> Result { let time_string = timestamp .format(&time::format_description::well_known::Rfc3339) .map_err(SledError::Format)?; - let inverse_identifier_access = self.inverse_identifier_access.clone(); + let inverse_variant_access = self.inverse_variant_access.clone(); let iter = - actix_rt::task::spawn_blocking(move || inverse_identifier_access.range(..=time_string)) + actix_rt::task::spawn_blocking(move || inverse_variant_access.range(..=time_string)) .await .map_err(|_| RepoError::Canceled)?; - Ok(IdentifierAccessStream { + Ok(VariantAccessStream { iter: IterStream { iter: Some(iter), next: None, }, - identifier: PhantomData, }) } - async fn remove(&self, identifier: I) -> Result<(), StoreError> { - let identifier_access = self.identifier_access.clone(); - let inverse_identifier_access = self.inverse_identifier_access.clone(); + async fn remove(&self, hash: Self::Bytes, variant: String) -> Result<(), RepoError> { + let key = variant_access_key(&hash, &variant); - let identifier = identifier.to_bytes()?; + let variant_access = self.variant_access.clone(); + let inverse_variant_access = self.inverse_variant_access.clone(); actix_rt::task::spawn_blocking(move || { - if let Some(old) = identifier_access.remove(identifier)? { - inverse_identifier_access.remove(old)?; + if let Some(old) = variant_access.remove(key)? { + inverse_variant_access.remove(old)?; } Ok(()) as Result<(), SledError> }) .await .map_err(|_| RepoError::Canceled)? .map_err(RepoError::from) - .map_err(StoreError::from) } } @@ -648,6 +644,61 @@ impl SettingsRepo for SledRepo { } } +fn variant_access_key(hash: &[u8], variant: &str) -> Vec { + let variant = variant.as_bytes(); + + let hash_len: u64 = u64::try_from(hash.len()).expect("Length is reasonable"); + + let mut out = Vec::with_capacity(8 + hash.len() + variant.len()); + + let hash_length_bytes: [u8; 8] = hash_len.to_be_bytes(); + out.extend(hash_length_bytes); + out.extend(hash); + out.extend(variant); + out +} + +#[derive(Debug, thiserror::Error)] +pub(crate) enum VariantKeyError { + #[error("Bytes too short to be VariantAccessKey")] + TooShort, + + #[error("Prefix Length is longer than backing bytes")] + InvalidLength, + + #[error("Invalid utf8 in Variant")] + Utf8, +} + +fn parse_variant_access_key(bytes: IVec) -> Result<(IVec, String), VariantKeyError> { + if bytes.len() < 8 { + return Err(VariantKeyError::TooShort); + } + + let hash_len = u64::from_be_bytes(bytes[..8].try_into().expect("Verified length")); + let hash_len: usize = usize::try_from(hash_len).expect("Length is reasonable"); + + if (hash_len + 8) > bytes.len() { + return Err(VariantKeyError::InvalidLength); + } + + let hash = bytes.subslice(8, hash_len); + + let variant_len = bytes.len().saturating_sub(8).saturating_sub(hash_len); + + if variant_len == 0 { + return Ok((hash, String::new())); + } + + let variant_start = 8 + hash_len; + + let variant = std::str::from_utf8(&bytes[variant_start..]) + .map_err(|_| VariantKeyError::Utf8)? + .to_string(); + + Ok((hash, variant)) +} + fn variant_key(hash: &[u8], variant: &str) -> Vec { let mut bytes = hash.to_vec(); bytes.push(b'/'); @@ -1106,3 +1157,20 @@ impl From for SledError { SledError::Panic } } + +#[cfg(test)] +mod tests { + #[test] + fn round_trip() { + let hash = sled::IVec::from(b"some hash value"); + let variant = String::from("some string value"); + + let key = super::variant_access_key(&hash, &variant); + + let (out_hash, out_variant) = + super::parse_variant_access_key(sled::IVec::from(key)).expect("Parsed bytes"); + + assert_eq!(out_hash, hash); + assert_eq!(out_variant, variant); + } +}