This is an automated email from the ASF dual-hosted git repository.

xuanwo pushed a commit to branch poll-write
in repository https://gitbox.apache.org/repos/asf/incubator-opendal.git

commit 89d5bed9b71d6352ae9ee8ef42888a573001cf83
Author: Xuanwo <[email protected]>
AuthorDate: Fri Sep 8 14:04:14 2023 +0800

    Save work
    
    Signed-off-by: Xuanwo <[email protected]>
---
 core/src/layers/oteltrace.rs                     |  37 ++---
 core/src/raw/oio/write/multipart_upload_write.rs | 181 +++++++++++++++++------
 2 files changed, 147 insertions(+), 71 deletions(-)

diff --git a/core/src/layers/oteltrace.rs b/core/src/layers/oteltrace.rs
index df9bc2400..145b997d5 100644
--- a/core/src/layers/oteltrace.rs
+++ b/core/src/layers/oteltrace.rs
@@ -17,6 +17,8 @@
 
 use std::io;
 use std::task;
+use std::task::Context;
+use std::task::Poll;
 
 use async_trait::async_trait;
 use bytes::Bytes;
@@ -27,7 +29,7 @@ use opentelemetry::trace::FutureExt as TraceFutureExt;
 use opentelemetry::trace::Span;
 use opentelemetry::trace::TraceContextExt;
 use opentelemetry::trace::Tracer;
-use opentelemetry::Context;
+use opentelemetry::Context as TraceContext;
 use opentelemetry::KeyValue;
 
 use crate::raw::*;
@@ -89,7 +91,7 @@ impl<A: Accessor> LayeredAccessor for OtelTraceAccessor<A> {
         let mut span = tracer.start("create");
         span.set_attribute(KeyValue::new("path", path.to_string()));
         span.set_attribute(KeyValue::new("args", format!("{:?}", args)));
-        let cx = Context::current_with_span(span);
+        let cx = TraceContext::current_with_span(span);
         self.inner.create_dir(path, args).with_context(cx).await
     }
 
@@ -121,7 +123,7 @@ impl<A: Accessor> LayeredAccessor for OtelTraceAccessor<A> {
         span.set_attribute(KeyValue::new("from", from.to_string()));
         span.set_attribute(KeyValue::new("to", to.to_string()));
         span.set_attribute(KeyValue::new("args", format!("{:?}", args)));
-        let cx = Context::current_with_span(span);
+        let cx = TraceContext::current_with_span(span);
         self.inner().copy(from, to, args).with_context(cx).await
     }
 
@@ -131,7 +133,7 @@ impl<A: Accessor> LayeredAccessor for OtelTraceAccessor<A> {
         span.set_attribute(KeyValue::new("from", from.to_string()));
         span.set_attribute(KeyValue::new("to", to.to_string()));
         span.set_attribute(KeyValue::new("args", format!("{:?}", args)));
-        let cx = Context::current_with_span(span);
+        let cx = TraceContext::current_with_span(span);
         self.inner().rename(from, to, args).with_context(cx).await
     }
 
@@ -140,7 +142,7 @@ impl<A: Accessor> LayeredAccessor for OtelTraceAccessor<A> {
         let mut span = tracer.start("stat");
         span.set_attribute(KeyValue::new("path", path.to_string()));
         span.set_attribute(KeyValue::new("args", format!("{:?}", args)));
-        let cx = Context::current_with_span(span);
+        let cx = TraceContext::current_with_span(span);
         self.inner().stat(path, args).with_context(cx).await
     }
 
@@ -149,7 +151,7 @@ impl<A: Accessor> LayeredAccessor for OtelTraceAccessor<A> {
         let mut span = tracer.start("delete");
         span.set_attribute(KeyValue::new("path", path.to_string()));
         span.set_attribute(KeyValue::new("args", format!("{:?}", args)));
-        let cx = Context::current_with_span(span);
+        let cx = TraceContext::current_with_span(span);
         self.inner().delete(path, args).with_context(cx).await
     }
 
@@ -168,7 +170,7 @@ impl<A: Accessor> LayeredAccessor for OtelTraceAccessor<A> {
         let tracer = global::tracer("opendal");
         let mut span = tracer.start("batch");
         span.set_attribute(KeyValue::new("args", format!("{:?}", args)));
-        let cx = Context::current_with_span(span);
+        let cx = TraceContext::current_with_span(span);
         self.inner().batch(args).with_context(cx).await
     }
 
@@ -177,7 +179,7 @@ impl<A: Accessor> LayeredAccessor for OtelTraceAccessor<A> {
         let mut span = tracer.start("presign");
         span.set_attribute(KeyValue::new("path", path.to_string()));
         span.set_attribute(KeyValue::new("args", format!("{:?}", args)));
-        let cx = Context::current_with_span(span);
+        let cx = TraceContext::current_with_span(span);
         self.inner().presign(path, args).with_context(cx).await
     }
 
