2
0
Fork 0
mirror of https://git.asonix.dog/asonix/pict-rs synced 2024-11-10 06:25:00 +00:00

Move some Stream implementations into 'stream' module

This commit is contained in:
Aode (lion) 2022-03-29 15:59:17 -05:00
parent 0e490ff54a
commit 5adb3fde89
5 changed files with 253 additions and 210 deletions

View file

@ -118,7 +118,7 @@ pub(crate) enum UploadError {
Range, Range,
#[error("Hit limit")] #[error("Hit limit")]
Limit(#[from] super::LimitError), Limit(#[from] crate::stream::LimitError),
} }
impl From<awc::error::SendRequestError> for UploadError { impl From<awc::error::SendRequestError> for UploadError {

View file

@ -14,13 +14,11 @@ use std::{
collections::BTreeSet, collections::BTreeSet,
future::ready, future::ready,
path::PathBuf, path::PathBuf,
pin::Pin,
sync::atomic::{AtomicU64, Ordering}, sync::atomic::{AtomicU64, Ordering},
task::{Context, Poll},
time::SystemTime, time::SystemTime,
}; };
use tokio::{io::AsyncReadExt, sync::Semaphore}; use tokio::{io::AsyncReadExt, sync::Semaphore};
use tracing::{debug, error, info, instrument}; use tracing::{debug, info, instrument};
use tracing_actix_web::TracingLogger; use tracing_actix_web::TracingLogger;
use tracing_awc::Tracing; use tracing_awc::Tracing;
use tracing_futures::Instrument; use tracing_futures::Instrument;
@ -44,6 +42,7 @@ mod range;
mod repo; mod repo;
mod serde_str; mod serde_str;
mod store; mod store;
mod stream;
mod tmp_file; mod tmp_file;
mod upload_manager; mod upload_manager;
mod validate; mod validate;
@ -61,6 +60,7 @@ use self::{
repo::{Alias, DeleteToken, Repo}, repo::{Alias, DeleteToken, Repo},
serde_str::Serde, serde_str::Serde,
store::{file_store::FileStore, object_store::ObjectStore, Store}, store::{file_store::FileStore, object_store::ObjectStore, Store},
stream::StreamLimit,
upload_manager::{UploadManager, UploadManagerSession}, upload_manager::{UploadManager, UploadManagerSession},
}; };
@ -138,59 +138,6 @@ struct UrlQuery {
url: String, url: String,
} }
pin_project_lite::pin_project! {
struct Limit<S> {
#[pin]
inner: S,
count: u64,
limit: u64,
}
}
impl<S> Limit<S> {
fn new(inner: S, limit: u64) -> Self {
Limit {
inner,
count: 0,
limit,
}
}
}
#[derive(Debug, thiserror::Error)]
#[error("Resonse body larger than size limit")]
struct LimitError;
impl<S, E> Stream for Limit<S>
where
S: Stream<Item = Result<web::Bytes, E>>,
E: From<LimitError>,
{
type Item = Result<web::Bytes, E>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.as_mut().project();
let limit = this.limit;
let count = this.count;
let inner = this.inner;
inner.poll_next(cx).map(|opt| {
opt.map(|res| match res {
Ok(bytes) => {
*count += bytes.len() as u64;
if *count > *limit {
return Err(LimitError.into());
}
Ok(bytes)
}
Err(e) => Err(e),
})
})
}
}
/// download an image from a URL /// download an image from a URL
#[instrument(name = "Downloading file", skip(client, manager))] #[instrument(name = "Downloading file", skip(client, manager))]
async fn download<S: Store>( async fn download<S: Store>(
@ -205,10 +152,9 @@ async fn download<S: Store>(
return Err(UploadError::Download(res.status()).into()); return Err(UploadError::Download(res.status()).into());
} }
let stream = Limit::new( let stream = res
res.map_err(Error::from), .map_err(Error::from)
(CONFIG.media.max_file_size * MEGABYTES) as u64, .limit((CONFIG.media.max_file_size * MEGABYTES) as u64);
);
futures_util::pin_mut!(stream); futures_util::pin_mut!(stream);

View file

@ -4,9 +4,11 @@ use crate::{
Alias, AliasRepo, AlreadyExists, DeleteToken, Details, HashRepo, Identifier, Alias, AliasRepo, AlreadyExists, DeleteToken, Details, HashRepo, Identifier,
IdentifierRepo, QueueRepo, SettingsRepo, IdentifierRepo, QueueRepo, SettingsRepo,
}, },
stream::from_iterator,
}; };
use futures_util::Stream;
use sled::{Db, IVec, Tree}; use sled::{Db, IVec, Tree};
use std::sync::Arc; use std::{pin::Pin, sync::Arc};
use tokio::sync::Notify; use tokio::sync::Notify;
use super::BaseRepo; use super::BaseRepo;
@ -205,65 +207,8 @@ impl IdentifierRepo for SledRepo {
} }
} }
type BoxIterator<'a, T> = Box<dyn std::iter::Iterator<Item = T> + Send + 'a>;
type HashIterator = BoxIterator<'static, Result<IVec, sled::Error>>;
type StreamItem = Result<IVec, Error>; type StreamItem = Result<IVec, Error>;
type LocalBoxStream<'a, T> = Pin<Box<dyn Stream<Item = T> + 'a>>;
type NextFutResult = Result<(HashIterator, Option<StreamItem>), Error>;
pub(crate) struct HashStream {
hashes: Option<HashIterator>,
next_fut: Option<futures_util::future::LocalBoxFuture<'static, NextFutResult>>,
}
impl futures_util::Stream for HashStream {
type Item = StreamItem;
fn poll_next(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
let this = self.get_mut();
if let Some(mut fut) = this.next_fut.take() {
match fut.as_mut().poll(cx) {
std::task::Poll::Ready(Ok((iter, opt))) => {
this.hashes = Some(iter);
std::task::Poll::Ready(opt)
}
std::task::Poll::Ready(Err(e)) => std::task::Poll::Ready(Some(Err(e))),
std::task::Poll::Pending => {
this.next_fut = Some(fut);
std::task::Poll::Pending
}
}
} else if let Some(mut iter) = this.hashes.take() {
let fut = Box::pin(async move {
actix_rt::task::spawn_blocking(move || {
let opt = iter.next();
(iter, opt)
})
.await
.map(|(iter, opt)| {
(
iter,
opt.map(|res| res.map_err(SledError::from).map_err(Error::from)),
)
})
.map_err(SledError::from)
.map_err(Error::from)
});
this.next_fut = Some(fut);
std::pin::Pin::new(this).poll_next(cx)
} else {
std::task::Poll::Ready(None)
}
}
}
fn hash_alias_key(hash: &IVec, alias: &Alias) -> Vec<u8> { fn hash_alias_key(hash: &IVec, alias: &Alias) -> Vec<u8> {
let mut v = hash.to_vec(); let mut v = hash.to_vec();
@ -273,15 +218,16 @@ fn hash_alias_key(hash: &IVec, alias: &Alias) -> Vec<u8> {
#[async_trait::async_trait(?Send)] #[async_trait::async_trait(?Send)]
impl HashRepo for SledRepo { impl HashRepo for SledRepo {
type Stream = HashStream; type Stream = LocalBoxStream<'static, StreamItem>;
async fn hashes(&self) -> Self::Stream { async fn hashes(&self) -> Self::Stream {
let iter = self.hashes.iter().keys(); let iter = self
.hashes
.iter()
.keys()
.map(|res| res.map_err(Error::from));
HashStream { Box::pin(from_iterator(iter))
hashes: Some(Box::new(iter)),
next_fut: None,
}
} }
#[tracing::instrument] #[tracing::instrument]

View file

@ -2,22 +2,16 @@ use crate::{
error::Error, error::Error,
repo::{Repo, SettingsRepo}, repo::{Repo, SettingsRepo},
store::Store, store::Store,
stream::StreamTimeout,
}; };
use actix_rt::time::Sleep;
use actix_web::web::Bytes; use actix_web::web::Bytes;
use futures_util::{stream::Stream, TryStreamExt}; use futures_util::{Stream, StreamExt};
use s3::{ use s3::{
client::Client, command::Command, creds::Credentials, request_trait::Request, Bucket, Region, client::Client, command::Command, creds::Credentials, request_trait::Request, Bucket, Region,
}; };
use std::{ use std::{
future::Future,
pin::Pin, pin::Pin,
string::FromUtf8Error, string::FromUtf8Error,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
task::{Context, Poll, Wake, Waker},
time::{Duration, Instant}, time::{Duration, Instant},
}; };
use storage_path_generator::{Generator, Path}; use storage_path_generator::{Generator, Path};
@ -58,17 +52,6 @@ pub(crate) struct ObjectStore {
client: reqwest::Client, client: reqwest::Client,
} }
pin_project_lite::pin_project! {
struct Timeout<S> {
sleep: Option<Pin<Box<Sleep>>>,
woken: Arc<AtomicBool>,
#[pin]
inner: S,
}
}
#[async_trait::async_trait(?Send)] #[async_trait::async_trait(?Send)]
impl Store for ObjectStore { impl Store for ObjectStore {
type Identifier = ObjectId; type Identifier = ObjectId;
@ -139,11 +122,12 @@ impl Store for ObjectStore {
let allotted = allotted.saturating_sub(now.elapsed()); let allotted = allotted.saturating_sub(now.elapsed());
let stream = response let stream = response.bytes_stream().timeout(allotted).map(|res| {
.bytes_stream() res.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e)); .and_then(|res| res.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e)))
});
Ok(request_span.in_scope(|| Box::pin(timeout(allotted, stream)))) Ok(request_span.in_scope(|| Box::pin(stream)))
} }
#[tracing::instrument(skip(writer))] #[tracing::instrument(skip(writer))]
@ -266,67 +250,6 @@ async fn init_generator(repo: &Repo) -> Result<Generator, Error> {
} }
} }
fn timeout<S, T>(duration: Duration, stream: S) -> impl Stream<Item = std::io::Result<T>>
where
S: Stream<Item = std::io::Result<T>>,
{
Timeout {
sleep: Some(Box::pin(actix_rt::time::sleep(duration))),
woken: Arc::new(AtomicBool::new(true)),
inner: stream,
}
}
struct TimeoutWaker {
woken: Arc<AtomicBool>,
inner: Waker,
}
impl Wake for TimeoutWaker {
fn wake(self: Arc<Self>) {
self.wake_by_ref()
}
fn wake_by_ref(self: &Arc<Self>) {
self.woken.store(true, Ordering::Release);
self.inner.wake_by_ref();
}
}
impl<S, T> Stream for Timeout<S>
where
S: Stream<Item = std::io::Result<T>>,
{
type Item = std::io::Result<T>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.as_mut().project();
if this.woken.swap(false, Ordering::Acquire) {
if let Some(mut sleep) = this.sleep.take() {
let timeout_waker = Arc::new(TimeoutWaker {
woken: Arc::clone(this.woken),
inner: cx.waker().clone(),
})
.into();
let mut timeout_cx = Context::from_waker(&timeout_waker);
if let Poll::Ready(()) = sleep.as_mut().poll(&mut timeout_cx) {
return Poll::Ready(Some(Err(std::io::Error::new(
std::io::ErrorKind::Other,
Error::from(ObjectError::Elapsed),
))));
} else {
*this.sleep = Some(sleep);
}
} else {
return Poll::Ready(None);
}
}
this.inner.poll_next(cx)
}
}
impl std::fmt::Debug for ObjectStore { impl std::fmt::Debug for ObjectStore {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("ObjectStore") f.debug_struct("ObjectStore")

228
src/stream.rs Normal file
View file

@ -0,0 +1,228 @@
use actix_rt::{task::JoinHandle, time::Sleep};
use actix_web::web::Bytes;
use futures_util::Stream;
use std::{
future::Future,
pin::Pin,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
task::{Context, Poll, Wake, Waker},
time::Duration,
};
pub(crate) trait StreamLimit {
fn limit(self, limit: u64) -> Limit<Self>
where
Self: Sized,
{
Limit {
inner: self,
count: 0,
limit,
}
}
}
pub(crate) trait StreamTimeout {
fn timeout(self, duration: Duration) -> Timeout<Self>
where
Self: Sized,
{
Timeout {
sleep: actix_rt::time::sleep(duration),
inner: self,
expired: false,
woken: Arc::new(AtomicBool::new(true)),
}
}
}
pub(crate) fn from_iterator<I: IntoIterator + Unpin + Send + 'static>(
iterator: I,
) -> IterStream<I, I::Item> {
IterStream {
state: IterStreamState::New { iterator },
}
}
impl<S, E> StreamLimit for S where S: Stream<Item = Result<Bytes, E>> {}
impl<S> StreamTimeout for S where S: Stream {}
pin_project_lite::pin_project! {
pub(crate) struct Limit<S> {
#[pin]
inner: S,
count: u64,
limit: u64,
}
}
pin_project_lite::pin_project! {
pub(crate) struct Timeout<S> {
#[pin]
sleep: Sleep,
#[pin]
inner: S,
expired: bool,
woken: Arc<AtomicBool>,
}
}
enum IterStreamState<I, T> {
New {
iterator: I,
},
Running {
handle: JoinHandle<()>,
receiver: tokio::sync::mpsc::Receiver<T>,
},
Pending,
}
pub(crate) struct IterStream<I, T> {
state: IterStreamState<I, T>,
}
struct TimeoutWaker {
woken: Arc<AtomicBool>,
inner: Waker,
}
#[derive(Debug, thiserror::Error)]
#[error("Resonse body larger than size limit")]
pub(crate) struct LimitError;
#[derive(Debug, thiserror::Error)]
#[error("Timeout in body")]
pub(crate) struct TimeoutError;
impl<S, E> Stream for Limit<S>
where
S: Stream<Item = Result<Bytes, E>>,
E: From<LimitError>,
{
type Item = Result<Bytes, E>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.as_mut().project();
let limit = this.limit;
let count = this.count;
let inner = this.inner;
inner.poll_next(cx).map(|opt| {
opt.map(|res| match res {
Ok(bytes) => {
*count += bytes.len() as u64;
if *count > *limit {
return Err(LimitError.into());
}
Ok(bytes)
}
Err(e) => Err(e),
})
})
}
}
impl Wake for TimeoutWaker {
fn wake(self: Arc<Self>) {
self.wake_by_ref()
}
fn wake_by_ref(self: &Arc<Self>) {
self.woken.store(true, Ordering::Release);
self.inner.wake_by_ref();
}
}
impl<S, T> Stream for Timeout<S>
where
S: Stream<Item = T>,
{
type Item = Result<T, TimeoutError>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.as_mut().project();
if *this.expired {
return Poll::Ready(None);
}
if this.woken.swap(false, Ordering::Acquire) {
let timeout_waker = Arc::new(TimeoutWaker {
woken: Arc::clone(this.woken),
inner: cx.waker().clone(),
})
.into();
let mut timeout_cx = Context::from_waker(&timeout_waker);
if this.sleep.poll(&mut timeout_cx).is_ready() {
*this.expired = true;
return Poll::Ready(Some(Err(TimeoutError)));
}
}
this.inner.poll_next(cx).map(|opt| opt.map(Ok))
}
}
impl<I, T> Stream for IterStream<I, T>
where
I: IntoIterator<Item = T> + Send + Unpin + 'static,
T: Send + 'static,
{
type Item = T;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.as_mut().get_mut();
match std::mem::replace(&mut this.state, IterStreamState::Pending) {
IterStreamState::New { iterator } => {
let (sender, receiver) = tokio::sync::mpsc::channel(1);
let mut handle = actix_rt::task::spawn_blocking(move || {
let iterator = iterator.into_iter();
for item in iterator {
if sender.blocking_send(item).is_err() {
break;
}
}
});
if Pin::new(&mut handle).poll(cx).is_ready() {
return Poll::Ready(None);
}
this.state = IterStreamState::Running { handle, receiver };
}
IterStreamState::Running {
mut handle,
mut receiver,
} => match Pin::new(&mut receiver).poll_recv(cx) {
Poll::Ready(Some(item)) => {
if Pin::new(&mut handle).poll(cx).is_ready() {
return Poll::Ready(Some(item));
}
this.state = IterStreamState::Running { handle, receiver };
}
Poll::Ready(None) => return Poll::Ready(None),
Poll::Pending => {
this.state = IterStreamState::Running { handle, receiver };
return Poll::Pending;
}
},
IterStreamState::Pending => return Poll::Ready(None),
}
self.poll_next(cx)
}
}