This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new ef5c45cf418 Add BufWriter for Adapative Put / Multipart Upload (#5431)
ef5c45cf418 is described below

commit ef5c45cf4186a8124da5a1603ebdbc09ef9928fc
Author: Raphael Taylor-Davies <1781103+tustv...@users.noreply.github.com>
AuthorDate: Tue Feb 27 15:39:36 2024 +1300

    Add BufWriter for Adapative Put / Multipart Upload (#5431)
    
    * Add BufWriter
    
    * Review feedback
---
 object_store/src/buffered.rs | 163 ++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 161 insertions(+), 2 deletions(-)

diff --git a/object_store/src/buffered.rs b/object_store/src/buffered.rs
index 3a1354f4f20..fdefe599f79 100644
--- a/object_store/src/buffered.rs
+++ b/object_store/src/buffered.rs
@@ -18,7 +18,7 @@
 //! Utilities for performing tokio-style buffered IO
 
 use crate::path::Path;
-use crate::{ObjectMeta, ObjectStore};
+use crate::{MultipartId, ObjectMeta, ObjectStore};
 use bytes::Bytes;
 use futures::future::{BoxFuture, FutureExt};
 use futures::ready;
@@ -27,7 +27,7 @@ use std::io::{Error, ErrorKind, SeekFrom};
 use std::pin::Pin;
 use std::sync::Arc;
 use std::task::{Context, Poll};
-use tokio::io::{AsyncBufRead, AsyncRead, AsyncSeek, ReadBuf};
+use tokio::io::{AsyncBufRead, AsyncRead, AsyncSeek, AsyncWrite, AsyncWriteExt, ReadBuf};
 
 /// The default buffer size used by [`BufReader`]
 pub const DEFAULT_BUFFER_SIZE: usize = 1024 * 1024;
@@ -205,6 +205,138 @@ impl AsyncBufRead for BufReader {
     }
 }
 
+/// An async buffered writer compatible with the tokio IO traits
+///
+/// Up to `capacity` bytes will be buffered in memory, and flushed on shutdown
+/// using [`ObjectStore::put`]. If `capacity` is exceeded, data will instead be
+/// streamed using [`ObjectStore::put_multipart`]
+pub struct BufWriter {
+    capacity: usize,
+    state: BufWriterState,
+    multipart_id: Option<MultipartId>,
+    store: Arc<dyn ObjectStore>,
+}
+
+impl std::fmt::Debug for BufWriter {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("BufWriter")
+            .field("capacity", &self.capacity)
+            .field("multipart_id", &self.multipart_id)
+            .finish()
+    }
+}
+
+type MultipartResult = (MultipartId, Box<dyn AsyncWrite + Send + Unpin>);
+
+enum BufWriterState {
+    /// Buffer up to capacity bytes
+    Buffer(Path, Vec<u8>),
+    /// [`ObjectStore::put_multipart`]
+    Prepare(BoxFuture<'static, std::io::Result<MultipartResult>>),
+    /// Write to a multipart upload
+    Write(Box<dyn AsyncWrite + Send + Unpin>),
+    /// [`ObjectStore::put`]
+    Put(BoxFuture<'static, std::io::Result<()>>),
+}
+
+impl BufWriter {
+    /// Create a new [`BufWriter`] from the provided [`ObjectStore`] and [`Path`]
+    pub fn new(store: Arc<dyn ObjectStore>, path: Path) -> Self {
+        Self::with_capacity(store, path, 10 * 1024 * 1024)
+    }
+
+    /// Create a new [`BufWriter`] from the provided [`ObjectStore`], [`Path`] and `capacity`
+    pub fn with_capacity(store: Arc<dyn ObjectStore>, path: Path, capacity: usize) -> Self {
+        Self {
+            capacity,
+            store,
+            state: BufWriterState::Buffer(path, Vec::new()),
+            multipart_id: None,
+        }
+    }
+
+    /// Returns the [`MultipartId`] if multipart upload
+    pub fn multipart_id(&self) -> Option<&MultipartId> {
+        self.multipart_id.as_ref()
+    }
+}
+
+impl AsyncWrite for BufWriter {
+    fn poll_write(
+        mut self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+        buf: &[u8],
+    ) -> Poll<Result<usize, Error>> {
+        let cap = self.capacity;
+        loop {
+            return match &mut self.state {
+                BufWriterState::Write(write) => Pin::new(write).poll_write(cx, buf),
+                BufWriterState::Put(_) => panic!("Already shut down"),
+                BufWriterState::Prepare(f) => {
+                    let (id, w) = ready!(f.poll_unpin(cx)?);
+                    self.state = BufWriterState::Write(w);
+                    self.multipart_id = Some(id);
+                    continue;
+                }
+                BufWriterState::Buffer(path, b) => {
+                    if b.len().saturating_add(buf.len()) >= cap {
+                        let buffer = std::mem::take(b);
+                        let path = std::mem::take(path);
+                        let store = Arc::clone(&self.store);
+                        self.state = BufWriterState::Prepare(Box::pin(async move {
+                            let (id, mut writer) = store.put_multipart(&path).await?;
+                            writer.write_all(&buffer).await?;
+                            Ok((id, writer))
+                        }));
+                        continue;
+                    }
+                    b.extend_from_slice(buf);
+                    Poll::Ready(Ok(buf.len()))
+                }
+            };
+        }
+    }
+
+    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
+        loop {
+            return match &mut self.state {
+                BufWriterState::Buffer(_, _) => Poll::Ready(Ok(())),
+                BufWriterState::Write(write) => Pin::new(write).poll_flush(cx),
+                BufWriterState::Put(_) => panic!("Already shut down"),
+                BufWriterState::Prepare(f) => {
+                    let (id, w) = ready!(f.poll_unpin(cx)?);
+                    self.state = BufWriterState::Write(w);
+                    self.multipart_id = Some(id);
+                    continue;
+                }
+            };
+        }
+    }
+
+    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
+        loop {
+            match &mut self.state {
+                BufWriterState::Prepare(f) => {
+                    let (id, w) = ready!(f.poll_unpin(cx)?);
+                    self.state = BufWriterState::Write(w);
+                    self.multipart_id = Some(id);
+                }
+                BufWriterState::Buffer(p, b) => {
+                    let buf = std::mem::take(b);
+                    let path = std::mem::take(p);
+                    let store = Arc::clone(&self.store);
+                    self.state = BufWriterState::Put(Box::pin(async move {
+                        store.put(&path, buf.into()).await?;
+                        Ok(())
+                    }));
+                }
+                BufWriterState::Put(f) => return f.poll_unpin(cx),
+                BufWriterState::Write(w) => return Pin::new(w).poll_shutdown(cx),
+            }
+        }
+    }
+}
+
 /// Port of standardised function as requires Rust 1.66
 ///
 /// <https://github.com/rust-lang/rust/pull/87601/files#diff-b9390ee807a1dae3c3128dce36df56748ad8d23c6e361c0ebba4d744bf6efdb9R1533>
