This is an automated email from the ASF dual-hosted git repository.

xuanwo pushed a commit to branch poll-write
in repository https://gitbox.apache.org/repos/asf/incubator-opendal.git

commit 9b2534b2f72e5439a24f92f78ea3a2b3518db07b
Author: Xuanwo <[email protected]>
AuthorDate: Mon Sep 11 14:03:36 2023 +0800

    Refactor gcs
    
    Signed-off-by: Xuanwo <[email protected]>
---
 core/src/raw/oio/write/mod.rs                    |   4 +
 core/src/raw/oio/write/multipart_upload_write.rs |  16 +-
 core/src/raw/oio/write/range_write.rs            | 269 +++++++++++++++++++++++
 core/src/services/gcs/backend.rs                 |  15 +-
 core/src/services/gcs/core.rs                    |  26 +--
 core/src/services/gcs/writer.rs                  | 178 +++++----------
 6 files changed, 357 insertions(+), 151 deletions(-)

diff --git a/core/src/raw/oio/write/mod.rs b/core/src/raw/oio/write/mod.rs
index ff6d8377e..df0b83b28 100644
--- a/core/src/raw/oio/write/mod.rs
+++ b/core/src/raw/oio/write/mod.rs
@@ -42,3 +42,7 @@ pub use one_shot_write::OneShotWriter;
 
 mod exact_buf_write;
 pub use exact_buf_write::ExactBufWriter;
+
+mod range_write;
+pub use range_write::RangeWrite;
+pub use range_write::RangeWriter;
diff --git a/core/src/raw/oio/write/multipart_upload_write.rs 
b/core/src/raw/oio/write/multipart_upload_write.rs
index 81ee1f1ee..1b2148cb4 100644
--- a/core/src/raw/oio/write/multipart_upload_write.rs
+++ b/core/src/raw/oio/write/multipart_upload_write.rs
@@ -166,12 +166,12 @@ where
                 }
                 State::Close(_) => {
                     unreachable!(
-                        "MultipartUploadWriter must not go into State:Close 
during poll_write"
+                        "MultipartUploadWriter must not go into State::Close 
during poll_write"
                     )
                 }
                 State::Abort(_) => {
                     unreachable!(
-                        "MultipartUploadWriter must not go into State:Abort 
during poll_write"
+                        "MultipartUploadWriter must not go into State::Abort 
during poll_write"
                     )
                 }
             }
@@ -200,13 +200,13 @@ where
                     return Poll::Ready(res);
                 }
                 State::Init(_) => unreachable!(
-                    "MultipartUploadWriter must not go into State:Init during 
poll_close"
+                    "MultipartUploadWriter must not go into State::Init during 
poll_close"
                 ),
                 State::Write(_) => unreachable!(
-                    "MultipartUploadWriter must not go into State:Write during 
poll_close"
+                    "MultipartUploadWriter must not go into State::Write 
during poll_close"
                 ),
                 State::Abort(_) => unreachable!(
-                    "MultipartUploadWriter must not go into State:Abort during 
poll_close"
+                    "MultipartUploadWriter must not go into State::Abort 
during poll_close"
                 ),
             }
         }
@@ -233,13 +233,13 @@ where
                     return Poll::Ready(res);
                 }
                 State::Init(_) => unreachable!(
-                    "MultipartUploadWriter must not go into State:Init during 
poll_abort"
+                    "MultipartUploadWriter must not go into State::Init during 
poll_abort"
                 ),
                 State::Write(_) => unreachable!(
-                    "MultipartUploadWriter must not go into State:Write during 
poll_abort"
+                    "MultipartUploadWriter must not go into State::Write 
during poll_abort"
                 ),
                 State::Close(_) => unreachable!(
-                    "MultipartUploadWriter must not go into State:Close during 
poll_abort"
+                    "MultipartUploadWriter must not go into State::Close 
during poll_abort"
                 ),
             }
         }
