This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new fba19b0142 Cleanup multipart upload trait (#4572)
fba19b0142 is described below

commit fba19b0142daed54c181cdb8f634f29cf7d37f8d
Author: Raphael Taylor-Davies <[email protected]>
AuthorDate: Thu Jul 27 02:32:07 2023 -0400

    Cleanup multipart upload trait (#4572)
    
    * Cleanup multipart upload trait
    
    * Update object_store/src/multipart.rs
    
    Co-authored-by: Liang-Chi Hsieh <[email protected]>
    
    ---------
    
    Co-authored-by: Liang-Chi Hsieh <[email protected]>
---
 object_store/src/aws/client.rs |  4 +--
 object_store/src/aws/mod.rs    | 30 +++++-----------
 object_store/src/azure/mod.rs  | 17 ++++------
 object_store/src/gcp/mod.rs    | 77 ++++++++++++++++++++----------------------
 object_store/src/multipart.rs  | 50 +++++++++------------------
 5 files changed, 69 insertions(+), 109 deletions(-)

diff --git a/object_store/src/aws/client.rs b/object_store/src/aws/client.rs
index 971d2c6086..188897620b 100644
--- a/object_store/src/aws/client.rs
+++ b/object_store/src/aws/client.rs
@@ -23,7 +23,7 @@ use crate::client::list::ListClient;
 use crate::client::list_response::ListResponse;
 use crate::client::retry::RetryExt;
 use crate::client::GetOptionsExt;
-use crate::multipart::UploadPart;
+use crate::multipart::PartId;
 use crate::path::DELIMITER;
 use crate::{
     ClientOptions, GetOptions, ListResult, MultipartId, Path, Result, RetryConfig,
@@ -479,7 +479,7 @@ impl S3Client {
         &self,
         location: &Path,
         upload_id: &str,
-        parts: Vec<UploadPart>,
+        parts: Vec<PartId>,
     ) -> Result<()> {
         let parts = parts
             .into_iter()
diff --git a/object_store/src/aws/mod.rs b/object_store/src/aws/mod.rs
index e74e6f2dfc..5a29bd0fc6 100644
--- a/object_store/src/aws/mod.rs
+++ b/object_store/src/aws/mod.rs
@@ -56,7 +56,7 @@ use crate::client::{
     TokenCredentialProvider,
 };
 use crate::config::ConfigValue;
-use crate::multipart::{CloudMultiPartUpload, CloudMultiPartUploadImpl, UploadPart};
+use crate::multipart::{PartId, PutPart, WriteMultiPart};
 use crate::{
     ClientOptions, GetOptions, GetResult, ListResult, MultipartId, ObjectMeta,
     ObjectStore, Path, Result, RetryConfig,
@@ -227,7 +227,7 @@ impl ObjectStore for AmazonS3 {
             client: Arc::clone(&self.client),
         };
 
-        Ok((id, Box::new(CloudMultiPartUpload::new(upload, 8))))
+        Ok((id, Box::new(WriteMultiPart::new(upload, 8))))
     }
 
     async fn abort_multipart(
@@ -308,12 +308,8 @@ struct S3MultiPartUpload {
 }
 
 #[async_trait]
-impl CloudMultiPartUploadImpl for S3MultiPartUpload {
-    async fn put_multipart_part(
-        &self,
-        buf: Vec<u8>,
-        part_idx: usize,
-    ) -> Result<UploadPart, std::io::Error> {
+impl PutPart for S3MultiPartUpload {
+    async fn put_part(&self, buf: Vec<u8>, part_idx: usize) -> Result<PartId> {
         use reqwest::header::ETAG;
         let part = (part_idx + 1).to_string();
 
@@ -326,26 +322,16 @@ impl CloudMultiPartUploadImpl for S3MultiPartUpload {
             )
             .await?;
 
-        let etag = response
-            .headers()
-            .get(ETAG)
-            .context(MissingEtagSnafu)
-            .map_err(crate::Error::from)?;
+        let etag = response.headers().get(ETAG).context(MissingEtagSnafu)?;
 
-        let etag = etag
-            .to_str()
-            .context(BadHeaderSnafu)
-            .map_err(crate::Error::from)?;
+        let etag = etag.to_str().context(BadHeaderSnafu)?;
 
-        Ok(UploadPart {
+        Ok(PartId {
             content_id: etag.to_string(),
         })
     }
 
-    async fn complete(
-        &self,
-        completed_parts: Vec<UploadPart>,
-    ) -> Result<(), std::io::Error> {
+    async fn complete(&self, completed_parts: Vec<PartId>) -> Result<()> {
         self.client
             .complete_multipart(&self.location, &self.upload_id, completed_parts)
             .await?;
diff --git a/object_store/src/azure/mod.rs b/object_store/src/azure/mod.rs
index d273503832..8619319a5b 100644
--- a/object_store/src/azure/mod.rs
+++ b/object_store/src/azure/mod.rs
@@ -28,7 +28,7 @@
 //! after 7 days.
 use self::client::{BlockId, BlockList};
 use crate::{
-    multipart::{CloudMultiPartUpload, CloudMultiPartUploadImpl, UploadPart},
+    multipart::{PartId, PutPart, WriteMultiPart},
     path::Path,
     ClientOptions, GetOptions, GetResult, ListResult, MultipartId, ObjectMeta,
     ObjectStore, Result, RetryConfig,
@@ -42,7 +42,6 @@ use percent_encoding::percent_decode_str;
 use serde::{Deserialize, Serialize};
 use snafu::{OptionExt, ResultExt, Snafu};
 use std::fmt::{Debug, Formatter};
-use std::io;
 use std::str::FromStr;
 use std::sync::Arc;
 use tokio::io::AsyncWrite;
@@ -186,7 +185,7 @@ impl ObjectStore for MicrosoftAzure {
             client: Arc::clone(&self.client),
             location: location.to_owned(),
         };
-        Ok((String::new(), Box::new(CloudMultiPartUpload::new(inner, 8))))
+        Ok((String::new(), Box::new(WriteMultiPart::new(inner, 8))))
     }
 
     async fn abort_multipart(
@@ -243,12 +242,8 @@ struct AzureMultiPartUpload {
 }
 
 #[async_trait]
-impl CloudMultiPartUploadImpl for AzureMultiPartUpload {
-    async fn put_multipart_part(
-        &self,
-        buf: Vec<u8>,
-        part_idx: usize,
-    ) -> Result<UploadPart, io::Error> {
+impl PutPart for AzureMultiPartUpload {
+    async fn put_part(&self, buf: Vec<u8>, part_idx: usize) -> Result<PartId> {
         let content_id = format!("{part_idx:20}");
         let block_id: BlockId = content_id.clone().into();
 
@@ -264,10 +259,10 @@ impl CloudMultiPartUploadImpl for AzureMultiPartUpload {
             )
             .await?;
 
-        Ok(UploadPart { content_id })
+        Ok(PartId { content_id })
     }
 
-    async fn complete(&self, completed_parts: Vec<UploadPart>) -> Result<(), io::Error> {
+    async fn complete(&self, completed_parts: Vec<PartId>) -> Result<()> {
         let blocks = completed_parts
             .into_iter()
             .map(|part| BlockId::from(part.content_id))
diff --git a/object_store/src/gcp/mod.rs b/object_store/src/gcp/mod.rs
index d4d370373d..d98e6b068d 100644
--- a/object_store/src/gcp/mod.rs
+++ b/object_store/src/gcp/mod.rs
@@ -29,7 +29,6 @@
 //! to abort the upload and drop those unneeded parts. In addition, you may wish to
 //! consider implementing automatic clean up of unused parts that are older than one
 //! week.
-use std::io;
 use std::str::FromStr;
 use std::sync::Arc;
 
@@ -52,7 +51,7 @@ use crate::client::{
     TokenCredentialProvider,
 };
 use crate::{
-    multipart::{CloudMultiPartUpload, CloudMultiPartUploadImpl, UploadPart},
+    multipart::{PartId, PutPart, WriteMultiPart},
     path::{Path, DELIMITER},
     ClientOptions, GetOptions, GetResult, ListResult, MultipartId, ObjectMeta,
     ObjectStore, Result, RetryConfig,
@@ -117,6 +116,15 @@ enum Error {
     #[snafu(display("Error getting put response body: {}", source))]
     PutResponseBody { source: reqwest::Error },
 
+    #[snafu(display("Got invalid put response: {}", source))]
+    InvalidPutResponse { source: quick_xml::de::DeError },
+
+    #[snafu(display("Error performing post request {}: {}", path, source))]
+    PostRequest {
+        source: crate::client::retry::Error,
+        path: String,
+    },
+
     #[snafu(display("Error decoding object size: {}", source))]
     InvalidSize { source: std::num::ParseIntError },
 
@@ -148,6 +156,12 @@ enum Error {
 
     #[snafu(display("Configuration key: '{}' is not known.", key))]
     UnknownConfigurationKey { key: String },
+
+    #[snafu(display("ETag Header missing from response"))]
+    MissingEtag,
+
+    #[snafu(display("Received header containing non-ASCII data"))]
+    BadHeader { source: header::ToStrError },
 }
 
 impl From<Error> for super::Error {
@@ -283,14 +297,9 @@ impl GoogleCloudStorageClient {
             })?;
 
         let data = response.bytes().await.context(PutResponseBodySnafu)?;
-        let result: InitiateMultipartUploadResult = quick_xml::de::from_reader(
-            data.as_ref().reader(),
-        )
-        .context(InvalidXMLResponseSnafu {
-            method: "POST".to_string(),
-            url,
-            data,
-        })?;
+        let result: InitiateMultipartUploadResult =
+            quick_xml::de::from_reader(data.as_ref().reader())
+                .context(InvalidPutResponseSnafu)?;
 
         Ok(result.upload_id)
     }
@@ -472,24 +481,16 @@ struct GCSMultipartUpload {
 }
 
 #[async_trait]
-impl CloudMultiPartUploadImpl for GCSMultipartUpload {
+impl PutPart for GCSMultipartUpload {
     /// Upload an object part <https://cloud.google.com/storage/docs/xml-api/put-object-multipart>
-    async fn put_multipart_part(
-        &self,
-        buf: Vec<u8>,
-        part_idx: usize,
-    ) -> Result<UploadPart, io::Error> {
+    async fn put_part(&self, buf: Vec<u8>, part_idx: usize) -> Result<PartId> {
         let upload_id = self.multipart_id.clone();
         let url = format!(
             "{}/{}/{}",
             self.client.base_url, self.client.bucket_name_encoded, self.encoded_path
         );
 
-        let credential = self
-            .client
-            .get_credential()
-            .await
-            .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?;
+        let credential = self.client.get_credential().await?;
 
         let response = self
             .client
@@ -504,26 +505,24 @@ impl CloudMultiPartUploadImpl for GCSMultipartUpload {
             .header(header::CONTENT_LENGTH, format!("{}", buf.len()))
             .body(buf)
             .send_retry(&self.client.retry_config)
-            .await?;
+            .await
+            .context(PutRequestSnafu {
+                path: &self.encoded_path,
+            })?;
 
         let content_id = response
             .headers()
             .get("ETag")
-            .ok_or_else(|| {
-                io::Error::new(
-                    io::ErrorKind::InvalidData,
-                    "response headers missing ETag",
-                )
-            })?
+            .context(MissingEtagSnafu)?
             .to_str()
-            .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?
+            .context(BadHeaderSnafu)?
             .to_string();
 
-        Ok(UploadPart { content_id })
+        Ok(PartId { content_id })
     }
 
     /// Complete a multipart upload <https://cloud.google.com/storage/docs/xml-api/post-object-complete>
-    async fn complete(&self, completed_parts: Vec<UploadPart>) -> Result<(), io::Error> {
+    async fn complete(&self, completed_parts: Vec<PartId>) -> Result<()> {
         let upload_id = self.multipart_id.clone();
         let url = format!(
             "{}/{}/{}",
@@ -539,16 +538,11 @@ impl CloudMultiPartUploadImpl for GCSMultipartUpload {
             })
             .collect();
 
-        let credential = self
-            .client
-            .get_credential()
-            .await
-            .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?;
-
+        let credential = self.client.get_credential().await?;
         let upload_info = CompleteMultipartUpload { parts };
 
         let data = quick_xml::se::to_string(&upload_info)
-            .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?
+            .context(InvalidPutResponseSnafu)?
             // We cannot disable the escaping that transforms "/" to "&quote;" :(
             // https://github.com/tafia/quick-xml/issues/362
             // https://github.com/tafia/quick-xml/issues/350
@@ -561,7 +555,10 @@ impl CloudMultiPartUploadImpl for GCSMultipartUpload {
             .query(&[("uploadId", upload_id)])
             .body(data)
             .send_retry(&self.client.retry_config)
-            .await?;
+            .await
+            .context(PostRequestSnafu {
+                path: &self.encoded_path,
+            })?;
 
         Ok(())
     }
@@ -588,7 +585,7 @@ impl ObjectStore for GoogleCloudStorage {
             multipart_id: upload_id.clone(),
         };
 
-        Ok((upload_id, Box::new(CloudMultiPartUpload::new(inner, 8))))
+        Ok((upload_id, Box::new(WriteMultiPart::new(inner, 8))))
     }
 
     async fn abort_multipart(
diff --git a/object_store/src/multipart.rs b/object_store/src/multipart.rs
index 5f9b7e6748..d4c911fcea 100644
--- a/object_store/src/multipart.rs
+++ b/object_store/src/multipart.rs
@@ -31,40 +31,33 @@ use crate::Result;
 type BoxedTryFuture<T> = Pin<Box<dyn Future<Output = Result<T, io::Error>> + Send>>;
 
 /// A trait that can be implemented by cloud-based object stores
-/// and used in combination with [`CloudMultiPartUpload`] to provide
+/// and used in combination with [`WriteMultiPart`] to provide
 /// multipart upload support
 #[async_trait]
-pub trait CloudMultiPartUploadImpl: 'static {
+pub trait PutPart: Send + Sync + 'static {
     /// Upload a single part
-    async fn put_multipart_part(
-        &self,
-        buf: Vec<u8>,
-        part_idx: usize,
-    ) -> Result<UploadPart, io::Error>;
+    async fn put_part(&self, buf: Vec<u8>, part_idx: usize) -> Result<PartId>;
 
     /// Complete the upload with the provided parts
     ///
     /// `completed_parts` is in order of part number
-    async fn complete(&self, completed_parts: Vec<UploadPart>) -> Result<(), io::Error>;
+    async fn complete(&self, completed_parts: Vec<PartId>) -> Result<()>;
 }
 
 /// Represents a part of a file that has been successfully uploaded in a multipart upload process.
 #[derive(Debug, Clone)]
-pub struct UploadPart {
+pub struct PartId {
     /// Id of this part
     pub content_id: String,
 }
 
-/// Struct that manages and controls multipart uploads to a cloud storage service.
-pub struct CloudMultiPartUpload<T>
-where
-    T: CloudMultiPartUploadImpl,
-{
+/// Wrapper around a [`PutPart`] that implements [`AsyncWrite`]
+pub struct WriteMultiPart<T: PutPart> {
     inner: Arc<T>,
     /// A list of completed parts, in sequential order.
-    completed_parts: Vec<Option<UploadPart>>,
+    completed_parts: Vec<Option<PartId>>,
     /// Part upload tasks currently running
-    tasks: FuturesUnordered<BoxedTryFuture<(usize, UploadPart)>>,
+    tasks: FuturesUnordered<BoxedTryFuture<(usize, PartId)>>,
     /// Maximum number of upload tasks to run concurrently
     max_concurrency: usize,
     /// Buffer that will be sent in next upload.
@@ -80,10 +73,7 @@ where
     completion_task: Option<BoxedTryFuture<()>>,
 }
 
-impl<T> CloudMultiPartUpload<T>
-where
-    T: CloudMultiPartUploadImpl,
-{
+impl<T: PutPart> WriteMultiPart<T> {
     /// Create a new multipart upload with the implementation and the given maximum concurrency
     pub fn new(inner: T, max_concurrency: usize) -> Self {
         Self {
@@ -114,7 +104,7 @@ where
     }
 
     /// Poll current tasks
-    pub fn poll_tasks(
+    fn poll_tasks(
         mut self: Pin<&mut Self>,
         cx: &mut std::task::Context<'_>,
     ) -> Result<(), io::Error> {
@@ -130,12 +120,7 @@ where
         }
         Ok(())
     }
-}
 
-impl<T> CloudMultiPartUpload<T>
-where
-    T: CloudMultiPartUploadImpl + Send + Sync,
-{
     // The `poll_flush` function will only flush the in-progress tasks.
     // The `final_flush` method called during `poll_shutdown` will flush
     // the `current_buffer` along with in-progress tasks.
@@ -153,7 +138,7 @@ where
             let inner = Arc::clone(&self.inner);
             let part_idx = self.current_part_idx;
             self.tasks.push(Box::pin(async move {
-                let upload_part = inner.put_multipart_part(out_buffer, part_idx).await?;
+                let upload_part = inner.put_part(out_buffer, part_idx).await?;
                 Ok((part_idx, upload_part))
             }));
         }
@@ -169,10 +154,7 @@ where
     }
 }
 
-impl<T> AsyncWrite for CloudMultiPartUpload<T>
-where
-    T: CloudMultiPartUploadImpl + Send + Sync,
-{
+impl<T: PutPart> AsyncWrite for WriteMultiPart<T> {
     fn poll_write(
         mut self: Pin<&mut Self>,
         cx: &mut std::task::Context<'_>,
@@ -199,7 +181,7 @@ where
             let inner = Arc::clone(&self.inner);
             let part_idx = self.current_part_idx;
             self.tasks.push(Box::pin(async move {
-                let upload_part = inner.put_multipart_part(out_buffer, part_idx).await?;
+                let upload_part = inner.put_part(out_buffer, part_idx).await?;
                 Ok((part_idx, upload_part))
             }));
             self.current_part_idx += 1;
@@ -269,9 +251,9 @@ where
     }
 }
 
-impl<T: CloudMultiPartUploadImpl> std::fmt::Debug for CloudMultiPartUpload<T> {
+impl<T: PutPart> std::fmt::Debug for WriteMultiPart<T> {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        f.debug_struct("CloudMultiPartUpload")
+        f.debug_struct("WriteMultiPart")
             .field("completed_parts", &self.completed_parts)
             .field("tasks", &self.tasks)
             .field("max_concurrency", &self.max_concurrency)

Reply via email to