alamb commented on code in PR #3378:
URL: https://github.com/apache/arrow-rs/pull/3378#discussion_r1053555213


##########
arrow-flight/src/error.rs:
##########
@@ -0,0 +1,59 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+/// Errors for the Apache Arrow Flight crate

Review Comment:
   This follows the same model as `ArrowError`, as I did not want to introduce a
new error paradigm (e.g. `thiserror`) as part of this PR.



##########
arrow-flight/tests/common/server.rs:
##########
@@ -0,0 +1,229 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::{
+    pin::Pin,
+    sync::{Arc, Mutex},
+};
+
+use futures::Stream;
+use tonic::{metadata::MetadataMap, Request, Response, Status, Streaming};
+
+use arrow_flight::{
+    flight_service_server::{FlightService, FlightServiceServer},
+    Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, 
FlightInfo,
+    HandshakeRequest, HandshakeResponse, PutResult, SchemaResult, Ticket,
+};
+
+#[derive(Debug, Clone)]
+/// Flight server for testing, with configurable responses.
+///
+/// Every clone shares the same underlying state (an `Arc<Mutex<State>>`),
+/// so a test can keep one handle to configure responses and inspect the
+/// requests received while a clone is being served (see `service()`).
+pub struct TestFlightServer {
+    /// Shared state used to configure responses and record requests
+    state: Arc<Mutex<State>>,
+}
+
+impl TestFlightServer {
+    /// Create a `TestFlightServer` with empty state: no configured
+    /// responses and no recorded requests.
+    pub fn new() -> Self {
+        Self {
+            state: Arc::new(Mutex::new(State::new())),
+        }
+    }
+
+    /// Return a [`FlightServiceServer`] that can be used with a
+    /// [`Server`](tonic::transport::Server)
+    ///
+    /// The returned service wraps a clone of `self`, so it shares state
+    /// with this handle.
+    pub fn service(&self) -> FlightServiceServer<TestFlightServer> {
+        // wrap a clone of this server in the tonic-generated service
+        FlightServiceServer::new(self.clone())
+    }
+
+    /// Specify the response returned from the next call to `handshake`.
+    /// Any previously configured response is discarded.
+    pub fn set_handshake_response(&self, response: Result<HandshakeResponse, 
Status>) {
+        let mut state = self.state.lock().expect("mutex not poisoned");
+
+        state.handshake_response.replace(response);
+    }
+
+    /// Take and return the last handshake request sent to the server, if
+    /// any. Taking clears the stored request.
+    pub fn take_handshake_request(&self) -> Option<HandshakeRequest> {
+        self.state
+            .lock()
+            .expect("mutex not poisoned")
+            .handshake_request
+            .take()
+    }
+
+    /// Specify the response returned from the next call to `get_flight_info`.
+    /// Any previously configured response is discarded.
+    pub fn set_get_flight_info_response(&self, response: Result<FlightInfo, 
Status>) {
+        let mut state = self.state.lock().expect("mutex not poisoned");
+
+        state.get_flight_info_response.replace(response);
+    }
+
+    /// Take and return the last `get_flight_info` request sent to the
+    /// server, if any. Taking clears the stored request.
+    pub fn take_get_flight_info_request(&self) -> Option<FlightDescriptor> {
+        self.state
+            .lock()
+            .expect("mutex not poisoned")
+            .get_flight_info_request
+            .take()
+    }
+
+    /// Take and return the metadata (gRPC headers) of the last request
+    /// received by the server, if any. Taking clears the stored value.
+    pub fn take_last_request_metadata(&self) -> Option<MetadataMap> {
+        self.state
+            .lock()
+            .expect("mutex not poisoned")
+            .last_request_metadata
+            .take()
+    }
+
+    /// Record a copy of `request`'s metadata as the last request metadata,
+    /// overwriting any previously recorded value.
+    fn save_metadata<T>(&self, request: &Request<T>) {
+        let metadata = request.metadata().clone();
+        let mut state = self.state.lock().expect("mutex not poisoned");
+        state.last_request_metadata = Some(metadata);
+    }
+}
+
+/// Mutable state for the `TestFlightServer`, shared between clones of the
+/// server via `Arc<Mutex<_>>`
+#[derive(Debug, Default)]
+struct State {
+    /// The last handshake request that was received
+    pub handshake_request: Option<HandshakeRequest>,
+    /// The next response to return from `handshake()`
+    pub handshake_response: Option<Result<HandshakeResponse, Status>>,
+    /// The last `get_flight_info` request received
+    pub get_flight_info_request: Option<FlightDescriptor>,
+    /// The next response to return from `get_flight_info()`
+    pub get_flight_info_response: Option<Result<FlightInfo, Status>>,
+    /// The metadata (headers) of the last request received
+    pub last_request_metadata: Option<MetadataMap>,
+}
+
+impl State {
+    /// Create an empty `State`: every field starts as `None`.
+    fn new() -> Self {
+        Default::default()
+    }
+}
+
+/// Implement the FlightService trait
+#[tonic::async_trait]
+impl FlightService for TestFlightServer {
+    type HandshakeStream = Pin<
+        Box<dyn Stream<Item = Result<HandshakeResponse, Status>> + Send + Sync 
+ 'static>,
+    >;

Review Comment:
   👍  -- this is actually copy/paste from the example server



##########
arrow-flight/tests/client.rs:
##########
@@ -0,0 +1,331 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Integration test for "mid level" Client
+
+mod common {

Review Comment:
   Here are the "missing" tests for arrow-flight -- I think this framework will
let us quickly write the tests for the Flight / Flight SQL functionality.



##########
arrow-flight/src/client.rs:
##########
@@ -0,0 +1,584 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::{
+    flight_service_client::FlightServiceClient, 
utils::flight_data_to_arrow_batch,
+    FlightData, FlightDescriptor, FlightInfo, HandshakeRequest, Ticket,
+};
+use arrow_array::{ArrayRef, RecordBatch};
+use arrow_schema::Schema;
+use futures::{ready, stream, StreamExt};
+use std::{collections::HashMap, convert::TryFrom, pin::Pin, sync::Arc, 
task::Poll};
+use tonic::{metadata::MetadataMap, transport::Channel, Streaming};
+
+use crate::error::{FlightError, Result};
+
+/// A "Mid level"
+/// [Apache Arrow Flight](https://arrow.apache.org/docs/format/Flight.html) client.
+///
+/// [`FlightClient`] is intended as a convenience for interactions
+/// with Arrow Flight servers. For more direct control, such as access
+/// to the response headers, use [`FlightServiceClient`] directly
+/// via methods such as [`Self::inner`] or [`Self::into_inner`].
+///
+/// # Example:
+/// ```no_run
+/// # async fn run() {
+/// # use arrow_flight::FlightClient;
+/// use tonic::transport::Channel;
+/// let channel = Channel::from_static("http://localhost:1234")
+///   .connect()
+///   .await
+///   .expect("error connecting");
+///
+/// let mut client = FlightClient::new(channel);
+///
+/// // Send 'Hi' bytes as the handshake request to the server
+/// let response = client
+///   .handshake(b"Hi".to_vec())
+///   .await
+///   .expect("error handshaking");
+///
+/// // Expect the server responded with 'Ho'
+/// assert_eq!(response, b"Ho".to_vec());
+/// # }
+/// ```
+#[derive(Debug)]
+pub struct FlightClient {
+    /// gRPC header metadata included with each request (empty by default)
+    metadata: MetadataMap,
+
+    /// The inner (lower level) tonic client
+    inner: FlightServiceClient<Channel>,
+}
+
+impl FlightClient {
+    /// Creates a client with the provided
+    /// [`Channel`](tonic::transport::Channel)
+    pub fn new(channel: Channel) -> Self {
+        Self::new_from_inner(FlightServiceClient::new(channel))
+    }
+
+    /// Creates a new higher level client with the provided lower level client.
+    ///
+    /// The new client starts with no gRPC metadata configured.
+    pub fn new_from_inner(inner: FlightServiceClient<Channel>) -> Self {
+        Self {
+            metadata: MetadataMap::new(),
+            inner,
+        }
+    }
+
+    /// Return a reference to gRPC metadata included with each request
+    pub fn metadata(&self) -> &MetadataMap {
+        &self.metadata
+    }
+
+    /// Return a mutable reference to gRPC metadata included with each request
+    ///
+    /// These headers can be used, for example, to include
+    /// authorization or other application specific headers.
+    pub fn metadata_mut(&mut self) -> &mut MetadataMap {
+        &mut self.metadata
+    }
+
+    /// Add the specified header with value to all subsequent
+    /// requests. See [`Self::metadata_mut`] for fine grained control.
+    ///
+    /// If a header with the same key is already configured, its value
+    /// is replaced.
+    ///
+    /// # Errors
+    ///
+    /// Returns [`FlightError::ExternalError`] if `key` is not a valid
+    /// metadata key or `value` cannot be parsed as a metadata value.
+    pub fn add_header(&mut self, key: &str, value: &str) -> Result<()> {
+        let key = tonic::metadata::MetadataKey::<_>::from_bytes(key.as_bytes())
+            .map_err(|e| FlightError::ExternalError(Box::new(e)))?;
+
+        let value = value
+            .parse()
+            .map_err(|e| FlightError::ExternalError(Box::new(e)))?;
+
+        // `insert` replaces any previous value for this key
+        self.metadata.insert(key, value);
+
+        Ok(())
+    }
+
+    /// Return a reference to the underlying tonic
+    /// [`FlightServiceClient`]
+    pub fn inner(&self) -> &FlightServiceClient<Channel> {
+        &self.inner
+    }
+
+    /// Return a mutable reference to the underlying tonic
+    /// [`FlightServiceClient`]
+    pub fn inner_mut(&mut self) -> &mut FlightServiceClient<Channel> {
+        &mut self.inner
+    }
+
+    /// Consume this client and return the underlying tonic
+    /// [`FlightServiceClient`]
+    pub fn into_inner(self) -> FlightServiceClient<Channel> {
+        self.inner
+    }
+
+    /// Perform an Arrow Flight handshake with the server, sending
+    /// `payload` as the [`HandshakeRequest`] payload and returning
+    /// the [`HandshakeResponse`](crate::HandshakeResponse)
+    /// payload bytes returned from the server
+    ///
+    /// See [`FlightClient`] docs for an example.
+    ///
+    /// # Errors
+    ///
+    /// Returns [`FlightError::Tonic`] if the call or the response stream
+    /// fails, and a protocol error (`FlightError::protocol`) if the server
+    /// sends no response or more than one response.
+    pub async fn handshake(&mut self, payload: Vec<u8>) -> Result<Vec<u8>> {
+        let request = HandshakeRequest {
+            protocol_version: 0,
+            payload,
+        };
+
+        // send the single request as a one-element stream, with the
+        // configured metadata applied
+        let request = self.make_request(stream::iter(vec![request]));
+
+        let mut response_stream = self
+            .inner
+            .handshake(request)
+            .await
+            .map_err(FlightError::Tonic)?
+            .into_inner();
+
+        if let Some(response) = response_stream.next().await {
+            let response = response.map_err(FlightError::Tonic)?;
+
+            // exactly one response is expected; a second one is a protocol error
+            if response_stream.next().await.is_some() {
+                return Err(FlightError::protocol(
+                    "Got unexpected second response from handshake",
+                ));
+            }
+
+            Ok(response.payload)
+        } else {
+            Err(FlightError::protocol("No response from handshake"))
+        }
+    }
+
+    /// Make a `DoGet` call to the server with the provided ticket,
+    /// returning a [`FlightRecordBatchStream`] for reading
+    /// [`RecordBatch`]es.
+    ///
+    /// # Errors
+    ///
+    /// Returns [`FlightError::Tonic`] if the `DoGet` call itself fails.
+    /// Errors encountered while reading data are yielded by the returned
+    /// stream.
+    ///
+    /// # Example:
+    /// ```no_run
+    /// # async fn run() {
+    /// # use arrow_flight::FlightClient;
+    /// # use arrow_flight::Ticket;
+    /// # use arrow_array::RecordBatch;
+    /// # use tonic::transport::Channel;
+    /// # use futures::stream::TryStreamExt;
+    /// # let channel = Channel::from_static("http://localhost:1234")
+    /// #  .connect()
+    /// #  .await
+    /// #  .expect("error connecting");
+    /// # let ticket = Ticket { ticket: b"foo".to_vec() };
+    ///
+    /// let mut client = FlightClient::new(channel);
+    ///
+    /// // Invoke a do_get request on the server with a previously
+    /// // received Ticket
+    ///
+    /// let response = client
+    ///    .do_get(ticket)
+    ///    .await
+    ///    .expect("error invoking do_get");
+    ///
+    /// // Use try_collect to get the RecordBatches from the server
+    /// let batches: Vec<RecordBatch> = response
+    ///    .try_collect()
+    ///    .await
+    ///    .expect("no stream errors");
+    /// # }
+    /// ```
+    pub async fn do_get(&mut self, ticket: Ticket) -> 
Result<FlightRecordBatchStream> {
+        let request = self.make_request(ticket);
+
+        let response = self
+            .inner
+            .do_get(request)
+            .await
+            .map_err(FlightError::Tonic)?
+            .into_inner();
+
+        let flight_data_stream = FlightDataStream::new(response);
+        Ok(FlightRecordBatchStream::new(flight_data_stream))
+    }
+
+    /// Make a `GetFlightInfo` call to the server with the provided
+    /// [`FlightDescriptor`] and return the [`FlightInfo`] from the
+    /// server. The [`FlightInfo`] can be used with [`Self::do_get`]
+    /// to retrieve the requested batches.
+    ///
+    /// # Errors
+    ///
+    /// Returns [`FlightError::Tonic`] if the `GetFlightInfo` call fails.
+    ///
+    /// # Example:
+    /// ```no_run
+    /// # async fn run() {
+    /// # use arrow_flight::FlightClient;
+    /// # use arrow_flight::FlightDescriptor;
+    /// # use tonic::transport::Channel;
+    /// # let channel = Channel::from_static("http://localhost:1234")
+    /// #   .connect()
+    /// #   .await
+    /// #   .expect("error connecting");
+    /// let mut client = FlightClient::new(channel);
+    ///
+    /// // Send a 'CMD' request to the server
+    /// let request = FlightDescriptor::new_cmd(b"MOAR DATA".to_vec());
+    /// let flight_info = client
+    ///   .get_flight_info(request)
+    ///   .await
+    ///   .expect("error getting flight info");
+    ///
+    /// // retrieve the first endpoint from the returned flight info
+    /// let ticket = flight_info
+    ///   .endpoint[0]
+    ///   // Extract the ticket
+    ///   .ticket
+    ///   .clone()
+    ///   .expect("expected ticket");
+    ///
+    /// // Retrieve the corresponding RecordBatch stream with do_get
+    /// let data = client
+    ///   .do_get(ticket)
+    ///   .await
+    ///   .expect("error fetching data");
+    /// # }
+    /// ```
+    pub async fn get_flight_info(
+        &mut self,
+        descriptor: FlightDescriptor,
+    ) -> Result<FlightInfo> {
+        let request = self.make_request(descriptor);
+
+        let response = self
+            .inner
+            .get_flight_info(request)
+            .await
+            .map_err(FlightError::Tonic)?
+            .into_inner();
+        Ok(response)
+    }
+
+    // TODO other methods
+    // list_flights
+    // get_schema
+    // do_put
+    // do_action
+    // list_actions
+    // do_exchange
+
+
+
+    /// Wrap `t` in a [`tonic::Request`] whose metadata is set to a copy
+    /// of the metadata configured on this client
+    fn make_request<T>(&self, t: T) -> tonic::Request<T> {
+        // Pass along metadata
+        let mut request = tonic::Request::new(t);
+        *request.metadata_mut() = self.metadata.clone();
+        request
+    }
+}
+
+/// A stream of [`RecordBatch`]es from an Arrow Flight server.
+///
+/// To access the lower level Flight messages directly, consider
+/// calling [`Self::into_inner`] and using the [`FlightDataStream`]
+/// directly.
+#[derive(Debug)]
+pub struct FlightRecordBatchStream {
+    /// The decoded `FlightData` stream that batches are read from
+    inner: FlightDataStream,
+    /// Whether a Schema message has been received yet
+    got_schema: bool,
+}
+
+impl FlightRecordBatchStream {
+    /// Create a new [`FlightRecordBatchStream`] that reads batches from
+    /// `inner`. No schema has been seen yet.
+    pub fn new(inner: FlightDataStream) -> Self {
+        Self {
+            inner,
+            got_schema: false,
+        }
+    }
+
+    /// Has a message defining the schema been received yet?
+    pub fn got_schema(&self) -> bool {
+        self.got_schema
+    }
+
+    /// Consume self and return the wrapped [`FlightDataStream`]
+    pub fn into_inner(self) -> FlightDataStream {
+        self.inner
+    }
+}
+impl futures::Stream for FlightRecordBatchStream {
+    type Item = Result<RecordBatch>;
+
+    /// Returns the next [`RecordBatch`] available in this stream, or `None` if
+    /// there are no further results available.
+    ///
+    /// Schema and `None` payloads are consumed internally and never yielded.
+    /// Errors from the inner stream are passed through unchanged, and a
+    /// second Schema message is reported as a protocol error.
+    fn poll_next(
+        mut self: Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+    ) -> Poll<Option<Result<RecordBatch>>> {
+        // Keep polling the inner stream until it yields a batch, an error,
+        // or ends; `ready!` returns `Poll::Pending` early if it is not ready.
+        loop {
+            let res = ready!(self.inner.poll_next_unpin(cx));
+            match res {
+                // Inner exhausted
+                None => {
+                    return Poll::Ready(None);
+                }
+                Some(Err(e)) => {
+                    return Poll::Ready(Some(Err(e)));
+                }
+                // translate data
+                Some(Ok(data)) => match data.payload {
+                    DecodedPayload::Schema(_) if self.got_schema => {
+                        return Poll::Ready(Some(Err(FlightError::protocol(
+                            "Unexpectedly saw multiple Schema messages in 
FlightData stream",
+                        ))));
+                    }
+                    DecodedPayload::Schema(_) => {
+                        self.got_schema = true;
+                        // Need next message, poll inner again
+                    }
+                    DecodedPayload::RecordBatch(batch) => {
+                        return Poll::Ready(Some(Ok(batch)));
+                    }
+                    DecodedPayload::None => {
+                        // Need next message
+                    }
+                },
+            }
+        }
+    }
+}
+
+/// Wrapper around a stream of [`FlightData`] that handles the details
+/// of decoding low level Flight messages into [`Schema`] and
+/// [`RecordBatch`]es, including details such as dictionaries.
+///
+/// # Protocol Details
+///
+/// The client handles flight messages as follows:
+///
+/// - **None:** This message has no effect. This is useful to
+///   transmit metadata without any actual payload.
+///
+/// - **Schema:** The schema is (re-)set. Dictionaries are cleared and
+///   the decoded schema is returned.
+///
+/// - **Dictionary Batch:** A new dictionary for a given column is
+///   registered. An existing dictionary for the same column will be
+///   overwritten. This message is NOT visible. Receiving one before
+///   any Schema message is an error.
+///
+/// - **Record Batch:** Record batch is created based on the current
+///   schema and dictionaries. This fails if no schema was transmitted
+///   yet.
+///
+/// All other message types (at the time of writing: e.g. tensor and
+/// sparse tensor) lead to an error.
+///
+/// Example use cases:
+///
+/// 1. Using this low level stream it is possible to receive a stream
+///    of RecordBatches in FlightData that have different schemas by
+///    handling multiple schema messages separately.
+#[derive(Debug)]
+pub struct FlightDataStream {
+    /// Underlying data stream
+    response: Streaming<FlightData>,
+    /// Decoding state (current schema and dictionaries); `None` until
+    /// the first Schema message has been received
+    state: Option<FlightStreamState>,
+    /// Whether the end of the inner stream has been seen
+    done: bool,
+}
+
+impl FlightDataStream {
+    /// Create a new wrapper around the stream of FlightData
+    pub fn new(response: Streaming<FlightData>) -> Self {
+        Self {
+            state: None,
+            response,
+            done: false,
+        }
+    }
+
+    /// Extracts flight data from the next message, updating decoding
+    /// state as necessary.
+    fn extract_message(&mut self, data: FlightData) -> 
Result<Option<DecodedFlightData>> {
+        use arrow_ipc::MessageHeader;
+        let message = 
arrow_ipc::root_as_message(&data.data_header[..]).map_err(|e| {
+            FlightError::DecodeError(format!("Error decoding root message: 
{e}"))
+        })?;
+
+        match message.header_type() {
+            MessageHeader::NONE => Ok(Some(DecodedFlightData::new_none(data))),
+            MessageHeader::Schema => {
+                let schema = Schema::try_from(&data).map_err(|e| {
+                    FlightError::DecodeError(format!("Error decoding schema: 
{e}"))
+                })?;
+
+                let schema = Arc::new(schema);
+                let dictionaries_by_field = HashMap::new();
+
+                self.state = Some(FlightStreamState {
+                    schema: Arc::clone(&schema),
+                    dictionaries_by_field,
+                });
+                Ok(Some(DecodedFlightData::new_schema(data, schema)))
+            }
+            MessageHeader::DictionaryBatch => {
+                let state = if let Some(state) = self.state.as_mut() {
+                    state
+                } else {
+                    return Err(FlightError::protocol(
+                        "Received DictionaryBatch prior to Schema",
+                    ));
+                };
+
+                let buffer: arrow_buffer::Buffer = data.data_body.into();
+                let dictionary_batch =
+                    message.header_as_dictionary_batch().ok_or_else(|| {
+                        FlightError::protocol(
+                            "Could not get dictionary batch from 
DictionaryBatch message",
+                        )
+                    })?;
+
+                arrow_ipc::reader::read_dictionary(

Review Comment:
   I believe there was some discussion about improving the efficiency of IPC 
reads, etc. This is now the place to do so. 



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to