tustvold commented on code in PR #5205:
URL: https://github.com/apache/arrow-rs/pull/5205#discussion_r1434885895
##########
object_store/src/multipart.rs:
##########
@@ -316,3 +317,215 @@ pub trait MultiPartStore: Send + Sync + 'static {
/// Aborts a multipart upload
async fn abort_multipart(&self, path: &Path, id: &MultipartId) ->
Result<()>;
}
+
+/// Create a lazy multipart writer for a given [`ObjectStore`] and [`Path`].
+///
+/// A multipart upload using `ObjectStore::put_multipart` will only be created
when the size exceeds `multipart_threshold`,
+/// otherwise a direct PUT will be performed on shutdown.
+pub fn put_multipart_lazy(
+ store: Arc<dyn ObjectStore>,
+ path: Path,
+ multipart_threshold: usize,
+) -> Box<dyn AsyncWrite + Send + Unpin> {
+ Box::new(LazyWriteMultiPart::new(store, path, multipart_threshold))
+}
+
+enum LazyWriteState {
+ /// Accumulating written bytes in memory; the multipart threshold has not been exceeded yet
+ Buffer(Vec<u8>),
+ /// Writer has been shut down and a direct put of the buffered data is in progress
+ Put(BoxedTryFuture<()>),
+ /// Threshold exceeded: future creating the multipart upload, plus the bytes buffered so far
+ CreateMultipart(BoxedTryFuture<Box<dyn AsyncWrite + Send + Unpin>>,
Vec<u8>),
+ /// Draining the pre-upload buffer into the multipart writer; the offset marks the first unwritten byte
+ FlushingInitialWrite(Option<Box<dyn AsyncWrite + Send + Unpin>>, Vec<u8>,
usize),
+ /// Initial buffer fully handed over; writes are delegated to the underlying multipart writer
+ AsyncWrite(Box<dyn AsyncWrite + Send + Unpin>),
+}
+
+/// Writer for a [`Path`] in an [`ObjectStore`] that buffers until `multipart_threshold` is exceeded; implements
[`AsyncWrite`]
+struct LazyWriteMultiPart {
+ store: Arc<dyn ObjectStore>,
+ path: Path,
+ multipart_threshold: usize,
+ state: LazyWriteState,
+}
+
+impl LazyWriteMultiPart {
+    /// Creates a writer for `path` in `store` that starts out buffering into an empty `Vec`.
+    pub fn new(store: Arc<dyn ObjectStore>, path: Path, multipart_threshold: usize) -> Self {
+        let state = LazyWriteState::Buffer(Vec::new());
+        Self {
+            store,
+            path,
+            multipart_threshold,
+            state,
+        }
+    }
+
+    /// Polls the pending multipart-creation future and advances the state once it resolves.
+    ///
+    /// On success the state becomes `AsyncWrite` when nothing was buffered, or
+    /// `FlushingInitialWrite` (starting at offset 0) when buffered bytes still have to be
+    /// written to the new multipart writer.
+    ///
+    /// # Panics
+    ///
+    /// Panics if the current state is not `CreateMultipart`.
+    fn poll_create_multipart(
+        &mut self,
+        cx: &mut std::task::Context<'_>,
+    ) -> Poll<Result<(), io::Error>> {
+        // Compute the successor state first, then install it with a single assignment.
+        let next_state = match &mut self.state {
+            LazyWriteState::CreateMultipart(fut, buffer) => {
+                let writer = ready!(Pin::new(fut).poll(cx))?;
+                if buffer.is_empty() {
+                    LazyWriteState::AsyncWrite(writer)
+                } else {
+                    let pending = std::mem::take(buffer);
+                    LazyWriteState::FlushingInitialWrite(Some(writer), pending, 0)
+                }
+            }
+            _ => unreachable!(),
+        };
+        self.state = next_state;
+        Poll::Ready(Ok(()))
+    }
+
+    /// Offers the next slice of `buffer`, starting at `*flush_offset`, to `writer`.
+    ///
+    /// At most `write_len` bytes are offered, clamped to the end of `buffer`.
+    /// `*flush_offset` is advanced by the number of bytes the writer accepted, which is
+    /// also the value returned.
+    fn do_inner_flush(
+        cx: &mut std::task::Context<'_>,
+        writer: &mut Box<dyn AsyncWrite + Send + Unpin>,
+        buffer: &mut Vec<u8>,
+        flush_offset: &mut usize,
+        write_len: usize,
+    ) -> Poll<Result<usize, io::Error>> {
+        let start = *flush_offset;
+        let end = (start + write_len).min(buffer.len());
+        let chunk = &buffer[start..end];
+        let written = ready!(Pin::new(writer).poll_write(cx, chunk))?;
+        *flush_offset += written;
+        Poll::Ready(Ok(written))
+    }
+}
+
+impl AsyncWrite for LazyWriteMultiPart {
+ fn poll_write(
+ mut self: Pin<&mut Self>,
+ cx: &mut std::task::Context<'_>,
+ buf: &[u8],
+ ) -> Poll<Result<usize, io::Error>> {
+ let multipart_threshold = self.multipart_threshold;
+ let store = Arc::clone(&self.store);
+ let path = self.path.clone();
+
+ let mut wrote = 0;
+ loop {
+ match &mut self.state {
+ LazyWriteState::Buffer(buffer) => {
+ let buf_len = buf.len();
+ let new_len = buffer.len() + buf_len;
+
+ if new_len > multipart_threshold {
+ let new_buffer = std::mem::take(buffer);
+ let store = Arc::clone(&store);
+ let path = path.clone();
+ let create_fut = Box::pin(async move {
+ let (_, multipart_writer) =
store.put_multipart(&path).await?;
+ Ok(multipart_writer)
+ });
+ self.state =
LazyWriteState::CreateMultipart(create_fut, new_buffer);
+ } else {
+ buffer.extend_from_slice(buf);
+ return Poll::Ready(Ok(buf_len));
+ }
+ }
+ LazyWriteState::CreateMultipart(_, _) => {
+ ready!(self.as_mut().poll_create_multipart(cx))?;
+ }
+ LazyWriteState::FlushingInitialWrite(writer, buffer,
flush_offset) => {
+ let n = ready!(Self::do_inner_flush(
+ cx,
+ writer.as_mut().unwrap(),
+ buffer,
+ flush_offset,
+ buf.len()
+ ))?;
+
+ if *flush_offset == buffer.len() {
+ wrote += n;
+ self.state =
LazyWriteState::AsyncWrite(writer.take().unwrap());
+ } else {
+ buffer.extend_from_slice(buf);
+ return Poll::Ready(Ok(n));
+ }
+ }
+ LazyWriteState::AsyncWrite(writer) => {
+ return Pin::new(writer).poll_write(cx, buf).map_ok(|n| n +
wrote)
Review Comment:
I'm not sure this is correct: the returned count is the number of bytes from
`buf` that were consumed. As it stands, I think this will also include any
data that was previously buffered.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]