diff --git a/src/repo.rs b/src/repo.rs index 429f925..24a38ae 100644 --- a/src/repo.rs +++ b/src/repo.rs @@ -73,6 +73,8 @@ pub(crate) trait FullRepo: + QueueRepo + HashRepo + MigrationRepo + + AliasAccessRepo + + IdentifierAccessRepo + Send + Sync + Clone @@ -143,6 +145,82 @@ where type Bytes = T::Bytes; } +#[async_trait::async_trait(?Send)] +pub(crate) trait AliasAccessRepo: BaseRepo { + type AliasAccessStream: Stream>; + + async fn accessed(&self, alias: Alias) -> Result<(), RepoError>; + + async fn older_aliases( + &self, + timestamp: time::OffsetDateTime, + ) -> Result; + + async fn remove(&self, alias: Alias) -> Result<(), RepoError>; +} + +#[async_trait::async_trait(?Send)] +impl AliasAccessRepo for actix_web::web::Data +where + T: AliasAccessRepo, +{ + type AliasAccessStream = T::AliasAccessStream; + + async fn accessed(&self, alias: Alias) -> Result<(), RepoError> { + T::accessed(self, alias).await + } + + async fn older_aliases( + &self, + timestamp: time::OffsetDateTime, + ) -> Result { + T::older_aliases(self, timestamp).await + } + + async fn remove(&self, alias: Alias) -> Result<(), RepoError> { + T::remove(self, alias).await + } +} + +#[async_trait::async_trait(?Send)] +pub(crate) trait IdentifierAccessRepo: BaseRepo { + type IdentifierAccessStream: Stream> + where + I: Identifier; + + async fn accessed(&self, identifier: I) -> Result<(), StoreError>; + + async fn older_identifiers( + &self, + timestamp: time::OffsetDateTime, + ) -> Result, RepoError>; + + async fn remove(&self, identifier: I) -> Result<(), StoreError>; +} + +#[async_trait::async_trait(?Send)] +impl IdentifierAccessRepo for actix_web::web::Data +where + T: IdentifierAccessRepo, +{ + type IdentifierAccessStream = T::IdentifierAccessStream where I: Identifier; + + async fn accessed(&self, identifier: I) -> Result<(), StoreError> { + T::accessed(self, identifier).await + } + + async fn older_identifiers( + &self, + timestamp: time::OffsetDateTime, + ) -> Result, RepoError> { + T::older_identifiers(self, timestamp).await + } + + async fn remove(&self, identifier: I) -> Result<(), StoreError> { + T::remove(self, identifier).await + } +} + #[async_trait::async_trait(?Send)] pub(crate) trait UploadRepo: BaseRepo { async fn create(&self, upload_id: UploadId) -> Result<(), RepoError>; diff --git a/src/repo/sled.rs b/src/repo/sled.rs index e752582..d4f34ea 100644 --- a/src/repo/sled.rs +++ b/src/repo/sled.rs @@ -9,10 +9,11 @@ use crate::{ store::StoreError, stream::from_iterator, }; -use futures_util::Stream; +use futures_util::{Future, Stream}; use sled::{CompareAndSwapError, Db, IVec, Tree}; use std::{ collections::HashMap, + marker::PhantomData, path::PathBuf, pin::Pin, sync::{ @@ -20,9 +21,9 @@ use std::{ Arc, RwLock, }, }; -use tokio::sync::Notify; +use tokio::{sync::Notify, task::JoinHandle}; -use super::RepoError; +use super::{AliasAccessRepo, IdentifierAccessRepo, RepoError}; macro_rules! b { ($self:ident.$ident:ident, $expr:expr) => {{ @@ -47,6 +48,9 @@ pub(crate) enum SledError { #[error("Invalid details json")] Details(#[from] serde_json::Error), + #[error("Error formatting timestamp")] + Format(#[source] time::error::Format), + #[error("Operation panicked")] Panic, } @@ -66,6 +70,10 @@ pub(crate) struct SledRepo { alias_hashes: Tree, alias_delete_tokens: Tree, queue: Tree, + alias_access: Tree, + inverse_alias_access: Tree, + identifier_access: Tree, + inverse_identifier_access: Tree, in_progress_queue: Tree, queue_notifier: Arc>>>, uploads: Tree, @@ -98,6 +106,10 @@ impl SledRepo { alias_hashes: db.open_tree("pict-rs-alias-hashes-tree")?, alias_delete_tokens: db.open_tree("pict-rs-alias-delete-tokens-tree")?, queue: db.open_tree("pict-rs-queue-tree")?, + alias_access: db.open_tree("pict-rs-alias-access-tree")?, + inverse_alias_access: db.open_tree("pict-rs-inverse-alias-access-tree")?, + identifier_access: db.open_tree("pict-rs-identifier-access-tree")?, + inverse_identifier_access: db.open_tree("pict-rs-inverse-identifier-access-tree")?, in_progress_queue: db.open_tree("pict-rs-in-progress-queue-tree")?, queue_notifier: Arc::new(RwLock::new(HashMap::new())), uploads: db.open_tree("pict-rs-uploads-tree")?, @@ -157,6 +169,234 @@ impl FullRepo for SledRepo { } } +pub(crate) struct IterStream { + iter: Option, + next: Option)>>>, +} + +impl futures_util::Stream for IterStream { + type Item = Result; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + if let Some(ref mut next) = self.next { + let res = std::task::ready!(Pin::new(next).poll(cx)); + + self.next.take(); + + let opt = match res { + Ok(opt) => opt, + Err(_) => return std::task::Poll::Ready(Some(Err(RepoError::Canceled))), + }; + + if let Some((iter, res)) = opt { + self.iter = Some(iter); + + std::task::Poll::Ready(Some(res)) + } else { + std::task::Poll::Ready(None) + } + } else if let Some(mut iter) = self.iter.take() { + self.next = Some(tokio::task::spawn_blocking(move || { + let opt = iter + .next() + .map(|res| res.map_err(SledError::from).map_err(RepoError::from)); + + opt.map(|res| (iter, res.map(|(_, value)| value))) + })); + self.poll_next(cx) + } else { + std::task::Poll::Ready(None) + } + } +} + +pub(crate) struct AliasAccessStream { + iter: IterStream, +} + +pub(crate) struct IdentifierAccessStream { + iter: IterStream, + identifier: PhantomData I>, +} + +impl futures_util::Stream for AliasAccessStream { + type Item = Result; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + match std::task::ready!(Pin::new(&mut self.iter).poll_next(cx)) { + Some(Ok(bytes)) => { + if let Some(alias) = Alias::from_slice(&bytes) { + std::task::Poll::Ready(Some(Ok(alias))) + } else { + self.poll_next(cx) + } + } + Some(Err(e)) => std::task::Poll::Ready(Some(Err(e))), + None => std::task::Poll::Ready(None), + } + } +} + +impl futures_util::Stream for IdentifierAccessStream +where + I: Identifier, +{ + type Item = Result; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + match std::task::ready!(Pin::new(&mut self.iter).poll_next(cx)) { + Some(Ok(bytes)) => std::task::Poll::Ready(Some(I::from_bytes(bytes.to_vec()))), + Some(Err(e)) => std::task::Poll::Ready(Some(Err(e.into()))), + None => std::task::Poll::Ready(None), + } + } +} + +#[async_trait::async_trait(?Send)] +impl AliasAccessRepo for SledRepo { + type AliasAccessStream = AliasAccessStream; + + async fn accessed(&self, alias: Alias) -> Result<(), RepoError> { + let now_string = time::OffsetDateTime::now_utc() + .format(&time::format_description::well_known::Rfc3339) + .map_err(SledError::Format)?; + + let alias_access = self.alias_access.clone(); + let inverse_alias_access = self.inverse_alias_access.clone(); + + actix_rt::task::spawn_blocking(move || { + if let Some(old) = alias_access.insert(alias.to_bytes(), now_string.as_bytes())? { + inverse_alias_access.remove(old)?; + } + inverse_alias_access.insert(now_string, alias.to_bytes())?; + Ok(()) as Result<(), SledError> + }) + .await + .map_err(|_| RepoError::Canceled)? + .map_err(RepoError::from) + } + + async fn older_aliases( + &self, + timestamp: time::OffsetDateTime, + ) -> Result { + let time_string = timestamp + .format(&time::format_description::well_known::Rfc3339) + .map_err(SledError::Format)?; + + let inverse_alias_access = self.inverse_alias_access.clone(); + + let iter = + actix_rt::task::spawn_blocking(move || inverse_alias_access.range(..=time_string)) + .await + .map_err(|_| RepoError::Canceled)?; + + Ok(AliasAccessStream { + iter: IterStream { + iter: Some(iter), + next: None, + }, + }) + } + + async fn remove(&self, alias: Alias) -> Result<(), RepoError> { + let alias_access = self.alias_access.clone(); + let inverse_alias_access = self.inverse_alias_access.clone(); + + actix_rt::task::spawn_blocking(move || { + if let Some(old) = alias_access.remove(alias.to_bytes())? { + inverse_alias_access.remove(old)?; + } + Ok(()) as Result<(), SledError> + }) + .await + .map_err(|_| RepoError::Canceled)? + .map_err(RepoError::from) + } +} + +#[async_trait::async_trait(?Send)] +impl IdentifierAccessRepo for SledRepo { + type IdentifierAccessStream = IdentifierAccessStream where I: Identifier; + + async fn accessed(&self, identifier: I) -> Result<(), StoreError> { + let now_string = time::OffsetDateTime::now_utc() + .format(&time::format_description::well_known::Rfc3339) + .map_err(SledError::Format) + .map_err(RepoError::from)?; + + let identifier_access = self.identifier_access.clone(); + let inverse_identifier_access = self.inverse_identifier_access.clone(); + + let identifier = identifier.to_bytes()?; + + actix_rt::task::spawn_blocking(move || { + if let Some(old) = + identifier_access.insert(identifier.clone(), now_string.as_bytes())? + { + inverse_identifier_access.remove(old)?; + } + inverse_identifier_access.insert(now_string, identifier)?; + Ok(()) as Result<(), SledError> + }) + .await + .map_err(|_| RepoError::Canceled)? + .map_err(RepoError::from) + .map_err(StoreError::from) + } + + async fn older_identifiers( + &self, + timestamp: time::OffsetDateTime, + ) -> Result, RepoError> { + let time_string = timestamp + .format(&time::format_description::well_known::Rfc3339) + .map_err(SledError::Format)?; + + let inverse_identifier_access = self.inverse_identifier_access.clone(); + + let iter = + actix_rt::task::spawn_blocking(move || inverse_identifier_access.range(..=time_string)) + .await + .map_err(|_| RepoError::Canceled)?; + + Ok(IdentifierAccessStream { + iter: IterStream { + iter: Some(iter), + next: None, + }, + identifier: PhantomData, + }) + } + + async fn remove(&self, identifier: I) -> Result<(), StoreError> { + let identifier_access = self.identifier_access.clone(); + let inverse_identifier_access = self.inverse_identifier_access.clone(); + + let identifier = identifier.to_bytes()?; + + actix_rt::task::spawn_blocking(move || { + if let Some(old) = identifier_access.remove(identifier)? { + inverse_identifier_access.remove(old)?; + } + Ok(()) as Result<(), SledError> + }) + .await + .map_err(|_| RepoError::Canceled)? + .map_err(RepoError::from) + .map_err(StoreError::from) + } +} + #[derive(serde::Deserialize, serde::Serialize)] enum InnerUploadResult { Success {