@@ -300,4 +432,31 @@ mod tests {
             assert!(buffer.is_empty());
         }
     }
+
+    #[tokio::test]
+    async fn test_buf_writer() {
+        let store = Arc::new(InMemory::new()) as Arc<dyn ObjectStore>;
+        let path = Path::from("file.txt");
+
+        // Test put
+        let mut writer = BufWriter::with_capacity(Arc::clone(&store), path.clone(), 30);
+        writer.write_all(&[0; 20]).await.unwrap();
+        writer.flush().await.unwrap();
+        writer.write_all(&[0; 5]).await.unwrap();
+        assert!(writer.multipart_id().is_none());
+        writer.shutdown().await.unwrap();
+        assert!(writer.multipart_id().is_none());
+        assert_eq!(store.head(&path).await.unwrap().size, 25);
+
+        // Test multipart
+        let mut writer = BufWriter::with_capacity(Arc::clone(&store), path.clone(), 30);
+        writer.write_all(&[0; 20]).await.unwrap();
+        writer.flush().await.unwrap();
+        writer.write_all(&[0; 20]).await.unwrap();
+        assert!(writer.multipart_id().is_some());
+        writer.shutdown().await.unwrap();
+        assert!(writer.multipart_id().is_some());
+
+        assert_eq!(store.head(&path).await.unwrap().size, 40);
+    }
 }

Reply via email to