This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new 9f36c883459 Implement MultipartStore for ThrottledStore (#5533)
9f36c883459 is described below

commit 9f36c883459405ecd9a5f4fdfa9a3317ab52302c
Author: Raphael Taylor-Davies <[email protected]>
AuthorDate: Fri Mar 29 10:14:35 2024 +0000

    Implement MultipartStore for ThrottledStore (#5533)
    
    * Implement MultipartStore for ThrottledStore
    
    Limit concurrency in BufWriter
    
    Tweak WriteMultipart
    
    * Fix MSRV
    
    * Format
---
 object_store/src/buffered.rs | 14 ++++++++
 object_store/src/throttle.rs | 78 ++++++++++++++++++++++++++++++++++++++++----
 object_store/src/upload.rs   | 76 ++++++++++++++++++++++++++++++++++--------
 3 files changed, 148 insertions(+), 20 deletions(-)

diff --git a/object_store/src/buffered.rs b/object_store/src/buffered.rs
index 39f8eafbef7..de6d4eb1bb9 100644
--- a/object_store/src/buffered.rs
+++ b/object_store/src/buffered.rs
@@ -216,6 +216,7 @@ impl AsyncBufRead for BufReader {
 /// streamed using [`ObjectStore::put_multipart`]
 pub struct BufWriter {
     capacity: usize,
+    max_concurrency: usize,
     state: BufWriterState,
     store: Arc<dyn ObjectStore>,
 }
@@ -250,10 +251,21 @@ impl BufWriter {
         Self {
             capacity,
             store,
+            max_concurrency: 8,
             state: BufWriterState::Buffer(path, Vec::new()),
         }
     }
 
+    /// Override the maximum number of in-flight requests for this writer
+    ///
+    /// Defaults to 8
+    pub fn with_max_concurrency(self, max_concurrency: usize) -> Self {
+        Self {
+            max_concurrency,
+            ..self
+        }
+    }
+
     /// Abort this writer, cleaning up any partially uploaded state
     ///
     /// # Panic
@@ -275,9 +287,11 @@ impl AsyncWrite for BufWriter {
         buf: &[u8],
     ) -> Poll<Result<usize, Error>> {
         let cap = self.capacity;
+        let max_concurrency = self.max_concurrency;
         loop {
             return match &mut self.state {
                 BufWriterState::Write(Some(write)) => {
+                    ready!(write.poll_for_capacity(cx, max_concurrency))?;
                     write.write(buf);
                     Poll::Ready(Ok(buf.len()))
                 }
diff --git a/object_store/src/throttle.rs b/object_store/src/throttle.rs
index 5ca1eedbf73..65fac5922f6 100644
--- a/object_store/src/throttle.rs
+++ b/object_store/src/throttle.rs
@@ -20,11 +20,12 @@ use parking_lot::Mutex;
 use std::ops::Range;
 use std::{convert::TryInto, sync::Arc};
 
-use crate::GetOptions;
+use crate::multipart::{MultipartStore, PartId};
 use crate::{
-    path::Path, GetResult, GetResultPayload, ListResult, MultipartUpload, ObjectMeta, ObjectStore,
-    PutOptions, PutResult, Result,
+    path::Path, GetResult, GetResultPayload, ListResult, MultipartId, MultipartUpload, ObjectMeta,
+    ObjectStore, PutOptions, PutResult, Result,
 };
+use crate::{GetOptions, UploadPart};
 use async_trait::async_trait;
 use bytes::Bytes;
 use futures::{stream::BoxStream, FutureExt, StreamExt};
@@ -110,12 +111,12 @@ async fn sleep(duration: Duration) {
 /// **Note that the behavior of the wrapper is deterministic and might not reflect real-world
 /// conditions!**
 #[derive(Debug)]
-pub struct ThrottledStore<T: ObjectStore> {
+pub struct ThrottledStore<T> {
     inner: T,
     config: Arc<Mutex<ThrottleConfig>>,
 }
 
-impl<T: ObjectStore> ThrottledStore<T> {
+impl<T> ThrottledStore<T> {
     /// Create new wrapper with zero waiting times.
     pub fn new(inner: T, config: ThrottleConfig) -> Self {
         Self {
@@ -157,8 +158,12 @@ impl<T: ObjectStore> ObjectStore for ThrottledStore<T> {
         self.inner.put_opts(location, bytes, opts).await
     }
 
-    async fn put_multipart(&self, _location: &Path) -> Result<Box<dyn MultipartUpload>> {
-        Err(super::Error::NotImplemented)
+    async fn put_multipart(&self, location: &Path) -> Result<Box<dyn MultipartUpload>> {
+        let upload = self.inner.put_multipart(location).await?;
+        Ok(Box::new(ThrottledUpload {
+            upload,
+            sleep: self.config().wait_put_per_call,
+        }))
     }
 
     async fn get(&self, location: &Path) -> Result<GetResult> {
@@ -316,6 +321,63 @@ where
         .boxed()
 }
 
+#[async_trait]
+impl<T: MultipartStore> MultipartStore for ThrottledStore<T> {
+    async fn create_multipart(&self, path: &Path) -> Result<MultipartId> {
+        self.inner.create_multipart(path).await
+    }
+
+    async fn put_part(
+        &self,
+        path: &Path,
+        id: &MultipartId,
+        part_idx: usize,
+        data: Bytes,
+    ) -> Result<PartId> {
+        sleep(self.config().wait_put_per_call).await;
+        self.inner.put_part(path, id, part_idx, data).await
+    }
+
+    async fn complete_multipart(
+        &self,
+        path: &Path,
+        id: &MultipartId,
+        parts: Vec<PartId>,
+    ) -> Result<PutResult> {
+        self.inner.complete_multipart(path, id, parts).await
+    }
+
+    async fn abort_multipart(&self, path: &Path, id: &MultipartId) -> Result<()> {
+        self.inner.abort_multipart(path, id).await
+    }
+}
+
+#[derive(Debug)]
+struct ThrottledUpload {
+    upload: Box<dyn MultipartUpload>,
+    sleep: Duration,
+}
+
+#[async_trait]
+impl MultipartUpload for ThrottledUpload {
+    fn put_part(&mut self, data: Bytes) -> UploadPart {
+        let duration = self.sleep;
+        let put = self.upload.put_part(data);
+        Box::pin(async move {
+            sleep(duration).await;
+            put.await
+        })
+    }
+
+    async fn complete(&mut self) -> Result<PutResult> {
+        self.upload.complete().await
+    }
+
+    async fn abort(&mut self) -> Result<()> {
+        self.upload.abort().await
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -351,6 +413,8 @@ mod tests {
         list_with_delimiter(&store).await;
         rename_and_copy(&store).await;
         copy_if_not_exists(&store).await;
+        stream_get(&store).await;
+        multipart(&store, &store).await;
     }
 
     #[tokio::test]
diff --git a/object_store/src/upload.rs b/object_store/src/upload.rs
index 6f8bfa8a5f7..fe864e2821c 100644
--- a/object_store/src/upload.rs
+++ b/object_store/src/upload.rs
@@ -15,12 +15,16 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use crate::{PutResult, Result};
+use std::task::{Context, Poll};
+
 use async_trait::async_trait;
 use bytes::Bytes;
 use futures::future::BoxFuture;
+use futures::ready;
 use tokio::task::JoinSet;
 
+use crate::{PutResult, Result};
+
 /// An upload part request
 pub type UploadPart = BoxFuture<'static, Result<()>>;
 
@@ -110,31 +114,44 @@ pub struct WriteMultipart {
 impl WriteMultipart {
     /// Create a new [`WriteMultipart`] that will upload using 5MB chunks
     pub fn new(upload: Box<dyn MultipartUpload>) -> Self {
-        Self::new_with_capacity(upload, 5 * 1024 * 1024)
+        Self::new_with_chunk_size(upload, 5 * 1024 * 1024)
     }
 
-    /// Create a new [`WriteMultipart`] that will upload in fixed `capacity` sized chunks
-    pub fn new_with_capacity(upload: Box<dyn MultipartUpload>, capacity: usize) -> Self {
+    /// Create a new [`WriteMultipart`] that will upload in fixed `chunk_size` sized chunks
+    pub fn new_with_chunk_size(upload: Box<dyn MultipartUpload>, chunk_size: usize) -> Self {
         Self {
             upload,
-            buffer: Vec::with_capacity(capacity),
+            buffer: Vec::with_capacity(chunk_size),
             tasks: Default::default(),
         }
     }
 
-    /// Wait until there are `max_concurrency` or fewer requests in-flight
-    pub async fn wait_for_capacity(&mut self, max_concurrency: usize) -> Result<()> {
-        while self.tasks.len() > max_concurrency {
-            self.tasks.join_next().await.unwrap()??;
+    /// Polls for there to be less than `max_concurrency` [`UploadPart`] in progress
+    ///
+    /// See [`Self::wait_for_capacity`] for an async version of this function
+    pub fn poll_for_capacity(
+        &mut self,
+        cx: &mut Context<'_>,
+        max_concurrency: usize,
+    ) -> Poll<Result<()>> {
+        while !self.tasks.is_empty() && self.tasks.len() >= max_concurrency {
+            ready!(self.tasks.poll_join_next(cx)).unwrap()??
         }
-        Ok(())
+        Poll::Ready(Ok(()))
+    }
+
+    /// Wait until there are less than `max_concurrency` [`UploadPart`] in progress
+    ///
+    /// See [`Self::poll_for_capacity`] for a [`Poll`] version of this function
+    pub async fn wait_for_capacity(&mut self, max_concurrency: usize) -> Result<()> {
+        futures::future::poll_fn(|cx| self.poll_for_capacity(cx, max_concurrency)).await
     }
 
     /// Write data to this [`WriteMultipart`]
     ///
-    /// Note this method is synchronous (not `async`) and will immediately start new uploads
-    /// as soon as the internal `capacity` is hit, regardless of
-    /// how many outstanding uploads are already in progress.
+    /// Note this method is synchronous (not `async`) and will immediately
+    /// start new uploads as soon as the internal `chunk_size` is hit,
+    /// regardless of how many outstanding uploads are already in progress.
     ///
     /// Back pressure can optionally be applied to producers by calling
     /// [`Self::wait_for_capacity`] prior to calling this method
@@ -173,3 +190,36 @@ impl WriteMultipart {
         self.upload.complete().await
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use std::time::Duration;
+
+    use futures::FutureExt;
+
+    use crate::memory::InMemory;
+    use crate::path::Path;
+    use crate::throttle::{ThrottleConfig, ThrottledStore};
+    use crate::ObjectStore;
+
+    use super::*;
+
+    #[tokio::test]
+    async fn test_concurrency() {
+        let config = ThrottleConfig {
+            wait_put_per_call: Duration::from_millis(1),
+            ..Default::default()
+        };
+
+        let path = Path::from("foo");
+        let store = ThrottledStore::new(InMemory::new(), config);
+        let upload = store.put_multipart(&path).await.unwrap();
+        let mut write = WriteMultipart::new_with_chunk_size(upload, 10);
+
+        for _ in 0..20 {
+            write.write(&[0; 5]);
+        }
+        assert!(write.wait_for_capacity(10).now_or_never().is_none());
+        write.wait_for_capacity(10).await.unwrap()
+    }
+}

Reply via email to