2
0
Fork 0
mirror of https://git.asonix.dog/asonix/pict-rs synced 2024-12-22 19:31:35 +00:00

Add external validation check

This commit is contained in:
asonix 2023-09-05 20:45:07 -05:00
parent 509a52ec6b
commit e15a82c0c7
11 changed files with 222 additions and 40 deletions

View file

@ -23,6 +23,7 @@ path = 'data/sled-repo-local'
cache_capacity = 67108864 cache_capacity = 67108864
[media] [media]
# external_validation = 'http://localhost:8076'
max_file_size = 40 max_file_size = 40
filters = ['blur', 'crop', 'identity', 'resize', 'thumbnail'] filters = ['blur', 'crop', 'identity', 'resize', 'thumbnail']

View file

@ -143,6 +143,15 @@ cache_capacity = 67108864
## Media Processing Configuration ## Media Processing Configuration
[media] [media]
## Optional: URL for external validation of media
# environment variable: PICTRS__MEDIA__EXTERNAL_VALIDATION
# default: empty
#
# The expected API for external validators is to accept a POST with the media as the request body,
# and a valid `Content-Type` header. The validator should return a 2XX response when the media
# passes validation. Any other status code is considered a validation failure.
external_validation = 'http://localhost:8076'
## Optional: max file size (in Megabytes) ## Optional: max file size (in Megabytes)
# environment variable: PICTRS__MEDIA__MAX_FILE_SIZE # environment variable: PICTRS__MEDIA__MAX_FILE_SIZE
# default: 40 # default: 40
@ -499,13 +508,15 @@ crf_1440 = 24
crf_2160 = 15 crf_2160 = 15
## Database configuration ### Database configuration
## Sled repo configuration example
[repo] [repo]
## Optional: database backend to use ## Optional: database backend to use
# environment variable: PICTRS__REPO__TYPE # environment variable: PICTRS__REPO__TYPE
# default: sled # default: sled
# #
# available options: sled # available options: sled, postgres
type = 'sled' type = 'sled'
## Optional: path to sled repository ## Optional: path to sled repository
@ -527,7 +538,39 @@ cache_capacity = 67108864
export_path = "/mnt/exports" export_path = "/mnt/exports"
## Media storage configuration ## Postgres repo configuration example
[repo]
## Optional: database backend to use
# environment variable: PICTRS__REPO__TYPE
# default: sled
#
# available options: sled, postgres
type = 'postgres'
## Required: URL to postgres database
# environment variable: PICTRS__REPO__URL
# default: empty
url = 'postgres://user:password@host:5432/db'
### Media storage configuration
## Filesystem media storage example
[store]
## Optional: type of media storage to use
# environment variable: PICTRS__STORE__TYPE
# default: filesystem
#
# available options: filesystem, object_storage
type = 'filesystem'
## Optional: path to uploaded media
# environment variable: PICTRS__STORE__PATH
# default: /mnt/files
path = '/mnt/files'
## Object media storage example
[store] [store]
## Optional: type of media storage to use ## Optional: type of media storage to use
# environment variable: PICTRS__STORE__TYPE # environment variable: PICTRS__STORE__TYPE
@ -597,18 +640,3 @@ signature_expiration = 15
# This value is the total wait time, and not additional wait time on top of the # This value is the total wait time, and not additional wait time on top of the
# signature_expiration. # signature_expiration.
client_timeout = 30 client_timeout = 30
## Filesystem media storage example
# ## Media storage configuration
# [store]
# ## Optional: type of media storage to use
# # environment variable: PICTRS__STORE__TYPE
# # default: filesystem
# #
# # available options: filesystem, object_storage
# type = 'filesystem'
#
# ## Optional: path to uploaded media
# # environment variable: PICTRS__STORE__PATH
# # default: /mnt/files
# path = '/mnt/files'

View file

