This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new 0130af34ba Expose bulk ingest in flight sql client and server (#6201)
0130af34ba is described below

commit 0130af34ba63f04e9a4a901f10fa602a087a84b4
Author: Douglas Anderson <[email protected]>
AuthorDate: Thu Aug 15 15:09:46 2024 -0600

    Expose bulk ingest in flight sql client and server (#6201)
    
    * Expose CommandStatementIngest as pub in sql module
    
    * Add do_put_statement_ingest to FlightSqlService
    
    Dispatch this handler for the new CommandStatementIngest command.
    
    * Sort list
    
    * Implement stub do_put_statement_ingest in example
    
    * Refactor helper functions into tests/common/utils
    
    * Implement execute_ingest for flight sql client
    
    I referenced the C++ implementation here: 
https://github.com/apache/arrow/commit/0d1ea5db1f9312412fe2cc28363e8c9deb2521ba
    
    * Add integration test for sql client execute_ingest
    
    * Fix lint clippy::new_without_default
    
    * Allow streaming ingest for FlightClient::execute_ingest
    
    * Properly return client errors
    
    ---------
    
    Co-authored-by: Andrew Lamb <[email protected]>
---
 arrow-flight/examples/flight_sql_server.rs |  14 +++-
 arrow-flight/src/client.rs                 |   4 +-
 arrow-flight/src/sql/client.rs             |  55 +++++++++++++-
 arrow-flight/src/sql/mod.rs                |   8 +-
 arrow-flight/src/sql/server.rs             |  25 +++++-
 arrow-flight/tests/common/fixture.rs       |   1 +
 arrow-flight/tests/common/mod.rs           |   1 +
 arrow-flight/tests/common/utils.rs         | 118 +++++++++++++++++++++++++++++
 arrow-flight/tests/encode_decode.rs        |  98 +-----------------------
 arrow-flight/tests/flight_sql_client.rs    | 107 ++++++++++++++++++++++++--
 10 files changed, 319 insertions(+), 112 deletions(-)

diff --git a/arrow-flight/examples/flight_sql_server.rs 
b/arrow-flight/examples/flight_sql_server.rs
index d5168debc4..81afecf856 100644
--- a/arrow-flight/examples/flight_sql_server.rs
+++ b/arrow-flight/examples/flight_sql_server.rs
@@ -46,9 +46,9 @@ use arrow_flight::sql::{
     ActionEndTransactionRequest, Any, CommandGetCatalogs, 
CommandGetCrossReference,
     CommandGetDbSchemas, CommandGetExportedKeys, CommandGetImportedKeys, 
CommandGetPrimaryKeys,
     CommandGetSqlInfo, CommandGetTableTypes, CommandGetTables, 
CommandGetXdbcTypeInfo,
-    CommandPreparedStatementQuery, CommandPreparedStatementUpdate, 
CommandStatementQuery,
-    CommandStatementSubstraitPlan, CommandStatementUpdate, Nullable, 
ProstMessageExt, Searchable,
-    SqlInfo, TicketStatementQuery, XdbcDataType,
+    CommandPreparedStatementQuery, CommandPreparedStatementUpdate, 
CommandStatementIngest,
+    CommandStatementQuery, CommandStatementSubstraitPlan, 
CommandStatementUpdate, Nullable,
+    ProstMessageExt, Searchable, SqlInfo, TicketStatementQuery, XdbcDataType,
 };
 use arrow_flight::utils::batches_to_flight_data;
 use arrow_flight::{
@@ -615,6 +615,14 @@ impl FlightSqlService for FlightSqlServiceImpl {
         Ok(FAKE_UPDATE_RESULT)
     }
 
+    /// Bulk ingestion stub for this example server.
+    ///
+    /// The uploaded data in `_request` is ignored; the fixed
+    /// `FAKE_UPDATE_RESULT` count is reported back as the number of rows ingested.
+    async fn do_put_statement_ingest(
+        &self,
+        _ticket: CommandStatementIngest,
+        _request: Request<PeekableFlightDataStream>,
+    ) -> Result<i64, Status> {
+        let ingested_rows: i64 = FAKE_UPDATE_RESULT;
+        Ok(ingested_rows)
+    }
+
     async fn do_put_substrait_plan(
         &self,
         _ticket: CommandStatementSubstraitPlan,
diff --git a/arrow-flight/src/client.rs b/arrow-flight/src/client.rs
index 3f62b256d5..af3c8fba30 100644
--- a/arrow-flight/src/client.rs
+++ b/arrow-flight/src/client.rs
@@ -679,7 +679,7 @@ impl FlightClient {
 /// it encounters an error it uses the oneshot sender to
 /// notify the error and stop any further streaming. See `do_put` or
 /// `do_exchange` for it's uses.
-struct FallibleRequestStream<T, E> {
+pub(crate) struct FallibleRequestStream<T, E> {
     /// sender to notify error
     sender: Option<Sender<E>>,
     /// fallible stream
@@ -687,7 +687,7 @@ struct FallibleRequestStream<T, E> {
 }
 
 impl<T, E> FallibleRequestStream<T, E> {
-    fn new(
+    pub(crate) fn new(
         sender: Sender<E>,
         fallible_stream: Pin<Box<dyn Stream<Item = std::result::Result<T, E>> 
+ Send + 'static>>,
     ) -> Self {
diff --git a/arrow-flight/src/sql/client.rs b/arrow-flight/src/sql/client.rs
index 91790898b1..9f9963c925 100644
--- a/arrow-flight/src/sql/client.rs
+++ b/arrow-flight/src/sql/client.rs
@@ -24,6 +24,7 @@ use std::collections::HashMap;
 use std::str::FromStr;
 use tonic::metadata::AsciiMetadataKey;
 
+use crate::client::FallibleRequestStream;
 use crate::decode::FlightRecordBatchStream;
 use crate::encode::FlightDataEncoderBuilder;
 use crate::error::FlightError;
@@ -39,8 +40,8 @@ use crate::sql::{
     CommandGetCrossReference, CommandGetDbSchemas, CommandGetExportedKeys, 
CommandGetImportedKeys,
     CommandGetPrimaryKeys, CommandGetSqlInfo, CommandGetTableTypes, 
CommandGetTables,
     CommandGetXdbcTypeInfo, CommandPreparedStatementQuery, 
CommandPreparedStatementUpdate,
-    CommandStatementQuery, CommandStatementUpdate, 
DoPutPreparedStatementResult, DoPutUpdateResult,
-    ProstMessageExt, SqlInfo,
+    CommandStatementIngest, CommandStatementQuery, CommandStatementUpdate,
+    DoPutPreparedStatementResult, DoPutUpdateResult, ProstMessageExt, SqlInfo,
 };
 use crate::trailers::extract_lazy_trailers;
 use crate::{
@@ -53,10 +54,10 @@ use arrow_ipc::convert::fb_to_schema;
 use arrow_ipc::reader::read_record_batch;
 use arrow_ipc::{root_as_message, MessageHeader};
 use arrow_schema::{ArrowError, Schema, SchemaRef};
-use futures::{stream, TryStreamExt};
+use futures::{stream, Stream, TryStreamExt};
 use prost::Message;
 use tonic::transport::Channel;
-use tonic::{IntoRequest, Streaming};
+use tonic::{IntoRequest, IntoStreamingRequest, Streaming};
 
 /// A FlightSQLServiceClient is an endpoint for retrieving or storing Arrow 
data
 /// by FlightSQL protocol.
@@ -227,6 +228,52 @@ impl FlightSqlServiceClient<Channel> {
         Ok(result.record_count)
     }
 
+    /// Execute a bulk ingest on the server and return the number of records added.
+    ///
+    /// The batches produced by `stream` are encoded as Flight data, tagged with
+    /// `command` as the flight descriptor, and uploaded with `DoPut`. If `stream`
+    /// yields an error, streaming stops and that error is returned wrapped in
+    /// [`ArrowError::ExternalError`].
+    ///
+    /// # Errors
+    ///
+    /// Returns an error if the `DoPut` call fails, if `stream` yields an error,
+    /// or if the server's reply is missing or does not contain a
+    /// `DoPutUpdateResult`.
+    pub async fn execute_ingest<S>(
+        &mut self,
+        command: CommandStatementIngest,
+        stream: S,
+    ) -> Result<i64, ArrowError>
+    where
+        S: Stream<Item = crate::error::Result<RecordBatch>> + Send + 'static,
+    {
+        let (sender, receiver) = futures::channel::oneshot::channel();
+
+        let descriptor = FlightDescriptor::new_cmd(command.as_any().encode_to_vec());
+        let flight_data = FlightDataEncoderBuilder::new()
+            .with_flight_descriptor(Some(descriptor))
+            .build(stream);
+
+        // Intercept client errors and send them to the oneshot channel above
+        let flight_data = Box::pin(flight_data);
+        let flight_data: FallibleRequestStream<FlightData, FlightError> =
+            FallibleRequestStream::new(sender, flight_data);
+
+        let req = self.set_request_headers(flight_data.into_streaming_request())?;
+        let mut result = self
+            .flight_client
+            .do_put(req)
+            .await
+            .map_err(status_to_arrow_error)?
+            .into_inner();
+
+        // Check whether there were any errors in the input stream provided.
+        // If `receiver.await` fails, the sender was dropped without sending,
+        // so there is no client-side error to return.
+        if let Ok(msg) = receiver.await {
+            return Err(ArrowError::ExternalError(Box::new(msg)));
+        }
+
+        // A missing or mistyped reply is a server protocol violation: report it
+        // as an error instead of panicking inside the client.
+        let result = result
+            .message()
+            .await
+            .map_err(status_to_arrow_error)?
+            .ok_or_else(|| {
+                ArrowError::IpcError("no response received from server for ingest".to_string())
+            })?;
+        let any = Any::decode(&*result.app_metadata).map_err(decode_error_to_arrow_error)?;
+        let result: DoPutUpdateResult = any.unpack()?.ok_or_else(|| {
+            ArrowError::IpcError("expected DoPutUpdateResult in ingest response".to_string())
+        })?;
+        Ok(result.record_count)
+    }
+
     /// Request a list of catalogs as tabular FlightInfo results
     pub async fn get_catalogs(&mut self) -> Result<FlightInfo, ArrowError> {
         self.get_flight_info_for_command(CommandGetCatalogs {})
diff --git a/arrow-flight/src/sql/mod.rs b/arrow-flight/src/sql/mod.rs
index 61eb67b693..453f608d35 100644
--- a/arrow-flight/src/sql/mod.rs
+++ b/arrow-flight/src/sql/mod.rs
@@ -50,6 +50,10 @@ mod gen {
 }
 
 pub use gen::action_end_transaction_request::EndTransaction;
+pub use gen::command_statement_ingest::table_definition_options::{
+    TableExistsOption, TableNotExistOption,
+};
+pub use gen::command_statement_ingest::TableDefinitionOptions;
 pub use gen::ActionBeginSavepointRequest;
 pub use gen::ActionBeginSavepointResult;
 pub use gen::ActionBeginTransactionRequest;
@@ -74,6 +78,7 @@ pub use gen::CommandGetTables;
 pub use gen::CommandGetXdbcTypeInfo;
 pub use gen::CommandPreparedStatementQuery;
 pub use gen::CommandPreparedStatementUpdate;
+pub use gen::CommandStatementIngest;
 pub use gen::CommandStatementQuery;
 pub use gen::CommandStatementSubstraitPlan;
 pub use gen::CommandStatementUpdate;
@@ -250,11 +255,12 @@ prost_message_ext!(
     CommandGetXdbcTypeInfo,
     CommandPreparedStatementQuery,
     CommandPreparedStatementUpdate,
+    CommandStatementIngest,
     CommandStatementQuery,
     CommandStatementSubstraitPlan,
     CommandStatementUpdate,
-    DoPutUpdateResult,
     DoPutPreparedStatementResult,
+    DoPutUpdateResult,
     TicketStatementQuery,
 );
 
diff --git a/arrow-flight/src/sql/server.rs b/arrow-flight/src/sql/server.rs
index b47691c7da..e348367a91 100644
--- a/arrow-flight/src/sql/server.rs
+++ b/arrow-flight/src/sql/server.rs
@@ -32,9 +32,9 @@ use super::{
     CommandGetCrossReference, CommandGetDbSchemas, CommandGetExportedKeys, 
CommandGetImportedKeys,
     CommandGetPrimaryKeys, CommandGetSqlInfo, CommandGetTableTypes, 
CommandGetTables,
     CommandGetXdbcTypeInfo, CommandPreparedStatementQuery, 
CommandPreparedStatementUpdate,
-    CommandStatementQuery, CommandStatementSubstraitPlan, 
CommandStatementUpdate,
-    DoPutPreparedStatementResult, DoPutUpdateResult, ProstMessageExt, SqlInfo,
-    TicketStatementQuery,
+    CommandStatementIngest, CommandStatementQuery, 
CommandStatementSubstraitPlan,
+    CommandStatementUpdate, DoPutPreparedStatementResult, DoPutUpdateResult, 
ProstMessageExt,
+    SqlInfo, TicketStatementQuery,
 };
 use crate::{
     flight_service_server::FlightService, gen::PollInfo, Action, ActionType, 
Criteria, Empty,
@@ -397,6 +397,17 @@ pub trait FlightSqlService: Sync + Send + Sized + 'static {
         ))
     }
 
+    /// Execute a bulk ingestion (`CommandStatementIngest`).
+    ///
+    /// Implementations receive the uploaded Flight data in `_request` and
+    /// return the number of rows ingested; the dispatcher reports that count
+    /// to the client as a `DoPutUpdateResult`.
+    ///
+    /// The default implementation rejects the request with `Status::unimplemented`.
+    async fn do_put_statement_ingest(
+        &self,
+        _ticket: CommandStatementIngest,
+        _request: Request<PeekableFlightDataStream>,
+    ) -> Result<i64, Status> {
+        let reason = "do_put_statement_ingest has no default implementation";
+        Err(Status::unimplemented(reason))
+    }
+
     /// Bind parameters to given prepared statement.
     ///
     /// Returns an opaque handle that the client should pass
@@ -713,6 +724,14 @@ where
                 })]);
                 Ok(Response::new(Box::pin(output)))
             }
+            Command::CommandStatementIngest(command) => {
+                let record_count = self.do_put_statement_ingest(command, 
request).await?;
+                let result = DoPutUpdateResult { record_count };
+                let output = futures::stream::iter(vec![Ok(PutResult {
+                    app_metadata: result.as_any().encode_to_vec().into(),
+                })]);
+                Ok(Response::new(Box::pin(output)))
+            }
             Command::CommandPreparedStatementQuery(command) => {
                 let result = self
                     .do_put_prepared_statement_query(command, request)
diff --git a/arrow-flight/tests/common/fixture.rs 
b/arrow-flight/tests/common/fixture.rs
index 141879e2a3..a666fa5d0d 100644
--- a/arrow-flight/tests/common/fixture.rs
+++ b/arrow-flight/tests/common/fixture.rs
@@ -41,6 +41,7 @@ pub struct TestFixture {
 
 impl TestFixture {
     /// create a new test fixture from the server
+    #[allow(dead_code)]
     pub async fn new<T: FlightService>(test_server: FlightServiceServer<T>) -> 
Self {
         // let OS choose a free port
         let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
diff --git a/arrow-flight/tests/common/mod.rs b/arrow-flight/tests/common/mod.rs
index 85716e5605..c4ac027c58 100644
--- a/arrow-flight/tests/common/mod.rs
+++ b/arrow-flight/tests/common/mod.rs
@@ -18,3 +18,4 @@
 pub mod fixture;
 pub mod server;
 pub mod trailers_layer;
+pub mod utils;
diff --git a/arrow-flight/tests/common/utils.rs 
b/arrow-flight/tests/common/utils.rs
new file mode 100644
index 0000000000..0f70e4b310
--- /dev/null
+++ b/arrow-flight/tests/common/utils.rs
@@ -0,0 +1,118 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Common utilities for testing flight clients and servers
+
+use std::sync::Arc;
+
+use arrow_array::{
+    types::Int32Type, ArrayRef, BinaryViewArray, DictionaryArray, 
Float64Array, RecordBatch,
+    StringViewArray, UInt8Array,
+};
+use arrow_schema::{DataType, Field, Schema};
+
+/// Make a primitive batch for testing
+///
+/// The batch has a `UInt8` column `i` counting up from 0 and a `Float64`
+/// column `f` counting down from `num_rows`; row `num_rows / 2` is null in
+/// both columns.
+///
+/// Example (`num_rows = 5`):
+/// i: 0, 1, None, 3, 4
+/// f: 5.0, 4.0, None, 2.0, 1.0
+///
+/// # Panics
+///
+/// Panics if a non-null row index does not fit in a `u8`.
+// Not every test crate that declares `mod common` uses every helper.
+#[allow(dead_code)]
+pub fn make_primitive_batch(num_rows: usize) -> RecordBatch {
+    let null_row = num_rows / 2;
+
+    // The conversion must stay lazy so the null row is never converted.
+    let i: UInt8Array = (0..num_rows)
+        .map(|row| (row != null_row).then(|| u8::try_from(row).unwrap()))
+        .collect();
+    let f: Float64Array = (0..num_rows)
+        .map(|row| (row != null_row).then_some((num_rows - row) as f64))
+        .collect();
+
+    let columns = vec![
+        ("i", Arc::new(i) as ArrayRef),
+        ("f", Arc::new(f) as ArrayRef),
+    ];
+    RecordBatch::try_from_iter(columns).unwrap()
+}
+
+/// Make a dictionary batch for testing
+///
+/// The batch has a single `Int32`-keyed dictionary column `a` of strings whose
+/// value for row `r` is `value{r / 3}`, so each label repeats up to three
+/// times (low cardinality); row `num_rows / 2` is null.
+///
+/// Example (`num_rows = 6`):
+/// a: value0, value0, value0, None, value1, value1
+// Not every test crate that declares `mod common` uses every helper.
+#[allow(dead_code)]
+pub fn make_dictionary_batch(num_rows: usize) -> RecordBatch {
+    let null_row = num_rows / 2;
+    let values: Vec<Option<String>> = (0..num_rows)
+        .map(|row| {
+            if row == null_row {
+                None
+            } else {
+                // repeat some values for low cardinality
+                Some(format!("value{}", row / 3))
+            }
+        })
+        .collect();
+
+    let a: DictionaryArray<Int32Type> = values.iter().map(Option::as_deref).collect();
+
+    RecordBatch::try_from_iter(vec![("a", Arc::new(a) as ArrayRef)]).unwrap()
+}
+
+/// Make a batch with view-typed columns for testing
+///
+/// `field1` is a nullable `BinaryView` column and `field2` a nullable
+/// `Utf8View` column. Row `r` follows the repeating pattern selected by
+/// `r % 3`: null, then a short value (`"bar"` / `"foo"`), then
+/// `LONG_TEST_STRING` (presumably long enough to exercise non-inlined
+/// views — confirm).
+// Not every test crate that declares `mod common` uses every helper.
+#[allow(dead_code)]
+pub fn make_view_batches(num_rows: usize) -> RecordBatch {
+    const LONG_TEST_STRING: &str =
+        "This is a long string to make sure binary view array handles it";
+
+    // Slot `row % 3` of the pattern [null, short, long].
+    fn pattern<T>(row: usize, short: T, long: T) -> Option<T> {
+        match row % 3 {
+            0 => None,
+            1 => Some(short),
+            2 => Some(long),
+            _ => unreachable!(),
+        }
+    }
+
+    let bin_view_values: Vec<Option<&[u8]>> = (0..num_rows)
+        .map(|row| pattern(row, "bar".as_bytes(), LONG_TEST_STRING.as_bytes()))
+        .collect();
+    let string_view_values: Vec<Option<&str>> = (0..num_rows)
+        .map(|row| pattern(row, "foo", LONG_TEST_STRING))
+        .collect();
+
+    let schema = Arc::new(Schema::new(vec![
+        Field::new("field1", DataType::BinaryView, true),
+        Field::new("field2", DataType::Utf8View, true),
+    ]));
+    let columns: Vec<ArrayRef> = vec![
+        Arc::new(BinaryViewArray::from_iter(bin_view_values)),
+        Arc::new(StringViewArray::from_iter(string_view_values)),
+    ];
+    RecordBatch::try_new(schema, columns).unwrap()
+}
diff --git a/arrow-flight/tests/encode_decode.rs 
b/arrow-flight/tests/encode_decode.rs
index 0185fa77f0..cbfae18258 100644
--- a/arrow-flight/tests/encode_decode.rs
+++ b/arrow-flight/tests/encode_decode.rs
@@ -19,11 +19,7 @@
 
 use std::{collections::HashMap, sync::Arc};
 
-use arrow_array::types::Int32Type;
-use arrow_array::{
-    ArrayRef, BinaryViewArray, DictionaryArray, Float64Array, RecordBatch, 
StringViewArray,
-    UInt8Array,
-};
+use arrow_array::{ArrayRef, RecordBatch};
 use arrow_cast::pretty::pretty_format_batches;
 use arrow_flight::flight_descriptor::DescriptorType;
 use arrow_flight::FlightDescriptor;
@@ -36,6 +32,9 @@ use arrow_schema::{DataType, Field, Fields, Schema, 
SchemaRef};
 use bytes::Bytes;
 use futures::{StreamExt, TryStreamExt};
 
+mod common;
+use common::utils::{make_dictionary_batch, make_primitive_batch, 
make_view_batches};
+
 #[tokio::test]
 async fn test_empty() {
     roundtrip(vec![]).await;
@@ -415,95 +414,6 @@ async fn test_mismatched_schema_message() {
     .await;
 }
 
-/// Make a primitive batch for testing
-///
-/// Example:
-/// i: 0, 1, None, 3, 4
-/// f: 5.0, 4.0, None, 2.0, 1.0
-fn make_primitive_batch(num_rows: usize) -> RecordBatch {
-    let i: UInt8Array = (0..num_rows)
-        .map(|i| {
-            if i == num_rows / 2 {
-                None
-            } else {
-                Some(i.try_into().unwrap())
-            }
-        })
-        .collect();
-
-    let f: Float64Array = (0..num_rows)
-        .map(|i| {
-            if i == num_rows / 2 {
-                None
-            } else {
-                Some((num_rows - i) as f64)
-            }
-        })
-        .collect();
-
-    RecordBatch::try_from_iter(vec![("i", Arc::new(i) as ArrayRef), ("f", 
Arc::new(f))]).unwrap()
-}
-
-/// Make a dictionary batch for testing
-///
-/// Example:
-/// a: value0, value1, value2, None, value1, value2
-fn make_dictionary_batch(num_rows: usize) -> RecordBatch {
-    let values: Vec<_> = (0..num_rows)
-        .map(|i| {
-            if i == num_rows / 2 {
-                None
-            } else {
-                // repeat some values for low cardinality
-                let v = i / 3;
-                Some(format!("value{v}"))
-            }
-        })
-        .collect();
-
-    let a: DictionaryArray<Int32Type> = values
-        .iter()
-        .map(|s| s.as_ref().map(|s| s.as_str()))
-        .collect();
-
-    RecordBatch::try_from_iter(vec![("a", Arc::new(a) as ArrayRef)]).unwrap()
-}
-
-fn make_view_batches(num_rows: usize) -> RecordBatch {
-    const LONG_TEST_STRING: &str =
-        "This is a long string to make sure binary view array handles it";
-    let schema = Schema::new(vec![
-        Field::new("field1", DataType::BinaryView, true),
-        Field::new("field2", DataType::Utf8View, true),
-    ]);
-
-    let string_view_values: Vec<Option<&str>> = (0..num_rows)
-        .map(|i| match i % 3 {
-            0 => None,
-            1 => Some("foo"),
-            2 => Some(LONG_TEST_STRING),
-            _ => unreachable!(),
-        })
-        .collect();
-
-    let bin_view_values: Vec<Option<&[u8]>> = (0..num_rows)
-        .map(|i| match i % 3 {
-            0 => None,
-            1 => Some("bar".as_bytes()),
-            2 => Some(LONG_TEST_STRING.as_bytes()),
-            _ => unreachable!(),
-        })
-        .collect();
-
-    let binary_array = BinaryViewArray::from_iter(bin_view_values);
-    let utf8_array = StringViewArray::from_iter(string_view_values);
-    RecordBatch::try_new(
-        Arc::new(schema.clone()),
-        vec![Arc::new(binary_array), Arc::new(utf8_array)],
-    )
-    .unwrap()
-}
-
 /// Encodes input as a FlightData stream, and then decodes it using
 /// FlightRecordBatchStream and validates the decoded record batches
 /// match the input.
diff --git a/arrow-flight/tests/flight_sql_client.rs 
b/arrow-flight/tests/flight_sql_client.rs
index 94b768a136..349da062a8 100644
--- a/arrow-flight/tests/flight_sql_client.rs
+++ b/arrow-flight/tests/flight_sql_client.rs
@@ -18,14 +18,21 @@
 mod common;
 
 use crate::common::fixture::TestFixture;
+use crate::common::utils::make_primitive_batch;
+
+use arrow_array::RecordBatch;
+use arrow_flight::decode::FlightRecordBatchStream;
+use arrow_flight::error::FlightError;
 use arrow_flight::flight_service_server::FlightServiceServer;
 use arrow_flight::sql::client::FlightSqlServiceClient;
-use arrow_flight::sql::server::FlightSqlService;
+use arrow_flight::sql::server::{FlightSqlService, PeekableFlightDataStream};
 use arrow_flight::sql::{
     ActionBeginTransactionRequest, ActionBeginTransactionResult, 
ActionEndTransactionRequest,
-    EndTransaction, SqlInfo,
+    CommandStatementIngest, EndTransaction, SqlInfo, TableDefinitionOptions, 
TableExistsOption,
+    TableNotExistOption,
 };
 use arrow_flight::Action;
+use futures::{StreamExt, TryStreamExt};
 use std::collections::HashMap;
 use std::sync::Arc;
 use tokio::sync::Mutex;
@@ -34,9 +41,7 @@ use uuid::Uuid;
 
 #[tokio::test]
 pub async fn test_begin_end_transaction() {
-    let test_server = FlightSqlServiceImpl {
-        transactions: Arc::new(Mutex::new(HashMap::new())),
-    };
+    let test_server = FlightSqlServiceImpl::new();
     let fixture = TestFixture::new(test_server.service()).await;
     let channel = fixture.channel().await;
     let mut flight_sql_client = FlightSqlServiceClient::new(channel);
@@ -63,12 +68,83 @@ pub async fn test_begin_end_transaction() {
         .is_err());
 }
 
+#[tokio::test]
+pub async fn test_execute_ingest() {
+    // Start the in-process test server and connect a client to it.
+    let test_server = FlightSqlServiceImpl::new();
+    let fixture = TestFixture::new(test_server.service()).await;
+    let mut flight_sql_client = FlightSqlServiceClient::new(fixture.channel().await);
+
+    // 5 + 3 + 2 rows in total
+    let batches: Vec<RecordBatch> = [5, 3, 2].into_iter().map(make_primitive_batch).collect();
+    let expected_rows = 10;
+
+    let actual_rows = flight_sql_client
+        .execute_ingest(
+            make_ingest_command(),
+            futures::stream::iter(batches.clone()).map(Ok),
+        )
+        .await
+        .expect("ingest should succeed");
+    assert_eq!(actual_rows, expected_rows);
+
+    // The server must have received exactly the batches that were sent.
+    let ingested_batches = test_server.ingested_batches.lock().await.clone();
+    assert_eq!(ingested_batches, batches);
+}
+
+#[tokio::test]
+pub async fn test_execute_ingest_error() {
+    let test_server = FlightSqlServiceImpl::new();
+    let fixture = TestFixture::new(test_server.service()).await;
+    let mut flight_sql_client = FlightSqlServiceClient::new(fixture.channel().await);
+
+    // The second item of the input stream is an error raised on the client side.
+    let client_error = FlightError::NotYetImplemented("Client error message".to_string());
+    let input = futures::stream::iter(vec![Ok(make_primitive_batch(5)), Err(client_error)]);
+
+    // execute_ingest must surface that client-side error to the caller.
+    let err = flight_sql_client
+        .execute_ingest(make_ingest_command(), input)
+        .await
+        .unwrap_err();
+    assert_eq!(
+        err.to_string(),
+        "External error: Not yet implemented: Client error message"
+    );
+}
+
+/// Build the `CommandStatementIngest` used by the ingest tests: a temporary
+/// table named `test` with `if_not_exist = Create` and `if_exists = Fail`;
+/// every other field is left at its default.
+fn make_ingest_command() -> CommandStatementIngest {
+    let table_definition_options = TableDefinitionOptions {
+        if_not_exist: TableNotExistOption::Create.into(),
+        if_exists: TableExistsOption::Fail.into(),
+    };
+    CommandStatementIngest {
+        table_definition_options: Some(table_definition_options),
+        table: "test".to_string(),
+        temporary: true,
+        ..Default::default()
+    }
+}
+
+/// Flight SQL service used by the integration tests in this file.
 #[derive(Clone)]
 pub struct FlightSqlServiceImpl {
     transactions: Arc<Mutex<HashMap<String, ()>>>,
+    /// Batches received by the most recent `do_put_statement_ingest` call
+    /// (overwritten on every call) so tests can inspect what was uploaded.
+    ingested_batches: Arc<Mutex<Vec<RecordBatch>>>,
 }
 
 impl FlightSqlServiceImpl {
+    pub fn new() -> Self {
+        Self {
+            transactions: Arc::new(Mutex::new(HashMap::new())),
+            ingested_batches: Arc::new(Mutex::new(Vec::new())),
+        }
+    }
+
     /// Return an [`FlightServiceServer`] that can be used with a
     /// [`Server`](tonic::transport::Server)
     pub fn service(&self) -> FlightServiceServer<Self> {
@@ -77,6 +153,12 @@ impl FlightSqlServiceImpl {
     }
 }
 
+/// Same as [`FlightSqlServiceImpl::new`]: an empty `transactions` map and no
+/// ingested batches.
+impl Default for FlightSqlServiceImpl {
+    fn default() -> Self {
+        FlightSqlServiceImpl::new()
+    }
+}
+
 #[tonic::async_trait]
 impl FlightSqlService for FlightSqlServiceImpl {
     type FlightService = FlightSqlServiceImpl;
@@ -116,4 +198,19 @@ impl FlightSqlService for FlightSqlServiceImpl {
     }
 
     async fn register_sql_info(&self, _id: i32, _result: &SqlInfo) {}
+
+    /// Decode the uploaded batches, store them for later inspection by the
+    /// test, and report the total number of rows received.
+    async fn do_put_statement_ingest(
+        &self,
+        _ticket: CommandStatementIngest,
+        request: Request<PeekableFlightDataStream>,
+    ) -> Result<i64, Status> {
+        let batch_stream = FlightRecordBatchStream::new_from_flight_data(
+            request.into_inner().map_err(|status| status.into()),
+        );
+        let batches: Vec<RecordBatch> = batch_stream.try_collect().await?;
+
+        let mut affected_rows: i64 = 0;
+        for batch in &batches {
+            affected_rows += batch.num_rows() as i64;
+        }
+
+        // Replace whatever a previous ingest stored.
+        *self.ingested_batches.lock().await = batches;
+        Ok(affected_rows)
+    }
 }

Reply via email to