This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new 0130af34ba Expose bulk ingest in flight sql client and server (#6201)
0130af34ba is described below
commit 0130af34ba63f04e9a4a901f10fa602a087a84b4
Author: Douglas Anderson <[email protected]>
AuthorDate: Thu Aug 15 15:09:46 2024 -0600
Expose bulk ingest in flight sql client and server (#6201)
* Expose CommandStatementIngest as pub in sql module
* Add do_put_statement_ingest to FlightSqlService
Dispatch this handler for the new CommandStatementIngest command.
* Sort list
* Implement stub do_put_statement_ingest in example
* Refactor helper functions into tests/common/utils
* Implement execute_ingest for flight sql client
I referenced the C++ implementation here:
https://github.com/apache/arrow/commit/0d1ea5db1f9312412fe2cc28363e8c9deb2521ba
* Add integration test for sql client execute_ingest
* Fix lint clippy::new_without_default
* Allow streaming ingest for FlightSqlServiceClient::execute_ingest
* Properly return client errors
---------
Co-authored-by: Andrew Lamb <[email protected]>
---
arrow-flight/examples/flight_sql_server.rs | 14 +++-
arrow-flight/src/client.rs | 4 +-
arrow-flight/src/sql/client.rs | 55 +++++++++++++-
arrow-flight/src/sql/mod.rs | 8 +-
arrow-flight/src/sql/server.rs | 25 +++++-
arrow-flight/tests/common/fixture.rs | 1 +
arrow-flight/tests/common/mod.rs | 1 +
arrow-flight/tests/common/utils.rs | 118 +++++++++++++++++++++++++++++
arrow-flight/tests/encode_decode.rs | 98 +-----------------------
arrow-flight/tests/flight_sql_client.rs | 107 ++++++++++++++++++++++++--
10 files changed, 319 insertions(+), 112 deletions(-)
diff --git a/arrow-flight/examples/flight_sql_server.rs b/arrow-flight/examples/flight_sql_server.rs
index d5168debc4..81afecf856 100644
--- a/arrow-flight/examples/flight_sql_server.rs
+++ b/arrow-flight/examples/flight_sql_server.rs
@@ -46,9 +46,9 @@ use arrow_flight::sql::{
ActionEndTransactionRequest, Any, CommandGetCatalogs,
CommandGetCrossReference,
CommandGetDbSchemas, CommandGetExportedKeys, CommandGetImportedKeys,
CommandGetPrimaryKeys,
CommandGetSqlInfo, CommandGetTableTypes, CommandGetTables,
CommandGetXdbcTypeInfo,
- CommandPreparedStatementQuery, CommandPreparedStatementUpdate,
CommandStatementQuery,
- CommandStatementSubstraitPlan, CommandStatementUpdate, Nullable,
ProstMessageExt, Searchable,
- SqlInfo, TicketStatementQuery, XdbcDataType,
+ CommandPreparedStatementQuery, CommandPreparedStatementUpdate,
CommandStatementIngest,
+ CommandStatementQuery, CommandStatementSubstraitPlan,
CommandStatementUpdate, Nullable,
+ ProstMessageExt, Searchable, SqlInfo, TicketStatementQuery, XdbcDataType,
};
use arrow_flight::utils::batches_to_flight_data;
use arrow_flight::{
@@ -615,6 +615,14 @@ impl FlightSqlService for FlightSqlServiceImpl {
Ok(FAKE_UPDATE_RESULT)
}
+    /// Pretend to perform a bulk ingest.
+    ///
+    /// This example server stores nothing: both the ingest command and the
+    /// uploaded `FlightData` stream are ignored, and the fixed value
+    /// `FAKE_UPDATE_RESULT` is reported back as the ingested row count.
+    async fn do_put_statement_ingest(
+        &self,
+        _ticket: CommandStatementIngest,
+        _request: Request<PeekableFlightDataStream>,
+    ) -> Result<i64, Status> {
+        let reported_rows = FAKE_UPDATE_RESULT;
+        Ok(reported_rows)
+    }
+
async fn do_put_substrait_plan(
&self,
_ticket: CommandStatementSubstraitPlan,
diff --git a/arrow-flight/src/client.rs b/arrow-flight/src/client.rs
index 3f62b256d5..af3c8fba30 100644
--- a/arrow-flight/src/client.rs
+++ b/arrow-flight/src/client.rs
@@ -679,7 +679,7 @@ impl FlightClient {
/// it encounters an error it uses the oneshot sender to
/// notify the error and stop any further streaming. See `do_put` or
/// `do_exchange` for it's uses.
-struct FallibleRequestStream<T, E> {
+pub(crate) struct FallibleRequestStream<T, E> {
/// sender to notify error
sender: Option<Sender<E>>,
/// fallible stream
@@ -687,7 +687,7 @@ struct FallibleRequestStream<T, E> {
}
impl<T, E> FallibleRequestStream<T, E> {
- fn new(
+ pub(crate) fn new(
sender: Sender<E>,
fallible_stream: Pin<Box<dyn Stream<Item = std::result::Result<T, E>>
+ Send + 'static>>,
) -> Self {
diff --git a/arrow-flight/src/sql/client.rs b/arrow-flight/src/sql/client.rs
index 91790898b1..9f9963c925 100644
--- a/arrow-flight/src/sql/client.rs
+++ b/arrow-flight/src/sql/client.rs
@@ -24,6 +24,7 @@ use std::collections::HashMap;
use std::str::FromStr;
use tonic::metadata::AsciiMetadataKey;
+use crate::client::FallibleRequestStream;
use crate::decode::FlightRecordBatchStream;
use crate::encode::FlightDataEncoderBuilder;
use crate::error::FlightError;
@@ -39,8 +40,8 @@ use crate::sql::{
CommandGetCrossReference, CommandGetDbSchemas, CommandGetExportedKeys,
CommandGetImportedKeys,
CommandGetPrimaryKeys, CommandGetSqlInfo, CommandGetTableTypes,
CommandGetTables,
CommandGetXdbcTypeInfo, CommandPreparedStatementQuery,
CommandPreparedStatementUpdate,
- CommandStatementQuery, CommandStatementUpdate,
DoPutPreparedStatementResult, DoPutUpdateResult,
- ProstMessageExt, SqlInfo,
+ CommandStatementIngest, CommandStatementQuery, CommandStatementUpdate,
+ DoPutPreparedStatementResult, DoPutUpdateResult, ProstMessageExt, SqlInfo,
};
use crate::trailers::extract_lazy_trailers;
use crate::{
@@ -53,10 +54,10 @@ use arrow_ipc::convert::fb_to_schema;
use arrow_ipc::reader::read_record_batch;
use arrow_ipc::{root_as_message, MessageHeader};
use arrow_schema::{ArrowError, Schema, SchemaRef};
-use futures::{stream, TryStreamExt};
+use futures::{stream, Stream, TryStreamExt};
use prost::Message;
use tonic::transport::Channel;
-use tonic::{IntoRequest, Streaming};
+use tonic::{IntoRequest, IntoStreamingRequest, Streaming};
/// A FlightSQLServiceClient is an endpoint for retrieving or storing Arrow
data
/// by FlightSQL protocol.
@@ -227,6 +228,52 @@ impl FlightSqlServiceClient<Channel> {
Ok(result.record_count)
}
+    /// Execute a bulk ingest on the server and return the number of records
+    /// added.
+    ///
+    /// `command` (which names the target table and its options) is encoded
+    /// into the `FlightDescriptor` of the uploaded `FlightData` stream, and
+    /// `stream` supplies the record batches to upload.
+    ///
+    /// # Errors
+    ///
+    /// Returns an error if:
+    /// * `stream` yields an error (reported as [`ArrowError::ExternalError`]);
+    /// * setting request headers fails or the server rejects the `DoPut` call;
+    /// * the server sends no response message, or the response metadata is
+    ///   not a `DoPutUpdateResult`.
+    pub async fn execute_ingest<S>(
+        &mut self,
+        command: CommandStatementIngest,
+        stream: S,
+    ) -> Result<i64, ArrowError>
+    where
+        S: Stream<Item = crate::error::Result<RecordBatch>> + Send + 'static,
+    {
+        let (sender, receiver) = futures::channel::oneshot::channel();
+
+        let descriptor = FlightDescriptor::new_cmd(command.as_any().encode_to_vec());
+        let flight_data = FlightDataEncoderBuilder::new()
+            .with_flight_descriptor(Some(descriptor))
+            .build(stream);
+
+        // Intercept client-side errors from the input stream and forward them
+        // to the oneshot channel above so they can be reported to the caller.
+        let flight_data = Box::pin(flight_data);
+        let flight_data: FallibleRequestStream<FlightData, FlightError> =
+            FallibleRequestStream::new(sender, flight_data);
+
+        let req = self.set_request_headers(flight_data.into_streaming_request())?;
+        let mut result = self
+            .flight_client
+            .do_put(req)
+            .await
+            .map_err(status_to_arrow_error)?
+            .into_inner();
+
+        // Check whether the input stream produced an error. If
+        // `receiver.await` fails, the sender was dropped without sending,
+        // meaning no client-side error was reported.
+        if let Ok(msg) = receiver.await {
+            return Err(ArrowError::ExternalError(Box::new(msg)));
+        }
+
+        // A well-behaved server answers with exactly one PutResult; report a
+        // missing or malformed answer as an error instead of panicking.
+        let result = result
+            .message()
+            .await
+            .map_err(status_to_arrow_error)?
+            .ok_or_else(|| {
+                ArrowError::IpcError(
+                    "No response received from server for bulk ingest".to_string(),
+                )
+            })?;
+        let any = Any::decode(&*result.app_metadata).map_err(decode_error_to_arrow_error)?;
+        let result: DoPutUpdateResult = any.unpack()?.ok_or_else(|| {
+            ArrowError::ParseError(
+                "Bulk ingest response is not a DoPutUpdateResult".to_string(),
+            )
+        })?;
+        Ok(result.record_count)
+    }
+
/// Request a list of catalogs as tabular FlightInfo results
pub async fn get_catalogs(&mut self) -> Result<FlightInfo, ArrowError> {
self.get_flight_info_for_command(CommandGetCatalogs {})
diff --git a/arrow-flight/src/sql/mod.rs b/arrow-flight/src/sql/mod.rs
index 61eb67b693..453f608d35 100644
--- a/arrow-flight/src/sql/mod.rs
+++ b/arrow-flight/src/sql/mod.rs
@@ -50,6 +50,10 @@ mod gen {
}
pub use gen::action_end_transaction_request::EndTransaction;
+pub use gen::command_statement_ingest::table_definition_options::{
+ TableExistsOption, TableNotExistOption,
+};
+pub use gen::command_statement_ingest::TableDefinitionOptions;
pub use gen::ActionBeginSavepointRequest;
pub use gen::ActionBeginSavepointResult;
pub use gen::ActionBeginTransactionRequest;
@@ -74,6 +78,7 @@ pub use gen::CommandGetTables;
pub use gen::CommandGetXdbcTypeInfo;
pub use gen::CommandPreparedStatementQuery;
pub use gen::CommandPreparedStatementUpdate;
+pub use gen::CommandStatementIngest;
pub use gen::CommandStatementQuery;
pub use gen::CommandStatementSubstraitPlan;
pub use gen::CommandStatementUpdate;
@@ -250,11 +255,12 @@ prost_message_ext!(
CommandGetXdbcTypeInfo,
CommandPreparedStatementQuery,
CommandPreparedStatementUpdate,
+ CommandStatementIngest,
CommandStatementQuery,
CommandStatementSubstraitPlan,
CommandStatementUpdate,
- DoPutUpdateResult,
DoPutPreparedStatementResult,
+ DoPutUpdateResult,
TicketStatementQuery,
);
diff --git a/arrow-flight/src/sql/server.rs b/arrow-flight/src/sql/server.rs
index b47691c7da..e348367a91 100644
--- a/arrow-flight/src/sql/server.rs
+++ b/arrow-flight/src/sql/server.rs
@@ -32,9 +32,9 @@ use super::{
CommandGetCrossReference, CommandGetDbSchemas, CommandGetExportedKeys,
CommandGetImportedKeys,
CommandGetPrimaryKeys, CommandGetSqlInfo, CommandGetTableTypes,
CommandGetTables,
CommandGetXdbcTypeInfo, CommandPreparedStatementQuery,
CommandPreparedStatementUpdate,
- CommandStatementQuery, CommandStatementSubstraitPlan,
CommandStatementUpdate,
- DoPutPreparedStatementResult, DoPutUpdateResult, ProstMessageExt, SqlInfo,
- TicketStatementQuery,
+ CommandStatementIngest, CommandStatementQuery,
CommandStatementSubstraitPlan,
+ CommandStatementUpdate, DoPutPreparedStatementResult, DoPutUpdateResult,
ProstMessageExt,
+ SqlInfo, TicketStatementQuery,
};
use crate::{
flight_service_server::FlightService, gen::PollInfo, Action, ActionType,
Criteria, Empty,
@@ -397,6 +397,17 @@ pub trait FlightSqlService: Sync + Send + Sized + 'static {
))
}
+    /// Execute a bulk ingestion (`CommandStatementIngest`).
+    ///
+    /// `request` carries the uploaded `FlightData` stream. The returned count
+    /// is sent back to the client as the `record_count` of a
+    /// `DoPutUpdateResult`.
+    ///
+    /// The default implementation always fails with `Status::unimplemented`.
+    async fn do_put_statement_ingest(
+        &self,
+        _ticket: CommandStatementIngest,
+        _request: Request<PeekableFlightDataStream>,
+    ) -> Result<i64, Status> {
+        let message = "do_put_statement_ingest has no default implementation";
+        Err(Status::unimplemented(message))
+    }
+
/// Bind parameters to given prepared statement.
///
/// Returns an opaque handle that the client should pass
@@ -713,6 +724,14 @@ where
})]);
Ok(Response::new(Box::pin(output)))
}
+ Command::CommandStatementIngest(command) => {
+ let record_count = self.do_put_statement_ingest(command,
request).await?;
+ let result = DoPutUpdateResult { record_count };
+ let output = futures::stream::iter(vec![Ok(PutResult {
+ app_metadata: result.as_any().encode_to_vec().into(),
+ })]);
+ Ok(Response::new(Box::pin(output)))
+ }
Command::CommandPreparedStatementQuery(command) => {
let result = self
.do_put_prepared_statement_query(command, request)
diff --git a/arrow-flight/tests/common/fixture.rs b/arrow-flight/tests/common/fixture.rs
index 141879e2a3..a666fa5d0d 100644
--- a/arrow-flight/tests/common/fixture.rs
+++ b/arrow-flight/tests/common/fixture.rs
@@ -41,6 +41,7 @@ pub struct TestFixture {
impl TestFixture {
/// create a new test fixture from the server
+ #[allow(dead_code)]
pub async fn new<T: FlightService>(test_server: FlightServiceServer<T>) ->
Self {
// let OS choose a free port
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
diff --git a/arrow-flight/tests/common/mod.rs b/arrow-flight/tests/common/mod.rs
index 85716e5605..c4ac027c58 100644
--- a/arrow-flight/tests/common/mod.rs
+++ b/arrow-flight/tests/common/mod.rs
@@ -18,3 +18,4 @@
pub mod fixture;
pub mod server;
pub mod trailers_layer;
+pub mod utils;
diff --git a/arrow-flight/tests/common/utils.rs b/arrow-flight/tests/common/utils.rs
new file mode 100644
index 0000000000..0f70e4b310
--- /dev/null
+++ b/arrow-flight/tests/common/utils.rs
@@ -0,0 +1,118 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Common utilities for testing flight clients and servers
+
+use std::sync::Arc;
+
+use arrow_array::{
+ types::Int32Type, ArrayRef, BinaryViewArray, DictionaryArray,
Float64Array, RecordBatch,
+ StringViewArray, UInt8Array,
+};
+use arrow_schema::{DataType, Field, Schema};
+
+/// Build a two-column primitive batch with `num_rows` rows for tests.
+///
+/// Column `i` (`UInt8`) counts up from 0 and column `f` (`Float64`) counts
+/// down from `num_rows`; the row at index `num_rows / 2` is null in both.
+///
+/// Example with `num_rows = 5`:
+/// i: 0, 1, None, 3, 4
+/// f: 5.0, 4.0, None, 2.0, 1.0
+///
+/// # Panics
+///
+/// Panics if a non-null value of `i` does not fit in a `u8`, i.e. when
+/// `num_rows` is greater than 256.
+#[allow(dead_code)]
+pub fn make_primitive_batch(num_rows: usize) -> RecordBatch {
+    let null_row = num_rows / 2;
+
+    let ints: UInt8Array = (0..num_rows)
+        .map(|row| (row != null_row).then(|| u8::try_from(row).unwrap()))
+        .collect();
+
+    let floats: Float64Array = (0..num_rows)
+        .map(|row| (row != null_row).then(|| (num_rows - row) as f64))
+        .collect();
+
+    let columns = [
+        ("i", Arc::new(ints) as ArrayRef),
+        ("f", Arc::new(floats) as ArrayRef),
+    ];
+    RecordBatch::try_from_iter(columns).unwrap()
+}
+
+/// Make a dictionary batch for testing
+///
+/// The single column `a` is an `Int32`-keyed dictionary of strings. Row `i`
+/// holds `value{i / 3}`, so each value repeats three times (low
+/// cardinality); the row at index `num_rows / 2` is null.
+///
+/// Example with `num_rows = 6`:
+/// a: value0, value0, value0, None, value1, value1
+#[allow(dead_code)]
+pub fn make_dictionary_batch(num_rows: usize) -> RecordBatch {
+    let values: Vec<_> = (0..num_rows)
+        .map(|i| {
+            if i == num_rows / 2 {
+                None
+            } else {
+                // repeat some values for low cardinality
+                let v = i / 3;
+                Some(format!("value{v}"))
+            }
+        })
+        .collect();
+
+    let a: DictionaryArray<Int32Type> = values
+        .iter()
+        .map(|s| s.as_ref().map(|s| s.as_str()))
+        .collect();
+
+    RecordBatch::try_from_iter(vec![("a", Arc::new(a) as ArrayRef)]).unwrap()
+}
+
+/// Make a batch of view-type arrays for testing.
+///
+/// Produces `num_rows` rows with two nullable columns, `field1`
+/// (`BinaryView`) and `field2` (`Utf8View`). Rows cycle through null, a
+/// short value (`bar` / `foo`), and `LONG_TEST_STRING`, so both short and
+/// long values are exercised.
+#[allow(dead_code)]
+pub fn make_view_batches(num_rows: usize) -> RecordBatch {
+    const LONG_TEST_STRING: &str =
+        "This is a long string to make sure binary view array handles it";
+    let schema = Schema::new(vec![
+        Field::new("field1", DataType::BinaryView, true),
+        Field::new("field2", DataType::Utf8View, true),
+    ]);
+
+    let string_view_values: Vec<Option<&str>> = (0..num_rows)
+        .map(|i| match i % 3 {
+            0 => None,
+            1 => Some("foo"),
+            2 => Some(LONG_TEST_STRING),
+            _ => unreachable!(),
+        })
+        .collect();
+
+    let bin_view_values: Vec<Option<&[u8]>> = (0..num_rows)
+        .map(|i| match i % 3 {
+            0 => None,
+            1 => Some("bar".as_bytes()),
+            2 => Some(LONG_TEST_STRING.as_bytes()),
+            _ => unreachable!(),
+        })
+        .collect();
+
+    let binary_array = BinaryViewArray::from_iter(bin_view_values);
+    let utf8_array = StringViewArray::from_iter(string_view_values);
+    // `schema` is not used after this point, so move it instead of cloning.
+    RecordBatch::try_new(
+        Arc::new(schema),
+        vec![Arc::new(binary_array), Arc::new(utf8_array)],
+    )
+    .unwrap()
+}
diff --git a/arrow-flight/tests/encode_decode.rs b/arrow-flight/tests/encode_decode.rs
index 0185fa77f0..cbfae18258 100644
--- a/arrow-flight/tests/encode_decode.rs
+++ b/arrow-flight/tests/encode_decode.rs
@@ -19,11 +19,7 @@
use std::{collections::HashMap, sync::Arc};
-use arrow_array::types::Int32Type;
-use arrow_array::{
- ArrayRef, BinaryViewArray, DictionaryArray, Float64Array, RecordBatch,
StringViewArray,
- UInt8Array,
-};
+use arrow_array::{ArrayRef, RecordBatch};
use arrow_cast::pretty::pretty_format_batches;
use arrow_flight::flight_descriptor::DescriptorType;
use arrow_flight::FlightDescriptor;
@@ -36,6 +32,9 @@ use arrow_schema::{DataType, Field, Fields, Schema, SchemaRef};
use bytes::Bytes;
use futures::{StreamExt, TryStreamExt};
+mod common;
+use common::utils::{make_dictionary_batch, make_primitive_batch,
make_view_batches};
+
#[tokio::test]
async fn test_empty() {
roundtrip(vec![]).await;
@@ -415,95 +414,6 @@ async fn test_mismatched_schema_message() {
.await;
}
-/// Make a primitive batch for testing
-///
-/// Example:
-/// i: 0, 1, None, 3, 4
-/// f: 5.0, 4.0, None, 2.0, 1.0
-fn make_primitive_batch(num_rows: usize) -> RecordBatch {
- let i: UInt8Array = (0..num_rows)
- .map(|i| {
- if i == num_rows / 2 {
- None
- } else {
- Some(i.try_into().unwrap())
- }
- })
- .collect();
-
- let f: Float64Array = (0..num_rows)
- .map(|i| {
- if i == num_rows / 2 {
- None
- } else {
- Some((num_rows - i) as f64)
- }
- })
- .collect();
-
- RecordBatch::try_from_iter(vec![("i", Arc::new(i) as ArrayRef), ("f",
Arc::new(f))]).unwrap()
-}
-
-/// Make a dictionary batch for testing
-///
-/// Example:
-/// a: value0, value1, value2, None, value1, value2
-fn make_dictionary_batch(num_rows: usize) -> RecordBatch {
- let values: Vec<_> = (0..num_rows)
- .map(|i| {
- if i == num_rows / 2 {
- None
- } else {
- // repeat some values for low cardinality
- let v = i / 3;
- Some(format!("value{v}"))
- }
- })
- .collect();
-
- let a: DictionaryArray<Int32Type> = values
- .iter()
- .map(|s| s.as_ref().map(|s| s.as_str()))
- .collect();
-
- RecordBatch::try_from_iter(vec![("a", Arc::new(a) as ArrayRef)]).unwrap()
-}
-
-fn make_view_batches(num_rows: usize) -> RecordBatch {
- const LONG_TEST_STRING: &str =
- "This is a long string to make sure binary view array handles it";
- let schema = Schema::new(vec![
- Field::new("field1", DataType::BinaryView, true),
- Field::new("field2", DataType::Utf8View, true),
- ]);
-
- let string_view_values: Vec<Option<&str>> = (0..num_rows)
- .map(|i| match i % 3 {
- 0 => None,
- 1 => Some("foo"),
- 2 => Some(LONG_TEST_STRING),
- _ => unreachable!(),
- })
- .collect();
-
- let bin_view_values: Vec<Option<&[u8]>> = (0..num_rows)
- .map(|i| match i % 3 {
- 0 => None,
- 1 => Some("bar".as_bytes()),
- 2 => Some(LONG_TEST_STRING.as_bytes()),
- _ => unreachable!(),
- })
- .collect();
-
- let binary_array = BinaryViewArray::from_iter(bin_view_values);
- let utf8_array = StringViewArray::from_iter(string_view_values);
- RecordBatch::try_new(
- Arc::new(schema.clone()),
- vec![Arc::new(binary_array), Arc::new(utf8_array)],
- )
- .unwrap()
-}
-
/// Encodes input as a FlightData stream, and then decodes it using
/// FlightRecordBatchStream and validates the decoded record batches
/// match the input.
diff --git a/arrow-flight/tests/flight_sql_client.rs b/arrow-flight/tests/flight_sql_client.rs
index 94b768a136..349da062a8 100644
--- a/arrow-flight/tests/flight_sql_client.rs
+++ b/arrow-flight/tests/flight_sql_client.rs
@@ -18,14 +18,21 @@
mod common;
use crate::common::fixture::TestFixture;
+use crate::common::utils::make_primitive_batch;
+
+use arrow_array::RecordBatch;
+use arrow_flight::decode::FlightRecordBatchStream;
+use arrow_flight::error::FlightError;
use arrow_flight::flight_service_server::FlightServiceServer;
use arrow_flight::sql::client::FlightSqlServiceClient;
-use arrow_flight::sql::server::FlightSqlService;
+use arrow_flight::sql::server::{FlightSqlService, PeekableFlightDataStream};
use arrow_flight::sql::{
ActionBeginTransactionRequest, ActionBeginTransactionResult,
ActionEndTransactionRequest,
- EndTransaction, SqlInfo,
+ CommandStatementIngest, EndTransaction, SqlInfo, TableDefinitionOptions,
TableExistsOption,
+ TableNotExistOption,
};
use arrow_flight::Action;
+use futures::{StreamExt, TryStreamExt};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::Mutex;
@@ -34,9 +41,7 @@ use uuid::Uuid;
#[tokio::test]
pub async fn test_begin_end_transaction() {
- let test_server = FlightSqlServiceImpl {
- transactions: Arc::new(Mutex::new(HashMap::new())),
- };
+ let test_server = FlightSqlServiceImpl::new();
let fixture = TestFixture::new(test_server.service()).await;
let channel = fixture.channel().await;
let mut flight_sql_client = FlightSqlServiceClient::new(channel);
@@ -63,12 +68,83 @@ pub async fn test_begin_end_transaction() {
.is_err());
}
+/// Ingests three batches and checks both the reported row count and that the
+/// server received exactly the batches that were sent.
+#[tokio::test]
+pub async fn test_execute_ingest() {
+    let test_server = FlightSqlServiceImpl::new();
+    let fixture = TestFixture::new(test_server.service()).await;
+    let mut client = FlightSqlServiceClient::new(fixture.channel().await);
+
+    // 5 + 3 + 2 = 10 rows in total.
+    let batches: Vec<RecordBatch> = [5, 3, 2].into_iter().map(make_primitive_batch).collect();
+    let input = futures::stream::iter(batches.clone()).map(Ok);
+
+    let ingested_rows = client
+        .execute_ingest(make_ingest_command(), input)
+        .await
+        .expect("ingest should succeed");
+    assert_eq!(ingested_rows, 10);
+
+    let received = test_server.ingested_batches.lock().await.clone();
+    assert_eq!(received, batches);
+}
+
+/// Checks that an error yielded by the client's input stream is surfaced by
+/// `execute_ingest` as an external error.
+#[tokio::test]
+pub async fn test_execute_ingest_error() {
+    let test_server = FlightSqlServiceImpl::new();
+    let fixture = TestFixture::new(test_server.service()).await;
+    let mut client = FlightSqlServiceClient::new(fixture.channel().await);
+
+    // The second item of the input stream is a client-side error.
+    let client_error = FlightError::NotYetImplemented("Client error message".to_string());
+    let input = futures::stream::iter(vec![Ok(make_primitive_batch(5)), Err(client_error)]);
+
+    let err = client
+        .execute_ingest(make_ingest_command(), input)
+        .await
+        .unwrap_err();
+    assert_eq!(
+        err.to_string(),
+        "External error: Not yet implemented: Client error message"
+    );
+}
+
+/// Build the ingest command used by the tests.
+///
+/// It targets a temporary table named `test`, with table definition options
+/// `if_not_exist: Create` and `if_exists: Fail`, and no schema, catalog,
+/// transaction or extra options.
+fn make_ingest_command() -> CommandStatementIngest {
+    let table_definition_options = TableDefinitionOptions {
+        if_not_exist: TableNotExistOption::Create.into(),
+        if_exists: TableExistsOption::Fail.into(),
+    };
+    CommandStatementIngest {
+        table_definition_options: Some(table_definition_options),
+        table: "test".to_string(),
+        schema: None,
+        catalog: None,
+        temporary: true,
+        transaction_id: None,
+        options: HashMap::new(),
+    }
+}
+
#[derive(Clone)]
pub struct FlightSqlServiceImpl {
transactions: Arc<Mutex<HashMap<String, ()>>>,
+ ingested_batches: Arc<Mutex<Vec<RecordBatch>>>,
}
impl FlightSqlServiceImpl {
+    /// Create a test service with no transactions and no ingested batches.
+    pub fn new() -> Self {
+        let transactions = Arc::new(Mutex::new(HashMap::new()));
+        let ingested_batches = Arc::new(Mutex::new(Vec::new()));
+        Self {
+            transactions,
+            ingested_batches,
+        }
+    }
+
/// Return an [`FlightServiceServer`] that can be used with a
/// [`Server`](tonic::transport::Server)
pub fn service(&self) -> FlightServiceServer<Self> {
@@ -77,6 +153,12 @@ impl FlightSqlServiceImpl {
}
}
+impl Default for FlightSqlServiceImpl {
+    /// Same as [`FlightSqlServiceImpl::new`].
+    fn default() -> Self {
+        FlightSqlServiceImpl::new()
+    }
+}
+
#[tonic::async_trait]
impl FlightSqlService for FlightSqlServiceImpl {
type FlightService = FlightSqlServiceImpl;
@@ -116,4 +198,19 @@ impl FlightSqlService for FlightSqlServiceImpl {
}
async fn register_sql_info(&self, _id: i32, _result: &SqlInfo) {}
+
+    /// Decode the uploaded `FlightData` into record batches, store them in
+    /// `ingested_batches` (replacing any previously stored batches) and
+    /// report the total number of rows received.
+    ///
+    /// Decoding errors are converted into a `Status` and returned.
+    async fn do_put_statement_ingest(
+        &self,
+        _ticket: CommandStatementIngest,
+        request: Request<PeekableFlightDataStream>,
+    ) -> Result<i64, Status> {
+        let batches: Vec<RecordBatch> = FlightRecordBatchStream::new_from_flight_data(
+            request.into_inner().map_err(|e| e.into()),
+        )
+        .try_collect()
+        .await?;
+        let total_rows: usize = batches.iter().map(RecordBatch::num_rows).sum();
+        // Checked conversion instead of a silently truncating `as` cast.
+        let affected_rows = i64::try_from(total_rows)
+            .map_err(|e| Status::internal(format!("row count overflows i64: {e}")))?;
+        *self.ingested_batches.lock().await = batches;
+        Ok(affected_rows)
+    }
}