@ -56,6 +56,7 @@ impl Args {
client_timeout, client_timeout,
metrics_prometheus_address, metrics_prometheus_address,
media_preprocess_steps, media_preprocess_steps,
media_external_validation,
media_max_file_size, media_max_file_size,
media_process_timeout, media_process_timeout,
media_retention_variants, media_retention_variants,
@ -183,6 +184,7 @@ impl Args {
max_file_size: media_max_file_size, max_file_size: media_max_file_size,
process_timeout: media_process_timeout, process_timeout: media_process_timeout,
preprocess_steps: media_preprocess_steps, preprocess_steps: media_preprocess_steps,
external_validation: media_external_validation,
filters: media_filters, filters: media_filters,
retention: retention.set(), retention: retention.set(),
image: image.set(), image: image.set(),
@ -549,6 +551,8 @@ struct Media {
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
preprocess_steps: Option<String>, preprocess_steps: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
external_validation: Option<Url>,
#[serde(skip_serializing_if = "Option::is_none")]
filters: Option<Vec<String>>, filters: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
retention: Option<Retention>, retention: Option<Retention>,
@ -884,6 +888,10 @@ struct Run {
#[arg(long)] #[arg(long)]
media_preprocess_steps: Option<String>, media_preprocess_steps: Option<String>,
/// Optional endpoint to submit uploaded media to for validation
#[arg(long)]
media_external_validation: Option<Url>,
/// Which media filters should be enabled on the `process` endpoint /// Which media filters should be enabled on the `process` endpoint
#[arg(long)] #[arg(long)]
media_filters: Option<Vec<String>>, media_filters: Option<Vec<String>>,

View file

@ -165,6 +165,8 @@ pub(crate) struct OldDb {
#[derive(Clone, Debug, serde::Deserialize, serde::Serialize)] #[derive(Clone, Debug, serde::Deserialize, serde::Serialize)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
pub(crate) struct Media { pub(crate) struct Media {
pub(crate) external_validation: Option<Url>,
pub(crate) max_file_size: usize, pub(crate) max_file_size: usize,
pub(crate) process_timeout: u64, pub(crate) process_timeout: u64,

View file

@ -151,6 +151,9 @@ pub(crate) enum UploadError {
#[error("Response timeout")] #[error("Response timeout")]
Timeout(#[from] crate::stream::TimeoutError), Timeout(#[from] crate::stream::TimeoutError),
#[error("Failed external validation")]
FailedExternalValidation,
} }
impl UploadError { impl UploadError {
@ -184,6 +187,7 @@ impl UploadError {
Self::Range => ErrorCode::RANGE_NOT_SATISFIABLE, Self::Range => ErrorCode::RANGE_NOT_SATISFIABLE,
Self::Limit(_) => ErrorCode::VALIDATE_FILE_SIZE, Self::Limit(_) => ErrorCode::VALIDATE_FILE_SIZE,
Self::Timeout(_) => ErrorCode::STREAM_TOO_SLOW, Self::Timeout(_) => ErrorCode::STREAM_TOO_SLOW,
Self::FailedExternalValidation => ErrorCode::FAILED_EXTERNAL_VALIDATION,
} }
} }
@ -232,7 +236,8 @@ impl ResponseError for Error {
| UploadError::Validation(_) | UploadError::Validation(_)
| UploadError::UnsupportedProcessExtension | UploadError::UnsupportedProcessExtension
| UploadError::InvalidProcessExtension | UploadError::InvalidProcessExtension
| UploadError::ReadOnly, | UploadError::ReadOnly
| UploadError::FailedExternalValidation,
) => StatusCode::BAD_REQUEST, ) => StatusCode::BAD_REQUEST,
Some(UploadError::Magick(e)) if e.is_client_error() => StatusCode::BAD_REQUEST, Some(UploadError::Magick(e)) if e.is_client_error() => StatusCode::BAD_REQUEST,
Some(UploadError::Ffmpeg(e)) if e.is_client_error() => StatusCode::BAD_REQUEST, Some(UploadError::Ffmpeg(e)) if e.is_client_error() => StatusCode::BAD_REQUEST,

View file

@ -141,4 +141,7 @@ impl ErrorCode {
pub(crate) const UNKNOWN_ERROR: ErrorCode = ErrorCode { pub(crate) const UNKNOWN_ERROR: ErrorCode = ErrorCode {
code: "unknown-error", code: "unknown-error",
}; };
pub(crate) const FAILED_EXTERNAL_VALIDATION: ErrorCode = ErrorCode {
code: "failed-external-validation",
};
} }

View file

@ -7,10 +7,12 @@ use crate::{
formats::{InternalFormat, Validations}, formats::{InternalFormat, Validations},
repo::{Alias, ArcRepo, DeleteToken, Hash}, repo::{Alias, ArcRepo, DeleteToken, Hash},
store::Store, store::Store,
stream::IntoStreamer, stream::{IntoStreamer, MakeSend},
}; };
use actix_web::web::Bytes; use actix_web::web::Bytes;
use futures_core::Stream; use futures_core::Stream;
use reqwest::Body;
use reqwest_middleware::ClientWithMiddleware;
use tracing::{Instrument, Span}; use tracing::{Instrument, Span};
mod hasher; mod hasher;
@ -41,10 +43,11 @@ where
Ok(buf.into_bytes()) Ok(buf.into_bytes())
} }
#[tracing::instrument(skip(repo, store, stream, media))] #[tracing::instrument(skip(repo, store, client, stream, media))]
pub(crate) async fn ingest<S>( pub(crate) async fn ingest<S>(
repo: &ArcRepo, repo: &ArcRepo,
store: &S, store: &S,
client: &ClientWithMiddleware,
stream: impl Stream<Item = Result<Bytes, Error>> + Unpin + 'static, stream: impl Stream<Item = Result<Bytes, Error>> + Unpin + 'static,
declared_alias: Option<Alias>, declared_alias: Option<Alias>,
media: &crate::config::Media, media: &crate::config::Media,
@ -113,6 +116,22 @@ where
identifier: Some(identifier.clone()), identifier: Some(identifier.clone()),
}; };
if let Some(endpoint) = &media.external_validation {
let stream = store.to_stream(&identifier, None, None).await?.make_send();
let response = client
.post(endpoint.as_str())
.header("Content-Type", input_type.media_type().as_ref())
.body(Body::wrap_stream(stream))
.send()
.instrument(tracing::info_span!("external-validation"))
.await?;
if !response.status().is_success() {
return Err(UploadError::FailedExternalValidation.into());
}
}
let (hash, size) = state.borrow_mut().finalize_reset(); let (hash, size) = state.borrow_mut().finalize_reset();
let hash = Hash::new(hash, size, input_type); let hash = Hash::new(hash, size, input_type);

View file

@ -140,6 +140,10 @@ impl<S: Store + 'static> FormData for Upload<S> {
.app_data::<web::Data<S>>() .app_data::<web::Data<S>>()
.expect("No store in request") .expect("No store in request")
.clone(); .clone();
let client = req
.app_data::<web::Data<ClientWithMiddleware>>()
.expect("No client in request")
.clone();
let config = req let config = req
.app_data::<web::Data<Configuration>>() .app_data::<web::Data<Configuration>>()
.expect("No configuration in request") .expect("No configuration in request")
@ -154,6 +158,7 @@ impl<S: Store + 'static> FormData for Upload<S> {
Field::array(Field::file(move |filename, _, stream| { Field::array(Field::file(move |filename, _, stream| {
let repo = repo.clone(); let repo = repo.clone();
let store = store.clone(); let store = store.clone();
let client = client.clone();
let config = config.clone(); let config = config.clone();
metrics::increment_counter!("pict-rs.files", "upload" => "inline"); metrics::increment_counter!("pict-rs.files", "upload" => "inline");
@ -168,7 +173,8 @@ impl<S: Store + 'static> FormData for Upload<S> {
return Err(UploadError::ReadOnly.into()); return Err(UploadError::ReadOnly.into());
} }
ingest::ingest(&repo, &**store, stream, None, &config.media).await ingest::ingest(&repo, &**store, &client, stream, None, &config.media)
.await
} }
.instrument(span), .instrument(span),
) )
@ -196,6 +202,10 @@ impl<S: Store + 'static> FormData for Import<S> {
.app_data::<web::Data<S>>() .app_data::<web::Data<S>>()
.expect("No store in request") .expect("No store in request")
.clone(); .clone();
let client = req
.app_data::<ClientWithMiddleware>()
.expect("No client in request")
.clone();
let config = req let config = req
.app_data::<web::Data<Configuration>>() .app_data::<web::Data<Configuration>>()
.expect("No configuration in request") .expect("No configuration in request")
@ -213,6 +223,7 @@ impl<S: Store + 'static> FormData for Import<S> {
Field::array(Field::file(move |filename, _, stream| { Field::array(Field::file(move |filename, _, stream| {
let repo = repo.clone(); let repo = repo.clone();
let store = store.clone(); let store = store.clone();
let client = client.clone();
let config = config.clone(); let config = config.clone();
metrics::increment_counter!("pict-rs.files", "import" => "inline"); metrics::increment_counter!("pict-rs.files", "import" => "inline");
@ -230,6 +241,7 @@ impl<S: Store + 'static> FormData for Import<S> {
ingest::ingest( ingest::ingest(
&repo, &repo,
&**store, &**store,
&client,
stream, stream,
Some(Alias::from_existing(&filename)), Some(Alias::from_existing(&filename)),
&config.media, &config.media,
@ -479,9 +491,10 @@ async fn ingest_inline<S: Store + 'static>(
stream: impl Stream<Item = Result<web::Bytes, Error>> + Unpin + 'static, stream: impl Stream<Item = Result<web::Bytes, Error>> + Unpin + 'static,
repo: &ArcRepo, repo: &ArcRepo,
store: &S, store: &S,
client: &ClientWithMiddleware,
config: &Configuration, config: &Configuration,
) -> Result<(Alias, DeleteToken, Details), Error> { ) -> Result<(Alias, DeleteToken, Details), Error> {
let session = ingest::ingest(repo, store, stream, None, &config.media).await?; let session = ingest::ingest(repo, store, client, stream, None, &config.media).await?;
let alias = session.alias().expect("alias should exist").to_owned(); let alias = session.alias().expect("alias should exist").to_owned();
@ -501,17 +514,17 @@ async fn download<S: Store + 'static>(
config: web::Data<Configuration>, config: web::Data<Configuration>,
query: web::Query<UrlQuery>, query: web::Query<UrlQuery>,
) -> Result<HttpResponse, Error> { ) -> Result<HttpResponse, Error> {
let stream = download_stream(client, &query.url, &config).await?; let stream = download_stream(&client, &query.url, &config).await?;
if query.backgrounded { if query.backgrounded {
do_download_backgrounded(stream, repo, store).await do_download_backgrounded(stream, repo, store).await
} else { } else {
do_download_inline(stream, repo, store, config).await do_download_inline(stream, repo, store, &client, config).await
} }
} }
async fn download_stream( async fn download_stream(
client: web::Data<ClientWithMiddleware>, client: &ClientWithMiddleware,
url: &str, url: &str,
config: &Configuration, config: &Configuration,
) -> Result<impl Stream<Item = Result<web::Bytes, Error>> + Unpin + 'static, Error> { ) -> Result<impl Stream<Item = Result<web::Bytes, Error>> + Unpin + 'static, Error> {
@ -533,16 +546,21 @@ async fn download_stream(
Ok(stream) Ok(stream)
} }
#[tracing::instrument(name = "Downloading file inline", skip(stream, repo, store, config))] #[tracing::instrument(
name = "Downloading file inline",
skip(stream, repo, store, client, config)
)]
async fn do_download_inline<S: Store + 'static>( async fn do_download_inline<S: Store + 'static>(
stream: impl Stream<Item = Result<web::Bytes, Error>> + Unpin + 'static, stream: impl Stream<Item = Result<web::Bytes, Error>> + Unpin + 'static,
repo: web::Data<ArcRepo>, repo: web::Data<ArcRepo>,
store: web::Data<S>, store: web::Data<S>,
client: &ClientWithMiddleware,
config: web::Data<Configuration>, config: web::Data<Configuration>,
) -> Result<HttpResponse, Error> { ) -> Result<HttpResponse, Error> {
metrics::increment_counter!("pict-rs.files", "download" => "inline"); metrics::increment_counter!("pict-rs.files", "download" => "inline");
let (alias, delete_token, details) = ingest_inline(stream, &repo, &store, &config).await?; let (alias, delete_token, details) =
ingest_inline(stream, &repo, &store, &client, &config).await?;
Ok(HttpResponse::Created().json(&serde_json::json!({ Ok(HttpResponse::Created().json(&serde_json::json!({
"msg": "ok", "msg": "ok",
@ -817,9 +835,9 @@ async fn process<S: Store + 'static>(
let alias = if let Some(alias) = repo.related(proxy.clone()).await? { let alias = if let Some(alias) = repo.related(proxy.clone()).await? {
alias alias
} else if !config.server.read_only { } else if !config.server.read_only {
let stream = download_stream(client, proxy.as_str(), &config).await?; let stream = download_stream(&client, proxy.as_str(), &config).await?;
let (alias, _, _) = ingest_inline(stream, &repo, &store, &config).await?; let (alias, _, _) = ingest_inline(stream, &repo, &store, &client, &config).await?;
repo.relate_url(proxy, alias.clone()).await?; repo.relate_url(proxy, alias.clone()).await?;
@ -1115,9 +1133,9 @@ async fn serve_query<S: Store + 'static>(
let alias = if let Some(alias) = repo.related(proxy.clone()).await? { let alias = if let Some(alias) = repo.related(proxy.clone()).await? {
alias alias
} else if !config.server.read_only { } else if !config.server.read_only {
let stream = download_stream(client, proxy.as_str(), &config).await?; let stream = download_stream(&client, proxy.as_str(), &config).await?;
let (alias, _, _) = ingest_inline(stream, &repo, &store, &config).await?; let (alias, _, _) = ingest_inline(stream, &repo, &store, &client, &config).await?;
repo.relate_url(proxy, alias.clone()).await?; repo.relate_url(proxy, alias.clone()).await?;
@ -1703,8 +1721,13 @@ fn spawn_cleanup(repo: ArcRepo, config: &Configuration) {
}); });
} }
fn spawn_workers<S>(repo: ArcRepo, store: S, config: Configuration, process_map: ProcessMap) fn spawn_workers<S>(
where repo: ArcRepo,
store: S,
client: ClientWithMiddleware,
config: Configuration,
process_map: ProcessMap,
) where
S: Store + 'static, S: Store + 'static,
{ {
crate::sync::spawn(queue::process_cleanup( crate::sync::spawn(queue::process_cleanup(
@ -1712,7 +1735,13 @@ where
store.clone(), store.clone(),
config.clone(), config.clone(),
)); ));
crate::sync::spawn(queue::process_images(repo, store, process_map, config)); crate::sync::spawn(queue::process_images(
repo,
store,
client,
process_map,
config,
));
} }
async fn launch_file_store<F: Fn(&mut web::ServiceConfig) + Send + Clone + 'static>( async fn launch_file_store<F: Fn(&mut web::ServiceConfig) + Send + Clone + 'static>(
@ -1738,6 +1767,7 @@ async fn launch_file_store<F: Fn(&mut web::ServiceConfig) + Send + Clone + 'stat
spawn_workers( spawn_workers(
repo.clone(), repo.clone(),
store.clone(), store.clone(),
client.clone(),
config.clone(), config.clone(),
process_map.clone(), process_map.clone(),
); );
@ -1777,6 +1807,7 @@ async fn launch_object_store<F: Fn(&mut web::ServiceConfig) + Send + Clone + 'st
spawn_workers( spawn_workers(
repo.clone(), repo.clone(),
store.clone(), store.clone(),
client.clone(),
config.clone(), config.clone(),
process_map.clone(), process_map.clone(),
); );

View file

@ -8,6 +8,7 @@ use crate::{
serde_str::Serde, serde_str::Serde,
store::Store, store::Store,
}; };
use reqwest_middleware::ClientWithMiddleware;
use std::{ use std::{
path::PathBuf, path::PathBuf,
sync::Arc, sync::Arc,
@ -164,12 +165,14 @@ pub(crate) async fn process_cleanup<S: Store>(
pub(crate) async fn process_images<S: Store + 'static>( pub(crate) async fn process_images<S: Store + 'static>(
repo: Arc<dyn FullRepo>, repo: Arc<dyn FullRepo>,
store: S, store: S,
client: ClientWithMiddleware,
process_map: ProcessMap, process_map: ProcessMap,
config: Configuration, config: Configuration,
) { ) {
process_image_jobs( process_image_jobs(
&repo, &repo,
&store, &store,
&client,
&process_map, &process_map,
&config, &config,
PROCESS_QUEUE, PROCESS_QUEUE,
@ -301,6 +304,7 @@ where
async fn process_image_jobs<S, F>( async fn process_image_jobs<S, F>(
repo: &Arc<dyn FullRepo>, repo: &Arc<dyn FullRepo>,
store: &S, store: &S,
client: &ClientWithMiddleware,
process_map: &ProcessMap, process_map: &ProcessMap,
config: &Configuration, config: &Configuration,
queue: &'static str, queue: &'static str,
@ -310,6 +314,7 @@ async fn process_image_jobs<S, F>(
for<'a> F: Fn( for<'a> F: Fn(
&'a Arc<dyn FullRepo>, &'a Arc<dyn FullRepo>,
&'a S, &'a S,
&'a ClientWithMiddleware,
&'a ProcessMap, &'a ProcessMap,
&'a Configuration, &'a Configuration,
serde_json::Value, serde_json::Value,
@ -319,8 +324,17 @@ async fn process_image_jobs<S, F>(
let worker_id = uuid::Uuid::new_v4(); let worker_id = uuid::Uuid::new_v4();
loop { loop {
let res = let res = image_job_loop(
image_job_loop(repo, store, process_map, config, worker_id, queue, callback).await; repo,
store,
client,
process_map,
config,
worker_id,
queue,
callback,
)
.await;
if let Err(e) = res { if let Err(e) = res {
tracing::warn!("Error processing jobs: {}", format!("{e}")); tracing::warn!("Error processing jobs: {}", format!("{e}"));
@ -340,6 +354,7 @@ async fn process_image_jobs<S, F>(
async fn image_job_loop<S, F>( async fn image_job_loop<S, F>(
repo: &Arc<dyn FullRepo>, repo: &Arc<dyn FullRepo>,
store: &S, store: &S,
client: &ClientWithMiddleware,
process_map: &ProcessMap, process_map: &ProcessMap,
config: &Configuration, config: &Configuration,
worker_id: uuid::Uuid, worker_id: uuid::Uuid,
@ -351,6 +366,7 @@ where
for<'a> F: Fn( for<'a> F: Fn(
&'a Arc<dyn FullRepo>, &'a Arc<dyn FullRepo>,
&'a S, &'a S,
&'a ClientWithMiddleware,
&'a ProcessMap, &'a ProcessMap,
&'a Configuration, &'a Configuration,
serde_json::Value, serde_json::Value,
@ -372,7 +388,7 @@ where
queue, queue,
worker_id, worker_id,
job_id, job_id,
(callback)(repo, store, process_map, config, job), (callback)(repo, store, client, process_map, config, job),
) )
}) })
.instrument(span) .instrument(span)

View file

@ -1,3 +1,5 @@
use reqwest_middleware::ClientWithMiddleware;
use crate::{ use crate::{
concurrent_processor::ProcessMap, concurrent_processor::ProcessMap,
config::Configuration, config::Configuration,
@ -16,6 +18,7 @@ use std::{path::PathBuf, sync::Arc};
pub(super) fn perform<'a, S>( pub(super) fn perform<'a, S>(
repo: &'a ArcRepo, repo: &'a ArcRepo,
store: &'a S, store: &'a S,
client: &'a ClientWithMiddleware,
process_map: &'a ProcessMap, process_map: &'a ProcessMap,
config: &'a Configuration, config: &'a Configuration,
job: serde_json::Value, job: serde_json::Value,
@ -34,6 +37,7 @@ where
process_ingest( process_ingest(
repo, repo,
store, store,
client,
Arc::from(identifier), Arc::from(identifier),
Serde::into_inner(upload_id), Serde::into_inner(upload_id),
declared_alias.map(Serde::into_inner), declared_alias.map(Serde::into_inner),
@ -69,10 +73,11 @@ where
}) })
} }
#[tracing::instrument(skip(repo, store, media))] #[tracing::instrument(skip(repo, store, client, media))]
async fn process_ingest<S>( async fn process_ingest<S>(
repo: &ArcRepo, repo: &ArcRepo,
store: &S, store: &S,
client: &ClientWithMiddleware,
unprocessed_identifier: Arc<str>, unprocessed_identifier: Arc<str>,
upload_id: UploadId, upload_id: UploadId,
declared_alias: Option<Alias>, declared_alias: Option<Alias>,
@ -85,6 +90,7 @@ where
let ident = unprocessed_identifier.clone(); let ident = unprocessed_identifier.clone();
let store2 = store.clone(); let store2 = store.clone();
let repo = repo.clone(); let repo = repo.clone();
let client = client.clone();
let media = media.clone(); let media = media.clone();
let error_boundary = crate::sync::spawn(async move { let error_boundary = crate::sync::spawn(async move {
@ -94,7 +100,8 @@ where
.map(|res| res.map_err(Error::from)); .map(|res| res.map_err(Error::from));
let session = let session =
crate::ingest::ingest(&repo, &store2, stream, declared_alias, &media).await?; crate::ingest::ingest(&repo, &store2, &client, stream, declared_alias, &media)
.await?;
Ok(session) as Result<Session, Error> Ok(session) as Result<Session, Error>
}) })

View file

@ -14,6 +14,68 @@ use std::{
time::Duration, time::Duration,
}; };
pub(crate) trait MakeSend<T>: Stream<Item = std::io::Result<T>>
where
T: 'static,
{
fn make_send(self) -> MakeSendStream<T>
where
Self: Sized + 'static,
{
let (tx, rx) = crate::sync::channel(4);
MakeSendStream {
handle: crate::sync::spawn(async move {
let this = std::pin::pin!(self);
let mut stream = this.into_streamer();
while let Some(res) = stream.next().await {
if tx.send_async(res).await.is_err() {
return;
}
}
}),
rx: rx.into_stream(),
}
}
}
impl<S, T> MakeSend<T> for S
where
S: Stream<Item = std::io::Result<T>>,
T: 'static,
{
}
pub(crate) struct MakeSendStream<T>
where
T: 'static,
{
handle: actix_rt::task::JoinHandle<()>,
rx: flume::r#async::RecvStream<'static, std::io::Result<T>>,
}
impl<T> Stream for MakeSendStream<T>
where
T: 'static,
{
type Item = std::io::Result<T>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
match Pin::new(&mut self.rx).poll_next(cx) {
Poll::Ready(opt) => Poll::Ready(opt),
Poll::Pending if std::task::ready!(Pin::new(&mut self.handle).poll(cx)).is_err() => {
Poll::Ready(Some(Err(std::io::Error::new(
std::io::ErrorKind::UnexpectedEof,
"Stream panicked",
))))
}
Poll::Pending => Poll::Pending,
}
}
}
pin_project_lite::pin_project! { pin_project_lite::pin_project! {
pub(crate) struct Map<S, F> { pub(crate) struct Map<S, F> {
#[pin] #[pin]