This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new 9a5ea83885c Implement FlightSQL spec change to support stateless 
prepared statements (#5433)
9a5ea83885c is described below

commit 9a5ea83885ce11439270910c5828943b747ac35e
Author: Adam Curtis <[email protected]>
AuthorDate: Tue Mar 26 14:42:55 2024 -0400

    Implement FlightSQL spec change to support stateless prepared statements 
(#5433)
    
    * feat: stateless FlightSQL prepared statements
    
    * update protobuf and improve legacy behavior
    
    * Update arrow-flight/src/sql/server.rs
    
    Co-authored-by: Andrew Lamb <[email protected]>
    
    * make DoPutPreparedStatementResult mandatory
    
    * update DoPutPreparedStatementResult docs to match arrow repo
    
    * update comment about legacy server behavior in DoPut
    
    ---------
    
    Co-authored-by: Andrew Lamb <[email protected]>
---
 arrow-flight/examples/flight_sql_server.rs        |  3 +-
 arrow-flight/src/sql/arrow.flight.protocol.sql.rs | 19 ++++++
 arrow-flight/src/sql/client.rs                    | 33 ++++++++--
 arrow-flight/src/sql/mod.rs                       |  2 +
 arrow-flight/src/sql/server.rs                    | 17 ++++-
 arrow-flight/tests/flight_sql_client_cli.rs       | 77 +++++++++++++++++------
 format/FlightSql.proto                            | 22 ++++++-
 7 files changed, 145 insertions(+), 28 deletions(-)

diff --git a/arrow-flight/examples/flight_sql_server.rs 
b/arrow-flight/examples/flight_sql_server.rs
index 02714e67c0b..a8f8d160650 100644
--- a/arrow-flight/examples/flight_sql_server.rs
+++ b/arrow-flight/examples/flight_sql_server.rs
@@ -16,6 +16,7 @@
 // under the License.
 
 use arrow_flight::sql::server::PeekableFlightDataStream;
+use arrow_flight::sql::DoPutPreparedStatementResult;
 use base64::prelude::BASE64_STANDARD;
 use base64::Engine;
 use futures::{stream, Stream, TryStreamExt};
@@ -619,7 +620,7 @@ impl FlightSqlService for FlightSqlServiceImpl {
         &self,
         _query: CommandPreparedStatementQuery,
         _request: Request<PeekableFlightDataStream>,
-    ) -> Result<Response<<Self as FlightService>::DoPutStream>, Status> {
+    ) -> Result<DoPutPreparedStatementResult, Status> {
         Err(Status::unimplemented(
             "do_put_prepared_statement_query not implemented",
         ))
diff --git a/arrow-flight/src/sql/arrow.flight.protocol.sql.rs 
b/arrow-flight/src/sql/arrow.flight.protocol.sql.rs
index 2b2f4af7ac9..01ea9b61a8f 100644
--- a/arrow-flight/src/sql/arrow.flight.protocol.sql.rs
+++ b/arrow-flight/src/sql/arrow.flight.protocol.sql.rs
@@ -808,6 +808,25 @@ pub struct DoPutUpdateResult {
     #[prost(int64, tag = "1")]
     pub record_count: i64,
 }
+/// An *optional* response returned when `DoPut` is called with 
`CommandPreparedStatementQuery`.
+///
+/// *Note on legacy behavior*: previous versions of the protocol did not 
return any result for
+/// this command, and that behavior should still be supported by clients. In 
that case, the client
+/// can continue as though the fields in this message were not provided or set 
to sensible default values.
+#[allow(clippy::derive_partial_eq_without_eq)]
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct DoPutPreparedStatementResult {
+    /// Represents a (potentially updated) opaque handle for the prepared 
statement on the server.
+    /// Because the handle could potentially be updated, any previous handles 
for this prepared
+    /// statement should be considered invalid, and all subsequent requests 
for this prepared
+    /// statement must use this new handle.
+    /// The updated handle allows implementing query parameters with stateless 
services.
+    ///
+    /// When an updated handle is not provided by the server, clients should 
continue
+    /// using the previous handle provided by 
`ActionCreatePreparedStatementResult`.
+    #[prost(bytes = "bytes", optional, tag = "1")]
+    pub prepared_statement_handle: 
::core::option::Option<::prost::bytes::Bytes>,
+}
 ///
 /// Request message for the "CancelQuery" action.
 ///
diff --git a/arrow-flight/src/sql/client.rs b/arrow-flight/src/sql/client.rs
index a014137f6fa..44250fbe63e 100644
--- a/arrow-flight/src/sql/client.rs
+++ b/arrow-flight/src/sql/client.rs
@@ -35,7 +35,8 @@ use crate::sql::{
     CommandGetDbSchemas, CommandGetExportedKeys, CommandGetImportedKeys, 
CommandGetPrimaryKeys,
     CommandGetSqlInfo, CommandGetTableTypes, CommandGetTables, 
CommandGetXdbcTypeInfo,
     CommandPreparedStatementQuery, CommandPreparedStatementUpdate, 
CommandStatementQuery,
-    CommandStatementUpdate, DoPutUpdateResult, ProstMessageExt, SqlInfo,
+    CommandStatementUpdate, DoPutPreparedStatementResult, DoPutUpdateResult, 
ProstMessageExt,
+    SqlInfo,
 };
 use crate::trailers::extract_lazy_trailers;
 use crate::{
@@ -501,6 +502,7 @@ impl PreparedStatement<Channel> {
     }
 
     /// Submit parameters to the server, if any have been set on this prepared 
statement instance
+    /// Updates our stored prepared statement handle with the handle given by 
the server response.
     async fn write_bind_params(&mut self) -> Result<(), ArrowError> {
         if let Some(ref params_batch) = self.parameter_binding {
             let cmd = CommandPreparedStatementQuery {
@@ -519,17 +521,38 @@ impl PreparedStatement<Channel> {
                 .await
                 .map_err(flight_error_to_arrow_error)?;
 
-            self.flight_sql_client
+            // Attempt to update the stored handle with any updated handle in 
the DoPut result.
+            // Older servers do not respond with a result for DoPut, so skip 
this step when
+            // the stream closes with no response.
+            if let Some(result) = self
+                .flight_sql_client
                 .do_put(stream::iter(flight_data))
                 .await?
-                .try_collect::<Vec<_>>()
+                .message()
                 .await
-                .map_err(status_to_arrow_error)?;
+                .map_err(status_to_arrow_error)?
+            {
+                if let Some(handle) = 
self.unpack_prepared_statement_handle(&result)? {
+                    self.handle = handle;
+                }
+            }
         }
-
         Ok(())
     }
 
+    /// Decodes the app_metadata stored in a [`PutResult`] as a
+    /// [`DoPutPreparedStatementResult`] and returns the inner prepared
+    /// statement handle as [`Bytes`], or `None` if no handle is present
+    fn unpack_prepared_statement_handle(
+        &self,
+        put_result: &PutResult,
+    ) -> Result<Option<Bytes>, ArrowError> {
+        let any = 
Any::decode(&*put_result.app_metadata).map_err(decode_error_to_arrow_error)?;
+        Ok(any
+            .unpack::<DoPutPreparedStatementResult>()?
+            .and_then(|result| result.prepared_statement_handle))
+    }
+
     /// Close the prepared statement, so that this PreparedStatement can not 
used
     /// anymore and server can free up any resources.
     pub async fn close(mut self) -> Result<(), ArrowError> {
diff --git a/arrow-flight/src/sql/mod.rs b/arrow-flight/src/sql/mod.rs
index 97645ae7840..089ee4dd8c3 100644
--- a/arrow-flight/src/sql/mod.rs
+++ b/arrow-flight/src/sql/mod.rs
@@ -75,6 +75,7 @@ pub use gen::CommandPreparedStatementUpdate;
 pub use gen::CommandStatementQuery;
 pub use gen::CommandStatementSubstraitPlan;
 pub use gen::CommandStatementUpdate;
+pub use gen::DoPutPreparedStatementResult;
 pub use gen::DoPutUpdateResult;
 pub use gen::Nullable;
 pub use gen::Searchable;
@@ -251,6 +252,7 @@ prost_message_ext!(
     CommandStatementSubstraitPlan,
     CommandStatementUpdate,
     DoPutUpdateResult,
+    DoPutPreparedStatementResult,
     TicketStatementQuery,
 );
 
diff --git a/arrow-flight/src/sql/server.rs b/arrow-flight/src/sql/server.rs
index 0431e58111a..c18024cf068 100644
--- a/arrow-flight/src/sql/server.rs
+++ b/arrow-flight/src/sql/server.rs
@@ -33,7 +33,8 @@ use super::{
     CommandGetPrimaryKeys, CommandGetSqlInfo, CommandGetTableTypes, 
CommandGetTables,
     CommandGetXdbcTypeInfo, CommandPreparedStatementQuery, 
CommandPreparedStatementUpdate,
     CommandStatementQuery, CommandStatementSubstraitPlan, 
CommandStatementUpdate,
-    DoPutUpdateResult, ProstMessageExt, SqlInfo, TicketStatementQuery,
+    DoPutPreparedStatementResult, DoPutUpdateResult, ProstMessageExt, SqlInfo,
+    TicketStatementQuery,
 };
 use crate::{
     flight_service_server::FlightService, gen::PollInfo, Action, ActionType, 
Criteria, Empty,
@@ -397,11 +398,15 @@ pub trait FlightSqlService: Sync + Send + Sized + 'static 
{
     }
 
     /// Bind parameters to given prepared statement.
+    ///
+    /// Returns a [`DoPutPreparedStatementResult`] whose optional handle,
+    /// if set, the client should pass back to the server during subsequent
+    /// requests with this prepared statement.
     async fn do_put_prepared_statement_query(
         &self,
         _query: CommandPreparedStatementQuery,
         _request: Request<PeekableFlightDataStream>,
-    ) -> Result<Response<<Self as FlightService>::DoPutStream>, Status> {
+    ) -> Result<DoPutPreparedStatementResult, Status> {
         Err(Status::unimplemented(
             "do_put_prepared_statement_query has no default implementation",
         ))
@@ -709,7 +714,13 @@ where
                 Ok(Response::new(Box::pin(output)))
             }
             Command::CommandPreparedStatementQuery(command) => {
-                self.do_put_prepared_statement_query(command, request).await
+                let result = self
+                    .do_put_prepared_statement_query(command, request)
+                    .await?;
+                let output = futures::stream::iter(vec![Ok(PutResult {
+                    app_metadata: result.as_any().encode_to_vec().into(),
+                })]);
+                Ok(Response::new(Box::pin(output)))
             }
             Command::CommandStatementSubstraitPlan(command) => {
                 let record_count = self.do_put_substrait_plan(command, 
request).await?;
diff --git a/arrow-flight/tests/flight_sql_client_cli.rs 
b/arrow-flight/tests/flight_sql_client_cli.rs
index cc270eeb618..50a4ec0d8c6 100644
--- a/arrow-flight/tests/flight_sql_client_cli.rs
+++ b/arrow-flight/tests/flight_sql_client_cli.rs
@@ -32,17 +32,18 @@ use arrow_flight::{
         CommandGetImportedKeys, CommandGetPrimaryKeys, CommandGetSqlInfo, 
CommandGetTableTypes,
         CommandGetTables, CommandGetXdbcTypeInfo, 
CommandPreparedStatementQuery,
         CommandPreparedStatementUpdate, CommandStatementQuery, 
CommandStatementSubstraitPlan,
-        CommandStatementUpdate, ProstMessageExt, SqlInfo, TicketStatementQuery,
+        CommandStatementUpdate, DoPutPreparedStatementResult, ProstMessageExt, 
SqlInfo,
+        TicketStatementQuery,
     },
     utils::batches_to_flight_data,
     Action, FlightData, FlightDescriptor, FlightEndpoint, FlightInfo, 
HandshakeRequest,
-    HandshakeResponse, IpcMessage, PutResult, SchemaAsIpc, Ticket,
+    HandshakeResponse, IpcMessage, SchemaAsIpc, Ticket,
 };
 use arrow_ipc::writer::IpcWriteOptions;
 use arrow_schema::{ArrowError, DataType, Field, Schema};
 use assert_cmd::Command;
 use bytes::Bytes;
-use futures::{Stream, StreamExt, TryStreamExt};
+use futures::{Stream, TryStreamExt};
 use prost::Message;
 use tokio::{net::TcpListener, task::JoinHandle};
 use tonic::{Request, Response, Status, Streaming};
@@ -51,7 +52,7 @@ const QUERY: &str = "SELECT * FROM table;";
 
 #[tokio::test]
 async fn test_simple() {
-    let test_server = FlightSqlServiceImpl {};
+    let test_server = FlightSqlServiceImpl::default();
     let fixture = TestFixture::new(&test_server).await;
     let addr = fixture.addr;
 
@@ -92,10 +93,9 @@ async fn test_simple() {
 
 const PREPARED_QUERY: &str = "SELECT * FROM table WHERE field = $1";
 const PREPARED_STATEMENT_HANDLE: &str = "prepared_statement_handle";
+const UPDATED_PREPARED_STATEMENT_HANDLE: &str = 
"updated_prepared_statement_handle";
 
-#[tokio::test]
-async fn test_do_put_prepared_statement() {
-    let test_server = FlightSqlServiceImpl {};
+async fn test_do_put_prepared_statement(test_server: FlightSqlServiceImpl) {
     let fixture = TestFixture::new(&test_server).await;
     let addr = fixture.addr;
 
@@ -136,11 +136,40 @@ async fn test_do_put_prepared_statement() {
     );
 }
 
+#[tokio::test]
+pub async fn test_do_put_prepared_statement_stateless() {
+    test_do_put_prepared_statement(FlightSqlServiceImpl {
+        stateless_prepared_statements: true,
+    })
+    .await
+}
+
+#[tokio::test]
+pub async fn test_do_put_prepared_statement_stateful() {
+    test_do_put_prepared_statement(FlightSqlServiceImpl {
+        stateless_prepared_statements: false,
+    })
+    .await
+}
+
 /// All tests must complete within this many seconds or else the test server 
is shutdown
 const DEFAULT_TIMEOUT_SECONDS: u64 = 30;
 
-#[derive(Clone, Default)]
-pub struct FlightSqlServiceImpl {}
+#[derive(Clone)]
+pub struct FlightSqlServiceImpl {
+    /// Whether to emulate stateless (true) or stateful (false) behavior for
+    /// prepared statements. Stateful servers return the original handle
+    /// unchanged after executing `DoPut(CommandPreparedStatementQuery)`
+    stateless_prepared_statements: bool,
+}
+
+impl Default for FlightSqlServiceImpl {
+    fn default() -> Self {
+        Self {
+            stateless_prepared_statements: true,
+        }
+    }
+}
 
 impl FlightSqlServiceImpl {
     /// Return an [`FlightServiceServer`] that can be used with a
@@ -274,10 +303,17 @@ impl FlightSqlService for FlightSqlServiceImpl {
         cmd: CommandPreparedStatementQuery,
         _request: Request<FlightDescriptor>,
     ) -> Result<Response<FlightInfo>, Status> {
-        assert_eq!(
-            cmd.prepared_statement_handle,
-            PREPARED_STATEMENT_HANDLE.as_bytes()
-        );
+        if self.stateless_prepared_statements {
+            assert_eq!(
+                cmd.prepared_statement_handle,
+                UPDATED_PREPARED_STATEMENT_HANDLE.as_bytes()
+            );
+        } else {
+            assert_eq!(
+                cmd.prepared_statement_handle,
+                PREPARED_STATEMENT_HANDLE.as_bytes()
+            );
+        }
         let resp = Response::new(self.fake_flight_info().unwrap());
         Ok(resp)
     }
@@ -524,7 +560,7 @@ impl FlightSqlService for FlightSqlServiceImpl {
         &self,
         _query: CommandPreparedStatementQuery,
         request: Request<PeekableFlightDataStream>,
-    ) -> Result<Response<<Self as FlightService>::DoPutStream>, Status> {
+    ) -> Result<DoPutPreparedStatementResult, Status> {
         // just make sure decoding the parameters works
         let parameters = FlightRecordBatchStream::new_from_flight_data(
             request.into_inner().map_err(|e| e.into()),
@@ -543,10 +579,15 @@ impl FlightSqlService for FlightSqlServiceImpl {
                 )));
             }
         }
-
-        Ok(Response::new(
-            futures::stream::once(async { Ok(PutResult::default()) }).boxed(),
-        ))
+        let handle = if self.stateless_prepared_statements {
+            UPDATED_PREPARED_STATEMENT_HANDLE.to_string().into()
+        } else {
+            PREPARED_STATEMENT_HANDLE.to_string().into()
+        };
+        let result = DoPutPreparedStatementResult {
+            prepared_statement_handle: Some(handle),
+        };
+        Ok(result)
     }
 
     async fn do_put_prepared_statement_update(
diff --git a/format/FlightSql.proto b/format/FlightSql.proto
index f78e77e2327..4fc68f2a5db 100644
--- a/format/FlightSql.proto
+++ b/format/FlightSql.proto
@@ -1796,7 +1796,27 @@
    // an unknown updated record count.
    int64 record_count = 1;
  }
- 
+
+  /* An *optional* response returned when `DoPut` is called with 
`CommandPreparedStatementQuery`.
+  *
+  * *Note on legacy behavior*: previous versions of the protocol did not 
return any result for
+  * this command, and that behavior should still be supported by clients. In 
that case, the client
+  * can continue as though the fields in this message were not provided or set 
to sensible default values.
+  */
+  message DoPutPreparedStatementResult {
+    option (experimental) = true;
+
+    // Represents a (potentially updated) opaque handle for the prepared 
statement on the server.
+    // Because the handle could potentially be updated, any previous handles 
for this prepared
+    // statement should be considered invalid, and all subsequent requests for 
this prepared
+    // statement must use this new handle.
+    // The updated handle allows implementing query parameters with stateless 
services.
+    // 
+    // When an updated handle is not provided by the server, clients should 
continue
+    // using the previous handle provided by 
`ActionCreatePreparedStatementResult`.
+    optional bytes prepared_statement_handle = 1;
+  }
+
  /*
   * Request message for the "CancelQuery" action.
   *

Reply via email to