This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/main by this push:
     new 298d3aaf62 Change FlightSQLClient to return `FlightError` & cleanup 
code (#8916)
298d3aaf62 is described below

commit 298d3aaf6225acc992099915de300497cbbba4b0
Author: 张林伟 <[email protected]>
AuthorDate: Sat Jan 10 21:40:52 2026 +0800

    Change FlightSQLClient to return `FlightError` & cleanup code (#8916)
    
    # Which issue does this PR close?
    
    None.
    
    # Rationale for this change
    
    1. Keep consistent with `FlightClient`.
    2. We need the inner error (for example `tonic::Status`) to decide whether
    to retry, but `ArrowError` stores only the string form of `tonic::Status`.
    
    # What changes are included in this PR?
    
    1. Make the Flight SQL client return `FlightError`.
    2. Cleanup code.
    
    # Are these changes tested?
    
    CI
    
    # Are there any user-facing changes?
    
    Yes.
    
    ---------
    
    Co-authored-by: Andrew Lamb <[email protected]>
---
 arrow-flight/src/error.rs      |   6 ++
 arrow-flight/src/sql/client.rs | 229 ++++++++++++-----------------------------
 2 files changed, 70 insertions(+), 165 deletions(-)

diff --git a/arrow-flight/src/error.rs b/arrow-flight/src/error.rs
index d5ac568e97..d22c24eea6 100644
--- a/arrow-flight/src/error.rs
+++ b/arrow-flight/src/error.rs
@@ -78,6 +78,12 @@ impl From<tonic::Status> for FlightError {
     }
 }
 
+impl From<prost::DecodeError> for FlightError {
+    fn from(error: prost::DecodeError) -> Self {
+        Self::DecodeError(error.to_string())
+    }
+}
+
 impl From<ArrowError> for FlightError {
     fn from(value: ArrowError) -> Self {
         Self::Arrow(value)
diff --git a/arrow-flight/src/sql/client.rs b/arrow-flight/src/sql/client.rs
index 4fb27ab5fc..5476d4ede9 100644
--- a/arrow-flight/src/sql/client.rs
+++ b/arrow-flight/src/sql/client.rs
@@ -17,6 +17,12 @@
 
 //! A FlightSQL Client [`FlightSqlServiceClient`]
 
+use arrow_buffer::Buffer;
+use arrow_ipc::MessageHeader;
+use arrow_ipc::convert::fb_to_schema;
+use arrow_ipc::reader::read_record_batch;
+use arrow_ipc::root_as_message;
+use arrow_schema::SchemaRef;
 use base64::Engine;
 use base64::prelude::BASE64_STANDARD;
 use bytes::Bytes;
@@ -27,6 +33,7 @@ use tonic::metadata::AsciiMetadataKey;
 use crate::decode::FlightRecordBatchStream;
 use crate::encode::FlightDataEncoderBuilder;
 use crate::error::FlightError;
+use crate::error::Result;
 use crate::flight_service_client::FlightServiceClient;
 use crate::sql::r#gen::action_end_transaction_request::EndTransaction;
 use crate::sql::server::{
@@ -49,11 +56,7 @@ use crate::{
     IpcMessage, PutResult, Ticket,
 };
 use arrow_array::RecordBatch;
-use arrow_buffer::Buffer;
-use arrow_ipc::convert::fb_to_schema;
-use arrow_ipc::reader::read_record_batch;
-use arrow_ipc::{MessageHeader, root_as_message};
-use arrow_schema::{ArrowError, Schema, SchemaRef};
+use arrow_schema::{ArrowError, Schema};
 use futures::{Stream, TryStreamExt, stream};
 use prost::Message;
 use tonic::codegen::{Body, StdError};
@@ -132,15 +135,10 @@ where
     async fn get_flight_info_for_command<M: ProstMessageExt>(
         &mut self,
         cmd: M,
-    ) -> Result<FlightInfo, ArrowError> {
+    ) -> Result<FlightInfo> {
         let descriptor = 
FlightDescriptor::new_cmd(cmd.as_any().encode_to_vec());
         let req = self.set_request_headers(descriptor.into_request())?;
-        let fi = self
-            .flight_client
-            .get_flight_info(req)
-            .await
-            .map_err(status_to_arrow_error)?
-            .into_inner();
+        let fi = self.flight_client.get_flight_info(req).await?.into_inner();
         Ok(fi)
     }
 
@@ -149,7 +147,7 @@ where
         &mut self,
         query: String,
         transaction_id: Option<Bytes>,
-    ) -> Result<FlightInfo, ArrowError> {
+    ) -> Result<FlightInfo> {
         let cmd = CommandStatementQuery {
             query,
             transaction_id,
@@ -162,7 +160,7 @@ where
     /// If the server returns an "authorization" header, it is automatically 
parsed and set as
     /// a token for future requests. Any other data returned by the server in 
the handshake
     /// response is returned as a binary blob.
-    pub async fn handshake(&mut self, username: &str, password: &str) -> 
Result<Bytes, ArrowError> {
+    pub async fn handshake(&mut self, username: &str, password: &str) -> 
Result<Bytes> {
         let cmd = HandshakeRequest {
             protocol_version: 0,
             payload: Default::default(),
@@ -185,7 +183,7 @@ where
                 .map_err(|_| ArrowError::ParseError("Can't read auth 
header".to_string()))?;
             let bearer = "Bearer ";
             if !auth.starts_with(bearer) {
-                Err(ArrowError::ParseError("Invalid auth 
header!".to_string()))?;
+                return Err(ArrowError::ParseError("Invalid auth 
header!".to_string()))?;
             }
             let auth = auth[bearer.len()..].to_string();
             self.token = Some(auth);
@@ -210,7 +208,7 @@ where
         &mut self,
         query: String,
         transaction_id: Option<Bytes>,
-    ) -> Result<i64, ArrowError> {
+    ) -> Result<i64> {
         let cmd = CommandStatementUpdate {
             query,
             transaction_id,
@@ -223,19 +221,9 @@ where
             }])
             .into_request(),
         )?;
-        let mut result = self
-            .flight_client
-            .do_put(req)
-            .await
-            .map_err(status_to_arrow_error)?
-            .into_inner();
-        let result = result
-            .message()
-            .await
-            .map_err(status_to_arrow_error)?
-            .unwrap();
-        let result: DoPutUpdateResult =
-            
Message::decode(&*result.app_metadata).map_err(decode_error_to_arrow_error)?;
+        let mut result = self.flight_client.do_put(req).await?.into_inner();
+        let result = result.message().await?.unwrap();
+        let result: DoPutUpdateResult = 
Message::decode(&*result.app_metadata)?;
         Ok(result.record_count)
     }
 
@@ -244,7 +232,7 @@ where
         &mut self,
         command: CommandStatementIngest,
         stream: S,
-    ) -> Result<i64, ArrowError>
+    ) -> Result<i64>
     where
         S: Stream<Item = crate::error::Result<RecordBatch>> + Send + 'static,
     {
@@ -261,41 +249,28 @@ where
             FallibleRequestStream::new(sender, flight_data);
 
         let req = 
self.set_request_headers(flight_data.into_streaming_request())?;
-        let mut result = self
-            .flight_client
-            .do_put(req)
-            .await
-            .map_err(status_to_arrow_error)?
-            .into_inner();
+        let mut result = self.flight_client.do_put(req).await?.into_inner();
 
         // check if the there were any errors in the input stream provided note
         // if receiver.await fails, it means the sender was dropped and there 
is
         // no message to return.
         if let Ok(msg) = receiver.await {
-            return Err(ArrowError::ExternalError(Box::new(msg)));
+            return Err(FlightError::ExternalError(Box::new(msg)));
         }
 
-        let result = result
-            .message()
-            .await
-            .map_err(status_to_arrow_error)?
-            .unwrap();
-        let result: DoPutUpdateResult =
-            
Message::decode(&*result.app_metadata).map_err(decode_error_to_arrow_error)?;
+        let result = result.message().await?.unwrap();
+        let result: DoPutUpdateResult = 
Message::decode(&*result.app_metadata)?;
         Ok(result.record_count)
     }
 
     /// Request a list of catalogs as tabular FlightInfo results
-    pub async fn get_catalogs(&mut self) -> Result<FlightInfo, ArrowError> {
+    pub async fn get_catalogs(&mut self) -> Result<FlightInfo> {
         self.get_flight_info_for_command(CommandGetCatalogs {})
             .await
     }
 
     /// Request a list of database schemas as tabular FlightInfo results
-    pub async fn get_db_schemas(
-        &mut self,
-        request: CommandGetDbSchemas,
-    ) -> Result<FlightInfo, ArrowError> {
+    pub async fn get_db_schemas(&mut self, request: CommandGetDbSchemas) -> 
Result<FlightInfo> {
         self.get_flight_info_for_command(request).await
     }
 
@@ -303,15 +278,10 @@ where
     pub async fn do_get(
         &mut self,
         ticket: impl IntoRequest<Ticket>,
-    ) -> Result<FlightRecordBatchStream, ArrowError> {
+    ) -> Result<FlightRecordBatchStream> {
         let req = self.set_request_headers(ticket.into_request())?;
 
-        let (md, response_stream, _ext) = self
-            .flight_client
-            .do_get(req)
-            .await
-            .map_err(status_to_arrow_error)?
-            .into_parts();
+        let (md, response_stream, _ext) = 
self.flight_client.do_get(req).await?.into_parts();
         let (response_stream, trailers) = 
extract_lazy_trailers(response_stream);
 
         Ok(FlightRecordBatchStream::new_from_flight_data(
@@ -325,43 +295,27 @@ where
     pub async fn do_put(
         &mut self,
         request: impl tonic::IntoStreamingRequest<Message = FlightData>,
-    ) -> Result<Streaming<PutResult>, ArrowError> {
+    ) -> Result<Streaming<PutResult>> {
         let req = self.set_request_headers(request.into_streaming_request())?;
-        Ok(self
-            .flight_client
-            .do_put(req)
-            .await
-            .map_err(status_to_arrow_error)?
-            .into_inner())
+        Ok(self.flight_client.do_put(req).await?.into_inner())
     }
 
     /// DoAction allows a flight client to do a specific action against a 
flight service
     pub async fn do_action(
         &mut self,
         request: impl IntoRequest<Action>,
-    ) -> Result<Streaming<crate::Result>, ArrowError> {
+    ) -> Result<Streaming<crate::Result>> {
         let req = self.set_request_headers(request.into_request())?;
-        Ok(self
-            .flight_client
-            .do_action(req)
-            .await
-            .map_err(status_to_arrow_error)?
-            .into_inner())
+        Ok(self.flight_client.do_action(req).await?.into_inner())
     }
 
     /// Request a list of tables.
-    pub async fn get_tables(
-        &mut self,
-        request: CommandGetTables,
-    ) -> Result<FlightInfo, ArrowError> {
+    pub async fn get_tables(&mut self, request: CommandGetTables) -> 
Result<FlightInfo> {
         self.get_flight_info_for_command(request).await
     }
 
     /// Request the primary keys for a table.
-    pub async fn get_primary_keys(
-        &mut self,
-        request: CommandGetPrimaryKeys,
-    ) -> Result<FlightInfo, ArrowError> {
+    pub async fn get_primary_keys(&mut self, request: CommandGetPrimaryKeys) 
-> Result<FlightInfo> {
         self.get_flight_info_for_command(request).await
     }
 
@@ -370,7 +324,7 @@ where
     pub async fn get_exported_keys(
         &mut self,
         request: CommandGetExportedKeys,
-    ) -> Result<FlightInfo, ArrowError> {
+    ) -> Result<FlightInfo> {
         self.get_flight_info_for_command(request).await
     }
 
@@ -378,7 +332,7 @@ where
     pub async fn get_imported_keys(
         &mut self,
         request: CommandGetImportedKeys,
-    ) -> Result<FlightInfo, ArrowError> {
+    ) -> Result<FlightInfo> {
         self.get_flight_info_for_command(request).await
     }
 
@@ -388,21 +342,18 @@ where
     pub async fn get_cross_reference(
         &mut self,
         request: CommandGetCrossReference,
-    ) -> Result<FlightInfo, ArrowError> {
+    ) -> Result<FlightInfo> {
         self.get_flight_info_for_command(request).await
     }
 
     /// Request a list of table types.
-    pub async fn get_table_types(&mut self) -> Result<FlightInfo, ArrowError> {
+    pub async fn get_table_types(&mut self) -> Result<FlightInfo> {
         self.get_flight_info_for_command(CommandGetTableTypes {})
             .await
     }
 
     /// Request a list of SQL information.
-    pub async fn get_sql_info(
-        &mut self,
-        sql_infos: Vec<SqlInfo>,
-    ) -> Result<FlightInfo, ArrowError> {
+    pub async fn get_sql_info(&mut self, sql_infos: Vec<SqlInfo>) -> 
Result<FlightInfo> {
         let request = CommandGetSqlInfo {
             info: sql_infos.iter().map(|sql_info| *sql_info as u32).collect(),
         };
@@ -413,7 +364,7 @@ where
     pub async fn get_xdbc_type_info(
         &mut self,
         request: CommandGetXdbcTypeInfo,
-    ) -> Result<FlightInfo, ArrowError> {
+    ) -> Result<FlightInfo> {
         self.get_flight_info_for_command(request).await
     }
 
@@ -422,7 +373,7 @@ where
         &mut self,
         query: String,
         transaction_id: Option<Bytes>,
-    ) -> Result<PreparedStatement<T>, ArrowError>
+    ) -> Result<PreparedStatement<T>>
     where
         T: Clone,
     {
@@ -435,18 +386,9 @@ where
             body: cmd.as_any().encode_to_vec().into(),
         };
         let req = self.set_request_headers(action.into_request())?;
-        let mut result = self
-            .flight_client
-            .do_action(req)
-            .await
-            .map_err(status_to_arrow_error)?
-            .into_inner();
-        let result = result
-            .message()
-            .await
-            .map_err(status_to_arrow_error)?
-            .unwrap();
-        let any = 
Any::decode(&*result.body).map_err(decode_error_to_arrow_error)?;
+        let mut result = self.flight_client.do_action(req).await?.into_inner();
+        let result = result.message().await?.unwrap();
+        let any = Any::decode(&*result.body)?;
         let prepared_result: ActionCreatePreparedStatementResult = 
any.unpack()?.unwrap();
         let dataset_schema = match prepared_result.dataset_schema.len() {
             0 => Schema::empty(),
@@ -465,25 +407,16 @@ where
     }
 
     /// Request to begin a transaction.
-    pub async fn begin_transaction(&mut self) -> Result<Bytes, ArrowError> {
+    pub async fn begin_transaction(&mut self) -> Result<Bytes> {
         let cmd = ActionBeginTransactionRequest {};
         let action = Action {
             r#type: BEGIN_TRANSACTION.to_string(),
             body: cmd.as_any().encode_to_vec().into(),
         };
         let req = self.set_request_headers(action.into_request())?;
-        let mut result = self
-            .flight_client
-            .do_action(req)
-            .await
-            .map_err(status_to_arrow_error)?
-            .into_inner();
-        let result = result
-            .message()
-            .await
-            .map_err(status_to_arrow_error)?
-            .unwrap();
-        let any = 
Any::decode(&*result.body).map_err(decode_error_to_arrow_error)?;
+        let mut result = self.flight_client.do_action(req).await?.into_inner();
+        let result = result.message().await?.unwrap();
+        let any = Any::decode(&*result.body)?;
         let begin_result: ActionBeginTransactionResult = 
any.unpack()?.unwrap();
         Ok(begin_result.transaction_id)
     }
@@ -493,7 +426,7 @@ where
         &mut self,
         transaction_id: Bytes,
         action: EndTransaction,
-    ) -> Result<(), ArrowError> {
+    ) -> Result<()> {
         let cmd = ActionEndTransactionRequest {
             transaction_id,
             action: action as i32,
@@ -503,25 +436,17 @@ where
             body: cmd.as_any().encode_to_vec().into(),
         };
         let req = self.set_request_headers(action.into_request())?;
-        let _ = self
-            .flight_client
-            .do_action(req)
-            .await
-            .map_err(status_to_arrow_error)?
-            .into_inner();
+        let _ = self.flight_client.do_action(req).await?.into_inner();
         Ok(())
     }
 
     /// Explicitly shut down and clean up the client.
-    pub async fn close(&mut self) -> Result<(), ArrowError> {
+    pub async fn close(&mut self) -> Result<()> {
         // TODO: consume self instead of &mut self to explicitly prevent reuse?
         Ok(())
     }
 
-    fn set_request_headers<M>(
-        &self,
-        mut req: tonic::Request<M>,
-    ) -> Result<tonic::Request<M>, ArrowError> {
+    fn set_request_headers<M>(&self, mut req: tonic::Request<M>) -> 
Result<tonic::Request<M>> {
         for (k, v) in &self.headers {
             let k = AsciiMetadataKey::from_str(k.as_str()).map_err(|e| {
                 ArrowError::ParseError(format!("Cannot convert header key 
\"{k}\": {e}"))
@@ -584,7 +509,7 @@ where
     }
 
     /// Executes the prepared statement query on the server.
-    pub async fn execute(&mut self) -> Result<FlightInfo, ArrowError> {
+    pub async fn execute(&mut self) -> Result<FlightInfo> {
         self.write_bind_params().await?;
 
         let cmd = CommandPreparedStatementQuery {
@@ -599,7 +524,7 @@ where
     }
 
     /// Executes the prepared statement update query on the server.
-    pub async fn execute_update(&mut self) -> Result<i64, ArrowError> {
+    pub async fn execute_update(&mut self) -> Result<i64> {
         self.write_bind_params().await?;
 
         let cmd = CommandPreparedStatementUpdate {
@@ -613,35 +538,30 @@ where
                 ..Default::default()
             }]))
             .await?;
-        let result = result
-            .message()
-            .await
-            .map_err(status_to_arrow_error)?
-            .unwrap();
-        let result: DoPutUpdateResult =
-            
Message::decode(&*result.app_metadata).map_err(decode_error_to_arrow_error)?;
+        let result = result.message().await?.unwrap();
+        let result: DoPutUpdateResult = 
Message::decode(&*result.app_metadata)?;
         Ok(result.record_count)
     }
 
     /// Retrieve the parameter schema from the query.
-    pub fn parameter_schema(&self) -> Result<&Schema, ArrowError> {
+    pub fn parameter_schema(&self) -> Result<&Schema> {
         Ok(&self.parameter_schema)
     }
 
     /// Retrieve the ResultSet schema from the query.
-    pub fn dataset_schema(&self) -> Result<&Schema, ArrowError> {
+    pub fn dataset_schema(&self) -> Result<&Schema> {
         Ok(&self.dataset_schema)
     }
 
     /// Set a RecordBatch that contains the parameters that will be bind.
-    pub fn set_parameters(&mut self, parameter_binding: RecordBatch) -> 
Result<(), ArrowError> {
+    pub fn set_parameters(&mut self, parameter_binding: RecordBatch) -> 
Result<()> {
         self.parameter_binding = Some(parameter_binding);
         Ok(())
     }
 
     /// Submit parameters to the server, if any have been set on this prepared 
statement instance
     /// Updates our stored prepared statement handle with the handle given by 
the server response.
-    async fn write_bind_params(&mut self) -> Result<(), ArrowError> {
+    async fn write_bind_params(&mut self) -> Result<()> {
         if let Some(ref params_batch) = self.parameter_binding {
             let cmd = CommandPreparedStatementQuery {
                 prepared_statement_handle: self.handle.clone(),
@@ -656,8 +576,7 @@ where
                     self.parameter_binding.clone().map(Ok),
                 ))
                 .try_collect::<Vec<_>>()
-                .await
-                .map_err(flight_error_to_arrow_error)?;
+                .await?;
 
             // Attempt to update the stored handle with any updated handle in 
the DoPut result.
             // Older servers do not respond with a result for DoPut, so skip 
this step when
@@ -667,8 +586,7 @@ where
                 .do_put(stream::iter(flight_data))
                 .await?
                 .message()
-                .await
-                .map_err(status_to_arrow_error)?
+                .await?
             {
                 if let Some(handle) = 
self.unpack_prepared_statement_handle(&result)? {
                     self.handle = handle;
@@ -681,18 +599,14 @@ where
     /// Decodes the app_metadata stored in a [`PutResult`] as a
     /// [`DoPutPreparedStatementResult`] and then returns
     /// the inner prepared statement handle as [`Bytes`]
-    fn unpack_prepared_statement_handle(
-        &self,
-        put_result: &PutResult,
-    ) -> Result<Option<Bytes>, ArrowError> {
-        let result: DoPutPreparedStatementResult =
-            
Message::decode(&*put_result.app_metadata).map_err(decode_error_to_arrow_error)?;
+    fn unpack_prepared_statement_handle(&self, put_result: &PutResult) -> 
Result<Option<Bytes>> {
+        let result: DoPutPreparedStatementResult = 
Message::decode(&*put_result.app_metadata)?;
         Ok(result.prepared_statement_handle)
     }
 
     /// Close the prepared statement, so that this PreparedStatement can not 
used
     /// anymore and server can free up any resources.
-    pub async fn close(mut self) -> Result<(), ArrowError> {
+    pub async fn close(mut self) -> Result<()> {
         let cmd = ActionClosePreparedStatementRequest {
             prepared_statement_handle: self.handle.clone(),
         };
@@ -705,21 +619,6 @@ where
     }
 }
 
-fn decode_error_to_arrow_error(err: prost::DecodeError) -> ArrowError {
-    ArrowError::IpcError(err.to_string())
-}
-
-fn status_to_arrow_error(status: tonic::Status) -> ArrowError {
-    ArrowError::IpcError(format!("{status:?}"))
-}
-
-fn flight_error_to_arrow_error(err: FlightError) -> ArrowError {
-    match err {
-        FlightError::Arrow(e) => e,
-        e => ArrowError::ExternalError(Box::new(e)),
-    }
-}
-
 /// A polymorphic structure to natively represent different types of data 
contained in `FlightData`
 pub enum ArrowFlightData {
     /// A record batch
@@ -732,7 +631,7 @@ pub enum ArrowFlightData {
 pub fn arrow_data_from_flight_data(
     flight_data: FlightData,
     arrow_schema_ref: &SchemaRef,
-) -> Result<ArrowFlightData, ArrowError> {
+) -> std::result::Result<ArrowFlightData, ArrowError> {
     let ipc_message = root_as_message(&flight_data.data_header[..])
         .map_err(|err| ArrowError::ParseError(format!("Unable to get root as 
message: {err:?}")))?;
 

Reply via email to