This is an automated email from the ASF dual-hosted git repository.
tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new 0373a9d77 Implement fallible streams for `FlightClient::do_put` (#3464)
0373a9d77 is described below
commit 0373a9d77f446918d44b1ee216ed33de3905b688
Author: Andrew Lamb <[email protected]>
AuthorDate: Thu Feb 23 10:57:10 2023 +0100
Implement fallible streams for `FlightClient::do_put` (#3464)
* Implement fallible streams for do_put
* Another approach to error wrapping
* implement basic client error test
* Add last error test
* comments
* fix docs
* Simplify
---------
Co-authored-by: Raphael Taylor-Davies <[email protected]>
---
arrow-flight/src/client.rs | 58 +++++++++++++++++++-------
arrow-flight/tests/client.rs | 99 ++++++++++++++++++++++++++++++++++++++++----
2 files changed, 136 insertions(+), 21 deletions(-)
diff --git a/arrow-flight/src/client.rs b/arrow-flight/src/client.rs
index bdd51dda4..fe1292fcf 100644
--- a/arrow-flight/src/client.rs
+++ b/arrow-flight/src/client.rs
@@ -15,6 +15,8 @@
// specific language governing permissions and limitations
// under the License.
+use std::task::Poll;
+
use crate::{
decode::FlightRecordBatchStream,
flight_service_client::FlightServiceClient, Action,
ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo,
@@ -24,8 +26,9 @@ use arrow_schema::Schema;
use bytes::Bytes;
use futures::{
future::ready,
+ ready,
stream::{self, BoxStream},
- Stream, StreamExt, TryStreamExt,
+ FutureExt, Stream, StreamExt, TryStreamExt,
};
use tonic::{metadata::MetadataMap, transport::Channel};
@@ -262,6 +265,15 @@ impl FlightClient {
/// [`Stream`](futures::Stream) of [`FlightData`] and returning a
/// stream of [`PutResult`].
///
+ /// # Note
+ ///
+ /// The input stream is [`Result`] so that this can be connected
+ /// to a streaming data source, such as [`FlightDataEncoder`](crate::encode::FlightDataEncoder),
+ /// without having to buffer. If the input stream returns an error
+ /// that error will not be sent to the server, instead it will be
+ /// placed into the result stream and the server connection
+ /// terminated.
+ ///
/// # Example:
/// ```no_run
/// # async fn run() {
@@ -279,9 +291,7 @@ impl FlightClient {
///
/// // encode the batch as a stream of `FlightData`
/// let flight_data_stream = FlightDataEncoderBuilder::new()
- /// .build(futures::stream::iter(vec![Ok(batch)]))
- /// // data encoder return Results, but do_put requires FlightData
- /// .map(|batch|batch.unwrap());
+ /// .build(futures::stream::iter(vec![Ok(batch)]));
///
/// // send the stream and get the results as `PutResult`
/// let response: Vec<PutResult>= client
@@ -293,20 +303,40 @@ impl FlightClient {
/// .expect("error calling do_put");
/// # }
/// ```
- pub async fn do_put<S: Stream<Item = FlightData> + Send + 'static>(
+ pub async fn do_put<S: Stream<Item = Result<FlightData>> + Send + 'static>(
&mut self,
request: S,
) -> Result<BoxStream<'static, Result<PutResult>>> {
- let request = self.make_request(request);
-
- let response = self
- .inner
- .do_put(request)
- .await?
- .into_inner()
- .map_err(FlightError::Tonic);
+ let (sender, mut receiver) = futures::channel::oneshot::channel();
+
+ // Intercepts client errors and sends them to the oneshot channel above
+ let mut request = Box::pin(request); // Pin to heap
+ let mut sender = Some(sender); // Wrap into Option so can be taken
+ let request_stream = futures::stream::poll_fn(move |cx| {
+ Poll::Ready(match ready!(request.poll_next_unpin(cx)) {
+ Some(Ok(data)) => Some(data),
+ Some(Err(e)) => {
+ let _ = sender.take().unwrap().send(e);
+ None
+ }
+ None => None,
+ })
+ });
+
+ let request = self.make_request(request_stream);
+ let mut response_stream = self.inner.do_put(request).await?.into_inner();
+
+ // Forwards errors from the error oneshot with priority over responses from server
+ let error_stream = futures::stream::poll_fn(move |cx| {
+ if let Poll::Ready(Ok(err)) = receiver.poll_unpin(cx) {
+ return Poll::Ready(Some(Err(err)));
+ }
+ let next = ready!(response_stream.poll_next_unpin(cx));
+ Poll::Ready(next.map(|x| x.map_err(FlightError::Tonic)))
+ });
- Ok(response.boxed())
+ // combine the response from the server and any error from the client
+ Ok(error_stream.boxed())
}
/// Make a `DoExchange` call to the server with the provided
diff --git a/arrow-flight/tests/client.rs b/arrow-flight/tests/client.rs
index ab1cfa1fb..ed928a52c 100644
--- a/arrow-flight/tests/client.rs
+++ b/arrow-flight/tests/client.rs
@@ -248,8 +248,10 @@ async fn test_do_put() {
test_server
.set_do_put_response(expected_response.clone().into_iter().map(Ok).collect());
+ let input_stream = futures::stream::iter(input_flight_data.clone()).map(Ok);
+
let response_stream = client
- .do_put(futures::stream::iter(input_flight_data.clone()))
+ .do_put(input_stream)
.await
.expect("error making request");
@@ -266,15 +268,15 @@ async fn test_do_put() {
}
#[tokio::test]
-async fn test_do_put_error() {
+async fn test_do_put_error_server() {
do_test(|test_server, mut client| async move {
client.add_header("foo-header", "bar-header-value").unwrap();
let input_flight_data = test_flight_data().await;
- let response = client
- .do_put(futures::stream::iter(input_flight_data.clone()))
- .await;
+ let input_stream = futures::stream::iter(input_flight_data.clone()).map(Ok);
+
+ let response = client.do_put(input_stream).await;
let response = match response {
Ok(_) => panic!("unexpected success"),
Err(e) => e,
@@ -290,7 +292,7 @@ async fn test_do_put_error() {
}
#[tokio::test]
-async fn test_do_put_error_stream() {
+async fn test_do_put_error_stream_server() {
do_test(|test_server, mut client| async move {
client.add_header("foo-header", "bar-header-value").unwrap();
@@ -307,8 +309,10 @@ async fn test_do_put_error_stream() {
test_server.set_do_put_response(response);
+ let input_stream = futures::stream::iter(input_flight_data.clone()).map(Ok);
+
let response_stream = client
- .do_put(futures::stream::iter(input_flight_data.clone()))
+ .do_put(input_stream)
.await
.expect("error making request");
@@ -326,6 +330,87 @@ async fn test_do_put_error_stream() {
.await;
}
+#[tokio::test]
+async fn test_do_put_error_client() {
+ do_test(|test_server, mut client| async move {
+ client.add_header("foo-header", "bar-header-value").unwrap();
+
+ let e = Status::invalid_argument("bad arg: client");
+
+ // input stream to client sends good FlightData followed by an error
+ let input_flight_data = test_flight_data().await;
+ let input_stream = futures::stream::iter(input_flight_data.clone())
+ .map(Ok)
+ .chain(futures::stream::iter(vec![Err(FlightError::from(
+ e.clone(),
+ ))]));
+
+ // server responds with one good message
+ let response = vec![Ok(PutResult {
+ app_metadata: Bytes::from("foo-metadata"),
+ })];
+ test_server.set_do_put_response(response);
+
+ let response_stream = client
+ .do_put(input_stream)
+ .await
+ .expect("error making request");
+
+ let response: Result<Vec<_>, _> = response_stream.try_collect().await;
+ let response = match response {
+ Ok(_) => panic!("unexpected success"),
+ Err(e) => e,
+ };
+
+ // expect to the error made from the client
+ expect_status(response, e);
+ // server still got the request messages until the client sent the error
+ assert_eq!(test_server.take_do_put_request(), Some(input_flight_data));
+ ensure_metadata(&client, &test_server);
+ })
+ .await;
+}
+
+#[tokio::test]
+async fn test_do_put_error_client_and_server() {
+ do_test(|test_server, mut client| async move {
+ client.add_header("foo-header", "bar-header-value").unwrap();
+
+ let e_client = Status::invalid_argument("bad arg: client");
+ let e_server = Status::invalid_argument("bad arg: server");
+
+ // input stream to client sends good FlightData followed by an error
+ let input_flight_data = test_flight_data().await;
+ let input_stream = futures::stream::iter(input_flight_data.clone())
+ .map(Ok)
+ .chain(futures::stream::iter(vec![Err(FlightError::from(
+ e_client.clone(),
+ ))]));
+
+ // server responds with an error (e.g. because it got truncated data)
+ let response = vec![Err(e_server)];
+ test_server.set_do_put_response(response);
+
+ let response_stream = client
+ .do_put(input_stream)
+ .await
+ .expect("error making request");
+
+ let response: Result<Vec<_>, _> = response_stream.try_collect().await;
+ let response = match response {
+ Ok(_) => panic!("unexpected success"),
+ Err(e) => e,
+ };
+
+ // expect to the error made from the client (not the server)
+ expect_status(response, e_client);
+ // server still got the request messages until the client sent the error
+ assert_eq!(test_server.take_do_put_request(), Some(input_flight_data));
+ ensure_metadata(&client, &test_server);
+ })
+ .await;
+}
+
#[tokio::test]
async fn test_do_exchange() {
do_test(|test_server, mut client| async move {