This is an automated email from the ASF dual-hosted git repository.

xuanwo pushed a commit to branch stream-based-write
in repository https://gitbox.apache.org/repos/asf/incubator-opendal.git

commit 44559b30e721f94407922e476ae66eaa5825ea29
Author: Xuanwo <[email protected]>
AuthorDate: Tue Aug 29 18:50:20 2023 +0800

    Add size in stream
    
    Signed-off-by: Xuanwo <[email protected]>
---
 core/src/raw/http_util/client.rs                   |   5 +-
 core/src/raw/http_util/multipart.rs                | 115 ++++++++++++++-------
 core/src/raw/oio/cursor.rs                         |   8 ++
 core/src/raw/oio/stream/api.rs                     |  38 ++++++-
 core/src/raw/oio/stream/into_stream.rs             |  17 ++-
 core/src/raw/oio/stream/into_stream_from_reader.rs |   9 +-
 core/src/services/webhdfs/error.rs                 |   5 +-
 core/src/types/writer.rs                           |   4 +-
 8 files changed, 153 insertions(+), 48 deletions(-)

diff --git a/core/src/raw/http_util/client.rs b/core/src/raw/http_util/client.rs
index d0926d1cd..ce39cb85b 100644
--- a/core/src/raw/http_util/client.rs
+++ b/core/src/raw/http_util/client.rs
@@ -155,7 +155,10 @@ impl HttpClient {
                 .set_source(err)
         });
 
-        let body = IncomingAsyncBody::new(Box::new(oio::into_stream(stream)), content_length);
+        let body = IncomingAsyncBody::new(
+            Box::new(oio::into_stream(content_length.unwrap_or_default(), stream)),
+            content_length,
+        );
 
         let resp = hr.body(body).expect("response must build succeed");
 
diff --git a/core/src/raw/http_util/multipart.rs b/core/src/raw/http_util/multipart.rs
index edc061c1f..6ec4e9c09 100644
--- a/core/src/raw/http_util/multipart.rs
+++ b/core/src/raw/http_util/multipart.rs
@@ -112,7 +112,7 @@ impl<T: Part> Multipart<T> {
         Ok(self)
     }
 