@@ -276,23 +278,15 @@ impl<R> OtelTraceWrapper<R> {
 }
 
 impl<R: oio::Read> oio::Read for OtelTraceWrapper<R> {
-    fn poll_read(
-        &mut self,
-        cx: &mut task::Context<'_>,
-        buf: &mut [u8],
-    ) -> task::Poll<Result<usize>> {
+    fn poll_read(&mut self, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<Result<usize>> {
         self.inner.poll_read(cx, buf)
     }
 
-    fn poll_seek(
-        &mut self,
-        cx: &mut task::Context<'_>,
-        pos: io::SeekFrom,
-    ) -> task::Poll<Result<u64>> {
+    fn poll_seek(&mut self, cx: &mut Context<'_>, pos: io::SeekFrom) -> Poll<Result<u64>> {
         self.inner.poll_seek(cx, pos)
     }
 
-    fn poll_next(&mut self, cx: &mut task::Context<'_>) -> task::Poll<Option<Result<Bytes>>> {
+    fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<Bytes>>> {
         self.inner.poll_next(cx)
     }
 }
@@ -314,15 +308,15 @@ impl<R: oio::BlockingRead> oio::BlockingRead for OtelTraceWrapper<R> {
 #[async_trait]
 impl<R: oio::Write> oio::Write for OtelTraceWrapper<R> {
     fn poll_write(&mut self, cx: &mut Context<'_>, bs: &dyn oio::WriteBuf) -> Poll<Result<usize>> {
-        self.inner.write(bs).await
+        self.inner.poll_write(cx, bs)
     }
 
     fn poll_abort(&mut self, cx: &mut Context<'_>) -> Poll<Result<()>> {
-        self.inner.abort().await
+        self.inner.poll_abort(cx)
     }
 
     fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll<Result<()>> {
-        self.inner.close().await
+        self.inner.poll_close(cx)
     }
 }
 
diff --git a/core/src/raw/oio/write/multipart_upload_write.rs b/core/src/raw/oio/write/multipart_upload_write.rs
index 10353714d..55a28e84a 100644
--- a/core/src/raw/oio/write/multipart_upload_write.rs
+++ b/core/src/raw/oio/write/multipart_upload_write.rs
@@ -16,6 +16,12 @@
 // under the License.
 
+use std::future::Future;
+use std::sync::Arc;
+use std::task::Context;
+use std::task::Poll;
+
 use async_trait::async_trait;
+use futures::future::BoxFuture;
 
 use crate::raw::*;
 use crate::*;
@@ -78,40 +84,35 @@ pub struct MultipartUploadPart {
 
 /// MultipartUploadWriter will implements [`Write`] based on multipart
 /// uploads.
-///
-/// ## TODO
-///
-/// - Add threshold for `write_once` to avoid unnecessary multipart uploads.
-/// - Allow users to switch to un-buffered mode if users write 16MiB every time.
 pub struct MultipartUploadWriter<W: MultipartUploadWrite> {
-    inner: W,
+    state: State<W>,
 
-    upload_id: Option<String>,
+    upload_id: Option<Arc<String>>,
     parts: Vec<MultipartUploadPart>,
 }
 
+/// State machine driven by the `poll_*` methods of [`MultipartUploadWriter`].
+///
+/// Each in-flight future owns the inner writer and hands it back alongside
+/// its result, so the writer can return to `Idle` once the future resolves.
+enum State<W> {
+    Idle(Option<W>),
+    Init(BoxFuture<'static, (W, Result<String>)>),
+    Write(BoxFuture<'static, (W, Result<(usize, MultipartUploadPart)>)>),
+    Close(BoxFuture<'static, (W, Result<()>)>),
+    Abort(BoxFuture<'static, (W, Result<()>)>),
+}
+
 impl<W: MultipartUploadWrite> MultipartUploadWriter<W> {
     /// Create a new MultipartUploadWriter.
     pub fn new(inner: W) -> Self {
         Self {
-            inner,
+            state: State::Idle(Some(inner)),
 
             upload_id: None,
             parts: Vec::new(),
         }
     }
-
-    /// Get the upload id. Initiate a new multipart upload if the upload id is empty.
-    pub async fn upload_id(&mut self) -> Result<String> {
-        match &self.upload_id {
-            Some(upload_id) => Ok(upload_id.to_string()),
-            None => {
-                let upload_id = self.inner.initiate_part().await?;
-                self.upload_id = Some(upload_id.clone());
-                Ok(upload_id)
-            }
-        }
-    }
 }
 
 #[async_trait]
@@ -120,40 +121,143 @@ where
     W: MultipartUploadWrite,
 {
     fn poll_write(&mut self, cx: &mut Context<'_>, bs: &dyn oio::WriteBuf) -> Poll<Result<usize>> {
-        let upload_id = self.upload_id().await?;
-
-        let size = bs.remaining();
-
-        self.inner
-            .write_part(
-                &upload_id,
-                self.parts.len(),
-                size as u64,
-                AsyncBody::Bytes(bs.copy_to_bytes(size)),
-            )
-            .await
-            .map(|v| self.parts.push(v))?;
-
-        Ok(size)
+        loop {
+            match &mut self.state {
+                State::Idle(w) => {
+                    let w = w.take().expect("writer must be valid");
+                    match self.upload_id.as_ref() {
+                        Some(upload_id) => {
+                            let size = bs.remaining();
+                            let bs = bs.copy_to_bytes(size);
+                            let upload_id = upload_id.clone();
+                            let part_number = self.parts.len();
+
+                            self.state = State::Write(Box::pin(async move {
+                                // Return `w` on both success and failure so the
+                                // writer goes back to `Idle` either way.
+                                let res = w
+                                    .write_part(
+                                        &upload_id,
+                                        part_number,
+                                        size as u64,
+                                        AsyncBody::Bytes(bs),
+                                    )
+                                    .await;
+
+                                (w, res.map(|part| (size, part)))
+                            }));
+                        }
+                        None => {
+                            self.state = State::Init(Box::pin(async move {
+                                let upload_id = w.initiate_part().await;
+                                (w, upload_id)
+                            }));
+                        }
+                    }
+                }
+                State::Init(fut) => {
+                    let (w, upload_id) = futures::ready!(fut.as_mut().poll(cx));
+                    self.state = State::Idle(Some(w));
+                    self.upload_id = Some(Arc::new(upload_id?));
+                }
+                State::Write(fut) => {
+                    let (w, res) = futures::ready!(fut.as_mut().poll(cx));
+                    self.state = State::Idle(Some(w));
+
+                    let (written, part) = res?;
+                    self.parts.push(part);
+                    return Poll::Ready(Ok(written));
+                }
+                State::Close(_) => {
+                    unreachable!(
+                        "MultipartUploadWriter must not go into State::Close during poll_write"
+                    )
+                }
+                State::Abort(_) => {
+                    unreachable!(
+                        "MultipartUploadWriter must not go into State::Abort during poll_write"
+                    )
+                }
+            }
+        }
     }
 
     fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll<Result<()>> {
-        let upload_id = if let Some(upload_id) = &self.upload_id {
-            upload_id
-        } else {
-            return Ok(());
-        };
-
-        self.inner.complete_part(upload_id, &self.parts).await
+        loop {
+            match &mut self.state {
+                State::Idle(w) => {
+                    let w = w.take().expect("writer must be valid");
+                    match self.upload_id.clone() {
+                        Some(upload_id) => {
+                            // The boxed future is `'static`, so it must own its
+                            // inputs instead of borrowing from `self`.
+                            let parts = self.parts.clone();
+                            self.state = State::Close(Box::pin(async move {
+                                let res = w.complete_part(&upload_id, &parts).await;
+                                (w, res)
+                            }));
+                        }
+                        None => {
+                            // No upload was initiated; restore the writer so the
+                            // state stays `Idle(Some(_))` for later calls.
+                            self.state = State::Idle(Some(w));
+                            return Poll::Ready(Ok(()));
+                        }
+                    }
+                }
+                State::Close(fut) => {
+                    let (w, res) = futures::ready!(fut.as_mut().poll(cx));
+                    self.state = State::Idle(Some(w));
+                    return Poll::Ready(res);
+                }
+                State::Init(_) => unreachable!(
+                    "MultipartUploadWriter must not go into State::Init during poll_close"
+                ),
+                State::Write(_) => unreachable!(
+                    "MultipartUploadWriter must not go into State::Write during poll_close"
+                ),
+                State::Abort(_) => unreachable!(
+                    "MultipartUploadWriter must not go into State::Abort during poll_close"
+                ),
+            }
+        }
     }
 
     fn poll_abort(&mut self, cx: &mut Context<'_>) -> Poll<Result<()>> {
-        let upload_id = if let Some(upload_id) = &self.upload_id {
-            upload_id
-        } else {
-            return Ok(());
-        };
-
-        self.inner.abort_part(upload_id).await
+        loop {
+            match &mut self.state {
+                State::Idle(w) => {
+                    let w = w.take().expect("writer must be valid");
+                    match self.upload_id.clone() {
+                        Some(upload_id) => {
+                            self.state = State::Abort(Box::pin(async move {
+                                let res = w.abort_part(&upload_id).await;
+                                (w, res)
+                            }));
+                        }
+                        None => {
+                            // No upload was initiated; restore the writer so the
+                            // state stays `Idle(Some(_))` for later calls.
+                            self.state = State::Idle(Some(w));
+                            return Poll::Ready(Ok(()));
+                        }
+                    }
+                }
+                State::Abort(fut) => {
+                    let (w, res) = futures::ready!(fut.as_mut().poll(cx));
+                    self.state = State::Idle(Some(w));
+                    return Poll::Ready(res);
+                }
+                State::Init(_) => unreachable!(
+                    "MultipartUploadWriter must not go into State::Init during poll_abort"
+                ),
+                State::Write(_) => unreachable!(
+                    "MultipartUploadWriter must not go into State::Write during poll_abort"
+                ),
+                State::Close(_) => unreachable!(
+                    "MultipartUploadWriter must not go into State::Close during poll_abort"
+                ),
+            }
+        }
     }
 }

Reply via email to