This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new dc09b0b42 Implement `RecordBatch` <--> `FlightData` encode/decode +
tests (#3391)
dc09b0b42 is described below
commit dc09b0b426cb4d4c9d4bf0a112668256565c25cd
Author: Andrew Lamb <[email protected]>
AuthorDate: Sat Dec 31 07:58:01 2022 -0500
Implement `RecordBatch` <--> `FlightData` encode/decode + tests (#3391)
* Implement `RecordBatch` <--> `FlightData` encode/decode + tests
* fix comment
* Update arrow-flight/src/encode.rs
Co-authored-by: Liang-Chi Hsieh <[email protected]>
* Add test encoding error
* Add test for chained streams
* Add mismatched schema and data test
* Add new test
* more tests
* Apply suggestions from code review
Co-authored-by: Liang-Chi Hsieh <[email protected]>
Co-authored-by: Raphael Taylor-Davies
<[email protected]>
* Add From ArrowError impl for FlightError
* Correct make_dictionary_batch and add tests
* do not take
* Make dictionary massaging non pub
* Add comment about memory size and make split function non pub
* explicitly return early from encode stream
* fix doc link
Co-authored-by: Liang-Chi Hsieh <[email protected]>
Co-authored-by: Raphael Taylor-Davies
<[email protected]>
---
arrow-flight/Cargo.toml | 2 +
arrow-flight/src/client.rs | 325 +------------------
arrow-flight/src/{client.rs => decode.rs} | 343 +++++---------------
arrow-flight/src/encode.rs | 511 ++++++++++++++++++++++++++++++
arrow-flight/src/error.rs | 25 ++
arrow-flight/src/lib.rs | 10 +-
arrow-flight/tests/client.rs | 92 +++++-
arrow-flight/tests/common/server.rs | 46 ++-
arrow-flight/tests/encode_decode.rs | 453 ++++++++++++++++++++++++++
arrow-ipc/src/reader.rs | 6 +
arrow-ipc/src/writer.rs | 5 +-
11 files changed, 1242 insertions(+), 576 deletions(-)
diff --git a/arrow-flight/Cargo.toml b/arrow-flight/Cargo.toml
index 4fc0e0d91..1664004bd 100644
--- a/arrow-flight/Cargo.toml
+++ b/arrow-flight/Cargo.toml
@@ -29,6 +29,8 @@ license = "Apache-2.0"
[dependencies]
arrow-array = { version = "30.0.0", path = "../arrow-array" }
arrow-buffer = { version = "30.0.0", path = "../arrow-buffer" }
+# Cast is needed to work around https://github.com/apache/arrow-rs/issues/3389
+arrow-cast = { version = "30.0.0", path = "../arrow-cast" }
arrow-ipc = { version = "30.0.0", path = "../arrow-ipc" }
arrow-schema = { version = "30.0.0", path = "../arrow-schema" }
base64 = { version = "0.20", default-features = false, features = ["std"] }
diff --git a/arrow-flight/src/client.rs b/arrow-flight/src/client.rs
index 0e75ac7c0..753c40f2a 100644
--- a/arrow-flight/src/client.rs
+++ b/arrow-flight/src/client.rs
@@ -16,15 +16,12 @@
// under the License.
use crate::{
- flight_service_client::FlightServiceClient,
utils::flight_data_to_arrow_batch,
- FlightData, FlightDescriptor, FlightInfo, HandshakeRequest, Ticket,
+ decode::FlightRecordBatchStream,
flight_service_client::FlightServiceClient,
+ FlightDescriptor, FlightInfo, HandshakeRequest, Ticket,
};
-use arrow_array::{ArrayRef, RecordBatch};
-use arrow_schema::Schema;
use bytes::Bytes;
-use futures::{future::ready, ready, stream, StreamExt};
-use std::{collections::HashMap, convert::TryFrom, pin::Pin, sync::Arc,
task::Poll};
-use tonic::{metadata::MetadataMap, transport::Channel, Streaming};
+use futures::{future::ready, stream, StreamExt, TryStreamExt};
+use tonic::{metadata::MetadataMap, transport::Channel};
use crate::error::{FlightError, Result};
@@ -161,7 +158,7 @@ impl FlightClient {
/// Make a `DoGet` call to the server with the provided ticket,
/// returning a [`FlightRecordBatchStream`] for reading
- /// [`RecordBatch`]es.
+ /// [`RecordBatch`](arrow_array::RecordBatch)es.
///
/// # Example:
/// ```no_run
@@ -197,10 +194,17 @@ impl FlightClient {
pub async fn do_get(&mut self, ticket: Ticket) ->
Result<FlightRecordBatchStream> {
let request = self.make_request(ticket);
- let response = self.inner.do_get(request).await?.into_inner();
-
- let flight_data_stream = FlightDataStream::new(response);
- Ok(FlightRecordBatchStream::new(flight_data_stream))
+ let response_stream = self
+ .inner
+ .do_get(request)
+ .await?
+ .into_inner()
+ // convert to FlightError
+ .map_err(|e| e.into());
+
+ Ok(FlightRecordBatchStream::new_from_flight_data(
+ response_stream,
+ ))
}
/// Make a `GetFlightInfo` call to the server with the provided
@@ -268,300 +272,3 @@ impl FlightClient {
request
}
}
-
-/// A stream of [`RecordBatch`]es from from an Arrow Flight server.
-///
-/// To access the lower level Flight messages directly, consider
-/// calling [`Self::into_inner`] and using the [`FlightDataStream`]
-/// directly.
-#[derive(Debug)]
-pub struct FlightRecordBatchStream {
- inner: FlightDataStream,
- got_schema: bool,
-}
-
-impl FlightRecordBatchStream {
- pub fn new(inner: FlightDataStream) -> Self {
- Self {
- inner,
- got_schema: false,
- }
- }
-
- /// Has a message defining the schema been received yet?
- pub fn got_schema(&self) -> bool {
- self.got_schema
- }
-
- /// Consume self and return the wrapped [`FlightDataStream`]
- pub fn into_inner(self) -> FlightDataStream {
- self.inner
- }
-}
-impl futures::Stream for FlightRecordBatchStream {
- type Item = Result<RecordBatch>;
-
- /// Returns the next [`RecordBatch`] available in this stream, or `None` if
- /// there are no further results available.
- fn poll_next(
- mut self: Pin<&mut Self>,
- cx: &mut std::task::Context<'_>,
- ) -> Poll<Option<Result<RecordBatch>>> {
- loop {
- let res = ready!(self.inner.poll_next_unpin(cx));
- match res {
- // Inner exhausted
- None => {
- return Poll::Ready(None);
- }
- Some(Err(e)) => {
- return Poll::Ready(Some(Err(e)));
- }
- // translate data
- Some(Ok(data)) => match data.payload {
- DecodedPayload::Schema(_) if self.got_schema => {
- return Poll::Ready(Some(Err(FlightError::protocol(
- "Unexpectedly saw multiple Schema messages in
FlightData stream",
- ))));
- }
- DecodedPayload::Schema(_) => {
- self.got_schema = true;
- // Need next message, poll inner again
- }
- DecodedPayload::RecordBatch(batch) => {
- return Poll::Ready(Some(Ok(batch)));
- }
- DecodedPayload::None => {
- // Need next message
- }
- },
- }
- }
- }
-}
-
-/// Wrapper around a stream of [`FlightData`] that handles the details
-/// of decoding low level Flight messages into [`Schema`] and
-/// [`RecordBatch`]es, including details such as dictionaries.
-///
-/// # Protocol Details
-///
-/// The client handles flight messages as followes:
-///
-/// - **None:** This message has no effect. This is useful to
-/// transmit metadata without any actual payload.
-///
-/// - **Schema:** The schema is (re-)set. Dictionaries are cleared and
-/// the decoded schema is returned.
-///
-/// - **Dictionary Batch:** A new dictionary for a given column is registered.
An existing
-/// dictionary for the same column will be overwritten. This
-/// message is NOT visible.
-///
-/// - **Record Batch:** Record batch is created based on the current
-/// schema and dictionaries. This fails if no schema was transmitted
-/// yet.
-///
-/// All other message types (at the time of writing: e.g. tensor and
-/// sparse tensor) lead to an error.
-///
-/// Example usecases
-///
-/// 1. Using this low level stream it is possible to receive a steam
-/// of RecordBatches in FlightData that have different schemas by
-/// handling multiple schema messages separately.
-#[derive(Debug)]
-pub struct FlightDataStream {
- /// Underlying data stream
- response: Streaming<FlightData>,
- /// Decoding state
- state: Option<FlightStreamState>,
- /// seen the end of the inner stream?
- done: bool,
-}
-
-impl FlightDataStream {
- /// Create a new wrapper around the stream of FlightData
- pub fn new(response: Streaming<FlightData>) -> Self {
- Self {
- state: None,
- response,
- done: false,
- }
- }
-
- /// Extracts flight data from the next message, updating decoding
- /// state as necessary.
- fn extract_message(&mut self, data: FlightData) ->
Result<Option<DecodedFlightData>> {
- use arrow_ipc::MessageHeader;
- let message =
arrow_ipc::root_as_message(&data.data_header[..]).map_err(|e| {
- FlightError::DecodeError(format!("Error decoding root message:
{e}"))
- })?;
-
- match message.header_type() {
- MessageHeader::NONE => Ok(Some(DecodedFlightData::new_none(data))),
- MessageHeader::Schema => {
- let schema = Schema::try_from(&data).map_err(|e| {
- FlightError::DecodeError(format!("Error decoding schema:
{e}"))
- })?;
-
- let schema = Arc::new(schema);
- let dictionaries_by_field = HashMap::new();
-
- self.state = Some(FlightStreamState {
- schema: Arc::clone(&schema),
- dictionaries_by_field,
- });
- Ok(Some(DecodedFlightData::new_schema(data, schema)))
- }
- MessageHeader::DictionaryBatch => {
- let state = if let Some(state) = self.state.as_mut() {
- state
- } else {
- return Err(FlightError::protocol(
- "Received DictionaryBatch prior to Schema",
- ));
- };
-
- let buffer: arrow_buffer::Buffer = data.data_body.into();
- let dictionary_batch =
- message.header_as_dictionary_batch().ok_or_else(|| {
- FlightError::protocol(
- "Could not get dictionary batch from
DictionaryBatch message",
- )
- })?;
-
- arrow_ipc::reader::read_dictionary(
- &buffer,
- dictionary_batch,
- &state.schema,
- &mut state.dictionaries_by_field,
- &message.version(),
- )
- .map_err(|e| {
- FlightError::DecodeError(format!(
- "Error decoding ipc dictionary: {e}"
- ))
- })?;
-
- // Updated internal state, but no decoded message
- Ok(None)
- }
- MessageHeader::RecordBatch => {
- let state = if let Some(state) = self.state.as_ref() {
- state
- } else {
- return Err(FlightError::protocol(
- "Received RecordBatch prior to Schema",
- ));
- };
-
- let batch = flight_data_to_arrow_batch(
- &data,
- Arc::clone(&state.schema),
- &state.dictionaries_by_field,
- )
- .map_err(|e| {
- FlightError::DecodeError(format!(
- "Error decoding ipc RecordBatch: {e}"
- ))
- })?;
-
- Ok(Some(DecodedFlightData::new_record_batch(data, batch)))
- }
- other => {
- let name = other.variant_name().unwrap_or("UNKNOWN");
- Err(FlightError::protocol(format!("Unexpected message:
{name}")))
- }
- }
- }
-}
-
-impl futures::Stream for FlightDataStream {
- type Item = Result<DecodedFlightData>;
- /// Returns the result of decoding the next [`FlightData`] message
- /// from the server, or `None` if there are no further results
- /// available.
- fn poll_next(
- mut self: Pin<&mut Self>,
- cx: &mut std::task::Context<'_>,
- ) -> Poll<Option<Self::Item>> {
- if self.done {
- return Poll::Ready(None);
- }
- loop {
- let res = ready!(self.response.poll_next_unpin(cx));
-
- return Poll::Ready(match res {
- None => {
- self.done = true;
- None // inner is exhausted
- }
- Some(data) => Some(match data {
- Err(e) => Err(FlightError::Tonic(e)),
- Ok(data) => match self.extract_message(data) {
- Ok(Some(extracted)) => Ok(extracted),
- Ok(None) => continue, // Need next input message
- Err(e) => Err(e),
- },
- }),
- });
- }
- }
-}
-
-/// tracks the state needed to reconstruct [`RecordBatch`]es from a
-/// streaming flight response.
-#[derive(Debug)]
-struct FlightStreamState {
- schema: Arc<Schema>,
- dictionaries_by_field: HashMap<i64, ArrayRef>,
-}
-
-/// FlightData and the decoded payload (Schema, RecordBatch), if any
-#[derive(Debug)]
-pub struct DecodedFlightData {
- pub inner: FlightData,
- pub payload: DecodedPayload,
-}
-
-impl DecodedFlightData {
- pub fn new_none(inner: FlightData) -> Self {
- Self {
- inner,
- payload: DecodedPayload::None,
- }
- }
-
- pub fn new_schema(inner: FlightData, schema: Arc<Schema>) -> Self {
- Self {
- inner,
- payload: DecodedPayload::Schema(schema),
- }
- }
-
- pub fn new_record_batch(inner: FlightData, batch: RecordBatch) -> Self {
- Self {
- inner,
- payload: DecodedPayload::RecordBatch(batch),
- }
- }
-
- /// return the metadata field of the inner flight data
- pub fn app_metadata(&self) -> &[u8] {
- &self.inner.app_metadata
- }
-}
-
-/// The result of decoding [`FlightData`]
-#[derive(Debug)]
-pub enum DecodedPayload {
- /// None (no data was sent in the corresponding FlightData)
- None,
-
- /// A decoded Schema message
- Schema(Arc<Schema>),
-
- /// A decoded Record batch.
- RecordBatch(RecordBatch),
-}
diff --git a/arrow-flight/src/client.rs b/arrow-flight/src/decode.rs
similarity index 52%
copy from arrow-flight/src/client.rs
copy to arrow-flight/src/decode.rs
index 0e75ac7c0..cab52a434 100644
--- a/arrow-flight/src/client.rs
+++ b/arrow-flight/src/decode.rs
@@ -15,275 +15,92 @@
// specific language governing permissions and limitations
// under the License.
-use crate::{
- flight_service_client::FlightServiceClient,
utils::flight_data_to_arrow_batch,
- FlightData, FlightDescriptor, FlightInfo, HandshakeRequest, Ticket,
-};
+use crate::{utils::flight_data_to_arrow_batch, FlightData};
use arrow_array::{ArrayRef, RecordBatch};
use arrow_schema::Schema;
use bytes::Bytes;
-use futures::{future::ready, ready, stream, StreamExt};
-use std::{collections::HashMap, convert::TryFrom, pin::Pin, sync::Arc,
task::Poll};
-use tonic::{metadata::MetadataMap, transport::Channel, Streaming};
+use futures::{ready, stream::BoxStream, Stream, StreamExt};
+use std::{
+ collections::HashMap, convert::TryFrom, fmt::Debug, pin::Pin, sync::Arc,
task::Poll,
+};
use crate::error::{FlightError, Result};
-/// A "Mid level" [Apache Arrow
Flight](https://arrow.apache.org/docs/format/Flight.html) client.
+/// Decodes a [Stream] of [`FlightData`] back into
+/// [`RecordBatch`]es. This can be used to decode the response from an
+/// Arrow Flight server
///
-/// [`FlightClient`] is intended as a convenience for interactions
-/// with Arrow Flight servers. For more direct control, such as access
-/// to the response headers, use [`FlightServiceClient`] directly
-/// via methods such as [`Self::inner`] or [`Self::into_inner`].
+/// # Note
+/// To access the lower level Flight messages (e.g. to access
+/// [`FlightData::app_metadata`]), you can call [`Self::into_inner`]
+/// and use the [`FlightDataDecoder`] directly.
///
/// # Example:
/// ```no_run
-/// # async fn run() {
-/// # use arrow_flight::FlightClient;
+/// # async fn f() -> Result<(), arrow_flight::error::FlightError>{
/// # use bytes::Bytes;
+/// // make a do_get request
+/// use arrow_flight::{
+/// error::Result,
+/// decode::FlightRecordBatchStream,
+/// Ticket,
+/// flight_service_client::FlightServiceClient
+/// };
/// use tonic::transport::Channel;
-/// let channel = Channel::from_static("http://localhost:1234")
-/// .connect()
-/// .await
-/// .expect("error connecting");
+/// use futures::stream::{StreamExt, TryStreamExt};
+///
+/// let client: FlightServiceClient<Channel> = // make client..
+/// # unimplemented!();
+///
+/// let request = tonic::Request::new(
+/// Ticket { ticket: Bytes::new() }
+/// );
///
-/// let mut client = FlightClient::new(channel);
+/// // Get a stream of FlightData;
+/// let flight_data_stream = client
+/// .do_get(request)
+/// .await?
+/// .into_inner();
///
-/// // Send 'Hi' bytes as the handshake request to the server
-/// let response = client
-/// .handshake(Bytes::from("Hi"))
-/// .await
-/// .expect("error handshaking");
+/// // Decode stream of FlightData to RecordBatches
+/// let record_batch_stream = FlightRecordBatchStream::new_from_flight_data(
+/// // convert tonic::Status to FlightError
+/// flight_data_stream.map_err(|e| e.into())
+/// );
///
-/// // Expect the server responded with 'Ho'
-/// assert_eq!(response, Bytes::from("Ho"));
+/// // Read back RecordBatches
+/// while let Some(batch) = record_batch_stream.next().await {
+/// match batch {
+/// Ok(batch) => { /* process batch */ },
+/// Err(e) => { /* handle error */ },
+/// };
+/// }
+///
+/// # Ok(())
/// # }
/// ```
#[derive(Debug)]
-pub struct FlightClient {
- /// Optional grpc header metadata to include with each request
- metadata: MetadataMap,
-
- /// The inner client
- inner: FlightServiceClient<Channel>,
+pub struct FlightRecordBatchStream {
+ inner: FlightDataDecoder,
+ got_schema: bool,
}
-impl FlightClient {
- /// Creates a client client with the provided
[`Channel`](tonic::transport::Channel)
- pub fn new(channel: Channel) -> Self {
- Self::new_from_inner(FlightServiceClient::new(channel))
- }
-
- /// Creates a new higher level client with the provided lower level client
- pub fn new_from_inner(inner: FlightServiceClient<Channel>) -> Self {
+impl FlightRecordBatchStream {
+ /// Create a new [`FlightRecordBatchStream`] from a decoded stream
+ pub fn new(inner: FlightDataDecoder) -> Self {
Self {
- metadata: MetadataMap::new(),
inner,
+ got_schema: false,
}
}
- /// Return a reference to gRPC metadata included with each request
- pub fn metadata(&self) -> &MetadataMap {
- &self.metadata
- }
-
- /// Return a reference to gRPC metadata included with each request
- ///
- /// These headers can be used, for example, to include
- /// authorization or other application specific headers.
- pub fn metadata_mut(&mut self) -> &mut MetadataMap {
- &mut self.metadata
- }
-
- /// Add the specified header with value to all subsequent
- /// requests. See [`Self::metadata_mut`] for fine grained control.
- pub fn add_header(&mut self, key: &str, value: &str) -> Result<()> {
- let key = tonic::metadata::MetadataKey::<_>::from_bytes(key.as_bytes())
- .map_err(|e| FlightError::ExternalError(Box::new(e)))?;
-
- let value = value
- .parse()
- .map_err(|e| FlightError::ExternalError(Box::new(e)))?;
-
- // ignore previous value
- self.metadata.insert(key, value);
-
- Ok(())
- }
-
- /// Return a reference to the underlying tonic
- /// [`FlightServiceClient`]
- pub fn inner(&self) -> &FlightServiceClient<Channel> {
- &self.inner
- }
-
- /// Return a mutable reference to the underlying tonic
- /// [`FlightServiceClient`]
- pub fn inner_mut(&mut self) -> &mut FlightServiceClient<Channel> {
- &mut self.inner
- }
-
- /// Consume this client and return the underlying tonic
- /// [`FlightServiceClient`]
- pub fn into_inner(self) -> FlightServiceClient<Channel> {
- self.inner
- }
-
- /// Perform an Arrow Flight handshake with the server, sending
- /// `payload` as the [`HandshakeRequest`] payload and returning
- /// the [`HandshakeResponse`](crate::HandshakeResponse)
- /// bytes returned from the server
- ///
- /// See [`FlightClient`] docs for an example.
- pub async fn handshake(&mut self, payload: impl Into<Bytes>) ->
Result<Bytes> {
- let request = HandshakeRequest {
- protocol_version: 0,
- payload: payload.into(),
- };
-
- // apply headers, etc
- let request = self.make_request(stream::once(ready(request)));
-
- let mut response_stream =
self.inner.handshake(request).await?.into_inner();
-
- if let Some(response) = response_stream.next().await.transpose()? {
- // check if there is another response
- if response_stream.next().await.is_some() {
- return Err(FlightError::protocol(
- "Got unexpected second response from handshake",
- ));
- }
-
- Ok(response.payload)
- } else {
- Err(FlightError::protocol("No response from handshake"))
- }
- }
-
- /// Make a `DoGet` call to the server with the provided ticket,
- /// returning a [`FlightRecordBatchStream`] for reading
- /// [`RecordBatch`]es.
- ///
- /// # Example:
- /// ```no_run
- /// # async fn run() {
- /// # use bytes::Bytes;
- /// # use arrow_flight::FlightClient;
- /// # use arrow_flight::Ticket;
- /// # use arrow_array::RecordBatch;
- /// # use tonic::transport::Channel;
- /// # use futures::stream::TryStreamExt;
- /// # let channel = Channel::from_static("http://localhost:1234")
- /// # .connect()
- /// # .await
- /// # .expect("error connecting");
- /// # let ticket = Ticket { ticket: Bytes::from("foo") };
- /// let mut client = FlightClient::new(channel);
- ///
- /// // Invoke a do_get request on the server with a previously
- /// // received Ticket
- ///
- /// let response = client
- /// .do_get(ticket)
- /// .await
- /// .expect("error invoking do_get");
- ///
- /// // Use try_collect to get the RecordBatches from the server
- /// let batches: Vec<RecordBatch> = response
- /// .try_collect()
- /// .await
- /// .expect("no stream errors");
- /// # }
- /// ```
- pub async fn do_get(&mut self, ticket: Ticket) ->
Result<FlightRecordBatchStream> {
- let request = self.make_request(ticket);
-
- let response = self.inner.do_get(request).await?.into_inner();
-
- let flight_data_stream = FlightDataStream::new(response);
- Ok(FlightRecordBatchStream::new(flight_data_stream))
- }
-
- /// Make a `GetFlightInfo` call to the server with the provided
- /// [`FlightDescriptor`] and return the [`FlightInfo`] from the
- /// server. The [`FlightInfo`] can be used with [`Self::do_get`]
- /// to retrieve the requested batches.
- ///
- /// # Example:
- /// ```no_run
- /// # async fn run() {
- /// # use arrow_flight::FlightClient;
- /// # use arrow_flight::FlightDescriptor;
- /// # use tonic::transport::Channel;
- /// # let channel = Channel::from_static("http://localhost:1234")
- /// # .connect()
- /// # .await
- /// # .expect("error connecting");
- /// let mut client = FlightClient::new(channel);
- ///
- /// // Send a 'CMD' request to the server
- /// let request = FlightDescriptor::new_cmd(b"MOAR DATA".to_vec());
- /// let flight_info = client
- /// .get_flight_info(request)
- /// .await
- /// .expect("error handshaking");
- ///
- /// // retrieve the first endpoint from the returned flight info
- /// let ticket = flight_info
- /// .endpoint[0]
- /// // Extract the ticket
- /// .ticket
- /// .clone()
- /// .expect("expected ticket");
- ///
- /// // Retrieve the corresponding RecordBatch stream with do_get
- /// let data = client
- /// .do_get(ticket)
- /// .await
- /// .expect("error fetching data");
- /// # }
- /// ```
- pub async fn get_flight_info(
- &mut self,
- descriptor: FlightDescriptor,
- ) -> Result<FlightInfo> {
- let request = self.make_request(descriptor);
-
- let response = self.inner.get_flight_info(request).await?.into_inner();
- Ok(response)
- }
-
- // TODO other methods
- // list_flights
- // get_schema
- // do_put
- // do_action
- // list_actions
- // do_exchange
-
- /// return a Request, adding any configured metadata
- fn make_request<T>(&self, t: T) -> tonic::Request<T> {
- // Pass along metadata
- let mut request = tonic::Request::new(t);
- *request.metadata_mut() = self.metadata.clone();
- request
- }
-}
-
-/// A stream of [`RecordBatch`]es from from an Arrow Flight server.
-///
-/// To access the lower level Flight messages directly, consider
-/// calling [`Self::into_inner`] and using the [`FlightDataStream`]
-/// directly.
-#[derive(Debug)]
-pub struct FlightRecordBatchStream {
- inner: FlightDataStream,
- got_schema: bool,
-}
-
-impl FlightRecordBatchStream {
- pub fn new(inner: FlightDataStream) -> Self {
+ /// Create a new [`FlightRecordBatchStream`] from a stream of
[`FlightData`]
+ pub fn new_from_flight_data<S>(inner: S) -> Self
+ where
+ S: Stream<Item = Result<FlightData>> + Send + 'static,
+ {
Self {
- inner,
+ inner: FlightDataDecoder::new(inner),
got_schema: false,
}
}
@@ -293,8 +110,8 @@ impl FlightRecordBatchStream {
self.got_schema
}
- /// Consume self and return the wrapped [`FlightDataStream`]
- pub fn into_inner(self) -> FlightDataStream {
+ /// Consume self and return the wrapped [`FlightDataDecoder`]
+ pub fn into_inner(self) -> FlightDataDecoder {
self.inner
}
}
@@ -370,22 +187,34 @@ impl futures::Stream for FlightRecordBatchStream {
/// 1. Using this low level stream it is possible to receive a steam
/// of RecordBatches in FlightData that have different schemas by
/// handling multiple schema messages separately.
-#[derive(Debug)]
-pub struct FlightDataStream {
+pub struct FlightDataDecoder {
/// Underlying data stream
- response: Streaming<FlightData>,
+ response: BoxStream<'static, Result<FlightData>>,
/// Decoding state
state: Option<FlightStreamState>,
- /// seen the end of the inner stream?
+ /// Seen the end of the inner stream?
done: bool,
}
-impl FlightDataStream {
- /// Create a new wrapper around the stream of FlightData
- pub fn new(response: Streaming<FlightData>) -> Self {
+impl Debug for FlightDataDecoder {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ f.debug_struct("FlightDataDecoder")
+ .field("response", &"<stream>")
+ .field("state", &self.state)
+ .field("done", &self.done)
+ .finish()
+ }
+}
+
+impl FlightDataDecoder {
+ /// Create a new wrapper around the stream of [`FlightData`]
+ pub fn new<S>(response: S) -> Self
+ where
+ S: Stream<Item = Result<FlightData>> + Send + 'static,
+ {
Self {
state: None,
- response,
+ response: response.boxed(),
done: false,
}
}
@@ -477,7 +306,7 @@ impl FlightDataStream {
}
}
-impl futures::Stream for FlightDataStream {
+impl futures::Stream for FlightDataDecoder {
type Item = Result<DecodedFlightData>;
/// Returns the result of decoding the next [`FlightData`] message
/// from the server, or `None` if there are no further results
@@ -498,7 +327,7 @@ impl futures::Stream for FlightDataStream {
None // inner is exhausted
}
Some(data) => Some(match data {
- Err(e) => Err(FlightError::Tonic(e)),
+ Err(e) => Err(e),
Ok(data) => match self.extract_message(data) {
Ok(Some(extracted)) => Ok(extracted),
Ok(None) => continue, // Need next input message
@@ -548,8 +377,8 @@ impl DecodedFlightData {
}
/// return the metadata field of the inner flight data
- pub fn app_metadata(&self) -> &[u8] {
- &self.inner.app_metadata
+ pub fn app_metadata(&self) -> Bytes {
+ self.inner.app_metadata.clone()
}
}
diff --git a/arrow-flight/src/encode.rs b/arrow-flight/src/encode.rs
new file mode 100644
index 000000000..7c339b67d
--- /dev/null
+++ b/arrow-flight/src/encode.rs
@@ -0,0 +1,511 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::{collections::VecDeque, fmt::Debug, pin::Pin, sync::Arc, task::Poll};
+
+use crate::{error::Result, FlightData, SchemaAsIpc};
+use arrow_array::{ArrayRef, RecordBatch};
+use arrow_ipc::writer::{DictionaryTracker, IpcDataGenerator, IpcWriteOptions};
+use arrow_schema::{DataType, Field, Schema, SchemaRef};
+use bytes::Bytes;
+use futures::{ready, stream::BoxStream, Stream, StreamExt};
+
+/// Creates a [`Stream`](futures::Stream) of [`FlightData`]s from a
+/// `Stream` of [`Result`]<[`RecordBatch`], [`FlightError`]>.
+///
+/// This can be used to implement [`FlightService::do_get`] in an
+/// Arrow Flight implementation;
+///
+/// # Caveats
+/// 1. [`DictionaryArray`](arrow_array::array::DictionaryArray)s
+/// are converted to their underlying types prior to transport, due to
+/// <https://github.com/apache/arrow-rs/issues/3389>.
+///
+/// # Example
+/// ```no_run
+/// # use std::sync::Arc;
+/// # use arrow_array::{ArrayRef, RecordBatch, UInt32Array};
+/// # async fn f() {
+/// # let c1 = UInt32Array::from(vec![1, 2, 3, 4, 5, 6]);
+/// # let record_batch = RecordBatch::try_from_iter(vec![
+/// # ("a", Arc::new(c1) as ArrayRef)
+/// # ])
+/// # .expect("cannot create record batch");
+/// use arrow_flight::encode::FlightDataEncoderBuilder;
+///
+/// // Get an input stream of Result<RecordBatch, FlightError>
+/// let input_stream = futures::stream::iter(vec![Ok(record_batch)]);
+///
+/// // Build a stream of `Result<FlightData>` (e.g. to return for do_get)
+/// let flight_data_stream = FlightDataEncoderBuilder::new()
+/// .build(input_stream);
+///
+/// // Create a tonic `Response` that can be returned from a Flight server
+/// let response = tonic::Response::new(flight_data_stream);
+/// # }
+/// ```
+///
+/// [`FlightService::do_get`]:
crate::flight_service_server::FlightService::do_get
+/// [`FlightError`]: crate::error::FlightError
+#[derive(Debug)]
+pub struct FlightDataEncoderBuilder {
+ /// The maximum message size (see details on
[`Self::with_max_message_size`]).
+ max_batch_size: usize,
+ /// Ipc writer options
+ options: IpcWriteOptions,
+ /// Metadata to add to the schema message
+ app_metadata: Bytes,
+}
+
+/// Default target size for record batches to send.
+///
+/// Note this value would normally be 4MB, but the size calculation is
+/// somewhat inexact, so we set it to 2MB.
+pub const GRPC_TARGET_MAX_BATCH_SIZE: usize = 2097152;
+
+impl Default for FlightDataEncoderBuilder {
+ fn default() -> Self {
+ Self {
+ max_batch_size: GRPC_TARGET_MAX_BATCH_SIZE,
+ options: IpcWriteOptions::default(),
+ app_metadata: Bytes::new(),
+ }
+ }
+}
+
+impl FlightDataEncoderBuilder {
+ pub fn new() -> Self {
+ Self::default()
+ }
+
+ /// Set the (approximate) maximum encoded [`RecordBatch`] size to
+ /// limit the gRPC message size. Defaults to 2MB.
+ ///
+ /// The encoder splits up [`RecordBatch`]s (preserving order) to
+ /// limit individual messages to approximately this size. The size
+ /// is approximate because there is additional encoding overhead on
+ /// top of the underlying data itself.
+ ///
+ pub fn with_max_message_size(mut self, max_batch_size: usize) -> Self {
+ self.max_batch_size = max_batch_size;
+ self
+ }
+
+ /// Specify application specific metadata included in the
+ /// [`FlightData::app_metadata`] field of the first Schema
+ /// message
+ pub fn with_metadata(mut self, app_metadata: Bytes) -> Self {
+ self.app_metadata = app_metadata;
+ self
+ }
+
+ /// Set the [`IpcWriteOptions`] used to encode the [`RecordBatch`]es for
transport.
+ pub fn with_options(mut self, options: IpcWriteOptions) -> Self {
+ self.options = options;
+ self
+ }
+
+ /// Return a [`Stream`](futures::Stream) of [`FlightData`],
+ /// consuming self. More details on [`FlightDataEncoder`]
+ pub fn build<S>(self, input: S) -> FlightDataEncoder
+ where
+ S: Stream<Item = Result<RecordBatch>> + Send + 'static,
+ {
+ let Self {
+ max_batch_size,
+ options,
+ app_metadata,
+ } = self;
+
+ FlightDataEncoder::new(input.boxed(), max_batch_size, options,
app_metadata)
+ }
+}
+
+/// Stream that encodes a stream of record batches to flight data.
+///
+/// See [`FlightDataEncoderBuilder`] for details and example.
+pub struct FlightDataEncoder {
+ /// Input stream
+ inner: BoxStream<'static, Result<RecordBatch>>,
+ /// schema, set after the first batch
+ schema: Option<SchemaRef>,
+ /// Max size of batches to encode
+ max_batch_size: usize,
+ /// do the encoding / tracking of dictionaries
+ encoder: FlightIpcEncoder,
+ /// optional metadata to add to schema FlightData
+ app_metadata: Option<Bytes>,
+ /// data queued up to send but not yet sent
+ queue: VecDeque<FlightData>,
+ /// Is this stream done (inner is empty or errored)
+ done: bool,
+}
+
+impl FlightDataEncoder {
+ fn new(
+ inner: BoxStream<'static, Result<RecordBatch>>,
+ max_batch_size: usize,
+ options: IpcWriteOptions,
+ app_metadata: Bytes,
+ ) -> Self {
+ Self {
+ inner,
+ schema: None,
+ max_batch_size,
+ encoder: FlightIpcEncoder::new(options),
+ app_metadata: Some(app_metadata),
+ queue: VecDeque::new(),
+ done: false,
+ }
+ }
+
+ /// Place the `FlightData` in the queue to send
+ fn queue_message(&mut self, data: FlightData) {
+ self.queue.push_back(data);
+ }
+
+ /// Place the `FlightData` in the queue to send
+ fn queue_messages(&mut self, datas: impl IntoIterator<Item = FlightData>) {
+ for data in datas {
+ self.queue_message(data)
+ }
+ }
+
+ /// Encodes batch into one or more `FlightData` messages in self.queue
+ fn encode_batch(&mut self, batch: RecordBatch) -> Result<()> {
+ let schema = match &self.schema {
+ Some(schema) => schema.clone(),
+ None => {
+ let batch_schema = batch.schema();
+ // The first message is the schema message, and all
+ // batches have the same schema
+ let schema =
Arc::new(prepare_schema_for_flight(&batch_schema));
+ let mut schema_flight_data =
self.encoder.encode_schema(&schema);
+
+ // attach any metadata requested
+ if let Some(app_metadata) = self.app_metadata.take() {
+ schema_flight_data.app_metadata = app_metadata;
+ }
+ self.queue_message(schema_flight_data);
+ // remember schema
+ self.schema = Some(schema.clone());
+ schema
+ }
+ };
+
+ // encode the batch
+ let batch = prepare_batch_for_flight(&batch, schema)?;
+
+ for batch in split_batch_for_grpc_response(batch, self.max_batch_size)
{
+ let (flight_dictionaries, flight_batch) =
+ self.encoder.encode_batch(&batch)?;
+
+ self.queue_messages(flight_dictionaries);
+ self.queue_message(flight_batch);
+ }
+
+ Ok(())
+ }
+}
+
+impl Stream for FlightDataEncoder {
+ type Item = Result<FlightData>;
+
+ fn poll_next(
+ mut self: Pin<&mut Self>,
+ cx: &mut std::task::Context<'_>,
+ ) -> Poll<Option<Self::Item>> {
+ loop {
+ if self.done && self.queue.is_empty() {
+ return Poll::Ready(None);
+ }
+
+ // Any messages queued to send?
+ if let Some(data) = self.queue.pop_front() {
+ return Poll::Ready(Some(Ok(data)));
+ }
+
+ // Get next batch
+ let batch = ready!(self.inner.poll_next_unpin(cx));
+
+ match batch {
+ None => {
+ // inner is done
+ self.done = true;
+ // queue must also be empty so we are done
+ assert!(self.queue.is_empty());
+ return Poll::Ready(None);
+ }
+ Some(Err(e)) => {
+ // error from inner
+ self.done = true;
+ self.queue.clear();
+ return Poll::Ready(Some(Err(e)));
+ }
+ Some(Ok(batch)) => {
+ // had data, encode into the queue
+ if let Err(e) = self.encode_batch(batch) {
+ self.done = true;
+ self.queue.clear();
+ return Poll::Ready(Some(Err(e)));
+ }
+ }
+ }
+ }
+ }
+}
+
+/// Prepare an arrow Schema for transport over the Arrow Flight protocol
+///
+/// Convert dictionary types to underlying types
+///
+/// See hydrate_dictionary for more information
+fn prepare_schema_for_flight(schema: &Schema) -> Schema {
+ let fields = schema
+ .fields()
+ .iter()
+ .map(|field| match field.data_type() {
+ DataType::Dictionary(_, value_type) => Field::new(
+ field.name(),
+ value_type.as_ref().clone(),
+ field.is_nullable(),
+ )
+ .with_metadata(field.metadata().clone()),
+ _ => field.clone(),
+ })
+ .collect();
+
+ Schema::new(fields)
+}
+
+/// Split [`RecordBatch`] so it hopefully fits into a gRPC response.
+///
+/// Data is zero-copy sliced into batches.
+///
+/// Note: this method does not take into account already sliced
+/// arrays: <https://github.com/apache/arrow-rs/issues/3407>
+fn split_batch_for_grpc_response(
+ batch: RecordBatch,
+ max_batch_size: usize,
+) -> Vec<RecordBatch> {
+ let size = batch
+ .columns()
+ .iter()
+ .map(|col| col.get_buffer_memory_size())
+ .sum::<usize>();
+
+ let n_batches =
+ (size / max_batch_size + usize::from(size % max_batch_size !=
0)).max(1);
+ let rows_per_batch = (batch.num_rows() / n_batches).max(1);
+ let mut out = Vec::with_capacity(n_batches + 1);
+
+ let mut offset = 0;
+ while offset < batch.num_rows() {
+ let length = (rows_per_batch).min(batch.num_rows() - offset);
+ out.push(batch.slice(offset, length));
+
+ offset += length;
+ }
+
+ out
+}
+
+/// The data needed to encode a stream of flight data, holding on to
+/// shared Dictionaries.
+///
+/// TODO: allow dictionaries to be flushed / avoid building them
+///
+/// TODO limit on the number of dictionaries???
+struct FlightIpcEncoder {
+ options: IpcWriteOptions,
+ data_gen: IpcDataGenerator,
+ dictionary_tracker: DictionaryTracker,
+}
+
+impl FlightIpcEncoder {
+ fn new(options: IpcWriteOptions) -> Self {
+ let error_on_replacement = true;
+ Self {
+ options,
+ data_gen: IpcDataGenerator::default(),
+ dictionary_tracker: DictionaryTracker::new(error_on_replacement),
+ }
+ }
+
+ /// Encode a schema as a FlightData
+ fn encode_schema(&self, schema: &Schema) -> FlightData {
+ SchemaAsIpc::new(schema, &self.options).into()
+ }
+
+ /// Convert a `RecordBatch` to a Vec of `FlightData` representing
+ /// dictionaries and a `FlightData` representing the batch
+ fn encode_batch(
+ &mut self,
+ batch: &RecordBatch,
+ ) -> Result<(Vec<FlightData>, FlightData)> {
+ let (encoded_dictionaries, encoded_batch) =
self.data_gen.encoded_batch(
+ batch,
+ &mut self.dictionary_tracker,
+ &self.options,
+ )?;
+
+ let flight_dictionaries =
+ encoded_dictionaries.into_iter().map(Into::into).collect();
+ let flight_batch = encoded_batch.into();
+
+ Ok((flight_dictionaries, flight_batch))
+ }
+}
+
+/// Prepares a RecordBatch for transport over the Arrow Flight protocol
+///
+/// This means:
+///
+/// 1. Hydrates any dictionaries to its underlying type. See
+/// hydrate_dictionary for more information.
+///
+fn prepare_batch_for_flight(
+ batch: &RecordBatch,
+ schema: SchemaRef,
+) -> Result<RecordBatch> {
+ let columns = batch
+ .columns()
+ .iter()
+ .map(hydrate_dictionary)
+ .collect::<Result<Vec<_>>>()?;
+
+ Ok(RecordBatch::try_new(schema, columns)?)
+}
+
+/// Hydrates a dictionary to its underlying type
+///
+/// An IPC response, streaming or otherwise, defines its schema up front
+/// which defines the mapping from dictionary IDs. It then sends these
+/// dictionaries over the wire.
+///
+/// This requires identifying the different dictionaries in use, assigning
+/// them IDs, and sending new dictionaries, delta or otherwise, when needed
+///
+/// See also:
+/// * <https://github.com/apache/arrow-rs/issues/1206>
+///
+/// For now we just hydrate the dictionaries to their underlying type
+fn hydrate_dictionary(array: &ArrayRef) -> Result<ArrayRef> {
+ let arr = if let DataType::Dictionary(_, value) = array.data_type() {
+ arrow_cast::cast(array, value)?
+ } else {
+ Arc::clone(array)
+ };
+ Ok(arr)
+}
+
+#[cfg(test)]
+mod tests {
+ use arrow::{
+ array::{UInt32Array, UInt8Array},
+ compute::concat_batches,
+ };
+
+ use super::*;
+
+ #[test]
+ /// ensure only the batch's used data (not the allocated data) is sent
+ /// <https://github.com/apache/arrow-rs/issues/208>
+ fn test_encode_flight_data() {
+ let options = arrow::ipc::writer::IpcWriteOptions::default();
+ let c1 = UInt32Array::from(vec![1, 2, 3, 4, 5, 6]);
+
+ let batch = RecordBatch::try_from_iter(vec![("a", Arc::new(c1) as
ArrayRef)])
+ .expect("cannot create record batch");
+ let schema = batch.schema();
+
+ let (_, baseline_flight_batch) = make_flight_data(&batch, &options);
+
+ let big_batch = batch.slice(0, batch.num_rows() - 1);
+ let optimized_big_batch =
+ prepare_batch_for_flight(&big_batch, Arc::clone(&schema))
+ .expect("failed to optimize");
+ let (_, optimized_big_flight_batch) =
+ make_flight_data(&optimized_big_batch, &options);
+
+ assert_eq!(
+ baseline_flight_batch.data_body.len(),
+ optimized_big_flight_batch.data_body.len()
+ );
+
+ let small_batch = batch.slice(0, 1);
+ let optimized_small_batch =
+ prepare_batch_for_flight(&small_batch, Arc::clone(&schema))
+ .expect("failed to optimize");
+ let (_, optimized_small_flight_batch) =
+ make_flight_data(&optimized_small_batch, &options);
+
+ assert!(
+ baseline_flight_batch.data_body.len()
+ > optimized_small_flight_batch.data_body.len()
+ );
+ }
+
+ pub fn make_flight_data(
+ batch: &RecordBatch,
+ options: &IpcWriteOptions,
+ ) -> (Vec<FlightData>, FlightData) {
+ let data_gen = IpcDataGenerator::default();
+ let mut dictionary_tracker = DictionaryTracker::new(false);
+
+ let (encoded_dictionaries, encoded_batch) = data_gen
+ .encoded_batch(batch, &mut dictionary_tracker, options)
+ .expect("DictionaryTracker configured above to not error on
replacement");
+
+ let flight_dictionaries =
+ encoded_dictionaries.into_iter().map(Into::into).collect();
+ let flight_batch = encoded_batch.into();
+
+ (flight_dictionaries, flight_batch)
+ }
+
+ #[test]
+ fn test_split_batch_for_grpc_response() {
+ let max_batch_size = 1024;
+
+ // no split
+ let c = UInt32Array::from(vec![1, 2, 3, 4, 5, 6]);
+ let batch = RecordBatch::try_from_iter(vec![("a", Arc::new(c) as
ArrayRef)])
+ .expect("cannot create record batch");
+ let split = split_batch_for_grpc_response(batch.clone(),
max_batch_size);
+ assert_eq!(split.len(), 1);
+ assert_eq!(batch, split[0]);
+
+ // split once
+ let n_rows = max_batch_size + 1;
+ assert!(n_rows % 2 == 1, "should be an odd number");
+ let c =
+ UInt8Array::from((0..n_rows).map(|i| (i % 256) as
u8).collect::<Vec<_>>());
+ let batch = RecordBatch::try_from_iter(vec![("a", Arc::new(c) as
ArrayRef)])
+ .expect("cannot create record batch");
+ let split = split_batch_for_grpc_response(batch.clone(),
max_batch_size);
+ assert_eq!(split.len(), 3);
+ assert_eq!(
+ split.iter().map(|batch| batch.num_rows()).sum::<usize>(),
+ n_rows
+ );
+ assert_eq!(concat_batches(&batch.schema(), &split).unwrap(), batch);
+ }
+
+ // test sending record batches
+ // test sending record batches with multiple different dictionaries
+}
diff --git a/arrow-flight/src/error.rs b/arrow-flight/src/error.rs
index fbb9efa44..11e0ae5c9 100644
--- a/arrow-flight/src/error.rs
+++ b/arrow-flight/src/error.rs
@@ -15,9 +15,13 @@
// specific language governing permissions and limitations
// under the License.
+use arrow_schema::ArrowError;
+
/// Errors for the Apache Arrow Flight crate
#[derive(Debug)]
pub enum FlightError {
+ /// Underlying arrow error
+ Arrow(ArrowError),
/// Returned when functionality is not yet available.
NotYetImplemented(String),
/// Error from the underlying tonic library
@@ -56,4 +60,25 @@ impl From<tonic::Status> for FlightError {
}
}
+impl From<ArrowError> for FlightError {
+ fn from(value: ArrowError) -> Self {
+ Self::Arrow(value)
+ }
+}
+
+// default conversion from FlightError to tonic treats everything
+// other than `Status` as an internal error
+impl From<FlightError> for tonic::Status {
+ fn from(value: FlightError) -> Self {
+ match value {
+ FlightError::Arrow(e) => tonic::Status::internal(e.to_string()),
+ FlightError::NotYetImplemented(e) => tonic::Status::internal(e),
+ FlightError::Tonic(status) => status,
+ FlightError::ProtocolError(e) => tonic::Status::internal(e),
+ FlightError::DecodeError(e) => tonic::Status::internal(e),
+ FlightError::ExternalError(e) =>
tonic::Status::internal(e.to_string()),
+ }
+ }
+}
+
pub type Result<T> = std::result::Result<T, FlightError>;
diff --git a/arrow-flight/src/lib.rs b/arrow-flight/src/lib.rs
index f30cb5484..c2da58eb5 100644
--- a/arrow-flight/src/lib.rs
+++ b/arrow-flight/src/lib.rs
@@ -71,10 +71,18 @@ pub mod flight_service_server {
pub use gen::flight_service_server::FlightServiceServer;
}
-/// Mid Level [`FlightClient`] for
+/// Mid Level [`FlightClient`]
pub mod client;
pub use client::FlightClient;
+/// Decoder to create [`RecordBatch`](arrow_array::RecordBatch) streams from
[`FlightData`] streams.
+/// See [`FlightRecordBatchStream`](decode::FlightRecordBatchStream).
+pub mod decode;
+
+/// Encoder to create [`FlightData`] streams from
[`RecordBatch`](arrow_array::RecordBatch) streams.
+/// See [`FlightDataEncoderBuilder`](encode::FlightDataEncoderBuilder).
+pub mod encode;
+
/// Common error types
pub mod error;
diff --git a/arrow-flight/tests/client.rs b/arrow-flight/tests/client.rs
index 5bc1062f0..c471294d7 100644
--- a/arrow-flight/tests/client.rs
+++ b/arrow-flight/tests/client.rs
@@ -20,20 +20,21 @@
mod common {
pub mod server;
}
+use arrow_array::{RecordBatch, UInt64Array};
use arrow_flight::{
error::FlightError, FlightClient, FlightDescriptor, FlightInfo,
HandshakeRequest,
- HandshakeResponse,
+ HandshakeResponse, Ticket,
};
use bytes::Bytes;
use common::server::TestFlightServer;
-use futures::Future;
+use futures::{Future, TryStreamExt};
use tokio::{net::TcpListener, task::JoinHandle};
use tonic::{
transport::{Channel, Uri},
Status,
};
-use std::{net::SocketAddr, time::Duration};
+use std::{net::SocketAddr, sync::Arc, time::Duration};
const DEFAULT_TIMEOUT_SECONDS: u64 = 30;
@@ -173,7 +174,90 @@ async fn test_get_flight_info_metadata() {
// TODO more negative tests (like if there are endpoints defined, etc)
-// TODO test for do_get
+#[tokio::test]
+async fn test_do_get() {
+ do_test(|test_server, mut client| async move {
+ let ticket = Ticket {
+ ticket: Bytes::from("my awesome flight ticket"),
+ };
+
+ let batch = RecordBatch::try_from_iter(vec![(
+ "col",
+ Arc::new(UInt64Array::from_iter([1, 2, 3, 4])) as _,
+ )])
+ .unwrap();
+
+ let response = vec![Ok(batch.clone())];
+ test_server.set_do_get_response(response);
+ let response_stream = client
+ .do_get(ticket.clone())
+ .await
+ .expect("error making request");
+
+ let expected_response = vec![batch];
+ let response: Vec<_> = response_stream
+ .try_collect()
+ .await
+ .expect("Error streaming data");
+
+ assert_eq!(response, expected_response);
+ assert_eq!(test_server.take_do_get_request(), Some(ticket));
+ })
+ .await;
+}
+
+#[tokio::test]
+async fn test_do_get_error() {
+ do_test(|test_server, mut client| async move {
+ client.add_header("foo", "bar").unwrap();
+ let ticket = Ticket {
+ ticket: Bytes::from("my awesome flight ticket"),
+ };
+
+ let response = client.do_get(ticket.clone()).await.unwrap_err();
+
+ let e = Status::internal("No do_get response configured");
+ expect_status(response, e);
+ // server still got the request
+ assert_eq!(test_server.take_do_get_request(), Some(ticket));
+ ensure_metadata(&client, &test_server);
+ })
+ .await;
+}
+
+#[tokio::test]
+async fn test_do_get_error_in_record_batch_stream() {
+ do_test(|test_server, mut client| async move {
+ let ticket = Ticket {
+ ticket: Bytes::from("my awesome flight ticket"),
+ };
+
+ let batch = RecordBatch::try_from_iter(vec![(
+ "col",
+ Arc::new(UInt64Array::from_iter([1, 2, 3, 4])) as _,
+ )])
+ .unwrap();
+
+ let e = Status::data_loss("she's dead jim");
+
+ let expected_response = vec![Ok(batch),
Err(FlightError::Tonic(e.clone()))];
+
+ test_server.set_do_get_response(expected_response);
+
+ let response_stream = client
+ .do_get(ticket.clone())
+ .await
+ .expect("error making request");
+
+ let response: Result<Vec<_>, FlightError> =
response_stream.try_collect().await;
+
+ let response = response.unwrap_err();
+ expect_status(response, e);
+ // server still got the request
+ assert_eq!(test_server.take_do_get_request(), Some(ticket));
+ })
+ .await;
+}
/// Runs the future returned by the function, passing it a test server and
client
async fn do_test<F, Fut>(f: F)
diff --git a/arrow-flight/tests/common/server.rs
b/arrow-flight/tests/common/server.rs
index f1cb140b6..45f81b189 100644
--- a/arrow-flight/tests/common/server.rs
+++ b/arrow-flight/tests/common/server.rs
@@ -17,10 +17,13 @@
use std::sync::{Arc, Mutex};
-use futures::stream::BoxStream;
+use arrow_array::RecordBatch;
+use futures::{stream::BoxStream, TryStreamExt};
use tonic::{metadata::MetadataMap, Request, Response, Status, Streaming};
use arrow_flight::{
+ encode::FlightDataEncoderBuilder,
+ error::FlightError,
flight_service_server::{FlightService, FlightServiceServer},
Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor,
FlightInfo,
HandshakeRequest, HandshakeResponse, PutResult, SchemaResult, Ticket,
@@ -80,6 +83,21 @@ impl TestFlightServer {
.take()
}
+ /// Specify the response returned from the next call to `do_get`
+ pub fn set_do_get_response(&self, response: Vec<Result<RecordBatch,
FlightError>>) {
+ let mut state = self.state.lock().expect("mutex not poisoned");
+ state.do_get_response.replace(response);
+ }
+
+ /// Take and return the last do_get request sent to the server
+ pub fn take_do_get_request(&self) -> Option<Ticket> {
+ self.state
+ .lock()
+ .expect("mutex not poisoned")
+ .do_get_request
+ .take()
+ }
+
/// Returns the last metadata from a request received by the server
pub fn take_last_request_metadata(&self) -> Option<MetadataMap> {
self.state
@@ -97,7 +115,7 @@ impl TestFlightServer {
}
}
-/// mutable state for the TestFlightSwrver
+/// mutable state for the TestFlightServer, captures requests and provides
responses
#[derive(Debug, Default)]
struct State {
/// The last handshake request that was received
@@ -108,6 +126,10 @@ struct State {
pub get_flight_info_request: Option<FlightDescriptor>,
/// the next response to return from `get_flight_info`
pub get_flight_info_response: Option<Result<FlightInfo, Status>>,
+ /// The last do_get request received
+ pub do_get_request: Option<Ticket>,
+ /// The next response returned from `do_get`
+ pub do_get_response: Option<Vec<Result<RecordBatch, FlightError>>>,
/// The last request headers received
pub last_request_metadata: Option<MetadataMap>,
}
@@ -177,9 +199,25 @@ impl FlightService for TestFlightServer {
async fn do_get(
&self,
- _request: Request<Ticket>,
+ request: Request<Ticket>,
) -> Result<Response<Self::DoGetStream>, Status> {
- Err(Status::unimplemented("Implement do_get"))
+ self.save_metadata(&request);
+ let mut state = self.state.lock().expect("mutex not poisoned");
+
+ state.do_get_request = Some(request.into_inner());
+
+ let batches: Vec<_> = state
+ .do_get_response
+ .take()
+ .ok_or_else(|| Status::internal("No do_get response configured"))?;
+
+ let batch_stream = futures::stream::iter(batches);
+
+ let stream = FlightDataEncoderBuilder::new()
+ .build(batch_stream)
+ .map_err(|e| e.into());
+
+ Ok(Response::new(Box::pin(stream) as _))
}
async fn do_put(
diff --git a/arrow-flight/tests/encode_decode.rs
b/arrow-flight/tests/encode_decode.rs
new file mode 100644
index 000000000..45b8c0bf5
--- /dev/null
+++ b/arrow-flight/tests/encode_decode.rs
@@ -0,0 +1,453 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Tests for round trip encoding / decoding
+
+use std::sync::Arc;
+
+use arrow::{compute::concat_batches, datatypes::Int32Type};
+use arrow_array::{ArrayRef, DictionaryArray, Float64Array, RecordBatch,
UInt8Array};
+use arrow_flight::{
+ decode::{DecodedPayload, FlightDataDecoder, FlightRecordBatchStream},
+ encode::FlightDataEncoderBuilder,
+ error::FlightError,
+};
+use arrow_schema::{DataType, Field, Schema, SchemaRef};
+use bytes::Bytes;
+use futures::{StreamExt, TryStreamExt};
+
+#[tokio::test]
+async fn test_empty() {
+ roundtrip(vec![]).await;
+}
+
+#[tokio::test]
+async fn test_empty_batch() {
+ let batch = make_primative_batch(5);
+ let empty = RecordBatch::new_empty(batch.schema());
+ roundtrip(vec![empty]).await;
+}
+
+#[tokio::test]
+async fn test_error() {
+ let input_batch_stream =
+
futures::stream::iter(vec![Err(FlightError::NotYetImplemented("foo".into()))]);
+
+ let encoder = FlightDataEncoderBuilder::default();
+ let encode_stream = encoder.build(input_batch_stream);
+
+ let decode_stream =
FlightRecordBatchStream::new_from_flight_data(encode_stream);
+ let result: Result<Vec<_>, _> = decode_stream.try_collect().await;
+
+ let result = result.unwrap_err();
+ assert_eq!(result.to_string(), r#"NotYetImplemented("foo")"#);
+}
+
+#[tokio::test]
+async fn test_primative_one() {
+ roundtrip(vec![make_primative_batch(5)]).await;
+}
+
+#[tokio::test]
+async fn test_primative_many() {
+ roundtrip(vec![
+ make_primative_batch(1),
+ make_primative_batch(7),
+ make_primative_batch(32),
+ ])
+ .await;
+}
+
+#[tokio::test]
+async fn test_primative_empty() {
+ let batch = make_primative_batch(5);
+ let empty = RecordBatch::new_empty(batch.schema());
+
+ roundtrip(vec![batch, empty]).await;
+}
+
+#[tokio::test]
+async fn test_dictionary_one() {
+ roundtrip_dictionary(vec![make_dictionary_batch(5)]).await;
+}
+
+#[tokio::test]
+async fn test_dictionary_many() {
+ roundtrip_dictionary(vec![
+ make_dictionary_batch(5),
+ make_dictionary_batch(9),
+ make_dictionary_batch(5),
+ make_dictionary_batch(5),
+ ])
+ .await;
+}
+
+#[tokio::test]
+async fn test_app_metadata() {
+ let input_batch_stream =
futures::stream::iter(vec![Ok(make_primative_batch(78))]);
+
+ let app_metadata = Bytes::from("My Metadata");
+ let encoder =
FlightDataEncoderBuilder::default().with_metadata(app_metadata.clone());
+
+ let encode_stream = encoder.build(input_batch_stream);
+
+ // use lower level stream to get access to app metadata
+ let decode_stream =
+
FlightRecordBatchStream::new_from_flight_data(encode_stream).into_inner();
+
+ let mut messages: Vec<_> =
decode_stream.try_collect().await.expect("encode fails");
+
+ println!("{messages:#?}");
+
+ // expect that the app metadata made it through on the schema message
+ assert_eq!(messages.len(), 2);
+ let message2 = messages.pop().unwrap();
+ let message1 = messages.pop().unwrap();
+
+ assert_eq!(message1.app_metadata(), app_metadata);
+ assert!(matches!(message1.payload, DecodedPayload::Schema(_)));
+
+ // but not on the data
+ assert_eq!(message2.app_metadata(), Bytes::new());
+ assert!(matches!(message2.payload, DecodedPayload::RecordBatch(_)));
+}
+
+#[tokio::test]
+async fn test_max_message_size() {
+ let input_batch_stream =
futures::stream::iter(vec![Ok(make_primative_batch(5))]);
+
+ // 5 input rows, with a very small limit should result in 5 batch messages
+ let encoder = FlightDataEncoderBuilder::default().with_max_message_size(1);
+
+ let encode_stream = encoder.build(input_batch_stream);
+
+ // use lower level stream to get access to app metadata
+ let decode_stream =
+
FlightRecordBatchStream::new_from_flight_data(encode_stream).into_inner();
+
+ let messages: Vec<_> = decode_stream.try_collect().await.expect("encode
fails");
+
+ println!("{messages:#?}");
+
+ assert_eq!(messages.len(), 6);
+ assert!(matches!(messages[0].payload, DecodedPayload::Schema(_)));
+ for message in messages.iter().skip(1) {
+ assert!(matches!(message.payload, DecodedPayload::RecordBatch(_)));
+ }
+}
+
+#[tokio::test]
+async fn test_max_message_size_fuzz() {
+ // send through batches of varying sizes with various max
+ // batch sizes and ensure the data gets through ok
+ let input = vec![
+ make_primative_batch(123),
+ make_primative_batch(17),
+ make_primative_batch(201),
+ make_primative_batch(2),
+ make_primative_batch(1),
+ make_primative_batch(11),
+ make_primative_batch(127),
+ ];
+
+ for max_message_size in [10, 1024, 2048, 6400, 3211212] {
+ let encoder =
+
FlightDataEncoderBuilder::default().with_max_message_size(max_message_size);
+
+ let input_batch_stream = futures::stream::iter(input.clone()).map(Ok);
+
+ let encode_stream = encoder.build(input_batch_stream);
+
+ let decode_stream =
FlightRecordBatchStream::new_from_flight_data(encode_stream);
+ let output: Vec<_> = decode_stream.try_collect().await.expect("encode
/ decode");
+
+ let input_batch = concat_batches(&input[0].schema(), &input).unwrap();
+ let output_batch = concat_batches(&output[0].schema(),
&output).unwrap();
+ assert_eq!(input_batch, output_batch);
+ }
+}
+
+#[tokio::test]
+async fn test_mismatched_record_batch_schema() {
+ // send 2 batches with different schemas
+ let input_batch_stream = futures::stream::iter(vec![
+ Ok(make_primative_batch(5)),
+ Ok(make_dictionary_batch(3)),
+ ]);
+
+ let encoder = FlightDataEncoderBuilder::default();
+ let encode_stream = encoder.build(input_batch_stream);
+
+ let result: Result<Vec<_>, FlightError> =
encode_stream.try_collect().await;
+ let err = result.unwrap_err();
+ assert_eq!(
+ err.to_string(),
+ "Arrow(InvalidArgumentError(\"number of columns(1) must match number
of fields(2) in schema\"))"
+ );
+}
+
+#[tokio::test]
+async fn test_chained_streams_batch_decoder() {
+ let batch1 = make_primative_batch(5);
+ let batch2 = make_dictionary_batch(3);
+
+ // Model sending two flight streams back to back, with different schemas
+ let encode_stream1 = FlightDataEncoderBuilder::default()
+ .build(futures::stream::iter(vec![Ok(batch1.clone())]));
+ let encode_stream2 = FlightDataEncoderBuilder::default()
+ .build(futures::stream::iter(vec![Ok(batch2.clone())]));
+
+ // append the two streams (so they will have two different schema messages)
+ let encode_stream = encode_stream1.chain(encode_stream2);
+
+ // FlightRecordBatchStream errors if the schema changes
+ let decode_stream =
FlightRecordBatchStream::new_from_flight_data(encode_stream);
+ let result: Result<Vec<_>, FlightError> =
decode_stream.try_collect().await;
+
+ let err = result.unwrap_err();
+ assert_eq!(
+ err.to_string(),
+ "ProtocolError(\"Unexpectedly saw multiple Schema messages in
FlightData stream\")"
+ );
+}
+
+#[tokio::test]
+async fn test_chained_streams_data_decoder() {
+ let batch1 = make_primative_batch(5);
+ let batch2 = make_dictionary_batch(3);
+
+ // Model sending two flight streams back to back, with different schemas
+ let encode_stream1 = FlightDataEncoderBuilder::default()
+ .build(futures::stream::iter(vec![Ok(batch1.clone())]));
+ let encode_stream2 = FlightDataEncoderBuilder::default()
+ .build(futures::stream::iter(vec![Ok(batch2.clone())]));
+
+ // append the two streams (so they will have two different schema messages)
+ let encode_stream = encode_stream1.chain(encode_stream2);
+
+ // lower level decode stream can handle multiple schema messages
+ let decode_stream = FlightDataDecoder::new(encode_stream);
+
+ let decoded_data: Vec<_> =
+ decode_stream.try_collect().await.expect("encode / decode");
+
+ println!("decoded data: {decoded_data:#?}");
+
+ // expect two schema messages with the data
+ assert_eq!(decoded_data.len(), 4);
+ assert!(matches!(decoded_data[0].payload, DecodedPayload::Schema(_)));
+ assert!(matches!(
+ decoded_data[1].payload,
+ DecodedPayload::RecordBatch(_)
+ ));
+ assert!(matches!(decoded_data[2].payload, DecodedPayload::Schema(_)));
+ assert!(matches!(
+ decoded_data[3].payload,
+ DecodedPayload::RecordBatch(_)
+ ));
+}
+
+#[tokio::test]
+async fn test_mismatched_schema_message() {
+ // Model sending schema that is mismatched with the data
+ // and expect an error
+ async fn do_test(batch1: RecordBatch, batch2: RecordBatch, expected: &str)
{
+ let encode_stream1 = FlightDataEncoderBuilder::default()
+ .build(futures::stream::iter(vec![Ok(batch1.clone())]))
+ // take only schema message from first stream
+ .take(1);
+ let encode_stream2 = FlightDataEncoderBuilder::default()
+ .build(futures::stream::iter(vec![Ok(batch2.clone())]))
+ // take only data message from second
+ .skip(1);
+
+ // append the two streams
+ let encode_stream = encode_stream1.chain(encode_stream2);
+
+ // FlightRecordBatchStream errors if the schema changes
+ let decode_stream =
FlightRecordBatchStream::new_from_flight_data(encode_stream);
+ let result: Result<Vec<_>, FlightError> =
decode_stream.try_collect().await;
+
+ let err = result.unwrap_err().to_string();
+ assert!(
+ err.contains(expected),
+ "could not find '{expected}' in '{err}'"
+ );
+ }
+
+ // primitive batch first (has more columns)
+ do_test(
+ make_primative_batch(5),
+ make_dictionary_batch(3),
+ "Error decoding ipc RecordBatch: Io error: Invalid data for schema",
+ )
+ .await;
+
+    // dictionary batch first
+ do_test(
+ make_dictionary_batch(3),
+ make_primative_batch(5),
+ "Error decoding ipc RecordBatch: Invalid argument error",
+ )
+ .await;
+}
+
+/// Make a primitive batch for testing
+///
+/// Example:
+/// i: 0, 1, None, 3, 4
+/// f: 5.0, 4.0, None, 2.0, 1.0
+fn make_primative_batch(num_rows: usize) -> RecordBatch {
+ let i: UInt8Array = (0..num_rows)
+ .map(|i| {
+ if i == num_rows / 2 {
+ None
+ } else {
+ Some(i.try_into().unwrap())
+ }
+ })
+ .collect();
+
+ let f: Float64Array = (0..num_rows)
+ .map(|i| {
+ if i == num_rows / 2 {
+ None
+ } else {
+ Some((num_rows - i) as f64)
+ }
+ })
+ .collect();
+
+ RecordBatch::try_from_iter(vec![("i", Arc::new(i) as ArrayRef), ("f",
Arc::new(f))])
+ .unwrap()
+}
+
+/// Make a dictionary batch for testing
+///
+/// Example:
+/// a: value0, value1, value2, None, value1, value2
+fn make_dictionary_batch(num_rows: usize) -> RecordBatch {
+ let values: Vec<_> = (0..num_rows)
+ .map(|i| {
+ if i == num_rows / 2 {
+ None
+ } else {
+ // repeat some values for low cardinality
+ let v = i / 3;
+ Some(format!("value{v}"))
+ }
+ })
+ .collect();
+
+ let a: DictionaryArray<Int32Type> = values
+ .iter()
+ .map(|s| s.as_ref().map(|s| s.as_str()))
+ .collect();
+
+ RecordBatch::try_from_iter(vec![("a", Arc::new(a) as ArrayRef)]).unwrap()
+}
+
+/// Encodes input as a FlightData stream, and then decodes it using
+/// FlightRecordBatchStream and validates the decoded record batches
+/// match the input.
+async fn roundtrip(input: Vec<RecordBatch>) {
+ let expected_output = input.clone();
+ roundtrip_with_encoder(FlightDataEncoderBuilder::default(), input,
expected_output)
+ .await
+}
+
+/// Encodes input as a FlightData stream, and then decodes it using
+/// FlightRecordBatchStream and validates the decoded record batches
+/// match the expected input.
+///
+/// When <https://github.com/apache/arrow-rs/issues/3389> is resolved,
+/// it should be possible to use `roundtrip`
+async fn roundtrip_dictionary(input: Vec<RecordBatch>) {
+ let schema = Arc::new(prepare_schema_for_flight(&input[0].schema()));
+ let expected_output: Vec<_> = input
+ .iter()
+ .map(|batch| prepare_batch_for_flight(batch, schema.clone()).unwrap())
+ .collect();
+ roundtrip_with_encoder(FlightDataEncoderBuilder::default(), input,
expected_output)
+ .await
+}
+
+async fn roundtrip_with_encoder(
+ encoder: FlightDataEncoderBuilder,
+ input_batches: Vec<RecordBatch>,
+ expected_batches: Vec<RecordBatch>,
+) {
+ println!("Round tripping with encoder:\n{encoder:#?}");
+
+ let input_batch_stream =
futures::stream::iter(input_batches.clone()).map(Ok);
+
+ let encode_stream = encoder.build(input_batch_stream);
+
+ let decode_stream =
FlightRecordBatchStream::new_from_flight_data(encode_stream);
+ let output_batches: Vec<_> =
+ decode_stream.try_collect().await.expect("encode / decode");
+
+ // remove any empty batches from input as they are not transmitted
+ let expected_batches: Vec<_> = expected_batches
+ .into_iter()
+ .filter(|b| b.num_rows() > 0)
+ .collect();
+
+ assert_eq!(expected_batches, output_batches);
+}
+
+/// Workaround for https://github.com/apache/arrow-rs/issues/1206
+fn prepare_schema_for_flight(schema: &Schema) -> Schema {
+ let fields = schema
+ .fields()
+ .iter()
+ .map(|field| match field.data_type() {
+ DataType::Dictionary(_, value_type) => Field::new(
+ field.name(),
+ value_type.as_ref().clone(),
+ field.is_nullable(),
+ )
+ .with_metadata(field.metadata().clone()),
+ _ => field.clone(),
+ })
+ .collect();
+
+ Schema::new(fields)
+}
+
+/// Workaround for https://github.com/apache/arrow-rs/issues/1206
+fn prepare_batch_for_flight(
+ batch: &RecordBatch,
+ schema: SchemaRef,
+) -> Result<RecordBatch, FlightError> {
+ let columns = batch
+ .columns()
+ .iter()
+ .map(hydrate_dictionary)
+ .collect::<Result<Vec<_>, _>>()?;
+
+ Ok(RecordBatch::try_new(schema, columns)?)
+}
+
+fn hydrate_dictionary(array: &ArrayRef) -> Result<ArrayRef, FlightError> {
+ let arr = if let DataType::Dictionary(_, value) = array.data_type() {
+ arrow_cast::cast(array, value)?
+ } else {
+ Arc::clone(array)
+ };
+ Ok(arr)
+}
diff --git a/arrow-ipc/src/reader.rs b/arrow-ipc/src/reader.rs
index ef0a49be6..231f72910 100644
--- a/arrow-ipc/src/reader.rs
+++ b/arrow-ipc/src/reader.rs
@@ -298,6 +298,12 @@ fn create_array(
make_array(data)
}
_ => {
+ if nodes.len() <= node_index {
+ return Err(ArrowError::IoError(format!(
+ "Invalid data for schema. {} refers to node index {} but
only {} in schema",
+ field, node_index, nodes.len()
+ )));
+ }
let array = create_primitive_array(
nodes.get(node_index),
data_type,
diff --git a/arrow-ipc/src/writer.rs b/arrow-ipc/src/writer.rs
index 106b4e4c9..82cf2c90b 100644
--- a/arrow-ipc/src/writer.rs
+++ b/arrow-ipc/src/writer.rs
@@ -37,7 +37,7 @@ use arrow_schema::*;
use crate::compression::CompressionCodec;
use crate::CONTINUATION_MARKER;
-/// IPC write options used to control the behaviour of the writer
+/// IPC write options used to control the behaviour of the [`IpcDataGenerator`]
#[derive(Debug, Clone)]
pub struct IpcWriteOptions {
/// Write padding after memory buffers to this multiple of bytes.
@@ -514,6 +514,9 @@ pub struct DictionaryTracker {
}
impl DictionaryTracker {
+ /// Create a new [`DictionaryTracker`]. If `error_on_replacement`
+ /// is true, an error will be generated if an update to an
+ /// existing dictionary is attempted.
pub fn new(error_on_replacement: bool) -> Self {
Self {
written: HashMap::new(),