-    pub(crate) fn build(self) -> (u64, MultipartStream<T>) {
+    pub(crate) fn build(self) -> MultipartStream<T> {
         let mut total_size = 0;
 
         let mut bs = BytesMut::new();
@@ -125,8 +125,8 @@ impl<T: Part> Multipart<T> {
         let mut parts = VecDeque::new();
         // Write headers.
         for v in self.parts.into_iter() {
-            let (size, stream) = v.format();
-            total_size += pre_part.len() as u64 + size;
+            let mut stream = v.format();
+            total_size += stream.size();
             parts.push_back(stream);
         }
 
@@ -140,15 +140,13 @@ impl<T: Part> Multipart<T> {
 
         total_size += final_part.len() as u64;
 
-        (
-            total_size,
-            MultipartStream {
-                pre_part,
-                pre_part_consumed: false,
-                parts,
-                final_part: Some(final_part),
-            },
-        )
+        MultipartStream {
+            size: total_size,
+            pre_part,
+            pre_part_consumed: false,
+            parts,
+            final_part: Some(final_part),
+        }
     }
 
     /// Consume the input and generate a request with multipart body.
@@ -156,7 +154,7 @@ impl<T: Part> Multipart<T> {
     /// This function will make sure content_type and content_length set correctly.
     pub fn apply(self, mut builder: http::request::Builder) -> Result<Request<AsyncBody>> {
         let boundary = self.boundary.clone();
-        let (content_length, stream) = self.build();
+        let mut stream = self.build();
 
         // Insert content type with correct boundary.
         builder = builder.header(
@@ -164,7 +162,7 @@ impl<T: Part> Multipart<T> {
             format!("multipart/{}; boundary={}", T::TYPE, boundary).as_str(),
         );
         // Insert content length with calculated size.
-        builder = builder.header(CONTENT_LENGTH, content_length);
+        builder = builder.header(CONTENT_LENGTH, stream.size());
 
         builder
             .body(AsyncBody::Stream(Box::new(stream)))
@@ -173,6 +171,8 @@ impl<T: Part> Multipart<T> {
 }
 
 pub struct MultipartStream<T: Part> {
+    size: u64,
+
     pre_part: Bytes,
     pre_part_consumed: bool,
 
@@ -182,10 +182,15 @@ pub struct MultipartStream<T: Part> {
 }
 
 impl<T: Part> Stream for MultipartStream<T> {
+    fn size(&mut self) -> u64 {
+        self.size
+    }
+
     fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<Bytes>>> {
         if let Some(stream) = self.parts.front_mut() {
             if !self.pre_part_consumed {
                 self.pre_part_consumed = true;
+                self.size -= self.pre_part.len() as u64;
                 return Poll::Ready(Some(Ok(self.pre_part.clone())));
             }
             return match ready!(stream.poll_next(cx)) {
@@ -194,11 +199,18 @@ impl<T: Part> Stream for MultipartStream<T> {
                     self.parts.pop_front();
                     return self.poll_next(cx);
                 }
-                Some(v) => Poll::Ready(Some(v)),
+                Some(v) => {
+                    let v = v.map(|bs| {
+                        self.size -= bs.len() as u64;
+                        bs
+                    });
+                    Poll::Ready(Some(v))
+                }
             };
         }
 
         if let Some(final_part) = self.final_part.take() {
+            self.size -= final_part.len() as u64;
             return Poll::Ready(Some(Ok(final_part)));
         }
 
@@ -224,7 +236,7 @@ pub trait Part: Sized + 'static {
     type STREAM: Stream;
 
     /// format will generates the bytes.
-    fn format(self) -> (u64, Self::STREAM);
+    fn format(self) -> Self::STREAM;
 
     /// parse will parse the bytes into a part.
     fn parse(s: &str) -> Result<Self>;
@@ -286,7 +298,7 @@ impl Part for FormDataPart {
     const TYPE: &'static str = "form-data";
     type STREAM = FormDataPartStream;
 
-    fn format(self) -> (u64, FormDataPartStream) {
+    fn format(self) -> FormDataPartStream {
         let mut bs = BytesMut::new();
 
         // Building pre-content.
@@ -302,13 +314,11 @@ impl Part for FormDataPart {
         // pre-content + content + post-content (b`\r\n`)
         let total_size = bs.len() as u64 + self.content_length + 2;
 
-        (
-            total_size,
-            FormDataPartStream {
-                pre_content: Some(bs),
-                content: Some(self.content),
-            },
-        )
+        FormDataPartStream {
+            size: total_size,
+            pre_content: Some(bs),
+            content: Some(self.content),
+        }
     }
 
     fn parse(_: &str) -> Result<Self> {
@@ -320,6 +330,7 @@ impl Part for FormDataPart {
 }
 
 pub struct FormDataPartStream {
+    size: u64,
     /// Including headers and the first `b\r\n`
     pre_content: Option<Bytes>,
     content: Option<Streamer>,
@@ -327,8 +338,13 @@ pub struct FormDataPartStream {
 
 #[async_trait]
 impl Stream for FormDataPartStream {
+    fn size(&mut self) -> u64 {
+        self.size
+    }
+
     fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<Bytes>>> {
         if let Some(pre_content) = self.pre_content.take() {
+            self.size -= pre_content.len() as u64;
             return Poll::Ready(Some(Ok(pre_content)));
         }
 
@@ -338,7 +354,13 @@ impl Stream for FormDataPartStream {
                     self.content = None;
                     Poll::Ready(Some(Ok(Bytes::from_static(b"\r\n"))))
                 }
-                Some(v) => Poll::Ready(Some(v)),
+                Some(v) => {
+                    let v = v.map(|bs| {
+                        self.size -= bs.len() as u64;
+                        bs
+                    });
+                    Poll::Ready(Some(v))
+                }
             };
         }
 
@@ -455,7 +477,7 @@ impl MixedPart {
         let body = if let Some(stream) = self.content {
             IncomingAsyncBody::new(stream, Some(self.content_length))
         } else {
-            IncomingAsyncBody::new(Box::new(oio::into_stream(stream::empty())), Some(0))
+            IncomingAsyncBody::new(Box::new(oio::into_stream(0, stream::empty())), Some(0))
         };
 
         builder
@@ -508,7 +530,7 @@ impl Part for MixedPart {
     const TYPE: &'static str = "mixed";
     type STREAM = MixedPartStream;
 
-    fn format(self) -> (u64, Self::STREAM) {
+    fn format(self) -> Self::STREAM {
         let mut bs = BytesMut::new();
 
         // Write parts headers.
@@ -576,13 +598,11 @@ impl Part for MixedPart {
             total_size += self.content_length + 2;
         }
 
-        (
-            total_size,
-            MixedPartStream {
-                pre_content: Some(bs),
-                content: self.content,
-            },
-        )
+        MixedPartStream {
+            size: total_size,
+            pre_content: Some(bs),
+            content: self.content,
+        }
     }
 
     /// TODO
@@ -674,14 +694,20 @@ impl Part for MixedPart {
 }
 
 pub struct MixedPartStream {
+    size: u64,
     /// Including headers and the first `b\r\n`
     pre_content: Option<Bytes>,
     content: Option<Streamer>,
 }
 
 impl Stream for MixedPartStream {
+    fn size(&mut self) -> u64 {
+        self.size
+    }
+
     fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<Bytes>>> {
         if let Some(pre_content) = self.pre_content.take() {
+            self.size -= pre_content.len() as u64;
             return Poll::Ready(Some(Ok(pre_content)));
         }
 
@@ -689,9 +715,16 @@ impl Stream for MixedPartStream {
             return match ready!(stream.poll_next(cx)) {
                 None => {
                     self.content = None;
+                    self.size -= 2;
                     Poll::Ready(Some(Ok(Bytes::from_static(b"\r\n"))))
                 }
-                Some(v) => Poll::Ready(Some(v)),
+                Some(v) => {
+                    let v = v.map(|bs| {
+                        self.size -= bs.len() as u64;
+                        bs
+                    });
+                    Poll::Ready(Some(v))
+                }
             };
         }
 
@@ -722,7 +755,8 @@ mod tests {
             .part(FormDataPart::new("foo").content(Bytes::from("bar")))
             .part(FormDataPart::new("hello").content(Bytes::from("world")));
 
-        let (size, body) = multipart.build();
+        let mut body = multipart.build();
+        let size = body.size();
         let bs = body.collect().await.unwrap();
         assert_eq!(size, bs.len() as u64);
 
@@ -758,7 +792,8 @@ mod tests {
             
.part(FormDataPart::new("Signature").content("0RavWzkygo6QX9caELEqKi9kDbU="))
             .part(FormDataPart::new("file").header(CONTENT_TYPE, "image/jpeg".parse().unwrap()).content("...file content...")).part(FormDataPart::new("submit").content("Upload to Amazon S3"));
 
-        let (size, body) = multipart.build();
+        let mut body = multipart.build();
+        let size = body.size();
         let bs = body.collect().await?;
         assert_eq!(size, bs.len() as u64);
 
@@ -882,7 +917,8 @@ Upload to Amazon S3
                     .content(r#"{"metadata": {"type": "calico"}}"#),
             );
 
-        let (size, body) = multipart.build();
+        let mut body = multipart.build();
+        let size = body.size();
         let bs = body.collect().await?;
         assert_eq!(size, bs.len() as u64);
 
@@ -988,7 +1024,8 @@ content-length: 32
                     .header("content-length".parse().unwrap(), "0".parse().unwrap()),
             );
 
-        let (size, body) = multipart.build();
+        let mut body = multipart.build();
+        let size = body.size();
         let bs = body.collect().await?;
         assert_eq!(size, bs.len() as u64);
 
diff --git a/core/src/raw/oio/cursor.rs b/core/src/raw/oio/cursor.rs
index e02522cc6..74f97ef4d 100644
--- a/core/src/raw/oio/cursor.rs
+++ b/core/src/raw/oio/cursor.rs
@@ -161,6 +161,10 @@ impl oio::BlockingRead for Cursor {
 }
 
 impl oio::Stream for Cursor {
+    fn size(&mut self) -> u64 {
+        self.inner.len() as u64 - self.pos
+    }
+
     fn poll_next(&mut self, _: &mut Context<'_>) -> Poll<Option<Result<Bytes>>> {
         if self.is_empty() {
             return Poll::Ready(None);
@@ -327,6 +331,10 @@ impl ChunkedCursor {
 }
 
 impl oio::Stream for ChunkedCursor {
+    fn size(&mut self) -> u64 {
+        self.len() as u64
+    }
+
     fn poll_next(&mut self, _: &mut Context<'_>) -> Poll<Option<Result<Bytes>>> {
         if self.is_empty() {
             return Poll::Ready(None);
diff --git a/core/src/raw/oio/stream/api.rs b/core/src/raw/oio/stream/api.rs
index 132fb78fc..fcf7b8b5c 100644
--- a/core/src/raw/oio/stream/api.rs
+++ b/core/src/raw/oio/stream/api.rs
@@ -36,7 +36,15 @@ pub type Streamer = Box<dyn Stream>;
 /// It's nearly the same with [`futures::Stream`], but it satisfied
 /// `Unpin` + `Send` + `Sync`. And the item is `Result<Bytes>`.
 pub trait Stream: Unpin + Send + Sync {
-    /// Poll next item `Result<Bytes>` from the stream.
+    /// Fetch remaining size of this stream.
+    ///
+    /// # NOTES
+    ///
+    /// It's by design that we take `&mut self` here to make sure we don't have other
+    /// threads reading the same stream at the same time.
+    fn size(&mut self) -> u64;
+
+    /// Fetch next item `Result<Bytes>` from the stream.
     fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<Bytes>>>;
 
     /// Reset this stream to the beginning.
@@ -44,6 +52,10 @@ pub trait Stream: Unpin + Send + Sync {
 }
 
 impl Stream for () {
+    fn size(&mut self) -> u64 {
+        unimplemented!("size is required to be implemented for oio::Stream")
+    }
+
     fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<Bytes>>> {
         let _ = cx;
 
@@ -60,6 +72,10 @@ impl Stream for () {
 /// `Box<dyn Stream>` won't implement `Stream` automatically.
 /// To make Streamer work as expected, we must add this impl.
 impl<T: Stream + ?Sized> Stream for Box<T> {
+    fn size(&mut self) -> u64 {
+        (**self).size()
+    }
+
     fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<Bytes>>> {
         (**self).poll_next(cx)
     }
@@ -70,6 +86,13 @@ impl<T: Stream + ?Sized> Stream for Box<T> {
 }
 
 impl<T: Stream + ?Sized> Stream for Arc<std::sync::Mutex<T>> {
+    fn size(&mut self) -> u64 {
+        match self.try_lock() {
+            Ok(mut this) => this.size(),
+            Err(_) => panic!("the stream is expected to have only one consumer, but it's not"),
+        }
+    }
+
     fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<Bytes>>> {
         match self.try_lock() {
             Ok(mut this) => this.poll_next(cx),
@@ -92,6 +115,13 @@ impl<T: Stream + ?Sized> Stream for Arc<std::sync::Mutex<T>> {
 }
 
 impl<T: Stream + ?Sized> Stream for Arc<tokio::sync::Mutex<T>> {
+    fn size(&mut self) -> u64 {
+        match self.try_lock() {
+            Ok(mut this) => this.size(),
+            Err(_) => panic!("the stream is expected to have only one consumer, but it's not"),
+        }
+    }
+
     fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<Bytes>>> {
         match self.try_lock() {
             Ok(mut this) => this.poll_next(cx),
@@ -113,6 +143,8 @@ impl<T: Stream + ?Sized> Stream for Arc<tokio::sync::Mutex<T>> {
     }
 }
 
+/// TODO: implement `FusedStream` for `Stream`
+/// TODO: implement `fn size_hint(&self) -> (usize, Option<usize>)` for `Stream`
 impl futures::Stream for dyn Stream {
     type Item = Result<Bytes>;
 
@@ -206,6 +238,10 @@ pub struct Chain<S1: Stream, S2: Stream> {
 }
 
 impl<S1: Stream, S2: Stream> Stream for Chain<S1, S2> {
+    fn size(&mut self) -> u64 {
+        self.first.as_mut().map(|v| v.size()).unwrap_or_default() + self.second.size()
+    }
+
     fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<Bytes>>> {
         if let Some(first) = self.first.as_mut() {
             if let Some(item) = ready!(first.poll_next(cx)) {
diff --git a/core/src/raw/oio/stream/into_stream.rs b/core/src/raw/oio/stream/into_stream.rs
index fd01f3fcc..81e2301ae 100644
--- a/core/src/raw/oio/stream/into_stream.rs
+++ b/core/src/raw/oio/stream/into_stream.rs
@@ -25,14 +25,18 @@ use crate::raw::*;
 use crate::*;
 
 /// Convert given futures stream into [`oio::Stream`].
-pub fn into_stream<S>(stream: S) -> IntoStream<S>
+pub fn into_stream<S>(size: u64, stream: S) -> IntoStream<S>
 where
     S: futures::Stream<Item = Result<Bytes>> + Send + Sync + Unpin,
 {
-    IntoStream { inner: stream }
+    IntoStream {
+        size,
+        inner: stream,
+    }
 }
 
 pub struct IntoStream<S> {
+    size: u64,
     inner: S,
 }
 
@@ -40,8 +44,15 @@ impl<S> oio::Stream for IntoStream<S>
 where
     S: futures::Stream<Item = Result<Bytes>> + Send + Sync + Unpin,
 {
+    fn size(&mut self) -> u64 {
+        self.size
+    }
+
     fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<Bytes>>> {
-        self.inner.try_poll_next_unpin(cx)
+        self.inner.try_poll_next_unpin(cx).map_ok(|v| {
+            self.size -= v.len() as u64;
+            v
+        })
     }
 
     fn poll_reset(&mut self, _: &mut Context<'_>) -> Poll<Result<()>> {
diff --git a/core/src/raw/oio/stream/into_stream_from_reader.rs b/core/src/raw/oio/stream/into_stream_from_reader.rs
index d8b29bff0..a038695ef 100644
--- a/core/src/raw/oio/stream/into_stream_from_reader.rs
+++ b/core/src/raw/oio/stream/into_stream_from_reader.rs
@@ -33,18 +33,20 @@ use crate::*;
 const DEFAULT_CAPACITY: usize = 64 * 1024;
 
 /// Convert given futures reader into [`oio::Stream`].
-pub fn into_stream_from_reader<R>(r: R) -> FromReaderStream<R>
+pub fn into_stream_from_reader<R>(size: u64, r: R) -> FromReaderStream<R>
 where
     R: AsyncRead + Send + Sync + Unpin,
 {
     FromReaderStream {
         inner: Some(r),
+        size,
         buf: BytesMut::new(),
     }
 }
 
 pub struct FromReaderStream<R> {
     inner: Option<R>,
+    size: u64,
     buf: BytesMut,
 }
 
@@ -52,6 +54,10 @@ impl<S> oio::Stream for FromReaderStream<S>
 where
     S: AsyncRead + Send + Sync + Unpin,
 {
+    fn size(&mut self) -> u64 {
+        self.size
+    }
+
     fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<Bytes>>> {
         let reader = match self.inner.as_mut() {
             Some(r) => r,
@@ -77,6 +83,7 @@ where
             Ok(n) => {
                 // Safety: read_exact makes sure this buffer has been filled.
                 unsafe { self.buf.advance_mut(n) }
+                self.size -= n as u64;
 
                 let chunk = self.buf.split();
                 Poll::Ready(Some(Ok(chunk.freeze())))
diff --git a/core/src/services/webhdfs/error.rs b/core/src/services/webhdfs/error.rs
index 449725e63..e7fced334 100644
--- a/core/src/services/webhdfs/error.rs
+++ b/core/src/services/webhdfs/error.rs
@@ -100,7 +100,10 @@ mod tests {
     "#,
         );
         let body = IncomingAsyncBody::new(
-            Box::new(oio::into_stream(stream::iter(vec![Ok(ill_args.clone())]))),
+            Box::new(oio::into_stream(
+                ill_args.len() as u64,
+                stream::iter(vec![Ok(ill_args.clone())]),
+            )),
             None,
         );
         let resp = Response::builder()
diff --git a/core/src/types/writer.rs b/core/src/types/writer.rs
index 204508f8f..f3011e69e 100644
--- a/core/src/types/writer.rs
+++ b/core/src/types/writer.rs
@@ -131,7 +131,7 @@ impl Writer {
         T: Into<Bytes>,
     {
         if let State::Idle(Some(w)) = &mut self.state {
-            let s = Box::new(oio::into_stream(sink_from.map_ok(|v| v.into())));
+            let s = Box::new(oio::into_stream(size, sink_from.map_ok(|v| v.into())));
             w.write(size, s).await
         } else {
             unreachable!(
@@ -176,7 +176,7 @@ impl Writer {
         R: futures::AsyncRead + Send + Sync + Unpin + 'static,
     {
         if let State::Idle(Some(w)) = &mut self.state {
-            let s = Box::new(oio::into_stream_from_reader(read_from));
+            let s = Box::new(oio::into_stream_from_reader(size, read_from));
             w.write(size, s).await
         } else {
             unreachable!(

Reply via email to