This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new eb2d00b3ab2 Fallible stream for arrow-flight do_exchange call (#3462) (#5698)
eb2d00b3ab2 is described below

commit eb2d00b3ab225395e664bc578a1b56c42c6ac32d
Author: opensourcegeek <[email protected]>
AuthorDate: Thu May 2 10:34:53 2024 +0100

    Fallible stream for arrow-flight do_exchange call (#3462) (#5698)
    
    Signed-off-by: Praveen Kumar <[email protected]>
    Co-authored-by: Praveen Kumar <[email protected]>
---
 arrow-flight/src/client.rs   | 42 +++++++++++++------
 arrow-flight/tests/client.rs | 97 ++++++++++++++++++++++++++++++++++++++++++--
 2 files changed, 124 insertions(+), 15 deletions(-)

diff --git a/arrow-flight/src/client.rs b/arrow-flight/src/client.rs
index b2abfb0c17b..a7e15fe24cc 100644
--- a/arrow-flight/src/client.rs
+++ b/arrow-flight/src/client.rs
@@ -417,9 +417,7 @@ impl FlightClient {
     ///
     /// // encode the batch as a stream of `FlightData`
     /// let flight_data_stream = FlightDataEncoderBuilder::new()
-    ///   .build(futures::stream::iter(vec![Ok(batch)]))
-    ///   // data encoder return Results, but do_exchange requires FlightData
-    ///   .map(|batch|batch.unwrap());
+    ///   .build(futures::stream::iter(vec![Ok(batch)]));
     ///
     /// // send the stream and get the results as `RecordBatches`
     /// let response: Vec<RecordBatch> = client
@@ -431,20 +429,40 @@ impl FlightClient {
     ///   .expect("error calling do_exchange");
     /// # }
     /// ```
-    pub async fn do_exchange<S: Stream<Item = FlightData> + Send + 'static>(
+    pub async fn do_exchange<S: Stream<Item = Result<FlightData>> + Send + 'static>(
         &mut self,
         request: S,
     ) -> Result<FlightRecordBatchStream> {
-        let request = self.make_request(request);
+        let (sender, mut receiver) = futures::channel::oneshot::channel();
 
-        let response = self
-            .inner
-            .do_exchange(request)
-            .await?
-            .into_inner()
-            .map_err(FlightError::Tonic);
+        // Intercepts client errors and sends them to the oneshot channel above
+        let mut request = Box::pin(request); // Pin to heap
+        let mut sender = Some(sender); // Wrap into Option so can be taken
+        let request_stream = futures::stream::poll_fn(move |cx| {
+            Poll::Ready(match ready!(request.poll_next_unpin(cx)) {
+                Some(Ok(data)) => Some(data),
+                Some(Err(e)) => {
+                    let _ = sender.take().unwrap().send(e);
+                    None
+                }
+                None => None,
+            })
+        });
+
+        let request = self.make_request(request_stream);
+        let mut response_stream = self.inner.do_exchange(request).await?.into_inner();
+
+        // Forwards errors from the error oneshot with priority over responses from server
+        let error_stream = futures::stream::poll_fn(move |cx| {
+            if let Poll::Ready(Ok(err)) = receiver.poll_unpin(cx) {
+                return Poll::Ready(Some(Err(err)));
+            }
+            let next = ready!(response_stream.poll_next_unpin(cx));
+            Poll::Ready(next.map(|x| x.map_err(FlightError::Tonic)))
+        });
 
-        Ok(FlightRecordBatchStream::new_from_flight_data(response))
+        // combine the response from the server and any error from the client
+        Ok(FlightRecordBatchStream::new_from_flight_data(error_stream))
     }
 
     /// Make a `ListFlights` call to the server with the provided
diff --git a/arrow-flight/tests/client.rs b/arrow-flight/tests/client.rs
index 47565334cb6..478938d939a 100644
--- a/arrow-flight/tests/client.rs
+++ b/arrow-flight/tests/client.rs
@@ -493,7 +493,7 @@ async fn test_do_exchange() {
             .set_do_exchange_response(output_flight_data.clone().into_iter().map(Ok).collect());
 
         let response_stream = client
-            .do_exchange(futures::stream::iter(input_flight_data.clone()))
+            .do_exchange(futures::stream::iter(input_flight_data.clone()).map(Ok))
             .await
             .expect("error making request");
 
@@ -528,7 +528,7 @@ async fn test_do_exchange_error() {
         let input_flight_data = test_flight_data().await;
 
         let response = client
-            .do_exchange(futures::stream::iter(input_flight_data.clone()))
+            .do_exchange(futures::stream::iter(input_flight_data.clone()).map(Ok))
             .await;
         let response = match response {
             Ok(_) => panic!("unexpected success"),
@@ -572,7 +572,7 @@ async fn test_do_exchange_error_stream() {
         test_server.set_do_exchange_response(response);
 
         let response_stream = client
-            .do_exchange(futures::stream::iter(input_flight_data.clone()))
+            .do_exchange(futures::stream::iter(input_flight_data.clone()).map(Ok))
             .await
             .expect("error making request");
 
@@ -593,6 +593,97 @@ async fn test_do_exchange_error_stream() {
     .await;
 }
 
+#[tokio::test]
+async fn test_do_exchange_error_stream_client() {
+    do_test(|test_server, mut client| async move {
+        client.add_header("foo-header", "bar-header-value").unwrap();
+
+        let e = Status::invalid_argument("bad arg: client");
+
+        // input stream to client sends good FlightData followed by an error
+        let input_flight_data = test_flight_data().await;
+        let input_stream = futures::stream::iter(input_flight_data.clone())
+            .map(Ok)
+            .chain(futures::stream::iter(vec![Err(FlightError::from(
+                e.clone(),
+            ))]));
+
+        let output_flight_data = FlightData::new()
+            .with_descriptor(FlightDescriptor::new_cmd("Sample command"))
+            .with_data_body("body".as_bytes())
+            .with_data_header("header".as_bytes())
+            .with_app_metadata("metadata".as_bytes());
+
+        // server responds with one good message
+        let response = vec![Ok(output_flight_data)];
+        test_server.set_do_exchange_response(response);
+
+        let response_stream = client
+            .do_exchange(input_stream)
+            .await
+            .expect("error making request");
+
+        let response: Result<Vec<_>, _> = response_stream.try_collect().await;
+        let response = match response {
+            Ok(_) => panic!("unexpected success"),
+            Err(e) => e,
+        };
+
+        // expect to the error made from the client
+        expect_status(response, e);
+        // server still got the request messages until the client sent the error
+        assert_eq!(
+            test_server.take_do_exchange_request(),
+            Some(input_flight_data)
+        );
+        ensure_metadata(&client, &test_server);
+    })
+    .await;
+}
+
+#[tokio::test]
+async fn test_do_exchange_error_client_and_server() {
+    do_test(|test_server, mut client| async move {
+        client.add_header("foo-header", "bar-header-value").unwrap();
+
+        let e_client = Status::invalid_argument("bad arg: client");
+        let e_server = Status::invalid_argument("bad arg: server");
+
+        // input stream to client sends good FlightData followed by an error
+        let input_flight_data = test_flight_data().await;
+        let input_stream = futures::stream::iter(input_flight_data.clone())
+            .map(Ok)
+            .chain(futures::stream::iter(vec![Err(FlightError::from(
+                e_client.clone(),
+            ))]));
+
+        // server responds with an error (e.g. because it got truncated data)
+        let response = vec![Err(e_server)];
+        test_server.set_do_exchange_response(response);
+
+        let response_stream = client
+            .do_exchange(input_stream)
+            .await
+            .expect("error making request");
+
+        let response: Result<Vec<_>, _> = response_stream.try_collect().await;
+        let response = match response {
+            Ok(_) => panic!("unexpected success"),
+            Err(e) => e,
+        };
+
+        // expect to the error made from the client (not the server)
+        expect_status(response, e_client);
+        // server still got the request messages until the client sent the error
+        assert_eq!(
+            test_server.take_do_exchange_request(),
+            Some(input_flight_data)
+        );
+        ensure_metadata(&client, &test_server);
+    })
+    .await;
+}
+
 #[tokio::test]
 async fn test_get_schema() {
     do_test(|test_server, mut client| async move {

Reply via email to