2
0
Fork 0
mirror of https://git.asonix.dog/asonix/pict-rs synced 2025-01-11 20:15:49 +00:00

Add external validation check

This commit is contained in:
asonix 2023-09-06 18:53:52 -05:00
parent 7b6190c045
commit b81bbb9b2d
10 changed files with 191 additions and 14 deletions

View file

@ -18,6 +18,7 @@ targets = 'info'
path = 'data/'
[media]
external_validation = "http://localhost:8076"
max_width = 10000
max_height = 10000
max_area = 40000000

View file

@ -122,6 +122,15 @@ path = '/mnt'
## Media Processing Configuration
[media]
## Optional: URL for external validation of media
# environment variable: PICTRS__MEDIA__EXTERNAL_VALIDATION
# default: empty
#
# The expected API for external validators is to accept a POST with the media as the request body,
# and a valid `Content-Type` header. The validator should return a 2XX response when the media
# passes validation. Any other status code is considered a validation failure.
external_validation = 'http://localhost:8076'
## Optional: preprocessing steps for uploaded images
# environment variable: PICTRS__MEDIA__PREPROCESS_STEPS
# default: empty

View file

@ -46,6 +46,7 @@ impl Args {
api_key,
worker_id,
client_pool_size,
media_external_validation,
media_preprocess_steps,
media_skip_validate_imports,
media_max_width,
@ -84,6 +85,7 @@ impl Args {
})
};
let media = Media {
external_validation: media_external_validation,
preprocess_steps: media_preprocess_steps,
skip_validate_imports: media_skip_validate_imports,
max_width: media_max_width,
@ -336,6 +338,8 @@ struct OldDb {
#[derive(Debug, Default, serde::Serialize)]
#[serde(rename_all = "snake_case")]
struct Media {
#[serde(skip_serializing_if = "Option::is_none")]
external_validation: Option<Url>,
#[serde(skip_serializing_if = "Option::is_none")]
preprocess_steps: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
@ -454,6 +458,10 @@ struct Run {
#[arg(long)]
client_pool_size: Option<usize>,
/// Optional endpoint to submit uploaded media to for validation
#[arg(long)]
media_external_validation: Option<Url>,
/// Optional pre-processing steps for uploaded media.
///
/// All still images will be put through these steps before saving

View file

@ -143,6 +143,9 @@ pub(crate) struct OldDb {
#[derive(Clone, Debug, serde::Deserialize, serde::Serialize)]
#[serde(rename_all = "snake_case")]
pub(crate) struct Media {
#[serde(skip_serializing_if = "Option::is_none")]
pub(crate) external_validation: Option<Url>,
#[serde(skip_serializing_if = "Option::is_none")]
pub(crate) preprocess_steps: Option<String>,

View file

@ -135,6 +135,9 @@ pub(crate) enum UploadError {
#[error("Response timeout")]
Timeout(#[from] crate::stream::TimeoutError),
#[error("Failed external validation")]
ExternalValidation,
}
impl From<actix_web::error::BlockingError> for UploadError {

View file

@ -9,11 +9,14 @@ use crate::{
};
use actix_web::web::Bytes;
use futures_util::{Stream, StreamExt};
use reqwest::Body;
use reqwest_middleware::ClientWithMiddleware;
use sha2::{Digest, Sha256};
use tracing::{Instrument, Span};
mod hasher;
use hasher::Hasher;
use url::Url;
#[derive(Debug)]
pub(crate) struct Session<R, S>
@ -41,12 +44,15 @@ where
Ok(buf.into_bytes())
}
#[tracing::instrument(skip(repo, store, stream))]
#[allow(clippy::too_many_arguments)]
#[tracing::instrument(skip(repo, store, client, stream))]
pub(crate) async fn ingest<R, S>(
repo: &R,
store: &S,
client: &ClientWithMiddleware,
stream: impl Stream<Item = Result<Bytes, Error>> + Unpin + 'static,
declared_alias: Option<Alias>,
external_validation: Option<&Url>,
should_validate: bool,
timeout: u64,
) -> Result<Session<R, S>, Error>
@ -97,6 +103,46 @@ where
identifier: Some(identifier.clone()),
};
if let Some(external_validation) = external_validation {
struct RxStream<T>(tokio::sync::mpsc::Receiver<T>);
impl<T> futures_util::Stream for RxStream<T> {
type Item = T;
fn poll_next(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
self.0.poll_recv(cx)
}
}
let mut stream = store.to_stream(&identifier, None, None).await?;
let (tx, rx) = tokio::sync::mpsc::channel(8);
let handle = actix_rt::spawn(async move {
while let Some(item) = stream.next().await {
if tx.send(item).await.is_err() {
return;
}
}
});
let result = client
.post(external_validation.as_str())
.header("Content-Type", input_type.content_type().to_string())
.body(Body::wrap_stream(RxStream(rx)))
.send()
.await?;
// structure that concurrency bb
let _ = handle.await;
if !result.status().is_success() {
return Err(UploadError::ExternalValidation.into());
}
}
let hash = hasher.borrow_mut().finalize_reset().to_vec();
save_upload(&mut session, repo, store, &hash, &identifier).await?;

View file

@ -154,6 +154,10 @@ impl<R: FullRepo, S: Store + 'static> FormData for Upload<R, S> {
.app_data::<web::Data<S>>()
.expect("No store in request")
.clone();
let client = req
.app_data::<web::Data<ClientWithMiddleware>>()
.expect("No client in request")
.clone();
Form::new()
.max_files(10)
@ -164,6 +168,7 @@ impl<R: FullRepo, S: Store + 'static> FormData for Upload<R, S> {
Field::array(Field::file(move |filename, _, stream| {
let repo = repo.clone();
let store = store.clone();
let client = client.clone();
let span = tracing::info_span!("file-upload", ?filename);
@ -174,8 +179,10 @@ impl<R: FullRepo, S: Store + 'static> FormData for Upload<R, S> {
ingest::ingest(
&**repo,
&**store,
&client,
stream,
None,
CONFIG.media.external_validation.as_ref(),
true,
CONFIG.media.process_timeout,
)
@ -207,6 +214,10 @@ impl<R: FullRepo, S: Store + 'static> FormData for Import<R, S> {
.app_data::<web::Data<S>>()
.expect("No store in request")
.clone();
let client = req
.app_data::<web::Data<ClientWithMiddleware>>()
.expect("No client in request")
.clone();
// Create a new Multipart Form validator for internal imports
//
@ -220,6 +231,7 @@ impl<R: FullRepo, S: Store + 'static> FormData for Import<R, S> {
Field::array(Field::file(move |filename, _, stream| {
let repo = repo.clone();
let store = store.clone();
let client = client.clone();
let span = tracing::info_span!("file-import", ?filename);
@ -230,8 +242,10 @@ impl<R: FullRepo, S: Store + 'static> FormData for Import<R, S> {
ingest::ingest(
&**repo,
&**store,
&client,
stream,
Some(Alias::from_existing(&filename)),
CONFIG.media.external_validation.as_ref(),
!CONFIG.media.skip_validate_imports,
CONFIG.media.process_timeout,
)
@ -479,7 +493,7 @@ async fn download<R: FullRepo + 'static, S: Store + 'static>(
if query.backgrounded {
do_download_backgrounded(stream, repo, store).await
} else {
do_download_inline(stream, repo, store).await
do_download_inline(stream, repo, store, client).await
}
}
@ -488,12 +502,15 @@ async fn do_download_inline<R: FullRepo + 'static, S: Store + 'static>(
stream: impl Stream<Item = Result<web::Bytes, Error>> + Unpin + 'static,
repo: web::Data<R>,
store: web::Data<S>,
client: web::Data<ClientWithMiddleware>,
) -> Result<HttpResponse, Error> {
let mut session = ingest::ingest(
&repo,
&store,
&client,
stream,
None,
CONFIG.media.external_validation.as_ref(),
true,
CONFIG.media.process_timeout,
)
@ -1225,7 +1242,7 @@ fn configure_endpoints<
);
}
fn spawn_workers<R, S>(repo: R, store: S)
fn spawn_workers<R, S>(repo: R, store: S, client: ClientWithMiddleware)
where
R: FullRepo + 'static,
S: Store + 'static,
@ -1234,11 +1251,18 @@ where
actix_rt::spawn(queue::process_cleanup(
repo.clone(),
store.clone(),
client.clone(),
next_worker_id(),
))
});
tracing::trace_span!(parent: None, "Spawn task").in_scope(|| {
actix_rt::spawn(queue::process_images(
repo,
store,
client.clone(),
next_worker_id(),
))
});
tracing::trace_span!(parent: None, "Spawn task")
.in_scope(|| actix_rt::spawn(queue::process_images(repo, store, next_worker_id())));
}
async fn launch_file_store<R: FullRepo + 'static, F: Fn(&mut web::ServiceConfig) + Send + Clone>(
@ -1254,7 +1278,7 @@ async fn launch_file_store<R: FullRepo + 'static, F: Fn(&mut web::ServiceConfig)
let repo = repo.clone();
let extra_config = extra_config.clone();
spawn_workers(repo.clone(), store.clone());
spawn_workers(repo.clone(), store.clone(), client.clone());
App::new()
.wrap(TracingLogger::default())
@ -1282,7 +1306,7 @@ async fn launch_object_store<
let repo = repo.clone();
let extra_config = extra_config.clone();
spawn_workers(repo.clone(), store.clone());
spawn_workers(repo.clone(), store.clone(), client.clone());
App::new()
.wrap(TracingLogger::default())

View file

@ -8,6 +8,7 @@ use crate::{
store::{Identifier, Store},
};
use base64::{prelude::BASE64_STANDARD, Engine};
use reqwest_middleware::ClientWithMiddleware;
use std::{future::Future, path::PathBuf, pin::Pin};
use tracing::Instrument;
@ -157,16 +158,38 @@ pub(crate) async fn queue_generate<R: QueueRepo>(
Ok(())
}
pub(crate) async fn process_cleanup<R: FullRepo, S: Store>(repo: R, store: S, worker_id: String) {
process_jobs(&repo, &store, worker_id, CLEANUP_QUEUE, cleanup::perform).await
pub(crate) async fn process_cleanup<R: FullRepo, S: Store>(
repo: R,
store: S,
client: ClientWithMiddleware,
worker_id: String,
) {
process_jobs(
&repo,
&store,
&client,
worker_id,
CLEANUP_QUEUE,
cleanup::perform,
)
.await
}
pub(crate) async fn process_images<R: FullRepo + 'static, S: Store + 'static>(
repo: R,
store: S,
client: ClientWithMiddleware,
worker_id: String,
) {
process_jobs(&repo, &store, worker_id, PROCESS_QUEUE, process::perform).await
process_jobs(
&repo,
&store,
&client,
worker_id,
PROCESS_QUEUE,
process::perform,
)
.await
}
type LocalBoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + 'a>>;
@ -174,6 +197,7 @@ type LocalBoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + 'a>>;
async fn process_jobs<R, S, F>(
repo: &R,
store: &S,
client: &ClientWithMiddleware,
worker_id: String,
queue: &'static str,
callback: F,
@ -181,10 +205,16 @@ async fn process_jobs<R, S, F>(
R: QueueRepo + HashRepo + IdentifierRepo + AliasRepo,
R::Bytes: Clone,
S: Store,
for<'a> F: Fn(&'a R, &'a S, &'a [u8]) -> LocalBoxFuture<'a, Result<(), Error>> + Copy,
for<'a> F: Fn(
&'a R,
&'a S,
&'a ClientWithMiddleware,
&'a [u8],
) -> LocalBoxFuture<'a, Result<(), Error>>
+ Copy,
{
loop {
let res = job_loop(repo, store, worker_id.clone(), queue, callback).await;
let res = job_loop(repo, store, client, worker_id.clone(), queue, callback).await;
if let Err(e) = res {
tracing::warn!("Error processing jobs: {}", format!("{e}"));
@ -199,6 +229,7 @@ async fn process_jobs<R, S, F>(
async fn job_loop<R, S, F>(
repo: &R,
store: &S,
client: &ClientWithMiddleware,
worker_id: String,
queue: &'static str,
callback: F,
@ -207,14 +238,20 @@ where
R: QueueRepo + HashRepo + IdentifierRepo + AliasRepo,
R::Bytes: Clone,
S: Store,
for<'a> F: Fn(&'a R, &'a S, &'a [u8]) -> LocalBoxFuture<'a, Result<(), Error>> + Copy,
for<'a> F: Fn(
&'a R,
&'a S,
&'a ClientWithMiddleware,
&'a [u8],
) -> LocalBoxFuture<'a, Result<(), Error>>
+ Copy,
{
loop {
let bytes = repo.pop(queue, worker_id.as_bytes().to_vec()).await?;
let span = tracing::info_span!("Running Job", worker_id = ?worker_id);
span.in_scope(|| (callback)(repo, store, bytes.as_ref()))
span.in_scope(|| (callback)(repo, store, client, bytes.as_ref()))
.instrument(span)
.await?;
}

View file

@ -6,10 +6,12 @@ use crate::{
store::{Identifier, Store},
};
use futures_util::StreamExt;
use reqwest_middleware::ClientWithMiddleware;
pub(super) fn perform<'a, R, S>(
repo: &'a R,
store: &'a S,
_client: &'a ClientWithMiddleware,
job: &'a [u8],
) -> LocalBoxFuture<'a, Result<(), Error>>
where

View file

@ -8,11 +8,14 @@ use crate::{
store::{Identifier, Store},
};
use futures_util::TryStreamExt;
use reqwest_middleware::ClientWithMiddleware;
use std::path::PathBuf;
use url::Url;
pub(super) fn perform<'a, R, S>(
repo: &'a R,
store: &'a S,
client: &'a ClientWithMiddleware,
job: &'a [u8],
) -> LocalBoxFuture<'a, Result<(), Error>>
where
@ -31,9 +34,11 @@ where
process_ingest(
repo,
store,
client,
identifier,
Serde::into_inner(upload_id),
declared_alias.map(Serde::into_inner),
crate::CONFIG.media.external_validation.as_ref(),
should_validate,
crate::CONFIG.media.process_timeout,
)
@ -66,13 +71,44 @@ where
})
}
struct LogDroppedIngest {
armed: bool,
upload_id: UploadId,
}
impl LogDroppedIngest {
fn guard(upload_id: UploadId) -> Self {
Self {
armed: true,
upload_id,
}
}
fn disarm(mut self) {
self.armed = false;
}
}
impl Drop for LogDroppedIngest {
fn drop(&mut self) {
if self.armed {
tracing::warn!(
"Failed to complete an upload {}- clients will hang",
self.upload_id
);
}
}
}
#[tracing::instrument(skip_all)]
async fn process_ingest<R, S>(
repo: &R,
store: &S,
client: &ClientWithMiddleware,
unprocessed_identifier: Vec<u8>,
upload_id: UploadId,
declared_alias: Option<Alias>,
external_validation: Option<&Url>,
should_validate: bool,
timeout: u64,
) -> Result<(), Error>
@ -80,6 +116,8 @@ where
R: FullRepo + 'static,
S: Store + 'static,
{
let guard = LogDroppedIngest::guard(upload_id);
let fut = async {
let unprocessed_identifier = S::Identifier::from_bytes(unprocessed_identifier)?;
@ -87,6 +125,8 @@ where
let store2 = store.clone();
let repo = repo.clone();
let client = client.clone();
let external_validation = external_validation.cloned();
let error_boundary = actix_rt::spawn(async move {
let stream = store2
.to_stream(&ident, None, None)
@ -96,8 +136,10 @@ where
let session = crate::ingest::ingest(
&repo,
&store2,
&client,
stream,
declared_alias,
external_validation.as_ref(),
should_validate,
timeout,
)
@ -132,6 +174,8 @@ where
repo.complete(upload_id, result).await?;
guard.disarm();
Ok(())
}