This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new 0373a9d77 Implement fallible streams for `FlightClient::do_put` (#3464)
0373a9d77 is described below

commit 0373a9d77f446918d44b1ee216ed33de3905b688
Author: Andrew Lamb <[email protected]>
AuthorDate: Thu Feb 23 10:57:10 2023 +0100

    Implement fallible streams for `FlightClient::do_put` (#3464)
    
    * Implement fallible streams for do_put
    
    * Another approach to error wrapping
    
    * implement basic client error test
    
    * Add last error test
    
    * comments
    
    * fix docs
    
    * Simplify
    
    ---------
    
    Co-authored-by: Raphael Taylor-Davies <[email protected]>
---
 arrow-flight/src/client.rs   | 58 +++++++++++++++++++-------
 arrow-flight/tests/client.rs | 99 ++++++++++++++++++++++++++++++++++++++++----
 2 files changed, 136 insertions(+), 21 deletions(-)

diff --git a/arrow-flight/src/client.rs b/arrow-flight/src/client.rs
index bdd51dda4..fe1292fcf 100644
--- a/arrow-flight/src/client.rs
+++ b/arrow-flight/src/client.rs
@@ -15,6 +15,8 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use std::task::Poll;
+
 use crate::{
     decode::FlightRecordBatchStream, flight_service_client::FlightServiceClient, Action,
     ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo,
@@ -24,8 +26,9 @@ use arrow_schema::Schema;
 use bytes::Bytes;
 use futures::{
     future::ready,
+    ready,
     stream::{self, BoxStream},
-    Stream, StreamExt, TryStreamExt,
+    FutureExt, Stream, StreamExt, TryStreamExt,
 };
 use tonic::{metadata::MetadataMap, transport::Channel};
 
@@ -262,6 +265,15 @@ impl FlightClient {
     /// [`Stream`](futures::Stream) of [`FlightData`] and returning a
     /// stream of [`PutResult`].
     ///
+    /// # Note
+    ///
+    /// The input stream is [`Result`] so that this can be connected
+    /// to a streaming data source, such as [`FlightDataEncoder`](crate::encode::FlightDataEncoder),
+    /// without having to buffer. If the input stream returns an error
+    /// that error will not be sent to the server, instead it will be
+    /// placed into the result stream and the server connection
+    /// terminated.
+    ///
     /// # Example:
     /// ```no_run
     /// # async fn run() {
@@ -279,9 +291,7 @@ impl FlightClient {
     ///
     /// // encode the batch as a stream of `FlightData`
     /// let flight_data_stream = FlightDataEncoderBuilder::new()
-    ///   .build(futures::stream::iter(vec![Ok(batch)]))
-    ///   // data encoder return Results, but do_put requires FlightData
-    ///   .map(|batch|batch.unwrap());
+    ///   .build(futures::stream::iter(vec![Ok(batch)]));
     ///
     /// // send the stream and get the results as `PutResult`
     /// let response: Vec<PutResult>= client
@@ -293,20 +303,40 @@ impl FlightClient {
     ///   .expect("error calling do_put");
     /// # }
     /// ```
-    pub async fn do_put<S: Stream<Item = FlightData> + Send + 'static>(
+    pub async fn do_put<S: Stream<Item = Result<FlightData>> + Send + 'static>(
         &mut self,
         request: S,
     ) -> Result<BoxStream<'static, Result<PutResult>>> {
-        let request = self.make_request(request);
-
-        let response = self
-            .inner
-            .do_put(request)
-            .await?
-            .into_inner()
-            .map_err(FlightError::Tonic);
+        let (sender, mut receiver) = futures::channel::oneshot::channel();
+
+        // Intercepts client errors and sends them to the oneshot channel above
+        let mut request = Box::pin(request); // Pin to heap
+        let mut sender = Some(sender); // Wrap into Option so can be taken
+        let request_stream = futures::stream::poll_fn(move |cx| {
+            Poll::Ready(match ready!(request.poll_next_unpin(cx)) {
+                Some(Ok(data)) => Some(data),
+                Some(Err(e)) => {
+                    let _ = sender.take().unwrap().send(e);
+                    None
+                }
+                None => None,
+            })
+        });
+
+        let request = self.make_request(request_stream);
+        let mut response_stream = self.inner.do_put(request).await?.into_inner();
+
+        // Forwards errors from the error oneshot with priority over responses from server
+        let error_stream = futures::stream::poll_fn(move |cx| {
+            if let Poll::Ready(Ok(err)) = receiver.poll_unpin(cx) {
+                return Poll::Ready(Some(Err(err)));
+            }
+            let next = ready!(response_stream.poll_next_unpin(cx));
+            Poll::Ready(next.map(|x| x.map_err(FlightError::Tonic)))
+        });
 
-        Ok(response.boxed())
+        // combine the response from the server and any error from the client
+        Ok(error_stream.boxed())
     }
 
     /// Make a `DoExchange` call to the server with the provided
diff --git a/arrow-flight/tests/client.rs b/arrow-flight/tests/client.rs
index ab1cfa1fb..ed928a52c 100644
--- a/arrow-flight/tests/client.rs
+++ b/arrow-flight/tests/client.rs
@@ -248,8 +248,10 @@ async fn test_do_put() {
         test_server
             .set_do_put_response(expected_response.clone().into_iter().map(Ok).collect());
 
+        let input_stream = futures::stream::iter(input_flight_data.clone()).map(Ok);
+
         let response_stream = client
-            .do_put(futures::stream::iter(input_flight_data.clone()))
+            .do_put(input_stream)
             .await
             .expect("error making request");
 
@@ -266,15 +268,15 @@ async fn test_do_put() {
 }
 
 #[tokio::test]
-async fn test_do_put_error() {
+async fn test_do_put_error_server() {
     do_test(|test_server, mut client| async move {
         client.add_header("foo-header", "bar-header-value").unwrap();
 
         let input_flight_data = test_flight_data().await;
 
-        let response = client
-            .do_put(futures::stream::iter(input_flight_data.clone()))
-            .await;
+        let input_stream = futures::stream::iter(input_flight_data.clone()).map(Ok);
+
+        let response = client.do_put(input_stream).await;
         let response = match response {
             Ok(_) => panic!("unexpected success"),
             Err(e) => e,
@@ -290,7 +292,7 @@ async fn test_do_put_error() {
 }
 
 #[tokio::test]
-async fn test_do_put_error_stream() {
+async fn test_do_put_error_stream_server() {
     do_test(|test_server, mut client| async move {
         client.add_header("foo-header", "bar-header-value").unwrap();
 
@@ -307,8 +309,10 @@ async fn test_do_put_error_stream() {
 
         test_server.set_do_put_response(response);
 
+        let input_stream = futures::stream::iter(input_flight_data.clone()).map(Ok);
+
         let response_stream = client
-            .do_put(futures::stream::iter(input_flight_data.clone()))
+            .do_put(input_stream)
             .await
             .expect("error making request");
 
@@ -326,6 +330,87 @@ async fn test_do_put_error_stream() {
     .await;
 }
 
+#[tokio::test]
+async fn test_do_put_error_client() {
+    do_test(|test_server, mut client| async move {
+        client.add_header("foo-header", "bar-header-value").unwrap();
+
+        let e = Status::invalid_argument("bad arg: client");
+
+        // input stream to client sends good FlightData followed by an error
+        let input_flight_data = test_flight_data().await;
+        let input_stream = futures::stream::iter(input_flight_data.clone())
+            .map(Ok)
+            .chain(futures::stream::iter(vec![Err(FlightError::from(
+                e.clone(),
+            ))]));
+
+        // server responds with one good message
+        let response = vec![Ok(PutResult {
+            app_metadata: Bytes::from("foo-metadata"),
+        })];
+        test_server.set_do_put_response(response);
+
+        let response_stream = client
+            .do_put(input_stream)
+            .await
+            .expect("error making request");
+
+        let response: Result<Vec<_>, _> = response_stream.try_collect().await;
+        let response = match response {
+            Ok(_) => panic!("unexpected success"),
+            Err(e) => e,
+        };
+
+        // expect to the error made from the client
+        expect_status(response, e);
+        // server still got the request messages until the client sent the error
+        assert_eq!(test_server.take_do_put_request(), Some(input_flight_data));
+        ensure_metadata(&client, &test_server);
+    })
+    .await;
+}
+
+#[tokio::test]
+async fn test_do_put_error_client_and_server() {
+    do_test(|test_server, mut client| async move {
+        client.add_header("foo-header", "bar-header-value").unwrap();
+
+        let e_client = Status::invalid_argument("bad arg: client");
+        let e_server = Status::invalid_argument("bad arg: server");
+
+        // input stream to client sends good FlightData followed by an error
+        let input_flight_data = test_flight_data().await;
+        let input_stream = futures::stream::iter(input_flight_data.clone())
+            .map(Ok)
+            .chain(futures::stream::iter(vec![Err(FlightError::from(
+                e_client.clone(),
+            ))]));
+
+        // server responds with an error (e.g. because it got truncated data)
+        let response = vec![Err(e_server)];
+        test_server.set_do_put_response(response);
+
+        let response_stream = client
+            .do_put(input_stream)
+            .await
+            .expect("error making request");
+
+        let response: Result<Vec<_>, _> = response_stream.try_collect().await;
+        let response = match response {
+            Ok(_) => panic!("unexpected success"),
+            Err(e) => e,
+        };
+
+        // expect to the error made from the client (not the server)
+        expect_status(response, e_client);
+        // server still got the request messages until the client sent the error
+        assert_eq!(test_server.take_do_put_request(), Some(input_flight_data));
+        ensure_metadata(&client, &test_server);
+    })
+    .await;
+}
+
 #[tokio::test]
 async fn test_do_exchange() {
     do_test(|test_server, mut client| async move {

Reply via email to