diff --git a/src/bytes_stream.rs b/src/bytes_stream.rs new file mode 100644 index 0000000..d18450a --- /dev/null +++ b/src/bytes_stream.rs @@ -0,0 +1,62 @@ +use actix_web::web::{Bytes, BytesMut}; +use futures_util::{Stream, StreamExt}; +use std::{ + collections::{vec_deque::IntoIter, VecDeque}, + pin::Pin, + task::{Context, Poll}, +}; + +#[derive(Clone, Debug)] +pub(crate) struct BytesStream { + inner: VecDeque, + total_len: usize, +} + +impl BytesStream { + pub(crate) fn new() -> Self { + Self { + inner: VecDeque::new(), + total_len: 0, + } + } + + pub(crate) fn add_bytes(&mut self, bytes: Bytes) { + self.total_len += bytes.len(); + self.inner.push_back(bytes); + } + + pub(crate) fn len(&self) -> usize { + self.total_len + } + + pub(crate) fn into_io_stream(self) -> impl Stream> + Unpin { + self.map(|bytes| Ok(bytes)) + } + + pub(crate) fn into_bytes(self) -> Bytes { + let mut buf = BytesMut::with_capacity(self.total_len); + + for bytes in self.inner { + buf.extend_from_slice(&bytes); + } + + buf.freeze() + } +} + +impl Stream for BytesStream { + type Item = Bytes; + + fn poll_next(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + Poll::Ready(self.get_mut().inner.pop_front()) + } +} + +impl IntoIterator for BytesStream { + type Item = Bytes; + type IntoIter = IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.inner.into_iter() + } +} diff --git a/src/ingest.rs b/src/ingest.rs index fc3fd14..8b0e5bf 100644 --- a/src/ingest.rs +++ b/src/ingest.rs @@ -1,4 +1,5 @@ use crate::{ + bytes_stream::BytesStream, either::Either, error::{Error, UploadError}, magick::ValidInputType, @@ -6,7 +7,7 @@ use crate::{ store::Store, CONFIG, }; -use actix_web::web::{Bytes, BytesMut}; +use actix_web::web::Bytes; use futures_util::{Stream, StreamExt}; use sha2::{Digest, Sha256}; use tracing::{Instrument, Span}; @@ -27,29 +28,17 @@ where } #[tracing::instrument(name = "Aggregate", skip(stream))] -async fn aggregate(stream: S) -> Result +async fn aggregate(mut stream: S) -> Result where - S: Stream>, + S: Stream> + Unpin, { - futures_util::pin_mut!(stream); + let mut buf = BytesStream::new(); - let mut total_len = 0; - let mut buf = Vec::new(); - tracing::debug!("Reading stream to memory"); while let Some(res) = stream.next().await { - let bytes = res?; - total_len += bytes.len(); - buf.push(bytes); + buf.add_bytes(res?); } - let bytes_mut = buf - .iter() - .fold(BytesMut::with_capacity(total_len), |mut acc, item| { - acc.extend_from_slice(item); - acc - }); - - Ok(bytes_mut.freeze()) + Ok(buf.into_bytes()) } #[tracing::instrument(name = "Ingest", skip(stream))] diff --git a/src/main.rs b/src/main.rs index c9b5d2a..b5961a9 100644 --- a/src/main.rs +++ b/src/main.rs @@ -24,6 +24,7 @@ use tracing_awc::Tracing; use tracing_futures::Instrument; mod backgrounded; +mod bytes_stream; mod concurrent_processor; mod config; mod details; diff --git a/src/store/object_store.rs b/src/store/object_store.rs index f9b01b2..c3cb033 100644 --- a/src/store/object_store.rs +++ b/src/store/object_store.rs @@ -1,4 +1,5 @@ use crate::{ + bytes_stream::BytesStream, error::Error, repo::{Repo, SettingsRepo}, store::{Store, StoreConfig}, @@ -10,9 +11,9 @@ use actix_web::{ header::{ByteRangeSpec, Range, CONTENT_LENGTH}, StatusCode, }, - web::{Bytes, BytesMut}, + web::Bytes, }; -use awc::{error::SendRequestError, Client, ClientRequest, SendClientRequest}; +use awc::{error::SendRequestError, Client, ClientRequest, ClientResponse, SendClientRequest}; use futures_util::{Stream, StreamExt, TryStreamExt}; use rusty_s3::{actions::S3Action, Bucket, BucketError, Credentials, UrlStyle}; use std::{pin::Pin, string::FromUtf8Error, time::Duration}; @@ -129,31 +130,32 @@ fn payload_to_io_error(e: PayloadError) -> std::io::Error { } #[tracing::instrument(skip(stream))] -async fn read_chunk(stream: &mut S) -> std::io::Result +async fn read_chunk(stream: &mut S) -> std::io::Result where S: Stream> + Unpin + 'static, { - let mut buf = Vec::new(); - let mut total_len = 0; + let mut buf = BytesStream::new(); - while total_len < CHUNK_SIZE { + while buf.len() < CHUNK_SIZE { if let Some(res) = stream.next().await { - let bytes = res?; - total_len += bytes.len(); - buf.push(bytes); + buf.add_bytes(res?) } else { break; } } - let bytes = buf - .iter() - .fold(BytesMut::with_capacity(total_len), |mut acc, item| { - acc.extend_from_slice(item); - acc - }); + Ok(buf) +} - Ok(bytes.freeze()) +async fn status_error(mut response: ClientResponse) -> Error { + let body = match response.body().await { + Err(e) => return e.into(), + Ok(body) => body, + }; + + let body = String::from_utf8_lossy(&body).to_string(); + + ObjectError::Status(response.status(), body).into() } #[async_trait::async_trait(?Send)] @@ -178,9 +180,7 @@ impl Store for ObjectStore { let mut response = req.send().await.map_err(ObjectError::from)?; if !response.status().is_success() { - let body = String::from_utf8_lossy(&response.body().await?).to_string(); - - return Err(ObjectError::Status(response.status(), body).into()); + return Err(status_error(response).await); } let body = response.body().await?; @@ -197,8 +197,8 @@ impl Store for ObjectStore { while !complete { part_number += 1; - let bytes = read_chunk(&mut stream).await?; - complete = bytes.len() < CHUNK_SIZE; + let buf = read_chunk(&mut stream).await?; + complete = buf.len() < CHUNK_SIZE; let this = self.clone(); @@ -206,21 +206,19 @@ impl Store for ObjectStore { let upload_id2 = upload_id.clone(); let handle = actix_rt::spawn( async move { - let mut response = this + let response = this .create_upload_part_request( - bytes.clone(), + buf.clone(), &object_id2, part_number, &upload_id2, ) .await? - .send_body(bytes) + .send_stream(buf.into_io_stream()) .await?; if !response.status().is_success() { - let body = String::from_utf8_lossy(&response.body().await?).to_string(); - - return Err(ObjectError::Status(response.status(), body).into()); + return Err(status_error(response).await); } let etag = response @@ -251,7 +249,7 @@ impl Store for ObjectStore { etags.push(future.await.map_err(ObjectError::from)??); } - let mut response = self + let response = self .send_complete_multipart_request( &object_id, upload_id, @@ -260,9 +258,7 @@ impl Store for ObjectStore { .await?; if !response.status().is_success() { - let body = String::from_utf8_lossy(&response.body().await?).to_string(); - - return Err(ObjectError::Status(response.status(), body).into()); + return Err(status_error(response).await); } Ok(()) as Result<(), Error> @@ -283,15 +279,13 @@ impl Store for ObjectStore { async fn save_bytes(&self, bytes: Bytes) -> Result { let (req, object_id) = self.put_object_request().await?; - let mut response = req.send_body(bytes).await.map_err(ObjectError::from)?; + let response = req.send_body(bytes).await.map_err(ObjectError::from)?; - if response.status().is_success() { - return Ok(object_id); + if !response.status().is_success() { + return Err(status_error(response).await); } - let body = String::from_utf8_lossy(&response.body().await?).to_string(); - - Err(ObjectError::Status(response.status(), body).into()) + Ok(object_id) } #[tracing::instrument] @@ -301,19 +295,17 @@ impl Store for ObjectStore { from_start: Option, len: Option, ) -> Result { - let mut response = self + let response = self .get_object_request(identifier, from_start, len) .send() .await .map_err(ObjectError::from)?; - if response.status().is_success() { - return Ok(Box::pin(response.map_err(payload_to_io_error))); + if !response.status().is_success() { + return Err(status_error(response).await); } - let body = String::from_utf8_lossy(&response.body().await?).to_string(); - - Err(ObjectError::Status(response.status(), body).into()) + Ok(Box::pin(response.map_err(payload_to_io_error))) } #[tracing::instrument(skip(writer))] @@ -332,12 +324,9 @@ impl Store for ObjectStore { .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, ObjectError::from(e)))?; if !response.status().is_success() { - let body = response.body().await.map_err(payload_to_io_error)?; - let body = String::from_utf8_lossy(&body).to_string(); - return Err(std::io::Error::new( std::io::ErrorKind::Other, - ObjectError::Status(response.status(), body), + status_error(response).await, )); } @@ -352,16 +341,14 @@ impl Store for ObjectStore { #[tracing::instrument] async fn len(&self, identifier: &Self::Identifier) -> Result { - let mut response = self + let response = self .head_object_request(identifier) .send() .await .map_err(ObjectError::from)?; if !response.status().is_success() { - let body = String::from_utf8_lossy(&response.body().await?).to_string(); - - return Err(ObjectError::Status(response.status(), body).into()); + return Err(status_error(response).await); } let length = response @@ -378,12 +365,10 @@ impl Store for ObjectStore { #[tracing::instrument] async fn remove(&self, identifier: &Self::Identifier) -> Result<(), Error> { - let mut response = self.delete_object_request(identifier).send().await?; + let response = self.delete_object_request(identifier).send().await?; if !response.status().is_success() { - let body = String::from_utf8_lossy(&response.body().await?).to_string(); - - return Err(ObjectError::Status(response.status(), body).into()); + return Err(status_error(response).await); } Ok(()) @@ -445,7 +430,7 @@ impl ObjectStore { async fn create_upload_part_request( &self, - bytes: Bytes, + buf: BytesStream, object_id: &ObjectId, part_number: u16, upload_id: &str, @@ -463,7 +448,9 @@ impl ObjectStore { let hash_string = actix_web::web::block(move || { let guard = hashing_span.enter(); let mut hasher = md5::Md5::new(); - hasher.update(&bytes); + for bytes in buf { + hasher.update(&bytes); + } let hash = hasher.finalize(); let hash_string = base64::encode(&hash); drop(guard);