diff --git a/core/src/raw/oio/write/range_write.rs 
b/core/src/raw/oio/write/range_write.rs
new file mode 100644
index 000000000..a96af8d3d
--- /dev/null
+++ b/core/src/raw/oio/write/range_write.rs
@@ -0,0 +1,269 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use async_trait::async_trait;
+use futures::future::BoxFuture;
+use futures::FutureExt;
+use std::task::{ready, Context, Poll};
+
+use crate::raw::oio::WriteBuf;
+use crate::raw::*;
+use crate::*;
+
+/// RangeWrite describes an upload that is performed as a sequence of byte ranges
+/// against a location returned by `initiate_range`.
+///
+/// `RangeWriter` drives implementations of this trait: it calls `initiate_range`
+/// when no location is known yet, `write_range` for intermediate chunks,
+/// `complete_range` from `poll_close` and `abort_range` from `poll_abort`.
+#[async_trait]
+pub trait RangeWrite: Send + Sync + Unpin + 'static {
+    /// Start a new upload and return its location.
+    ///
+    /// `RangeWriter` stores the returned string and passes it back as `location`
+    /// to every other method of this trait.
+    async fn initiate_range(&self) -> Result<String>;
+
+    /// Upload one intermediate chunk.
+    ///
+    /// - `location`: the value returned by `initiate_range`.
+    /// - `written`: number of bytes confirmed by earlier successful `write_range` calls.
+    /// - `size`: number of bytes `RangeWriter` placed into `body`.
+    /// - `body`: the chunk content.
+    async fn write_range(
+        &self,
+        location: &str,
+        written: u64,
+        size: u64,
+        body: AsyncBody,
+    ) -> Result<()>;
+
+    /// Upload the final chunk and finish the upload.
+    ///
+    /// Parameters have the same meaning as in `write_range`; `body` carries whatever
+    /// `RangeWriter` still had buffered when it was closed.
+    async fn complete_range(
+        &self,
+        location: &str,
+        written: u64,
+        size: u64,
+        body: AsyncBody,
+    ) -> Result<()>;
+
+    /// Cancel the upload identified by `location`.
+    async fn abort_range(&self, location: &str) -> Result<()>;
+}
+
+/// RangeWriter implements `oio::Write` on top of a `RangeWrite` backend.
+///
+/// Incoming data is staged in `align_buffer` and sent to the backend in pieces whose
+/// size is a multiple of `align_size`; the remainder is sent when the writer is closed.
+pub struct RangeWriter<W: RangeWrite> {
+    /// Upload location returned by `initiate_range`; `None` until `poll_write` has
+    /// initiated the upload.
+    location: Option<String>,
+    /// Number of bytes confirmed by successful `write_range` calls so far.
+    written: u64,
+    /// Chunk granularity in bytes. `new` sets it to `256 * 1024`.
+    align_size: usize,
+    /// Data accepted from the caller that has not been confirmed as uploaded yet.
+    align_buffer: oio::ChunkedCursor,
+
+    /// Current step of the poll-driven state machine; owns the inner writer while idle.
+    state: State<W>,
+}
+
+/// Steps of the `RangeWriter` state machine.
+///
+/// The inner writer `W` is moved into each in-flight future and handed back as the
+/// first element of the future's output, so the state can return to `Idle` afterwards.
+enum State<W> {
+    /// No request in flight. Holds the inner writer until a future takes it.
+    Idle(Option<W>),
+    /// `initiate_range` in flight; resolves to the upload location.
+    Init(BoxFuture<'static, (W, Result<String>)>),
+    /// `write_range` in flight.
+    ///
+    /// The returning value is (consume size, written size): the number of bytes taken
+    /// from the caller's buffer and the number of bytes sent in this request.
+    Write(BoxFuture<'static, (W, Result<(usize, u64)>)>),
+    /// `complete_range` in flight.
+    Complete(BoxFuture<'static, (W, Result<()>)>),
+    /// `abort_range` in flight.
+    Abort(BoxFuture<'static, (W, Result<()>)>),
+}
+
+/// # Safety
+///
+/// We will only take `&mut Self` reference for State: every `poll_*` method of
+/// `RangeWriter` matches on `&mut self.state`, so the boxed futures are never reached
+/// through a shared reference.
+// NOTE(review): `BoxFuture` is `Send` but not `Sync`, which is presumably why this
+// manual impl is needed — confirm that no code path ever hands out `&State<W>`.
+unsafe impl<W: RangeWrite> Sync for State<W> {}
+
+impl<W: RangeWrite> RangeWriter<W> {
+    /// Create a new RangeWriter.
+    ///
+    /// The writer starts idle with no upload location, nothing written, an empty
+    /// buffer and an align size of `256 * 1024` bytes.
+    pub fn new(inner: W) -> Self {
+        Self {
+            state: State::Idle(Some(inner)),
+
+            location: None,
+            written: 0,
+            align_size: 256 * 1024,
+            align_buffer: oio::ChunkedCursor::default(),
+        }
+    }
+
+    /// Override the align size (in bytes) used to cut intermediate chunks.
+    ///
+    /// `size` must be non-zero: `poll_write` computes `total_size % self.align_size`.
+    pub fn with_align_size(mut self, size: usize) -> Self {
+        self.align_size = size;
+        self
+    }
+}
+
+impl<W: RangeWrite> RangeWriter<W> {
+    /// Compute how many bytes of `bs` to take so that the staged buffer plus the
+    /// taken bytes is a multiple of `align_size`.
+    ///
+    /// When everything still fits within one `align_size`, all of `bs` is taken.
+    ///
+    /// Not called by the `poll_*` methods in this file; `poll_write` inlines a variant
+    /// of this computation.
+    fn align(&mut self, bs: &dyn WriteBuf) -> usize {
+        let incoming = bs.remaining();
+        let buffered = self.align_buffer.len();
+        let combined = buffered + incoming;
+
+        if combined > self.align_size {
+            // Round the combined size down to a whole number of chunks, then subtract
+            // what is already staged to get the share coming from `bs`.
+            let aligned = combined - combined % self.align_size;
+            aligned - buffered
+        } else {
+            incoming
+        }
+    }
+}
+
+impl<W: RangeWrite> oio::Write for RangeWriter<W> {
+    /// Stage or upload data from `bs`, returning how many bytes of `bs` were consumed.
+    ///
+    /// - If no upload location is known yet, `initiate_range` is run first.
+    /// - If the staged data plus `bs` does not exceed `align_size`, all of `bs` is
+    ///   copied into the buffer and no request is sent.
+    /// - Otherwise the staged data plus a prefix of `bs` is sent through `write_range`
+    ///   so that the request size is a multiple of `align_size`.
+    ///
+    /// On a failed `write_range` the staged buffer is left untouched (only a clone was
+    /// sent), so the call can be retried.
+    fn poll_write(&mut self, cx: &mut Context<'_>, bs: &dyn WriteBuf) -> Poll<Result<usize>> {
+        loop {
+            match &mut self.state {
+                State::Idle(w) => match self.location.clone() {
+                    Some(location) => {
+                        let remaining = bs.remaining();
+                        let current_size = self.align_buffer.len();
+                        let mut total_size = current_size + remaining;
+
+                        // Buffer-only path. The inner writer must stay inside
+                        // `State::Idle` here: taking it before this return would drop
+                        // it and make the next poll panic.
+                        if total_size <= self.align_size {
+                            let bs = bs.copy_to_bytes(remaining);
+                            self.align_buffer.push(bs);
+                            return Poll::Ready(Ok(remaining));
+                        }
+                        // If total_size is aligned, we need to write one less chunk to make sure
+                        // that the file has at least one chunk during complete stage.
+                        if total_size % self.align_size == 0 {
+                            total_size -= self.align_size;
+                        }
+
+                        let consume = total_size - total_size % self.align_size - current_size;
+                        let mut align_buffer = self.align_buffer.clone();
+                        let bs = bs.copy_to_bytes(consume);
+                        align_buffer.push(bs);
+
+                        let written = self.written;
+                        // Only now move the writer into the future.
+                        let w = w
+                            .take()
+                            .expect("inner writer must be present in State::Idle");
+                        let fut = async move {
+                            let size = align_buffer.len() as u64;
+                            let res = w
+                                .write_range(
+                                    &location,
+                                    written,
+                                    size,
+                                    AsyncBody::Stream(Box::new(align_buffer)),
+                                )
+                                .await;
+
+                            // Hand the writer back together with the result.
+                            (w, res.map(|_| (consume, size)))
+                        };
+                        self.state = State::Write(Box::pin(fut));
+                    }
+                    None => {
+                        let w = w
+                            .take()
+                            .expect("inner writer must be present in State::Idle");
+                        let fut = async move {
+                            let res = w.initiate_range().await;
+
+                            (w, res)
+                        };
+                        self.state = State::Init(Box::pin(fut));
+                    }
+                },
+                State::Init(fut) => {
+                    let (w, res) = ready!(fut.poll_unpin(cx));
+                    // Restore the writer before propagating a possible error.
+                    self.state = State::Idle(Some(w));
+                    self.location = Some(res?);
+                }
+                State::Write(fut) => {
+                    let (w, res) = ready!(fut.poll_unpin(cx));
+                    self.state = State::Idle(Some(w));
+                    let (consume, written) = res?;
+                    self.written += written;
+                    // The staged data was part of the request that just succeeded.
+                    self.align_buffer.clear();
+                    return Poll::Ready(Ok(consume));
+                }
+                State::Complete(_) => {
+                    unreachable!("RangeWriter must not go into State::Complete during poll_write")
+                }
+                State::Abort(_) => {
+                    unreachable!("RangeWriter must not go into State::Abort during poll_write")
+                }
+            }
+        }
+    }
+
+    /// Finish the upload by sending the staged data through `complete_range`.
+    ///
+    /// Returns `Ok(())` immediately when no upload was ever initiated. In that case the
+    /// inner writer is left in place so later polls do not panic.
+    fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll<Result<()>> {
+        loop {
+            match &mut self.state {
+                State::Idle(w) => match self.location.clone() {
+                    Some(location) => {
+                        debug_assert!(
+                            self.align_buffer.len() > 0,
+                            "RangeWriter requires to have last chunk"
+                        );
+                        let mut align_buffer = self.align_buffer.clone();
+
+                        let written = self.written;
+                        let w = w
+                            .take()
+                            .expect("inner writer must be present in State::Idle");
+                        let fut = async move {
+                            let size = align_buffer.len() as u64;
+                            let res = w
+                                .complete_range(
+                                    &location,
+                                    written,
+                                    size,
+                                    AsyncBody::Stream(Box::new(align_buffer)),
+                                )
+                                .await;
+
+                            // Hand the writer back together with the result.
+                            (w, res)
+                        };
+                        self.state = State::Complete(Box::pin(fut));
+                    }
+                    // Nothing to complete; keep the writer inside `State::Idle`.
+                    None => return Poll::Ready(Ok(())),
+                },
+                State::Init(_) => {
+                    unreachable!("RangeWriter must not go into State::Init during poll_close")
+                }
+                State::Write(_) => {
+                    unreachable!("RangeWriter must not go into State::Write during poll_close")
+                }
+                State::Complete(fut) => {
+                    let (w, res) = ready!(fut.poll_unpin(cx));
+                    self.state = State::Idle(Some(w));
+                    self.align_buffer.clear();
+                    return Poll::Ready(res);
+                }
+                State::Abort(_) => {
+                    unreachable!("RangeWriter must not go into State::Abort during poll_close")
+                }
+            }
+        }
+    }
+
+    /// Cancel the upload through `abort_range` and drop any staged data.
+    ///
+    /// Returns `Ok(())` immediately when no upload was ever initiated. In that case the
+    /// inner writer is left in place so later polls do not panic.
+    fn poll_abort(&mut self, cx: &mut Context<'_>) -> Poll<Result<()>> {
+        loop {
+            match &mut self.state {
+                State::Idle(w) => match self.location.clone() {
+                    Some(location) => {
+                        let w = w
+                            .take()
+                            .expect("inner writer must be present in State::Idle");
+                        let fut = async move {
+                            let res = w.abort_range(&location).await;
+
+                            // Hand the writer back together with the result.
+                            (w, res)
+                        };
+                        // Must be `State::Abort`: storing the future in
+                        // `State::Complete` would hit the unreachable arm below on the
+                        // next loop iteration instead of polling it.
+                        self.state = State::Abort(Box::pin(fut));
+                    }
+                    // Nothing to abort; keep the writer inside `State::Idle`.
+                    None => return Poll::Ready(Ok(())),
+                },
+                State::Init(_) => {
+                    unreachable!("RangeWriter must not go into State::Init during poll_abort")
+                }
+                State::Write(_) => {
+                    unreachable!("RangeWriter must not go into State::Write during poll_abort")
+                }
+                State::Complete(_) => {
+                    unreachable!("RangeWriter must not go into State::Complete during poll_abort")
+                }
+                State::Abort(fut) => {
+                    let (w, res) = ready!(fut.poll_unpin(cx));
+                    self.state = State::Idle(Some(w));
+                    self.align_buffer.clear();
+                    return Poll::Ready(res);
+                }
+            }
+        }
+    }
+}
diff --git a/core/src/services/gcs/backend.rs b/core/src/services/gcs/backend.rs
index 77549908e..f91c792a0 100644
--- a/core/src/services/gcs/backend.rs
+++ b/core/src/services/gcs/backend.rs
@@ -35,6 +35,7 @@ use super::error::parse_error;
 use super::pager::GcsPager;
 use super::writer::GcsWriter;
 use crate::raw::*;
+use crate::services::gcs::writer::GcsWriters;
 use crate::*;
 
 const DEFAULT_GCS_ENDPOINT: &str = "https://storage.googleapis.com";
@@ -336,7 +337,7 @@ pub struct GcsBackend {
 impl Accessor for GcsBackend {
     type Reader = IncomingAsyncBody;
     type BlockingReader = ();
-    type Writer = GcsWriter;
+    type Writer = GcsWriters;
     type BlockingWriter = ();
     type Pager = GcsPager;
     type BlockingPager = ();
@@ -418,10 +419,14 @@ impl Accessor for GcsBackend {
     }
 
     async fn write(&self, path: &str, args: OpWrite) -> Result<(RpWrite, 
Self::Writer)> {
-        Ok((
-            RpWrite::default(),
-            GcsWriter::new(self.core.clone(), path, args),
-        ))
+        let w = GcsWriter::new(self.core.clone(), path, args.clone());
+        let w = if args.content_length().is_some() {
+            GcsWriters::One(oio::OneShotWriter::new(w))
+        } else {
+            GcsWriters::Two(oio::RangeWriter::new(w))
+        };
+
+        Ok((RpWrite::default(), w))
     }
 
     async fn copy(&self, from: &str, to: &str, _: OpCopy) -> Result<RpCopy> {
diff --git a/core/src/services/gcs/core.rs b/core/src/services/gcs/core.rs
index b8eca1ece..82c6d9054 100644
--- a/core/src/services/gcs/core.rs
+++ b/core/src/services/gcs/core.rs
@@ -558,21 +558,11 @@ impl GcsCore {
         location: &str,
         size: u64,
         written: u64,
-        is_last_part: bool,
         body: AsyncBody,
     ) -> Result<Request<AsyncBody>> {
         let mut req = Request::put(location);
 
-        let range_header = if is_last_part {
-            format!(
-                "bytes {}-{}/{}",
-                written,
-                written + size - 1,
-                written + size
-            )
-        } else {
-            format!("bytes {}-{}/*", written, written + size - 1)
-        };
+        let range_header = format!("bytes {}-{}/*", written, written + size - 
1);
 
         req = req
             .header(CONTENT_LENGTH, size)
@@ -587,22 +577,22 @@ impl GcsCore {
     pub async fn gcs_complete_resumable_upload(
         &self,
         location: &str,
-        written_bytes: u64,
-        bs: Bytes,
+        written: u64,
+        size: u64,
+        body: AsyncBody,
     ) -> Result<Response<IncomingAsyncBody>> {
-        let size = bs.len() as u64;
         let mut req = Request::post(location)
             .header(CONTENT_LENGTH, size)
             .header(
                 CONTENT_RANGE,
                 format!(
                     "bytes {}-{}/{}",
-                    written_bytes,
-                    written_bytes + size - 1,
-                    written_bytes + size
+                    written,
+                    written + size - 1,
+                    written + size
                 ),
             )
-            .body(AsyncBody::Bytes(bs))
+            .body(body)
             .map_err(new_request_build_error)?;
 
         self.sign(&mut req).await?;
diff --git a/core/src/services/gcs/writer.rs b/core/src/services/gcs/writer.rs
index c61c7df4c..6b66240f7 100644
--- a/core/src/services/gcs/writer.rs
+++ b/core/src/services/gcs/writer.rs
@@ -15,11 +15,14 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use std::fmt::{Debug, Formatter};
 use std::sync::Arc;
-use std::task::{Context, Poll};
+use std::task::{ready, Context, Poll};
 
 use async_trait::async_trait;
 use bytes::Bytes;
+use futures::future::BoxFuture;
+use futures::FutureExt;
 use http::StatusCode;
 
 use super::core::GcsCore;
@@ -27,38 +30,33 @@ use super::error::parse_error;
 use crate::raw::*;
 use crate::*;
 
+pub type GcsWriters =
+    oio::TwoWaysWriter<oio::OneShotWriter<GcsWriter>, 
oio::RangeWriter<GcsWriter>>;
+
 pub struct GcsWriter {
     core: Arc<GcsCore>,
     path: String,
     op: OpWrite,
-
-    location: Option<String>,
-    written: u64,
-    buffer: oio::VectorCursor,
-    write_fixed_size: usize,
 }
 
 impl GcsWriter {
     pub fn new(core: Arc<GcsCore>, path: &str, op: OpWrite) -> Self {
-        let write_fixed_size = core.write_fixed_size;
         GcsWriter {
             core,
             path: path.to_string(),
             op,
-
-            location: None,
-            written: 0,
-            buffer: oio::VectorCursor::new(),
-            write_fixed_size,
         }
     }
+}
 
-    async fn write_oneshot(&self, size: u64, body: AsyncBody) -> Result<()> {
+#[async_trait]
+impl oio::OneShotWrite for GcsWriter {
+    async fn write_once(&self, bs: Bytes) -> Result<()> {
         let mut req = self.core.gcs_insert_object_request(
             &percent_encode_path(&self.path),
-            Some(size),
+            Some(bs.len() as u64),
             &self.op,
-            body,
+            AsyncBody::Bytes(bs),
         )?;
 
         self.core.sign(&mut req).await?;
@@ -75,8 +73,11 @@ impl GcsWriter {
             _ => Err(parse_error(resp).await?),
         }
     }
+}
 
-    async fn initiate_upload(&self) -> Result<String> {
+#[async_trait]
+impl oio::RangeWrite for GcsWriter {
+    async fn initiate_range(&self) -> Result<String> {
         let resp = self.core.gcs_initiate_resumable_upload(&self.path).await?;
         let status = resp.status();
 
@@ -96,14 +97,16 @@ impl GcsWriter {
         }
     }
 
-    async fn write_part(&self, location: &str, bs: Bytes) -> Result<()> {
-        let mut req = self.core.gcs_upload_in_resumable_upload(
-            location,
-            bs.len() as u64,
-            self.written,
-            false,
-            AsyncBody::Bytes(bs),
-        )?;
+    async fn write_range(
+        &self,
+        location: &str,
+        written: u64,
+        size: u64,
+        body: AsyncBody,
+    ) -> Result<()> {
+        let mut req = self
+            .core
+            .gcs_upload_in_resumable_upload(location, size, written, body)?;
 
         self.core.sign(&mut req).await?;
 
@@ -115,105 +118,40 @@ impl GcsWriter {
             _ => Err(parse_error(resp).await?),
         }
     }
-}
 
-#[async_trait]
-impl oio::Write for GcsWriter {
-    fn poll_write(&mut self, _: &mut Context<'_>, _: &dyn oio::WriteBuf) -> 
Poll<Result<usize>> {
-        // let size = bs.remaining();
-        //
-        // let location = match &self.location {
-        //     Some(location) => location,
-        //     None => {
-        //         if self.op.content_length().unwrap_or_default() == size as 
u64 && self.written == 0
-        //         {
-        //             self.write_oneshot(size as u64, 
AsyncBody::Bytes(bs.copy_to_bytes(size)))
-        //                 .await?;
-        //
-        //             return Ok(size);
-        //         } else {
-        //             let location = self.initiate_upload().await?;
-        //             self.location = Some(location);
-        //             self.location.as_deref().unwrap()
-        //         }
-        //     }
-        // };
-        //
-        // self.buffer.push(bs.copy_to_bytes(size));
-        // // Return directly if the buffer is not full
-        // if self.buffer.len() <= self.write_fixed_size {
-        //     return Ok(size);
-        // }
-        //
-        // let bs = self.buffer.peak_exact(self.write_fixed_size);
-        //
-        // match self.write_part(location, bs).await {
-        //     Ok(_) => {
-        //         self.buffer.take(self.write_fixed_size);
-        //         self.written += self.write_fixed_size as u64;
-        //         Ok(size)
-        //     }
-        //     Err(e) => {
-        //         // If the upload fails, we should pop the given bs to make 
sure
-        //         // write is re-enter safe.
-        //         self.buffer.pop();
-        //         Err(e)
-        //     }
-        // }
-
-        todo!()
-    }
+    async fn complete_range(
+        &self,
+        location: &str,
+        written: u64,
+        size: u64,
+        body: AsyncBody,
+    ) -> Result<()> {
+        let resp = self
+            .core
+            .gcs_complete_resumable_upload(location, written, size, body)
+            .await?;
 
-    fn poll_abort(&mut self, _: &mut Context<'_>) -> Poll<Result<()>> {
-        // let location = if let Some(location) = &self.location {
-        //     location
-        // } else {
-        //     return Ok(());
-        // };
-        //
-        // let resp = self.core.gcs_abort_resumable_upload(location).await?;
-        //
-        // match resp.status().as_u16() {
-        //     // gcs returns 499 if the upload aborted successfully
-        //     // reference: 
https://cloud.google.com/storage/docs/performing-resumable-uploads#cancel-upload-json
-        //     499 => {
-        //         resp.into_body().consume().await?;
-        //         self.location = None;
-        //         self.buffer.clear();
-        //         Ok(())
-        //     }
-        //     _ => Err(parse_error(resp).await?),
-        // }
-
-        todo!()
+        let status = resp.status();
+        match status {
+            StatusCode::OK => {
+                resp.into_body().consume().await?;
+                Ok(())
+            }
+            _ => Err(parse_error(resp).await?),
+        }
     }
 
-    fn poll_close(&mut self, _: &mut Context<'_>) -> Poll<Result<()>> {
-        // let location = if let Some(location) = &self.location {
-        //     location
-        // } else {
-        //     return Ok(());
-        // };
-        //
-        // let bs = self.buffer.peak_exact(self.buffer.len());
-        //
-        // let resp = self
-        //     .core
-        //     .gcs_complete_resumable_upload(location, self.written, bs)
-        //     .await?;
-        //
-        // let status = resp.status();
-        // match status {
-        //     StatusCode::OK => {
-        //         resp.into_body().consume().await?;
-        //
-        //         self.location = None;
-        //         self.buffer.clear();
-        //         Ok(())
-        //     }
-        //     _ => Err(parse_error(resp).await?),
-        // }
-
-        todo!()
+    async fn abort_range(&self, location: &str) -> Result<()> {
+        let resp = self.core.gcs_abort_resumable_upload(location).await?;
+
+        match resp.status().as_u16() {
+            // gcs returns 499 if the upload aborted successfully
+            // reference: 
https://cloud.google.com/storage/docs/performing-resumable-uploads#cancel-upload-json
+            499 => {
+                resp.into_body().consume().await?;
+                Ok(())
+            }
+            _ => Err(parse_error(resp).await?),
+        }
     }
 }

Reply via email to