
Start replacing manual stream implementations with streem

asonix 2023-09-10 18:55:13 -04:00
parent 5a6179c0ff
commit 1b97ac1c5a
13 changed files with 212 additions and 430 deletions
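
The pattern this commit moves toward: rather than hand-rolling Stream types with manual poll_next implementations, adapters are written with streem::from_fn, which hands the body a yielder, and streams are consumed through streem::IntoStreamer. Below is a minimal sketch of that shape, using a hypothetical double_ok adapter that is not part of this commit, and assuming the streem 0.1 API only as the diff itself exercises it (from_fn, yield_, into_streamer, next):

    use futures_core::Stream;
    use streem::IntoStreamer;

    // Hypothetical adapter in the style of crate::stream::map_ok below:
    // build the output stream from an async block instead of a poll_next impl.
    fn double_ok<S, E>(stream: S) -> impl Stream<Item = Result<u32, E>>
    where
        S: Stream<Item = Result<u32, E>>,
        E: 'static,
    {
        streem::from_fn(|yielder| async move {
            // pin the input on the stack, then drive it item by item
            let stream = std::pin::pin!(stream);
            let mut streamer = stream.into_streamer();

            while let Some(res) = streamer.next().await {
                yielder.yield_(res.map(|n| n * 2)).await;
            }
        })
    }

The fallible helpers added below (crate::stream::limit, crate::stream::timeout) follow the same shape through streem::try_from_fn, yield_ok, and try_next.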

11
Cargo.lock generated
View file

@@ -1852,6 +1852,7 @@ dependencies = [
  "sha2",
  "sled",
  "storage-path-generator",
+ "streem",
  "thiserror",
  "time",
  "tokio",
@@ -2646,6 +2647,16 @@ version = "0.1.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "7f11d35dae9818c4313649da4a97c8329e29357a7fe584526c1d78f5b63ef836"

+[[package]]
+name = "streem"
+version = "0.1.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "641396a5ae90767cb12d21832444ab760841ee887717d802b2c456c4f8199114"
+dependencies = [
+ "futures-core",
+ "pin-project-lite",
+]
+
 [[package]]
 name = "stringprep"
 version = "0.1.4"

View file

@@ -57,6 +57,7 @@ serde_urlencoded = "0.7.1"
 sha2 = "0.10.0"
 sled = { version = "0.34.7" }
 storage-path-generator = "0.1.0"
+streem = "0.1.1"
 thiserror = "1.0"
 time = { version = "0.3.0", features = ["serde", "serde-well-known"] }
 tokio = { version = "1", features = ["full", "tracing"] }

View file

@@ -4,7 +4,6 @@ use crate::{
     error::Error,
     repo::{ArcRepo, UploadId},
     store::Store,
-    stream::StreamMap,
 };
 use actix_web::web::Bytes;
 use futures_core::Stream;
@@ -34,7 +33,7 @@ impl Backgrounded {
     pub(crate) async fn proxy<S, P>(repo: ArcRepo, store: S, stream: P) -> Result<Self, Error>
     where
         S: Store,
-        P: Stream<Item = Result<Bytes, Error>> + Unpin + 'static,
+        P: Stream<Item = Result<Bytes, Error>> + 'static,
     {
         let mut this = Self {
             repo,
@@ -50,12 +49,13 @@ impl Backgrounded {
     async fn do_proxy<S, P>(&mut self, store: S, stream: P) -> Result<(), Error>
     where
         S: Store,
-        P: Stream<Item = Result<Bytes, Error>> + Unpin + 'static,
+        P: Stream<Item = Result<Bytes, Error>> + 'static,
     {
         self.upload_id = Some(self.repo.create_upload().await?);

-        let stream =
-            stream.map(|res| res.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e)));
+        let stream = Box::pin(crate::stream::map_err(stream, |e| {
+            std::io::Error::new(std::io::ErrorKind::Other, e)
+        }));

         // use octet-stream, we don't know the upload's real type yet
         let identifier = store.save_stream(stream, APPLICATION_OCTET_STREAM).await?;

View file

@@ -7,9 +7,9 @@ use crate::{
     formats::{InternalFormat, InternalVideoFormat},
     serde_str::Serde,
     store::Store,
-    stream::IntoStreamer,
 };
 use actix_web::web;
+use streem::IntoStreamer;
 use time::{format_description::well_known::Rfc3339, OffsetDateTime};

 #[derive(Copy, Clone, Debug, serde::Deserialize, serde::Serialize)]

View file

@@ -6,14 +6,11 @@ pub(crate) use tokio_file::File;

 #[cfg(not(feature = "io-uring"))]
 mod tokio_file {
-    use crate::{
-        store::file_store::FileError,
-        stream::{IntoStreamer, StreamMap},
-        Either,
-    };
+    use crate::{store::file_store::FileError, Either};
     use actix_web::web::{Bytes, BytesMut};
     use futures_core::Stream;
     use std::{io::SeekFrom, path::Path};
+    use streem::IntoStreamer;
     use tokio::io::{AsyncRead, AsyncReadExt, AsyncSeekExt, AsyncWrite, AsyncWriteExt};
     use tokio_util::codec::{BytesCodec, FramedRead};
@@ -100,14 +97,17 @@
                 (None, None) => Either::right(self.inner),
             };

-            Ok(FramedRead::new(obj, BytesCodec::new()).map(|res| res.map(BytesMut::freeze)))
+            Ok(crate::stream::map_ok(
+                FramedRead::new(obj, BytesCodec::new()),
+                BytesMut::freeze,
+            ))
         }
     }
 }

 #[cfg(feature = "io-uring")]
 mod io_uring {
-    use crate::{store::file_store::FileError, stream::IntoStreamer};
+    use crate::store::file_store::FileError;
     use actix_web::web::{Bytes, BytesMut};
     use futures_core::Stream;
     use std::{
@@ -118,6 +118,7 @@
         pin::Pin,
         task::{Context, Poll},
     };
+    use streem::IntoStreamer;
     use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
     use tokio_uring::{
         buf::{IoBuf, IoBufMut},

View file

@@ -7,12 +7,12 @@ use crate::{
     formats::{InternalFormat, Validations},
     repo::{Alias, ArcRepo, DeleteToken, Hash},
     store::Store,
-    stream::{IntoStreamer, MakeSend},
 };
 use actix_web::web::Bytes;
 use futures_core::Stream;
 use reqwest::Body;
 use reqwest_middleware::ClientWithMiddleware;
+use streem::IntoStreamer;
 use tracing::{Instrument, Span};

 mod hasher;
@@ -30,10 +30,11 @@ pub(crate) struct Session {
 #[tracing::instrument(skip(stream))]
 async fn aggregate<S>(stream: S) -> Result<Bytes, Error>
 where
-    S: Stream<Item = Result<Bytes, Error>> + Unpin,
+    S: Stream<Item = Result<Bytes, Error>>,
 {
     let mut buf = BytesStream::new();

+    let stream = std::pin::pin!(stream);
     let mut stream = stream.into_streamer();

     while let Some(res) = stream.next().await {
@@ -48,7 +49,7 @@ pub(crate) async fn ingest<S>(
     repo: &ArcRepo,
     store: &S,
     client: &ClientWithMiddleware,
-    stream: impl Stream<Item = Result<Bytes, Error>> + Unpin + 'static,
+    stream: impl Stream<Item = Result<Bytes, Error>> + 'static,
     declared_alias: Option<Alias>,
     media: &crate::config::Media,
 ) -> Result<Session, Error>
@@ -117,12 +118,12 @@ where
     };

     if let Some(endpoint) = &media.external_validation {
-        let stream = store.to_stream(&identifier, None, None).await?.make_send();
+        let stream = store.to_stream(&identifier, None, None).await?;

         let response = client
             .post(endpoint.as_str())
             .header("Content-Type", input_type.media_type().as_ref())
-            .body(Body::wrap_stream(stream))
+            .body(Body::wrap_stream(crate::stream::make_send(stream)))
             .send()
             .instrument(tracing::info_span!("external-validation"))
             .await?;

View file

@@ -54,6 +54,7 @@ use std::{
     sync::Arc,
     time::{Duration, SystemTime},
 };
+use streem::IntoStreamer;
 use tokio::sync::Semaphore;
 use tracing::Instrument;
 use tracing_actix_web::TracingLogger;
@@ -74,7 +75,7 @@ use self::{
     repo::{sled::SledRepo, Alias, DeleteToken, Hash, Repo, UploadId, UploadResult},
     serde_str::Serde,
     store::{file_store::FileStore, object_store::ObjectStore, Store},
-    stream::{empty, once, StreamLimit, StreamMap, StreamTimeout},
+    stream::{empty, once},
 };

 pub use self::config::{ConfigSource, PictRsConfiguration};
@@ -165,14 +166,14 @@ impl<S: Store + 'static> FormData for Upload<S> {
         let span = tracing::info_span!("file-upload", ?filename);

-        let stream = stream.map(|res| res.map_err(Error::from));
-
         Box::pin(
             async move {
                 if config.server.read_only {
                     return Err(UploadError::ReadOnly.into());
                 }

+                let stream = crate::stream::from_err(stream);
+
                 ingest::ingest(&repo, &**store, &client, stream, None, &config.media)
                     .await
             }
@@ -230,14 +231,14 @@ impl<S: Store + 'static> FormData for Import<S> {
         let span = tracing::info_span!("file-import", ?filename);

-        let stream = stream.map(|res| res.map_err(Error::from));
-
         Box::pin(
             async move {
                 if config.server.read_only {
                     return Err(UploadError::ReadOnly.into());
                 }

+                let stream = crate::stream::from_err(stream);
+
                 ingest::ingest(
                     &repo,
                     &**store,
@@ -368,14 +369,14 @@ impl<S: Store + 'static> FormData for BackgroundedUpload<S> {
         let span = tracing::info_span!("file-proxy", ?filename);

-        let stream = stream.map(|res| res.map_err(Error::from));
-
         Box::pin(
             async move {
                 if read_only {
                     return Err(UploadError::ReadOnly.into());
                 }

+                let stream = crate::stream::from_err(stream);
+
                 Backgrounded::proxy(repo, store, stream).await
             }
             .instrument(span),
@@ -488,7 +489,7 @@ struct UrlQuery {
 }

 async fn ingest_inline<S: Store + 'static>(
-    stream: impl Stream<Item = Result<web::Bytes, Error>> + Unpin + 'static,
+    stream: impl Stream<Item = Result<web::Bytes, Error>> + 'static,
     repo: &ArcRepo,
     store: &S,
     client: &ClientWithMiddleware,
@@ -527,7 +528,7 @@ async fn download_stream(
     client: &ClientWithMiddleware,
     url: &str,
     config: &Configuration,
-) -> Result<impl Stream<Item = Result<web::Bytes, Error>> + Unpin + 'static, Error> {
+) -> Result<impl Stream<Item = Result<web::Bytes, Error>> + 'static, Error> {
     if config.server.read_only {
         return Err(UploadError::ReadOnly.into());
     }
@@ -538,10 +539,10 @@ async fn download_stream(
         return Err(UploadError::Download(res.status()).into());
     }

-    let stream = res
-        .bytes_stream()
-        .map(|res| res.map_err(Error::from))
-        .limit((config.media.max_file_size * MEGABYTES) as u64);
+    let stream = crate::stream::limit(
+        config.media.max_file_size * MEGABYTES,
+        crate::stream::from_err(res.bytes_stream()),
+    );

     Ok(stream)
 }
@@ -551,7 +552,7 @@ async fn download_stream(
     skip(stream, repo, store, client, config)
 )]
 async fn do_download_inline<S: Store + 'static>(
-    stream: impl Stream<Item = Result<web::Bytes, Error>> + Unpin + 'static,
+    stream: impl Stream<Item = Result<web::Bytes, Error>> + 'static,
     repo: web::Data<ArcRepo>,
     store: web::Data<S>,
     client: &ClientWithMiddleware,
@@ -574,7 +575,7 @@ async fn do_download_inline<S: Store + 'static>(
 #[tracing::instrument(name = "Downloading file in background", skip(stream, repo, store))]
 async fn do_download_backgrounded<S: Store + 'static>(
-    stream: impl Stream<Item = Result<web::Bytes, Error>> + Unpin + 'static,
+    stream: impl Stream<Item = Result<web::Bytes, Error>> + 'static,
     repo: web::Data<ArcRepo>,
     store: web::Data<S>,
 ) -> Result<HttpResponse, Error> {
@@ -1325,9 +1326,7 @@ async fn ranged_file_resp<S: Store + 'static>(
                 (
                     builder,
                     Either::left(Either::left(
-                        range::chop_store(range, store, &identifier, len)
-                            .await?
-                            .map(|res| res.map_err(Error::from)),
+                        range::chop_store(range, store, &identifier, len).await?,
                     )),
                 )
             } else {
@@ -1341,10 +1340,7 @@
         }
     } else {
         //No Range header in the request - return the entire document
-        let stream = store
-            .to_stream(&identifier, None, None)
-            .await?
-            .map(|res| res.map_err(Error::from));
+        let stream = crate::stream::from_err(store.to_stream(&identifier, None, None).await?);

         if not_found {
             (HttpResponse::NotFound(), Either::right(stream))
@@ -1375,10 +1371,18 @@ where
     E: std::error::Error + 'static,
     actix_web::Error: From<E>,
 {
-    let stream = stream.timeout(Duration::from_secs(5)).map(|res| match res {
-        Ok(Ok(item)) => Ok(item),
-        Ok(Err(e)) => Err(actix_web::Error::from(e)),
-        Err(e) => Err(Error::from(e).into()),
+    let stream = crate::stream::timeout(Duration::from_secs(5), stream);
+
+    let stream = streem::try_from_fn(|yielder| async move {
+        let stream = std::pin::pin!(stream);
+        let mut streamer = stream.into_streamer();
+
+        while let Some(res) = streamer.next().await {
+            let item = res.map_err(Error::from)??;
+            yielder.yield_ok(item).await;
+        }
+
+        Ok(()) as Result<(), actix_web::Error>
     });

     srv_head(builder, ext, expires, modified).streaming(stream)

View file

@@ -7,12 +7,13 @@ use std::{
     time::{Duration, Instant},
 };

+use streem::IntoStreamer;
+
 use crate::{
     details::Details,
     error::{Error, UploadError},
     repo::{ArcRepo, Hash},
     store::Store,
-    stream::IntoStreamer,
 };

 pub(super) async fn migrate_store<S1, S2>(

View file

@@ -1,5 +1,7 @@
 use std::sync::Arc;

+use streem::IntoStreamer;
+
 use crate::{
     config::Configuration,
     error::{Error, UploadError},
@@ -8,7 +10,6 @@ use crate::{
     repo::{Alias, ArcRepo, DeleteToken, Hash},
     serde_str::Serde,
     store::Store,
-    stream::IntoStreamer,
 };

 pub(super) fn perform<'a, S>(

View file

@@ -12,7 +12,6 @@ use crate::{
     repo::{Alias, ArcRepo, UploadId, UploadResult},
     serde_str::Serde,
     store::Store,
-    stream::StreamMap,
 };
 use std::{path::PathBuf, sync::Arc};
@@ -131,10 +130,7 @@
         let media = media.clone();
         let error_boundary = crate::sync::spawn(async move {
-            let stream = store2
-                .to_stream(&ident, None, None)
-                .await?
-                .map(|res| res.map_err(Error::from));
+            let stream = crate::stream::from_err(store2.to_stream(&ident, None, None).await?);

             let session =
                 crate::ingest::ingest(&repo, &store2, &client, stream, declared_alias, &media)

View file

@@ -1,5 +1,6 @@
 use std::sync::Arc;

+use streem::IntoStreamer;
 use tokio::task::JoinSet;

 use crate::{
@@ -12,7 +13,6 @@ use crate::{
         SledRepo as OldSledRepo,
     },
     store::Store,
-    stream::IntoStreamer,
 };

 const MIGRATE_CONCURRENCY: usize = 32;

View file

@@ -1,9 +1,6 @@
 use crate::{
-    bytes_stream::BytesStream,
-    error_code::ErrorCode,
-    repo::ArcRepo,
-    store::Store,
-    stream::{IntoStreamer, LocalBoxStream, StreamMap},
+    bytes_stream::BytesStream, error_code::ErrorCode, repo::ArcRepo, store::Store,
+    stream::LocalBoxStream,
 };
 use actix_rt::task::JoinError;
 use actix_web::{
@@ -21,6 +18,7 @@ use reqwest_middleware::{ClientWithMiddleware, RequestBuilder};
 use rusty_s3::{actions::S3Action, Bucket, BucketError, Credentials, UrlStyle};
 use std::{string::FromUtf8Error, sync::Arc, time::Duration};
 use storage_path_generator::{Generator, Path};
+use streem::IntoStreamer;
 use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
 use tokio_util::io::ReaderStream;
 use tracing::Instrument;
@@ -393,11 +391,10 @@ impl Store for ObjectStore {
             return Err(status_error(response).await);
         }

-        Ok(Box::pin(
-            response
-                .bytes_stream()
-                .map(|res| res.map_err(payload_to_io_error)),
-        ))
+        Ok(Box::pin(crate::stream::map_err(
+            response.bytes_stream(),
+            payload_to_io_error,
+        )))
     }

     #[tracing::instrument(skip(self, writer))]

View file

@@ -1,269 +1,170 @@
-use actix_rt::{task::JoinHandle, time::Sleep};
 use actix_web::web::Bytes;
-use flume::r#async::RecvStream;
 use futures_core::Stream;
-use std::{
-    future::Future,
-    marker::PhantomData,
-    pin::Pin,
-    sync::{
-        atomic::{AtomicBool, Ordering},
-        Arc,
-    },
-    task::{Context, Poll, Wake, Waker},
-    time::Duration,
-};
+use std::{pin::Pin, time::Duration};
+use streem::IntoStreamer;
-pub(crate) trait MakeSend<T>: Stream<Item = std::io::Result<T>>
-where
-    T: 'static,
-{
-    fn make_send(self) -> MakeSendStream<T>
-    where
-        Self: Sized + 'static,
-    {
-        let (tx, rx) = crate::sync::channel(4);
-
-        MakeSendStream {
-            handle: crate::sync::spawn(async move {
-                let this = std::pin::pin!(self);
-                let mut stream = this.into_streamer();
-
-                while let Some(res) = stream.next().await {
-                    if tx.send_async(res).await.is_err() {
-                        return;
-                    }
-                }
-            }),
-            rx: rx.into_stream(),
-        }
-    }
-}
-
-impl<S, T> MakeSend<T> for S
-where
-    S: Stream<Item = std::io::Result<T>>,
-    T: 'static,
-{
-}
-
-pub(crate) struct MakeSendStream<T>
-where
-    T: 'static,
-{
-    handle: actix_rt::task::JoinHandle<()>,
-    rx: flume::r#async::RecvStream<'static, std::io::Result<T>>,
-}
-
-impl<T> Stream for MakeSendStream<T>
-where
-    T: 'static,
-{
-    type Item = std::io::Result<T>;
-
-    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
-        match Pin::new(&mut self.rx).poll_next(cx) {
-            Poll::Ready(opt) => Poll::Ready(opt),
-            Poll::Pending if std::task::ready!(Pin::new(&mut self.handle).poll(cx)).is_err() => {
-                Poll::Ready(Some(Err(std::io::Error::new(
-                    std::io::ErrorKind::UnexpectedEof,
-                    "Stream panicked",
-                ))))
-            }
-            Poll::Pending => Poll::Pending,
-        }
-    }
-}
+pub(crate) fn make_send<S>(stream: S) -> impl Stream<Item = S::Item> + Send
+where
+    S: Stream + 'static,
+    S::Item: Send + Sync,
+{
+    let (tx, rx) = crate::sync::channel(1);
+
+    let handle = crate::sync::spawn(async move {
+        let stream = std::pin::pin!(stream);
+        let mut streamer = stream.into_streamer();
+
+        while let Some(res) = streamer.next().await {
+            if tx.send_async(res).await.is_err() {
+                break;
+            }
+        }
+    });
+
+    streem::from_fn(|yiedler| async move {
+        let mut stream = rx.into_stream().into_streamer();
+
+        while let Some(res) = stream.next().await {
+            yiedler.yield_(res).await;
+        }
+
+        let _ = handle.await;
+    })
+}
-pin_project_lite::pin_project! {
-    pub(crate) struct Map<S, F> {
-        #[pin]
-        stream: S,
-        func: F,
-    }
-}
+pub(crate) fn from_iterator<I>(iterator: I, buffer: usize) -> impl Stream<Item = I::Item> + Send
+where
+    I: IntoIterator + Send + 'static,
+    I::Item: Send + Sync,
+{
+    let (tx, rx) = crate::sync::channel(buffer);
+
+    let handle = crate::sync::spawn_blocking(move || {
+        for value in iterator {
+            if tx.send(value).is_err() {
+                break;
+            }
+        }
+    });
+
+    streem::from_fn(|yielder| async move {
+        let mut stream = rx.into_stream().into_streamer();
+
+        while let Some(res) = stream.next().await {
+            yielder.yield_(res).await;
+        }
+
+        let _ = handle.await;
+    })
+}
-pub(crate) trait StreamMap: Stream {
-    fn map<F, U>(self, func: F) -> Map<Self, F>
-    where
-        F: FnMut(Self::Item) -> U,
-        Self: Sized,
-    {
-        Map { stream: self, func }
-    }
-}
+pub(crate) fn map_ok<S, T1, T2, E, F>(stream: S, f: F) -> impl Stream<Item = Result<T2, E>>
+where
+    S: Stream<Item = Result<T1, E>>,
+    T2: 'static,
+    E: 'static,
+    F: Fn(T1) -> T2 + Copy,
+{
+    streem::from_fn(|yielder| async move {
+        let stream = std::pin::pin!(stream);
+        let mut streamer = stream.into_streamer();
+
+        while let Some(res) = streamer.next().await {
+            yielder.yield_(res.map(f)).await;
+        }
+    })
+}
-impl<T> StreamMap for T where T: Stream {}
-
-impl<S, F, U> Stream for Map<S, F>
-where
-    S: Stream,
-    F: FnMut(S::Item) -> U,
-{
-    type Item = U;
-
-    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
-        let this = self.project();
-
-        let value = std::task::ready!(this.stream.poll_next(cx));
-
-        Poll::Ready(value.map(this.func))
-    }
-}
+pub(crate) fn map_err<S, T, E1, E2, F>(stream: S, f: F) -> impl Stream<Item = Result<T, E2>>
+where
+    S: Stream<Item = Result<T, E1>>,
+    T: 'static,
+    E2: 'static,
+    F: Fn(E1) -> E2 + Copy,
+{
+    streem::from_fn(|yielder| async move {
+        let stream = std::pin::pin!(stream);
+        let mut streamer = stream.into_streamer();
+
+        while let Some(res) = streamer.next().await {
+            yielder.yield_(res.map_err(f)).await;
+        }
+    })
+}
+
+pub(crate) fn from_err<S, T, E1, E2>(stream: S) -> impl Stream<Item = Result<T, E2>>
+where
+    S: Stream<Item = Result<T, E1>>,
+    T: 'static,
+    E1: Into<E2>,
+    E2: 'static,
+{
+    map_err(stream, Into::into)
+}
+
+pub(crate) fn empty<T>() -> impl Stream<Item = T>
+where
+    T: 'static,
+{
+    streem::from_fn(|_| std::future::ready(()))
+}
+
+pub(crate) fn once<T>(value: T) -> impl Stream<Item = T>
+where
+    T: 'static,
+{
+    streem::from_fn(|yielder| yielder.yield_(value))
+}
+
+pub(crate) fn timeout<S>(
+    duration: Duration,
+    stream: S,
+) -> impl Stream<Item = Result<S::Item, TimeoutError>>
+where
+    S: Stream,
+    S::Item: 'static,
+{
+    streem::try_from_fn(|yielder| async move {
+        actix_rt::time::timeout(duration, async move {
+            let stream = std::pin::pin!(stream);
+            let mut streamer = stream.into_streamer();
+
+            while let Some(res) = streamer.next().await {
+                yielder.yield_ok(res).await;
+            }
+        })
+        .await
+        .map_err(|_| TimeoutError)
+    })
+}
-pub(crate) struct Empty<T>(PhantomData<T>);
-
-impl<T> Stream for Empty<T> {
-    type Item = T;
-
-    fn poll_next(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Option<Self::Item>> {
-        Poll::Ready(None)
-    }
-}
-
-pub(crate) fn empty<T>() -> Empty<T> {
-    Empty(PhantomData)
-}
-
-pub(crate) struct Once<T>(Option<T>);
-
-impl<T> Stream for Once<T>
-where
-    T: Unpin,
-{
-    type Item = T;
-
-    fn poll_next(mut self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Option<Self::Item>> {
-        Poll::Ready(self.0.take())
-    }
-}
-
-pub(crate) fn once<T>(value: T) -> Once<T> {
-    Once(Some(value))
-}
+pub(crate) fn limit<S, E>(limit: usize, stream: S) -> impl Stream<Item = Result<Bytes, E>>
+where
+    S: Stream<Item = Result<Bytes, E>>,
+    E: From<LimitError> + 'static,
+{
+    streem::try_from_fn(|yielder| async move {
+        let stream = std::pin::pin!(stream);
+        let mut streamer = stream.into_streamer();
+
+        let mut count = 0;
+
+        while let Some(bytes) = streamer.try_next().await? {
+            count += bytes.len();
+
+            if count > limit {
+                return Err(LimitError.into());
+            }
+
+            yielder.yield_ok(bytes).await;
+        }
+
+        Ok(())
+    })
+}
 pub(crate) type LocalBoxStream<'a, T> = Pin<Box<dyn Stream<Item = T> + 'a>>;
-pub(crate) trait StreamLimit {
-    fn limit(self, limit: u64) -> Limit<Self>
-    where
-        Self: Sized,
-    {
-        Limit {
-            inner: self,
-            count: 0,
-            limit,
-        }
-    }
-}
-
-pub(crate) trait StreamTimeout {
-    fn timeout(self, duration: Duration) -> Timeout<Self>
-    where
-        Self: Sized,
-    {
-        Timeout {
-            sleep: actix_rt::time::sleep(duration),
-            inner: self,
-            expired: false,
-            woken: Arc::new(AtomicBool::new(true)),
-        }
-    }
-}
-
-pub(crate) trait IntoStreamer: Stream {
-    fn into_streamer(self) -> Streamer<Self>
-    where
-        Self: Sized,
-    {
-        Streamer(Some(self))
-    }
-}
-
-impl<T> IntoStreamer for T where T: Stream + Unpin {}
-
-pub(crate) fn from_iterator<I: IntoIterator + Unpin + Send + 'static>(
-    iterator: I,
-    buffer: usize,
-) -> IterStream<I, I::Item> {
-    IterStream {
-        state: IterStreamState::New { iterator, buffer },
-    }
-}
-
-impl<S, E> StreamLimit for S where S: Stream<Item = Result<Bytes, E>> {}
-impl<S> StreamTimeout for S where S: Stream {}
-
-pub(crate) struct Streamer<S>(Option<S>);
-
-impl<S> Streamer<S> {
-    pub(crate) async fn next(&mut self) -> Option<S::Item>
-    where
-        S: Stream + Unpin,
-    {
-        let stream = self.0.as_mut().take()?;
-
-        let opt = std::future::poll_fn(|cx| Pin::new(&mut *stream).poll_next(cx)).await;
-
-        if opt.is_none() {
-            self.0.take();
-        }
-
-        opt
-    }
-}
-
-pin_project_lite::pin_project! {
-    pub(crate) struct Limit<S> {
-        #[pin]
-        inner: S,
-
-        count: u64,
-        limit: u64,
-    }
-}
-
-pin_project_lite::pin_project! {
-    pub(crate) struct Timeout<S> {
-        #[pin]
-        sleep: Sleep,
-
-        #[pin]
-        inner: S,
-
-        expired: bool,
-        woken: Arc<AtomicBool>,
-    }
-}
-
-enum IterStreamState<I, T>
-where
-    T: 'static,
-{
-    New {
-        iterator: I,
-        buffer: usize,
-    },
-    Running {
-        handle: JoinHandle<()>,
-        receiver: RecvStream<'static, T>,
-    },
-    Pending,
-}
-
-pub(crate) struct IterStream<I, T>
-where
-    T: 'static,
-{
-    state: IterStreamState<I, T>,
-}
-
-struct TimeoutWaker {
-    woken: Arc<AtomicBool>,
-    inner: Waker,
-}
 #[derive(Debug, thiserror::Error)]
 #[error("Resonse body larger than size limit")]
 pub(crate) struct LimitError;
@@ -271,135 +172,3 @@ pub(crate) struct LimitError;
 #[derive(Debug, thiserror::Error)]
 #[error("Timeout in body")]
 pub(crate) struct TimeoutError;
-impl<S, E> Stream for Limit<S>
-where
-    S: Stream<Item = Result<Bytes, E>>,
-    E: From<LimitError>,
-{
-    type Item = Result<Bytes, E>;
-
-    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
-        let this = self.as_mut().project();
-
-        let limit = this.limit;
-        let count = this.count;
-        let inner = this.inner;
-
-        inner.poll_next(cx).map(|opt| {
-            opt.map(|res| match res {
-                Ok(bytes) => {
-                    *count += bytes.len() as u64;
-                    if *count > *limit {
-                        return Err(LimitError.into());
-                    }
-                    Ok(bytes)
-                }
-                Err(e) => Err(e),
-            })
-        })
-    }
-}
-
-impl Wake for TimeoutWaker {
-    fn wake(self: Arc<Self>) {
-        self.wake_by_ref()
-    }
-
-    fn wake_by_ref(self: &Arc<Self>) {
-        self.woken.store(true, Ordering::Release);
-        self.inner.wake_by_ref();
-    }
-}
-
-impl<S, T> Stream for Timeout<S>
-where
-    S: Stream<Item = T>,
-{
-    type Item = Result<T, TimeoutError>;
-
-    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
-        let this = self.as_mut().project();
-
-        if *this.expired {
-            return Poll::Ready(None);
-        }
-
-        if this.woken.swap(false, Ordering::Acquire) {
-            let timeout_waker = Arc::new(TimeoutWaker {
-                woken: Arc::clone(this.woken),
-                inner: cx.waker().clone(),
-            })
-            .into();
-
-            let mut timeout_cx = Context::from_waker(&timeout_waker);
-
-            if this.sleep.poll(&mut timeout_cx).is_ready() {
-                *this.expired = true;
-                return Poll::Ready(Some(Err(TimeoutError)));
-            }
-        }
-
-        this.inner.poll_next(cx).map(|opt| opt.map(Ok))
-    }
-}
-
-impl<I, T> Stream for IterStream<I, T>
-where
-    I: IntoIterator<Item = T> + Send + Unpin + 'static,
-    T: Send + 'static,
-{
-    type Item = T;
-
-    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
-        let this = self.as_mut().get_mut();
-
-        match std::mem::replace(&mut this.state, IterStreamState::Pending) {
-            IterStreamState::New { iterator, buffer } => {
-                let (sender, receiver) = crate::sync::channel(buffer);
-
-                let mut handle = crate::sync::spawn_blocking(move || {
-                    let iterator = iterator.into_iter();
-
-                    for item in iterator {
-                        if sender.send(item).is_err() {
-                            break;
-                        }
-                    }
-                });
-
-                if Pin::new(&mut handle).poll(cx).is_ready() {
-                    return Poll::Ready(None);
-                }
-
-                this.state = IterStreamState::Running {
-                    handle,
-                    receiver: receiver.into_stream(),
-                };
-
-                self.poll_next(cx)
-            }
-            IterStreamState::Running {
-                mut handle,
-                mut receiver,
-            } => match Pin::new(&mut receiver).poll_next(cx) {
-                Poll::Ready(Some(item)) => {
-                    this.state = IterStreamState::Running { handle, receiver };
-
-                    Poll::Ready(Some(item))
-                }
-                Poll::Ready(None) => Poll::Ready(None),
-                Poll::Pending => {
-                    if Pin::new(&mut handle).poll(cx).is_ready() {
-                        return Poll::Ready(None);
-                    }
-
-                    this.state = IterStreamState::Running { handle, receiver };
-
-                    Poll::Pending
-                }
-            },
-            IterStreamState::Pending => panic!("Polled after completion"),
-        }
-    }
-}
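
On the consuming side, callers now pin the stream and drive it through a streamer, as ingest::aggregate does above. A small illustrative sketch, not taken from the commit, assuming the same streem 0.1 API shown in this diff:

    use streem::IntoStreamer;

    // Hypothetical consumer: drain a fallible byte stream into a buffer,
    // mirroring the std::pin::pin! + into_streamer + next pattern above.
    async fn collect<S>(stream: S) -> std::io::Result<Vec<u8>>
    where
        S: futures_core::Stream<Item = std::io::Result<actix_web::web::Bytes>>,
    {
        let stream = std::pin::pin!(stream);
        let mut streamer = stream.into_streamer();

        let mut out = Vec::new();
        while let Some(res) = streamer.next().await {
            out.extend_from_slice(&res?);
        }

        Ok(out)
    }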