This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new e3f1c9669cb Refactor to share code between do_put and do_exchange
calls (#5728)
e3f1c9669cb is described below
commit e3f1c9669cb16939383c880927c0b15f943e3690
Author: opensourcegeek <[email protected]>
AuthorDate: Tue May 7 20:21:15 2024 +0100
Refactor to share code between do_put and do_exchange calls (#5728)
Signed-off-by: Praveen Kumar <[email protected]>
Co-authored-by: Praveen Kumar <[email protected]>
---
arrow-flight/src/client.rs | 154 ++++++++++++++++++++++++++++++++-------------
1 file changed, 110 insertions(+), 44 deletions(-)
diff --git a/arrow-flight/src/client.rs b/arrow-flight/src/client.rs
index a7e15fe24cc..f5f28f683a9 100644
--- a/arrow-flight/src/client.rs
+++ b/arrow-flight/src/client.rs
@@ -15,7 +15,7 @@
// specific language governing permissions and limitations
// under the License.
-use std::task::Poll;
+use std::{pin::Pin, task::Poll};
use crate::{
decode::FlightRecordBatchStream,
@@ -28,6 +28,7 @@ use crate::{
use arrow_schema::Schema;
use bytes::Bytes;
use futures::{
+ channel::oneshot::{Receiver, Sender},
future::ready,
ready,
stream::{self, BoxStream},
@@ -364,33 +365,18 @@ impl FlightClient {
&mut self,
request: S,
) -> Result<BoxStream<'static, Result<PutResult>>> {
- let (sender, mut receiver) = futures::channel::oneshot::channel();
+ let (sender, receiver) = futures::channel::oneshot::channel();
// Intercepts client errors and sends them to the oneshot channel above
- let mut request = Box::pin(request); // Pin to heap
- let mut sender = Some(sender); // Wrap into Option so can be taken
- let request_stream = futures::stream::poll_fn(move |cx| {
- Poll::Ready(match ready!(request.poll_next_unpin(cx)) {
- Some(Ok(data)) => Some(data),
- Some(Err(e)) => {
- let _ = sender.take().unwrap().send(e);
- None
- }
- None => None,
- })
- });
+ let request = Box::pin(request); // Pin to heap
+ let request_stream = FallibleRequestStream::new(sender, request);
let request = self.make_request(request_stream);
- let mut response_stream =
self.inner.do_put(request).await?.into_inner();
+ let response_stream = self.inner.do_put(request).await?.into_inner();
// Forwards errors from the error oneshot with priority over responses
from server
- let error_stream = futures::stream::poll_fn(move |cx| {
- if let Poll::Ready(Ok(err)) = receiver.poll_unpin(cx) {
- return Poll::Ready(Some(Err(err)));
- }
- let next = ready!(response_stream.poll_next_unpin(cx));
- Poll::Ready(next.map(|x| x.map_err(FlightError::Tonic)))
- });
+ let response_stream = Box::pin(response_stream);
+ let error_stream = FallibleTonicResponseStream::new(receiver,
response_stream);
// combine the response from the server and any error from the client
Ok(error_stream.boxed())
@@ -433,33 +419,17 @@ impl FlightClient {
&mut self,
request: S,
) -> Result<FlightRecordBatchStream> {
- let (sender, mut receiver) = futures::channel::oneshot::channel();
+ let (sender, receiver) = futures::channel::oneshot::channel();
+ let request = Box::pin(request);
// Intercepts client errors and sends them to the oneshot channel above
- let mut request = Box::pin(request); // Pin to heap
- let mut sender = Some(sender); // Wrap into Option so can be taken
- let request_stream = futures::stream::poll_fn(move |cx| {
- Poll::Ready(match ready!(request.poll_next_unpin(cx)) {
- Some(Ok(data)) => Some(data),
- Some(Err(e)) => {
- let _ = sender.take().unwrap().send(e);
- None
- }
- None => None,
- })
- });
+ let request_stream = FallibleRequestStream::new(sender, request);
let request = self.make_request(request_stream);
- let mut response_stream =
self.inner.do_exchange(request).await?.into_inner();
+ let response_stream =
self.inner.do_exchange(request).await?.into_inner();
- // Forwards errors from the error oneshot with priority over responses
from server
- let error_stream = futures::stream::poll_fn(move |cx| {
- if let Poll::Ready(Ok(err)) = receiver.poll_unpin(cx) {
- return Poll::Ready(Some(Err(err)));
- }
- let next = ready!(response_stream.poll_next_unpin(cx));
- Poll::Ready(next.map(|x| x.map_err(FlightError::Tonic)))
- });
+ let response_stream = Box::pin(response_stream);
+ let error_stream = FallibleTonicResponseStream::new(receiver,
response_stream);
// combine the response from the server and any error from the client
Ok(FlightRecordBatchStream::new_from_flight_data(error_stream))
@@ -704,3 +674,99 @@ impl FlightClient {
request
}
}
+
+/// Wrapper around fallible stream such that when
+/// it encounters an error it uses the oneshot sender to
+/// notify the error and stop any further streaming. See `do_put` or
+/// `do_exchange` for it's uses.
+struct FallibleRequestStream<T, E> {
+ /// sender to notify error
+ sender: Option<Sender<E>>,
+ /// fallible stream
+ fallible_stream: Pin<Box<dyn Stream<Item = std::result::Result<T, E>> +
Send + 'static>>,
+}
+
+impl<T, E> FallibleRequestStream<T, E> {
+ fn new(
+ sender: Sender<E>,
+ fallible_stream: Pin<Box<dyn Stream<Item = std::result::Result<T, E>>
+ Send + 'static>>,
+ ) -> Self {
+ Self {
+ sender: Some(sender),
+ fallible_stream,
+ }
+ }
+}
+
+impl<T, E> Stream for FallibleRequestStream<T, E> {
+ type Item = T;
+
+ fn poll_next(
+ self: std::pin::Pin<&mut Self>,
+ cx: &mut std::task::Context<'_>,
+ ) -> std::task::Poll<Option<Self::Item>> {
+ let pinned = self.get_mut();
+ let mut request_streams = pinned.fallible_stream.as_mut();
+ match ready!(request_streams.poll_next_unpin(cx)) {
+ Some(Ok(data)) => Poll::Ready(Some(data)),
+ Some(Err(e)) => {
+ // unwrap() here is safe, ownership of sender will
+ // be moved only once as this stream will not be polled
+ // again
+ let _ = pinned.sender.take().unwrap().send(e);
+ Poll::Ready(None)
+ }
+ None => Poll::Ready(None),
+ }
+ }
+}
+
+/// Wrapper for a tonic response stream that can produce a tonic
+/// error. It is tied to a oneshot receiver that can be notified of
+/// other errors. An error that arrives through the receiver is
+/// yielded before any further item from the response stream.
+/// See `do_put` or `do_exchange` for its uses.
+struct FallibleTonicResponseStream<T> {
+    /// Receiver for FlightError. It becomes `None` once the oneshot has resolved
+    /// (an error was delivered or the sender was dropped), so the completed
+    /// future is never polled again.
+    receiver: Option<Receiver<FlightError>>,
+    /// Tonic response stream
+    response_stream:
+        Pin<Box<dyn Stream<Item = std::result::Result<T, tonic::Status>> + Send + 'static>>,
+}
+
+impl<T> FallibleTonicResponseStream<T> {
+    /// Combines `response_stream` with a `receiver` for out-of-band errors.
+    fn new(
+        receiver: Receiver<FlightError>,
+        response_stream: Pin<
+            Box<dyn Stream<Item = std::result::Result<T, tonic::Status>> + Send + 'static>,
+        >,
+    ) -> Self {
+        Self {
+            receiver: Some(receiver),
+            response_stream,
+        }
+    }
+}
+
+impl<T> Stream for FallibleTonicResponseStream<T> {
+    type Item = Result<T>;
+
+    fn poll_next(
+        self: Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+    ) -> Poll<Option<Self::Item>> {
+        let pinned = self.get_mut();
+        // Prioritise sending the error that's been notified over
+        // polling the response_stream
+        if let Some(receiver) = pinned.receiver.as_mut() {
+            if let Poll::Ready(outcome) = receiver.poll_unpin(cx) {
+                // A oneshot resolves at most once. Drop it so the completed
+                // future is not polled again on later calls.
+                pinned.receiver = None;
+                if let Ok(err) = outcome {
+                    return Poll::Ready(Some(Err(err)));
+                }
+                // `Err(Canceled)`: the sender was dropped without reporting an
+                // error, so only the server responses remain.
+            }
+        }
+
+        match ready!(pinned.response_stream.poll_next_unpin(cx)) {
+            Some(Ok(res)) => Poll::Ready(Some(Ok(res))),
+            Some(Err(status)) => Poll::Ready(Some(Err(FlightError::Tonic(status)))),
+            None => Poll::Ready(None),
+        }
+    }
+}