This is an automated email from the ASF dual-hosted git repository.

xuanwo pushed a commit to branch fix-timeout-layer
in repository https://gitbox.apache.org/repos/asf/opendal.git

commit 02320a866b921159abad51612cac49279a5356c9
Author: Xuanwo <[email protected]>
AuthorDate: Tue Jan 23 22:46:20 2024 +0800

    Implement timeout layer correctly
    
    Signed-off-by: Xuanwo <[email protected]>
---
 core/src/layers/timeout.rs | 333 ++++++++++++++-------------------------------
 1 file changed, 104 insertions(+), 229 deletions(-)

diff --git a/core/src/layers/timeout.rs b/core/src/layers/timeout.rs
index 6f50177f08..77e55c6012 100644
--- a/core/src/layers/timeout.rs
+++ b/core/src/layers/timeout.rs
@@ -17,13 +17,14 @@
 
 use std::future::Future;
 use std::io::SeekFrom;
-use std::task::Context;
+use std::pin::Pin;
 use std::task::Poll;
+use std::task::{ready, Context};
 use std::time::Duration;
-use std::time::Instant;
 
 use async_trait::async_trait;
 use bytes::Bytes;
+use tokio::time::Sleep;
 
 use crate::raw::oio::ListOperation;
 use crate::raw::oio::ReadOperation;
@@ -46,15 +47,10 @@ use crate::*;
 /// - IO Operation like `read`, `Reader::read` and `Writer::write`, they 
operate on data directly, we
 ///   control them by setting `io_timeout`.
 ///
-/// It happens that a connection could be slow but not dead, so we have a 
`max_io_timeouts` to
-/// control how many consecutive IO timeouts we can tolerate. If 
`max_io_timeouts` is not reached,
-/// we will print a warning and keep waiting this io operation instead.
-///
 /// # Default
 ///
 /// - timeout: 60 seconds
 /// - io_timeout: 10 seconds
-/// - max_io_timeouts: 3 times
 ///
 /// # Examples
 ///
@@ -71,14 +67,13 @@ use crate::*;
 ///
 /// let _ = Operator::new(services::Memory::default())
 ///     .expect("must init")
-///     
.layer(TimeoutLayer::default().with_timeout(Duration::from_secs(10)).with_io_timeout(3).with_max_io_timeouts(2))
+///     
.layer(TimeoutLayer::default().with_timeout(Duration::from_secs(10)).with_io_timeout(3))
 ///     .finish();
 /// ```
 #[derive(Clone)]
 pub struct TimeoutLayer {
     timeout: Duration,
     io_timeout: Duration,
-    max_io_timeouts: usize,
 }
 
 impl Default for TimeoutLayer {
@@ -86,7 +81,6 @@ impl Default for TimeoutLayer {
         Self {
             timeout: Duration::from_secs(60),
             io_timeout: Duration::from_secs(10),
-            max_io_timeouts: 3,
         }
     }
 }
@@ -113,14 +107,6 @@ impl TimeoutLayer {
         self
     }
 
-    /// Set max io timeouts for TimeoutLayer with given value.
-    ///
-    /// This value is used to control how many consecutive io timeouts we can 
tolerate.
-    pub fn with_max_io_timeouts(mut self, v: usize) -> Self {
-        self.max_io_timeouts = v;
-        self
-    }
-
     /// Set speed for TimeoutLayer with given value.
     ///
     /// # Notes
@@ -132,7 +118,7 @@ impl TimeoutLayer {
     ///
     /// This function will panic if speed is 0.
     #[deprecated(note = "with speed is not supported anymore, please use 
with_io_timeout instead")]
-    pub fn with_speed(mut self, _: u64) -> Self {
+    pub fn with_speed(self, _: u64) -> Self {
         self
     }
 }
@@ -146,7 +132,6 @@ impl<A: Accessor> Layer<A> for TimeoutLayer {
 
             timeout: self.timeout,
             io_timeout: self.io_timeout,
-            max_io_timeouts: self.max_io_timeouts,
         }
     }
 }
