This is an automated email from the ASF dual-hosted git repository. xuanwo pushed a commit to branch poll-write in repository https://gitbox.apache.org/repos/asf/incubator-opendal.git
commit 9b2534b2f72e5439a24f92f78ea3a2b3518db07b Author: Xuanwo <[email protected]> AuthorDate: Mon Sep 11 14:03:36 2023 +0800 Refactor gcs Signed-off-by: Xuanwo <[email protected]> --- core/src/raw/oio/write/mod.rs | 4 + core/src/raw/oio/write/multipart_upload_write.rs | 16 +- core/src/raw/oio/write/range_write.rs | 269 +++++++++++++++++++++++ core/src/services/gcs/backend.rs | 15 +- core/src/services/gcs/core.rs | 26 +-- core/src/services/gcs/writer.rs | 178 +++++---------- 6 files changed, 357 insertions(+), 151 deletions(-) diff --git a/core/src/raw/oio/write/mod.rs b/core/src/raw/oio/write/mod.rs index ff6d8377e..df0b83b28 100644 --- a/core/src/raw/oio/write/mod.rs +++ b/core/src/raw/oio/write/mod.rs @@ -42,3 +42,7 @@ pub use one_shot_write::OneShotWriter; mod exact_buf_write; pub use exact_buf_write::ExactBufWriter; + +mod range_write; +pub use range_write::RangeWrite; +pub use range_write::RangeWriter; diff --git a/core/src/raw/oio/write/multipart_upload_write.rs b/core/src/raw/oio/write/multipart_upload_write.rs index 81ee1f1ee..1b2148cb4 100644 --- a/core/src/raw/oio/write/multipart_upload_write.rs +++ b/core/src/raw/oio/write/multipart_upload_write.rs @@ -166,12 +166,12 @@ where } State::Close(_) => { unreachable!( - "MultipartUploadWriter must not go into State:Close during poll_write" + "MultipartUploadWriter must not go into State::Close during poll_write" ) } State::Abort(_) => { unreachable!( - "MultipartUploadWriter must not go into State:Abort during poll_write" + "MultipartUploadWriter must not go into State::Abort during poll_write" ) } } @@ -200,13 +200,13 @@ where return Poll::Ready(res); } State::Init(_) => unreachable!( - "MultipartUploadWriter must not go into State:Init during poll_close" + "MultipartUploadWriter must not go into State::Init during poll_close" ), State::Write(_) => unreachable!( - "MultipartUploadWriter must not go into State:Write during poll_close" + "MultipartUploadWriter must not go into State::Write during poll_close" 
), State::Abort(_) => unreachable!( - "MultipartUploadWriter must not go into State:Abort during poll_close" + "MultipartUploadWriter must not go into State::Abort during poll_close" ), } } @@ -233,13 +233,13 @@ where return Poll::Ready(res); } State::Init(_) => unreachable!( - "MultipartUploadWriter must not go into State:Init during poll_abort" + "MultipartUploadWriter must not go into State::Init during poll_abort" ), State::Write(_) => unreachable!( - "MultipartUploadWriter must not go into State:Write during poll_abort" + "MultipartUploadWriter must not go into State::Write during poll_abort" ), State::Close(_) => unreachable!( - "MultipartUploadWriter must not go into State:Close during poll_abort" + "MultipartUploadWriter must not go into State::Close during poll_abort" ), } } diff --git a/core/src/raw/oio/write/range_write.rs b/core/src/raw/oio/write/range_write.rs new file mode 100644 index 000000000..a96af8d3d --- /dev/null +++ b/core/src/raw/oio/write/range_write.rs @@ -0,0 +1,269 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use async_trait::async_trait; +use futures::future::BoxFuture; +use futures::FutureExt; +use std::task::{ready, Context, Poll}; + +use crate::raw::oio::WriteBuf; +use crate::raw::*; +use crate::*; + +#[async_trait] +pub trait RangeWrite: Send + Sync + Unpin + 'static { + async fn initiate_range(&self) -> Result<String>; + + async fn write_range( + &self, + location: &str, + written: u64, + size: u64, + body: AsyncBody, + ) -> Result<()>; + + async fn complete_range( + &self, + location: &str, + written: u64, + size: u64, + body: AsyncBody, + ) -> Result<()>; + + async fn abort_range(&self, location: &str) -> Result<()>; +} + +pub struct RangeWriter<W: RangeWrite> { + location: Option<String>, + written: u64, + align_size: usize, + align_buffer: oio::ChunkedCursor, + + state: State<W>, +} + +enum State<W> { + Idle(Option<W>), + Init(BoxFuture<'static, (W, Result<String>)>), + /// The returning value is (consume size, written size). + Write(BoxFuture<'static, (W, Result<(usize, u64)>)>), + Complete(BoxFuture<'static, (W, Result<()>)>), + Abort(BoxFuture<'static, (W, Result<()>)>), +} + +/// # Safety +/// +/// We will only take `&mut Self` reference for State. +unsafe impl<W: RangeWrite> Sync for State<W> {} + +impl<W: RangeWrite> RangeWriter<W> { + /// Create a new RangeWriter. 
+ pub fn new(inner: W) -> Self { + Self { + state: State::Idle(Some(inner)), + + location: None, + written: 0, + align_size: 256 * 1024, + align_buffer: oio::ChunkedCursor::default(), + } + } + + pub fn with_align_size(mut self, size: usize) -> Self { + self.align_size = size; + self + } +} + +impl<W: RangeWrite> RangeWriter<W> { + fn align(&mut self, bs: &dyn WriteBuf) -> usize { + let remaining = bs.remaining(); + let current_size = self.align_buffer.len(); + + let total_size = current_size + remaining; + if total_size <= self.align_size { + return remaining; + } + + total_size - total_size % self.align_size - current_size + } +} + +impl<W: RangeWrite> oio::Write for RangeWriter<W> { + fn poll_write(&mut self, cx: &mut Context<'_>, bs: &dyn WriteBuf) -> Poll<Result<usize>> { + loop { + match &mut self.state { + State::Idle(w) => { + let w = w.take().unwrap(); + match self.location.clone() { + Some(location) => { + let remaining = bs.remaining(); + let current_size = self.align_buffer.len(); + let mut total_size = current_size + remaining; + + if total_size <= self.align_size { + let bs = bs.copy_to_bytes(remaining); + self.align_buffer.push(bs); + return Poll::Ready(Ok(remaining)); + } + // If total_size is aligned, we need to write one less chunk to make sure + // that the file has at least one chunk during complete stage. 
+ if total_size % self.align_size == 0 { + total_size -= self.align_size; + } + + let consume = total_size - total_size % self.align_size - current_size; + let mut align_buffer = self.align_buffer.clone(); + let bs = bs.copy_to_bytes(consume); + align_buffer.push(bs); + + let written = self.written; + let fut = async move { + let size = align_buffer.len() as u64; + let res = w + .write_range( + &location, + written, + size, + AsyncBody::Stream(Box::new(align_buffer)), + ) + .await; + + (w, res.map(|_| (consume, size))) + }; + self.state = State::Write(Box::pin(fut)); + } + None => { + let fut = async move { + let res = w.initiate_range().await; + + (w, res) + }; + self.state = State::Init(Box::pin(fut)); + } + } + } + State::Init(fut) => { + let (w, res) = ready!(fut.poll_unpin(cx)); + self.state = State::Idle(Some(w)); + self.location = Some(res?); + } + State::Write(fut) => { + let (w, res) = ready!(fut.poll_unpin(cx)); + self.state = State::Idle(Some(w)); + let (consume, written) = res?; + self.written += written; + self.align_buffer.clear(); + return Poll::Ready(Ok(consume)); + } + State::Complete(_) => { + unreachable!("RangeWriter must not go into State::Complete during poll_write") + } + State::Abort(_) => { + unreachable!("RangeWriter must not go into State::Abort during poll_write") + } + } + } + } + + fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll<Result<()>> { + loop { + match &mut self.state { + State::Idle(w) => { + let w = w.take().unwrap(); + match self.location.clone() { + Some(location) => { + debug_assert!( + self.align_buffer.len() > 0, + "RangeWriter requires to have last chunk" + ); + let mut align_buffer = self.align_buffer.clone(); + + let written = self.written; + let fut = async move { + let size = align_buffer.len() as u64; + let res = w + .complete_range( + &location, + written, + size, + AsyncBody::Stream(Box::new(align_buffer)), + ) + .await; + + (w, res) + }; + self.state = State::Complete(Box::pin(fut)); + } + None => return 
Poll::Ready(Ok(())), + } + } + State::Init(_) => { + unreachable!("RangeWriter must not go into State::Init during poll_close") + } + State::Write(_) => { + unreachable!("RangeWriter must not go into State::Write during poll_close") + } + State::Complete(fut) => { + let (w, res) = ready!(fut.poll_unpin(cx)); + self.state = State::Idle(Some(w)); + self.align_buffer.clear(); + return Poll::Ready(res); + } + State::Abort(_) => { + unreachable!("RangeWriter must not go into State::Abort during poll_close") + } + } + } + } + + fn poll_abort(&mut self, cx: &mut Context<'_>) -> Poll<Result<()>> { + loop { + match &mut self.state { + State::Idle(w) => { + let w = w.take().unwrap(); + match self.location.clone() { + Some(location) => { + let fut = async move { + let res = w.abort_range(&location).await; + + (w, res) + }; + self.state = State::Complete(Box::pin(fut)); + } + None => return Poll::Ready(Ok(())), + } + } + State::Init(_) => { + unreachable!("RangeWriter must not go into State::Init during poll_close") + } + State::Write(_) => { + unreachable!("RangeWriter must not go into State::Write during poll_close") + } + State::Complete(_) => { + unreachable!("RangeWriter must not go into State::Complete during poll_close") + } + State::Abort(fut) => { + let (w, res) = ready!(fut.poll_unpin(cx)); + self.state = State::Idle(Some(w)); + self.align_buffer.clear(); + return Poll::Ready(res); + } + } + } + } +} diff --git a/core/src/services/gcs/backend.rs b/core/src/services/gcs/backend.rs index 77549908e..f91c792a0 100644 --- a/core/src/services/gcs/backend.rs +++ b/core/src/services/gcs/backend.rs @@ -35,6 +35,7 @@ use super::error::parse_error; use super::pager::GcsPager; use super::writer::GcsWriter; use crate::raw::*; +use crate::services::gcs::writer::GcsWriters; use crate::*; const DEFAULT_GCS_ENDPOINT: &str = "https://storage.googleapis.com"; @@ -336,7 +337,7 @@ pub struct GcsBackend { impl Accessor for GcsBackend { type Reader = IncomingAsyncBody; type BlockingReader 
= (); - type Writer = GcsWriter; + type Writer = GcsWriters; type BlockingWriter = (); type Pager = GcsPager; type BlockingPager = (); @@ -418,10 +419,14 @@ impl Accessor for GcsBackend { } async fn write(&self, path: &str, args: OpWrite) -> Result<(RpWrite, Self::Writer)> { - Ok(( - RpWrite::default(), - GcsWriter::new(self.core.clone(), path, args), - )) + let w = GcsWriter::new(self.core.clone(), path, args.clone()); + let w = if args.content_length().is_some() { + GcsWriters::One(oio::OneShotWriter::new(w)) + } else { + GcsWriters::Two(oio::RangeWriter::new(w)) + }; + + Ok((RpWrite::default(), w)) } async fn copy(&self, from: &str, to: &str, _: OpCopy) -> Result<RpCopy> { diff --git a/core/src/services/gcs/core.rs b/core/src/services/gcs/core.rs index b8eca1ece..82c6d9054 100644 --- a/core/src/services/gcs/core.rs +++ b/core/src/services/gcs/core.rs @@ -558,21 +558,11 @@ impl GcsCore { location: &str, size: u64, written: u64, - is_last_part: bool, body: AsyncBody, ) -> Result<Request<AsyncBody>> { let mut req = Request::put(location); - let range_header = if is_last_part { - format!( - "bytes {}-{}/{}", - written, - written + size - 1, - written + size - ) - } else { - format!("bytes {}-{}/*", written, written + size - 1) - }; + let range_header = format!("bytes {}-{}/*", written, written + size - 1); req = req .header(CONTENT_LENGTH, size) @@ -587,22 +577,22 @@ impl GcsCore { pub async fn gcs_complete_resumable_upload( &self, location: &str, - written_bytes: u64, - bs: Bytes, + written: u64, + size: u64, + body: AsyncBody, ) -> Result<Response<IncomingAsyncBody>> { - let size = bs.len() as u64; let mut req = Request::post(location) .header(CONTENT_LENGTH, size) .header( CONTENT_RANGE, format!( "bytes {}-{}/{}", - written_bytes, - written_bytes + size - 1, - written_bytes + size + written, + written + size - 1, + written + size ), ) - .body(AsyncBody::Bytes(bs)) + .body(body) .map_err(new_request_build_error)?; self.sign(&mut req).await?; diff --git 
a/core/src/services/gcs/writer.rs b/core/src/services/gcs/writer.rs index c61c7df4c..6b66240f7 100644 --- a/core/src/services/gcs/writer.rs +++ b/core/src/services/gcs/writer.rs @@ -15,11 +15,14 @@ // specific language governing permissions and limitations // under the License. +use std::fmt::{Debug, Formatter}; use std::sync::Arc; -use std::task::{Context, Poll}; +use std::task::{ready, Context, Poll}; use async_trait::async_trait; use bytes::Bytes; +use futures::future::BoxFuture; +use futures::FutureExt; use http::StatusCode; use super::core::GcsCore; @@ -27,38 +30,33 @@ use super::error::parse_error; use crate::raw::*; use crate::*; +pub type GcsWriters = + oio::TwoWaysWriter<oio::OneShotWriter<GcsWriter>, oio::RangeWriter<GcsWriter>>; + pub struct GcsWriter { core: Arc<GcsCore>, path: String, op: OpWrite, - - location: Option<String>, - written: u64, - buffer: oio::VectorCursor, - write_fixed_size: usize, } impl GcsWriter { pub fn new(core: Arc<GcsCore>, path: &str, op: OpWrite) -> Self { - let write_fixed_size = core.write_fixed_size; GcsWriter { core, path: path.to_string(), op, - - location: None, - written: 0, - buffer: oio::VectorCursor::new(), - write_fixed_size, } } +} - async fn write_oneshot(&self, size: u64, body: AsyncBody) -> Result<()> { +#[async_trait] +impl oio::OneShotWrite for GcsWriter { + async fn write_once(&self, bs: Bytes) -> Result<()> { let mut req = self.core.gcs_insert_object_request( &percent_encode_path(&self.path), - Some(size), + Some(bs.len() as u64), &self.op, - body, + AsyncBody::Bytes(bs), )?; self.core.sign(&mut req).await?; @@ -75,8 +73,11 @@ impl GcsWriter { _ => Err(parse_error(resp).await?), } } +} - async fn initiate_upload(&self) -> Result<String> { +#[async_trait] +impl oio::RangeWrite for GcsWriter { + async fn initiate_range(&self) -> Result<String> { let resp = self.core.gcs_initiate_resumable_upload(&self.path).await?; let status = resp.status(); @@ -96,14 +97,16 @@ impl GcsWriter { } } - async fn write_part(&self, 
location: &str, bs: Bytes) -> Result<()> { - let mut req = self.core.gcs_upload_in_resumable_upload( - location, - bs.len() as u64, - self.written, - false, - AsyncBody::Bytes(bs), - )?; + async fn write_range( + &self, + location: &str, + written: u64, + size: u64, + body: AsyncBody, + ) -> Result<()> { + let mut req = self + .core + .gcs_upload_in_resumable_upload(location, size, written, body)?; self.core.sign(&mut req).await?; @@ -115,105 +118,40 @@ impl GcsWriter { _ => Err(parse_error(resp).await?), } } -} -#[async_trait] -impl oio::Write for GcsWriter { - fn poll_write(&mut self, _: &mut Context<'_>, _: &dyn oio::WriteBuf) -> Poll<Result<usize>> { - // let size = bs.remaining(); - // - // let location = match &self.location { - // Some(location) => location, - // None => { - // if self.op.content_length().unwrap_or_default() == size as u64 && self.written == 0 - // { - // self.write_oneshot(size as u64, AsyncBody::Bytes(bs.copy_to_bytes(size))) - // .await?; - // - // return Ok(size); - // } else { - // let location = self.initiate_upload().await?; - // self.location = Some(location); - // self.location.as_deref().unwrap() - // } - // } - // }; - // - // self.buffer.push(bs.copy_to_bytes(size)); - // // Return directly if the buffer is not full - // if self.buffer.len() <= self.write_fixed_size { - // return Ok(size); - // } - // - // let bs = self.buffer.peak_exact(self.write_fixed_size); - // - // match self.write_part(location, bs).await { - // Ok(_) => { - // self.buffer.take(self.write_fixed_size); - // self.written += self.write_fixed_size as u64; - // Ok(size) - // } - // Err(e) => { - // // If the upload fails, we should pop the given bs to make sure - // // write is re-enter safe. 
- // self.buffer.pop(); - // Err(e) - // } - // } - - todo!() - } + async fn complete_range( + &self, + location: &str, + written: u64, + size: u64, + body: AsyncBody, + ) -> Result<()> { + let resp = self + .core + .gcs_complete_resumable_upload(location, written, size, body) + .await?; - fn poll_abort(&mut self, _: &mut Context<'_>) -> Poll<Result<()>> { - // let location = if let Some(location) = &self.location { - // location - // } else { - // return Ok(()); - // }; - // - // let resp = self.core.gcs_abort_resumable_upload(location).await?; - // - // match resp.status().as_u16() { - // // gcs returns 499 if the upload aborted successfully - // // reference: https://cloud.google.com/storage/docs/performing-resumable-uploads#cancel-upload-json - // 499 => { - // resp.into_body().consume().await?; - // self.location = None; - // self.buffer.clear(); - // Ok(()) - // } - // _ => Err(parse_error(resp).await?), - // } - - todo!() + let status = resp.status(); + match status { + StatusCode::OK => { + resp.into_body().consume().await?; + Ok(()) + } + _ => Err(parse_error(resp).await?), + } } - fn poll_close(&mut self, _: &mut Context<'_>) -> Poll<Result<()>> { - // let location = if let Some(location) = &self.location { - // location - // } else { - // return Ok(()); - // }; - // - // let bs = self.buffer.peak_exact(self.buffer.len()); - // - // let resp = self - // .core - // .gcs_complete_resumable_upload(location, self.written, bs) - // .await?; - // - // let status = resp.status(); - // match status { - // StatusCode::OK => { - // resp.into_body().consume().await?; - // - // self.location = None; - // self.buffer.clear(); - // Ok(()) - // } - // _ => Err(parse_error(resp).await?), - // } - - todo!() + async fn abort_range(&self, location: &str) -> Result<()> { + let resp = self.core.gcs_abort_resumable_upload(location).await?; + + match resp.status().as_u16() { + // gcs returns 499 if the upload aborted successfully + // reference: 
https://cloud.google.com/storage/docs/performing-resumable-uploads#cancel-upload-json + 499 => { + resp.into_body().consume().await?; + Ok(()) + } + _ => Err(parse_error(resp).await?), + } } }
