This is an automated email from the ASF dual-hosted git repository.
tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new eb2d00b3ab2 Fallible stream for arrow-flight do_exchange call (#3462) (#5698)
eb2d00b3ab2 is described below
commit eb2d00b3ab225395e664bc578a1b56c42c6ac32d
Author: opensourcegeek <[email protected]>
AuthorDate: Thu May 2 10:34:53 2024 +0100
Fallible stream for arrow-flight do_exchange call (#3462) (#5698)
Signed-off-by: Praveen Kumar <[email protected]>
Co-authored-by: Praveen Kumar <[email protected]>
---
arrow-flight/src/client.rs | 42 +++++++++++++------
arrow-flight/tests/client.rs | 97 ++++++++++++++++++++++++++++++++++++++++++--
2 files changed, 124 insertions(+), 15 deletions(-)
diff --git a/arrow-flight/src/client.rs b/arrow-flight/src/client.rs
index b2abfb0c17b..a7e15fe24cc 100644
--- a/arrow-flight/src/client.rs
+++ b/arrow-flight/src/client.rs
@@ -417,9 +417,7 @@ impl FlightClient {
///
/// // encode the batch as a stream of `FlightData`
/// let flight_data_stream = FlightDataEncoderBuilder::new()
- /// .build(futures::stream::iter(vec![Ok(batch)]))
- /// // data encoder return Results, but do_exchange requires FlightData
- /// .map(|batch|batch.unwrap());
+ /// .build(futures::stream::iter(vec![Ok(batch)]));
///
/// // send the stream and get the results as `RecordBatches`
/// let response: Vec<RecordBatch> = client
@@ -431,20 +429,40 @@ impl FlightClient {
/// .expect("error calling do_exchange");
/// # }
/// ```
- pub async fn do_exchange<S: Stream<Item = FlightData> + Send + 'static>(
+ pub async fn do_exchange<S: Stream<Item = Result<FlightData>> + Send + 'static>(
&mut self,
request: S,
) -> Result<FlightRecordBatchStream> {
- let request = self.make_request(request);
+ let (sender, mut receiver) = futures::channel::oneshot::channel();
- let response = self
- .inner
- .do_exchange(request)
- .await?
- .into_inner()
- .map_err(FlightError::Tonic);
+ // Intercepts client errors and sends them to the oneshot channel above
+ let mut request = Box::pin(request); // Pin to heap
+ let mut sender = Some(sender); // Wrap into Option so can be taken
+ let request_stream = futures::stream::poll_fn(move |cx| {
+ Poll::Ready(match ready!(request.poll_next_unpin(cx)) {
+ Some(Ok(data)) => Some(data),
+ Some(Err(e)) => {
+ let _ = sender.take().unwrap().send(e);
+ None
+ }
+ None => None,
+ })
+ });
+
+ let request = self.make_request(request_stream);
+ let mut response_stream = self.inner.do_exchange(request).await?.into_inner();
+
+ // Forwards errors from the error oneshot with priority over responses from server
+ let error_stream = futures::stream::poll_fn(move |cx| {
+ if let Poll::Ready(Ok(err)) = receiver.poll_unpin(cx) {
+ return Poll::Ready(Some(Err(err)));
+ }
+ let next = ready!(response_stream.poll_next_unpin(cx));
+ Poll::Ready(next.map(|x| x.map_err(FlightError::Tonic)))
+ });
- Ok(FlightRecordBatchStream::new_from_flight_data(response))
+ // combine the response from the server and any error from the client
+ Ok(FlightRecordBatchStream::new_from_flight_data(error_stream))
}
/// Make a `ListFlights` call to the server with the provided
diff --git a/arrow-flight/tests/client.rs b/arrow-flight/tests/client.rs
index 47565334cb6..478938d939a 100644
--- a/arrow-flight/tests/client.rs
+++ b/arrow-flight/tests/client.rs
@@ -493,7 +493,7 @@ async fn test_do_exchange() {
.set_do_exchange_response(output_flight_data.clone().into_iter().map(Ok).collect());
let response_stream = client
- .do_exchange(futures::stream::iter(input_flight_data.clone()))
+ .do_exchange(futures::stream::iter(input_flight_data.clone()).map(Ok))
.await
.expect("error making request");
@@ -528,7 +528,7 @@ async fn test_do_exchange_error() {
let input_flight_data = test_flight_data().await;
let response = client
- .do_exchange(futures::stream::iter(input_flight_data.clone()))
+ .do_exchange(futures::stream::iter(input_flight_data.clone()).map(Ok))
.await;
let response = match response {
Ok(_) => panic!("unexpected success"),
@@ -572,7 +572,7 @@ async fn test_do_exchange_error_stream() {
test_server.set_do_exchange_response(response);
let response_stream = client
- .do_exchange(futures::stream::iter(input_flight_data.clone()))
+ .do_exchange(futures::stream::iter(input_flight_data.clone()).map(Ok))
.await
.expect("error making request");
@@ -593,6 +593,97 @@ async fn test_do_exchange_error_stream() {
.await;
}
+#[tokio::test]
+async fn test_do_exchange_error_stream_client() {
+ do_test(|test_server, mut client| async move {
+ client.add_header("foo-header", "bar-header-value").unwrap();
+
+ let e = Status::invalid_argument("bad arg: client");
+
+ // input stream to client sends good FlightData followed by an error
+ let input_flight_data = test_flight_data().await;
+ let input_stream = futures::stream::iter(input_flight_data.clone())
+ .map(Ok)
+ .chain(futures::stream::iter(vec![Err(FlightError::from(
+ e.clone(),
+ ))]));
+
+ let output_flight_data = FlightData::new()
+ .with_descriptor(FlightDescriptor::new_cmd("Sample command"))
+ .with_data_body("body".as_bytes())
+ .with_data_header("header".as_bytes())
+ .with_app_metadata("metadata".as_bytes());
+
+ // server responds with one good message
+ let response = vec![Ok(output_flight_data)];
+ test_server.set_do_exchange_response(response);
+
+ let response_stream = client
+ .do_exchange(input_stream)
+ .await
+ .expect("error making request");
+
+ let response: Result<Vec<_>, _> = response_stream.try_collect().await;
+ let response = match response {
+ Ok(_) => panic!("unexpected success"),
+ Err(e) => e,
+ };
+
+ // expect the error made from the client
+ expect_status(response, e);
+ // server still got the request messages until the client sent the error
+ assert_eq!(
+ test_server.take_do_exchange_request(),
+ Some(input_flight_data)
+ );
+ ensure_metadata(&client, &test_server);
+ })
+ .await;
+}
+
+#[tokio::test]
+async fn test_do_exchange_error_client_and_server() {
+ do_test(|test_server, mut client| async move {
+ client.add_header("foo-header", "bar-header-value").unwrap();
+
+ let e_client = Status::invalid_argument("bad arg: client");
+ let e_server = Status::invalid_argument("bad arg: server");
+
+ // input stream to client sends good FlightData followed by an error
+ let input_flight_data = test_flight_data().await;
+ let input_stream = futures::stream::iter(input_flight_data.clone())
+ .map(Ok)
+ .chain(futures::stream::iter(vec![Err(FlightError::from(
+ e_client.clone(),
+ ))]));
+
+ // server responds with an error (e.g. because it got truncated data)
+ let response = vec![Err(e_server)];
+ test_server.set_do_exchange_response(response);
+
+ let response_stream = client
+ .do_exchange(input_stream)
+ .await
+ .expect("error making request");
+
+ let response: Result<Vec<_>, _> = response_stream.try_collect().await;
+ let response = match response {
+ Ok(_) => panic!("unexpected success"),
+ Err(e) => e,
+ };
+
+ // expect the error made from the client (not the server)
+ expect_status(response, e_client);
+ // server still got the request messages until the client sent the error
+ assert_eq!(
+ test_server.take_do_exchange_request(),
+ Some(input_flight_data)
+ );
+ ensure_metadata(&client, &test_server);
+ })
+ .await;
+}
+
#[tokio::test]
async fn test_get_schema() {
do_test(|test_server, mut client| async move {