@@ -157,10 +142,18 @@ pub struct TimeoutAccessor<A: Accessor> {
 
     timeout: Duration,
     io_timeout: Duration,
-    max_io_timeouts: usize,
 }
 
 impl<A: Accessor> TimeoutAccessor<A> {
+    async fn timeout<F: Future<Output = Result<T>>, T>(&self, op: Operation, 
fut: F) -> Result<T> {
+        tokio::time::timeout(self.timeout, fut).await.map_err(|_| {
+            Error::new(ErrorKind::Unexpected, "operation timeout reached")
+                .with_operation(op)
+                .with_context("timeout", 
self.timeout.as_secs_f64().to_string())
+                .set_temporary()
+        })?
+    }
+
     async fn io_timeout<F: Future<Output = Result<T>>, T>(
         &self,
         op: Operation,
@@ -169,9 +162,9 @@ impl<A: Accessor> TimeoutAccessor<A> {
         tokio::time::timeout(self.io_timeout, fut)
             .await
             .map_err(|_| {
-                Error::new(ErrorKind::Unexpected, "io operation timeout 
reached")
+                Error::new(ErrorKind::Unexpected, "io timeout reached")
                     .with_operation(op)
-                    .with_context("io_timeout", 
self.io_timeout.as_secs_f64().to_string())
+                    .with_context("timeout", 
self.io_timeout.as_secs_f64().to_string())
                     .set_temporary()
             })?
     }
@@ -192,37 +185,56 @@ impl<A: Accessor> LayeredAccessor for TimeoutAccessor<A> {
         &self.inner
     }
 
+    async fn create_dir(&self, path: &str, args: OpCreateDir) -> 
Result<RpCreateDir> {
+        self.timeout(Operation::CreateDir, self.inner.create_dir(path, args))
+            .await
+    }
+
     async fn read(&self, path: &str, args: OpRead) -> Result<(RpRead, 
Self::Reader)> {
         self.io_timeout(Operation::Read, self.inner.read(path, args))
             .await
-            .map(|(rp, r)| {
-                (
-                    rp,
-                    TimeoutWrapper::new(r, self.io_timeout, 
self.max_io_timeouts),
-                )
-            })
+            .map(|(rp, r)| (rp, TimeoutWrapper::new(r, self.io_timeout)))
     }
 
     async fn write(&self, path: &str, args: OpWrite) -> Result<(RpWrite, 
Self::Writer)> {
         self.io_timeout(Operation::Write, self.inner.write(path, args))
             .await
-            .map(|(rp, r)| {
-                (
-                    rp,
-                    TimeoutWrapper::new(r, self.io_timeout, 
self.max_io_timeouts),
-                )
-            })
+            .map(|(rp, r)| (rp, TimeoutWrapper::new(r, self.io_timeout)))
+    }
+
+    async fn copy(&self, from: &str, to: &str, args: OpCopy) -> Result<RpCopy> 
{
+        self.timeout(Operation::Copy, self.inner.copy(from, to, args))
+            .await
+    }
+
+    async fn rename(&self, from: &str, to: &str, args: OpRename) -> 
Result<RpRename> {
+        self.timeout(Operation::Rename, self.inner.rename(from, to, args))
+            .await
+    }
+
+    async fn stat(&self, path: &str, args: OpStat) -> Result<RpStat> {
+        self.timeout(Operation::Stat, self.inner.stat(path, args))
+            .await
+    }
+
+    async fn delete(&self, path: &str, args: OpDelete) -> Result<RpDelete> {
+        self.timeout(Operation::Delete, self.inner.delete(path, args))
+            .await
     }
 
     async fn list(&self, path: &str, args: OpList) -> Result<(RpList, 
Self::Lister)> {
         self.io_timeout(Operation::List, self.inner.list(path, args))
             .await
-            .map(|(rp, r)| {
-                (
-                    rp,
-                    TimeoutWrapper::new(r, self.io_timeout, 
self.max_io_timeouts),
-                )
-            })
+            .map(|(rp, r)| (rp, TimeoutWrapper::new(r, self.io_timeout)))
+    }
+
+    async fn batch(&self, args: OpBatch) -> Result<RpBatch> {
+        self.timeout(Operation::Batch, self.inner.batch(args)).await
+    }
+
+    async fn presign(&self, path: &str, args: OpPresign) -> Result<RpPresign> {
+        self.timeout(Operation::Presign, self.inner.presign(path, args))
+            .await
     }
 
     fn blocking_read(&self, path: &str, args: OpRead) -> Result<(RpRead, 
Self::BlockingReader)> {
@@ -242,235 +254,98 @@ pub struct TimeoutWrapper<R> {
     inner: R,
 
     timeout: Duration,
-    max_timeouts: usize,
-
-    current_timeouts: usize,
-    futures: Option<(BoxedFuture)>,
+    sleep: Option<Pin<Box<Sleep>>>,
 }
 
 impl<R> TimeoutWrapper<R> {
-    fn new(inner: R, timeout: Duration, max_timeouts: usize) -> Self {
+    fn new(inner: R, timeout: Duration) -> Self {
         Self {
             inner,
             timeout,
-            max_timeouts,
-            start: None,
+            sleep: None,
         }
     }
-}
 
-impl<R: oio::Read> oio::Read for TimeoutWrapper<R> {
-    fn poll_read(&mut self, cx: &mut Context<'_>, buf: &mut [u8]) -> 
Poll<Result<usize>> {
-        match self.start {
-            Some(start) => {
-                if start.elapsed() > self.timeout {
-                    // Clean up the start time before return ready.
-                    self.start = None;
-
-                    return Poll::Ready(Err(Error::new(
-                        ErrorKind::Unexpected,
-                        "operation timeout",
+    #[inline]
+    fn poll_timeout(&mut self, cx: &mut Context<'_>, op: &'static str) -> 
Result<()> {
+        if let Some(sleep) = self.sleep.as_mut() {
+            match sleep.as_mut().poll(cx) {
+                Poll::Pending => Ok(()),
+                Poll::Ready(_) => {
+                    self.sleep = None;
+                    Err(
+                        Error::new(ErrorKind::Unexpected, "io operation 
timeout reached")
+                            .with_operation(op)
+                            .with_context("io_timeout", 
self.timeout.as_secs_f64().to_string())
+                            .set_temporary(),
                     )
-                    .with_operation(ReadOperation::Read)
-                    .with_context("timeout", 
self.timeout.as_secs_f64().to_string())
-                    .set_temporary()));
                 }
             }
-            None => {
-                self.start = Some(Instant::now());
-            }
+        } else {
+            self.sleep = Some(Box::pin(tokio::time::sleep(self.timeout)));
+            Ok(())
         }
+    }
+}
 
-        match self.inner.poll_read(cx, buf) {
-            Poll::Pending => Poll::Pending,
-            Poll::Ready(v) => {
-                self.start = None;
-                Poll::Ready(v)
-            }
-        }
+impl<R: oio::Read> oio::Read for TimeoutWrapper<R> {
+    fn poll_read(&mut self, cx: &mut Context<'_>, buf: &mut [u8]) -> 
Poll<Result<usize>> {
+        self.poll_timeout(cx, ReadOperation::Read.into_static())?;
+
+        let v = ready!(self.inner.poll_read(cx, buf));
+        self.sleep = None;
+        Poll::Ready(v)
     }
 
     fn poll_seek(&mut self, cx: &mut Context<'_>, pos: SeekFrom) -> 
Poll<Result<u64>> {
-        match self.start {
-            Some(start) => {
-                if start.elapsed() > self.timeout {
-                    // Clean up the start time before return ready.
-                    self.start = None;
-
-                    return Poll::Ready(Err(Error::new(
-                        ErrorKind::Unexpected,
-                        "operation timeout",
-                    )
-                    .with_operation(ReadOperation::Seek)
-                    .with_context("timeout", 
self.timeout.as_secs_f64().to_string())
-                    .set_temporary()));
-                }
-            }
-            None => {
-                self.start = Some(Instant::now());
-            }
-        }
+        self.poll_timeout(cx, ReadOperation::Seek.into_static())?;
 
-        match self.inner.poll_seek(cx, pos) {
-            Poll::Pending => Poll::Pending,
-            Poll::Ready(v) => {
-                self.start = None;
-                Poll::Ready(v)
-            }
-        }
+        let v = ready!(self.inner.poll_seek(cx, pos));
+        self.sleep = None;
+        Poll::Ready(v)
     }
 
     fn poll_next(&mut self, cx: &mut Context<'_>) -> 
Poll<Option<Result<Bytes>>> {
-        match self.start {
-            Some(start) => {
-                if start.elapsed() > self.timeout {
-                    // Clean up the start time before return ready.
-                    self.start = None;
-
-                    return Poll::Ready(Some(Err(Error::new(
-                        ErrorKind::Unexpected,
-                        "operation timeout",
-                    )
-                    .with_operation(ReadOperation::Next)
-                    .with_context("timeout", 
self.timeout.as_secs_f64().to_string())
-                    .set_temporary())));
-                }
-            }
-            None => {
-                self.start = Some(Instant::now());
-            }
-        }
+        self.poll_timeout(cx, ReadOperation::Next.into_static())?;
 
-        match self.inner.poll_next(cx) {
-            Poll::Pending => Poll::Pending,
-            Poll::Ready(v) => {
-                self.start = None;
-                Poll::Ready(v)
-            }
-        }
+        let v = ready!(self.inner.poll_next(cx));
+        self.sleep = None;
+        Poll::Ready(v)
     }
 }
 
 impl<R: oio::Write> oio::Write for TimeoutWrapper<R> {
     fn poll_write(&mut self, cx: &mut Context<'_>, bs: &dyn oio::WriteBuf) -> 
Poll<Result<usize>> {
-        match self.start {
-            Some(start) => {
-                if start.elapsed() > self.timeout {
-                    // Clean up the start time before return ready.
-                    self.start = None;
-
-                    return Poll::Ready(Err(Error::new(
-                        ErrorKind::Unexpected,
-                        "operation timeout",
-                    )
-                    .with_operation(WriteOperation::Write)
-                    .with_context("timeout", 
self.timeout.as_secs_f64().to_string())
-                    .set_temporary()));
-                }
-            }
-            None => {
-                self.start = Some(Instant::now());
-            }
-        }
+        self.poll_timeout(cx, WriteOperation::Write.into_static())?;
 
-        match self.inner.poll_write(cx, bs) {
-            Poll::Pending => Poll::Pending,
-            Poll::Ready(v) => {
-                self.start = None;
-                Poll::Ready(v)
-            }
-        }
+        let v = ready!(self.inner.poll_write(cx, bs));
+        self.sleep = None;
+        Poll::Ready(v)
     }
 
     fn poll_abort(&mut self, cx: &mut Context<'_>) -> Poll<Result<()>> {
-        match self.start {
-            Some(start) => {
-                if start.elapsed() > self.timeout {
-                    // Clean up the start time before return ready.
-                    self.start = None;
-
-                    return Poll::Ready(Err(Error::new(
-                        ErrorKind::Unexpected,
-                        "operation timeout",
-                    )
-                    .with_operation(WriteOperation::Abort)
-                    .with_context("timeout", 
self.timeout.as_secs_f64().to_string())
-                    .set_temporary()));
-                }
-            }
-            None => {
-                self.start = Some(Instant::now());
-            }
-        }
+        self.poll_timeout(cx, WriteOperation::Abort.into_static())?;
 
-        match self.inner.poll_abort(cx) {
-            Poll::Pending => Poll::Pending,
-            Poll::Ready(v) => {
-                self.start = None;
-                Poll::Ready(v)
-            }
-        }
+        let v = ready!(self.inner.poll_abort(cx));
+        self.sleep = None;
+        Poll::Ready(v)
     }
 
     fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll<Result<()>> {
-        match self.start {
-            Some(start) => {
-                if start.elapsed() > self.timeout {
-                    // Clean up the start time before return ready.
-                    self.start = None;
-
-                    return Poll::Ready(Err(Error::new(
-                        ErrorKind::Unexpected,
-                        "operation timeout",
-                    )
-                    .with_operation(WriteOperation::Close)
-                    .with_context("timeout", 
self.timeout.as_secs_f64().to_string())
-                    .set_temporary()));
-                }
-            }
-            None => {
-                self.start = Some(Instant::now());
-            }
-        }
+        self.poll_timeout(cx, WriteOperation::Close.into_static())?;
 
-        match self.inner.poll_close(cx) {
-            Poll::Pending => Poll::Pending,
-            Poll::Ready(v) => {
-                self.start = None;
-                Poll::Ready(v)
-            }
-        }
+        let v = ready!(self.inner.poll_close(cx));
+        self.sleep = None;
+        Poll::Ready(v)
     }
 }
 
 impl<R: oio::List> oio::List for TimeoutWrapper<R> {
     fn poll_next(&mut self, cx: &mut Context<'_>) -> 
Poll<Result<Option<oio::Entry>>> {
-        match self.start {
-            Some(start) => {
-                if start.elapsed() > self.timeout {
-                    // Clean up the start time before return ready.
-                    self.start = None;
-
-                    return Poll::Ready(Err(Error::new(
-                        ErrorKind::Unexpected,
-                        "operation timeout",
-                    )
-                    .with_operation(ListOperation::Next)
-                    .with_context("timeout", 
self.timeout.as_secs_f64().to_string())
-                    .set_temporary()));
-                }
-            }
-            None => {
-                self.start = Some(Instant::now());
-            }
-        }
+        self.poll_timeout(cx, ListOperation::Next.into_static())?;
 
-        match self.inner.poll_next(cx) {
-            Poll::Pending => Poll::Pending,
-            Poll::Ready(v) => {
-                self.start = None;
-                Poll::Ready(v)
-            }
-        }
+        let v = ready!(self.inner.poll_next(cx));
+        self.sleep = None;
+        Poll::Ready(v)
     }
 }

Reply via email to