This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/main by this push:
new 298d3aaf62 Change FlightSQLClient to return `FlightError` & cleanup
code (#8916)
298d3aaf62 is described below
commit 298d3aaf6225acc992099915de300497cbbba4b0
Author: 张林伟 <[email protected]>
AuthorDate: Sat Jan 10 21:40:52 2026 +0800
Change FlightSQLClient to return `FlightError` & cleanup code (#8916)
# Which issue does this PR close?
None.
# Rationale for this change
1. Keep the FlightSQL client consistent with `FlightClient`.
2. Callers need the inner error (for example `tonic::Status`) to decide whether
to retry, but `ArrowError` only stores the string form of `tonic::Status`.
# What changes are included in this PR?
1. Make the FlightSQL client return `FlightError`.
2. Clean up the code.
# Are these changes tested?
Yes, by CI.
# Are there any user-facing changes?
Yes. The FlightSQL client methods now return `FlightError` instead of `ArrowError`.
---------
Co-authored-by: Andrew Lamb <[email protected]>
---
arrow-flight/src/error.rs | 6 ++
arrow-flight/src/sql/client.rs | 229 ++++++++++++-----------------------------
2 files changed, 70 insertions(+), 165 deletions(-)
diff --git a/arrow-flight/src/error.rs b/arrow-flight/src/error.rs
index d5ac568e97..d22c24eea6 100644
--- a/arrow-flight/src/error.rs
+++ b/arrow-flight/src/error.rs
@@ -78,6 +78,12 @@ impl From<tonic::Status> for FlightError {
}
}
+impl From<prost::DecodeError> for FlightError {
+ fn from(error: prost::DecodeError) -> Self {
+ Self::DecodeError(error.to_string())
+ }
+}
+
impl From<ArrowError> for FlightError {
fn from(value: ArrowError) -> Self {
Self::Arrow(value)
diff --git a/arrow-flight/src/sql/client.rs b/arrow-flight/src/sql/client.rs
index 4fb27ab5fc..5476d4ede9 100644
--- a/arrow-flight/src/sql/client.rs
+++ b/arrow-flight/src/sql/client.rs
@@ -17,6 +17,12 @@
//! A FlightSQL Client [`FlightSqlServiceClient`]
+use arrow_buffer::Buffer;
+use arrow_ipc::MessageHeader;
+use arrow_ipc::convert::fb_to_schema;
+use arrow_ipc::reader::read_record_batch;
+use arrow_ipc::root_as_message;
+use arrow_schema::SchemaRef;
use base64::Engine;
use base64::prelude::BASE64_STANDARD;
use bytes::Bytes;
@@ -27,6 +33,7 @@ use tonic::metadata::AsciiMetadataKey;
use crate::decode::FlightRecordBatchStream;
use crate::encode::FlightDataEncoderBuilder;
use crate::error::FlightError;
+use crate::error::Result;
use crate::flight_service_client::FlightServiceClient;
use crate::sql::r#gen::action_end_transaction_request::EndTransaction;
use crate::sql::server::{
@@ -49,11 +56,7 @@ use crate::{
IpcMessage, PutResult, Ticket,
};
use arrow_array::RecordBatch;
-use arrow_buffer::Buffer;
-use arrow_ipc::convert::fb_to_schema;
-use arrow_ipc::reader::read_record_batch;
-use arrow_ipc::{MessageHeader, root_as_message};
-use arrow_schema::{ArrowError, Schema, SchemaRef};
+use arrow_schema::{ArrowError, Schema};
use futures::{Stream, TryStreamExt, stream};
use prost::Message;
use tonic::codegen::{Body, StdError};
@@ -132,15 +135,10 @@ where
async fn get_flight_info_for_command<M: ProstMessageExt>(
&mut self,
cmd: M,
- ) -> Result<FlightInfo, ArrowError> {
+ ) -> Result<FlightInfo> {
let descriptor =
FlightDescriptor::new_cmd(cmd.as_any().encode_to_vec());
let req = self.set_request_headers(descriptor.into_request())?;
- let fi = self
- .flight_client
- .get_flight_info(req)
- .await
- .map_err(status_to_arrow_error)?
- .into_inner();
+ let fi = self.flight_client.get_flight_info(req).await?.into_inner();
Ok(fi)
}
@@ -149,7 +147,7 @@ where
&mut self,
query: String,
transaction_id: Option<Bytes>,
- ) -> Result<FlightInfo, ArrowError> {
+ ) -> Result<FlightInfo> {
let cmd = CommandStatementQuery {
query,
transaction_id,
@@ -162,7 +160,7 @@ where
/// If the server returns an "authorization" header, it is automatically
parsed and set as
/// a token for future requests. Any other data returned by the server in
the handshake
/// response is returned as a binary blob.
- pub async fn handshake(&mut self, username: &str, password: &str) ->
Result<Bytes, ArrowError> {
+ pub async fn handshake(&mut self, username: &str, password: &str) ->
Result<Bytes> {
let cmd = HandshakeRequest {
protocol_version: 0,
payload: Default::default(),
@@ -185,7 +183,7 @@ where
.map_err(|_| ArrowError::ParseError("Can't read auth
header".to_string()))?;
let bearer = "Bearer ";
if !auth.starts_with(bearer) {
- Err(ArrowError::ParseError("Invalid auth
header!".to_string()))?;
+ return Err(ArrowError::ParseError("Invalid auth
header!".to_string()))?;
}
let auth = auth[bearer.len()..].to_string();
self.token = Some(auth);
@@ -210,7 +208,7 @@ where
&mut self,
query: String,
transaction_id: Option<Bytes>,
- ) -> Result<i64, ArrowError> {
+ ) -> Result<i64> {
let cmd = CommandStatementUpdate {
query,
transaction_id,
@@ -223,19 +221,9 @@ where
}])
.into_request(),
)?;
- let mut result = self
- .flight_client
- .do_put(req)
- .await
- .map_err(status_to_arrow_error)?
- .into_inner();
- let result = result
- .message()
- .await
- .map_err(status_to_arrow_error)?
- .unwrap();
- let result: DoPutUpdateResult =
-
Message::decode(&*result.app_metadata).map_err(decode_error_to_arrow_error)?;
+ let mut result = self.flight_client.do_put(req).await?.into_inner();
+ let result = result.message().await?.unwrap();
+ let result: DoPutUpdateResult =
Message::decode(&*result.app_metadata)?;
Ok(result.record_count)
}
@@ -244,7 +232,7 @@ where
&mut self,
command: CommandStatementIngest,
stream: S,
- ) -> Result<i64, ArrowError>
+ ) -> Result<i64>
where
S: Stream<Item = crate::error::Result<RecordBatch>> + Send + 'static,
{
@@ -261,41 +249,28 @@ where
FallibleRequestStream::new(sender, flight_data);
let req =
self.set_request_headers(flight_data.into_streaming_request())?;
- let mut result = self
- .flight_client
- .do_put(req)
- .await
- .map_err(status_to_arrow_error)?
- .into_inner();
+ let mut result = self.flight_client.do_put(req).await?.into_inner();
// check if the there were any errors in the input stream provided note
// if receiver.await fails, it means the sender was dropped and there
is
// no message to return.
if let Ok(msg) = receiver.await {
- return Err(ArrowError::ExternalError(Box::new(msg)));
+ return Err(FlightError::ExternalError(Box::new(msg)));
}
- let result = result
- .message()
- .await
- .map_err(status_to_arrow_error)?
- .unwrap();
- let result: DoPutUpdateResult =
-
Message::decode(&*result.app_metadata).map_err(decode_error_to_arrow_error)?;
+ let result = result.message().await?.unwrap();
+ let result: DoPutUpdateResult =
Message::decode(&*result.app_metadata)?;
Ok(result.record_count)
}
/// Request a list of catalogs as tabular FlightInfo results
- pub async fn get_catalogs(&mut self) -> Result<FlightInfo, ArrowError> {
+ pub async fn get_catalogs(&mut self) -> Result<FlightInfo> {
self.get_flight_info_for_command(CommandGetCatalogs {})
.await
}
/// Request a list of database schemas as tabular FlightInfo results
- pub async fn get_db_schemas(
- &mut self,
- request: CommandGetDbSchemas,
- ) -> Result<FlightInfo, ArrowError> {
+ pub async fn get_db_schemas(&mut self, request: CommandGetDbSchemas) ->
Result<FlightInfo> {
self.get_flight_info_for_command(request).await
}
@@ -303,15 +278,10 @@ where
pub async fn do_get(
&mut self,
ticket: impl IntoRequest<Ticket>,
- ) -> Result<FlightRecordBatchStream, ArrowError> {
+ ) -> Result<FlightRecordBatchStream> {
let req = self.set_request_headers(ticket.into_request())?;
- let (md, response_stream, _ext) = self
- .flight_client
- .do_get(req)
- .await
- .map_err(status_to_arrow_error)?
- .into_parts();
+ let (md, response_stream, _ext) =
self.flight_client.do_get(req).await?.into_parts();
let (response_stream, trailers) =
extract_lazy_trailers(response_stream);
Ok(FlightRecordBatchStream::new_from_flight_data(
@@ -325,43 +295,27 @@ where
pub async fn do_put(
&mut self,
request: impl tonic::IntoStreamingRequest<Message = FlightData>,
- ) -> Result<Streaming<PutResult>, ArrowError> {
+ ) -> Result<Streaming<PutResult>> {
let req = self.set_request_headers(request.into_streaming_request())?;
- Ok(self
- .flight_client
- .do_put(req)
- .await
- .map_err(status_to_arrow_error)?
- .into_inner())
+ Ok(self.flight_client.do_put(req).await?.into_inner())
}
/// DoAction allows a flight client to do a specific action against a
flight service
pub async fn do_action(
&mut self,
request: impl IntoRequest<Action>,
- ) -> Result<Streaming<crate::Result>, ArrowError> {
+ ) -> Result<Streaming<crate::Result>> {
let req = self.set_request_headers(request.into_request())?;
- Ok(self
- .flight_client
- .do_action(req)
- .await
- .map_err(status_to_arrow_error)?
- .into_inner())
+ Ok(self.flight_client.do_action(req).await?.into_inner())
}
/// Request a list of tables.
- pub async fn get_tables(
- &mut self,
- request: CommandGetTables,
- ) -> Result<FlightInfo, ArrowError> {
+ pub async fn get_tables(&mut self, request: CommandGetTables) ->
Result<FlightInfo> {
self.get_flight_info_for_command(request).await
}
/// Request the primary keys for a table.
- pub async fn get_primary_keys(
- &mut self,
- request: CommandGetPrimaryKeys,
- ) -> Result<FlightInfo, ArrowError> {
+ pub async fn get_primary_keys(&mut self, request: CommandGetPrimaryKeys)
-> Result<FlightInfo> {
self.get_flight_info_for_command(request).await
}
@@ -370,7 +324,7 @@ where
pub async fn get_exported_keys(
&mut self,
request: CommandGetExportedKeys,
- ) -> Result<FlightInfo, ArrowError> {
+ ) -> Result<FlightInfo> {
self.get_flight_info_for_command(request).await
}
@@ -378,7 +332,7 @@ where
pub async fn get_imported_keys(
&mut self,
request: CommandGetImportedKeys,
- ) -> Result<FlightInfo, ArrowError> {
+ ) -> Result<FlightInfo> {
self.get_flight_info_for_command(request).await
}
@@ -388,21 +342,18 @@ where
pub async fn get_cross_reference(
&mut self,
request: CommandGetCrossReference,
- ) -> Result<FlightInfo, ArrowError> {
+ ) -> Result<FlightInfo> {
self.get_flight_info_for_command(request).await
}
/// Request a list of table types.
- pub async fn get_table_types(&mut self) -> Result<FlightInfo, ArrowError> {
+ pub async fn get_table_types(&mut self) -> Result<FlightInfo> {
self.get_flight_info_for_command(CommandGetTableTypes {})
.await
}
/// Request a list of SQL information.
- pub async fn get_sql_info(
- &mut self,
- sql_infos: Vec<SqlInfo>,
- ) -> Result<FlightInfo, ArrowError> {
+ pub async fn get_sql_info(&mut self, sql_infos: Vec<SqlInfo>) ->
Result<FlightInfo> {
let request = CommandGetSqlInfo {
info: sql_infos.iter().map(|sql_info| *sql_info as u32).collect(),
};
@@ -413,7 +364,7 @@ where
pub async fn get_xdbc_type_info(
&mut self,
request: CommandGetXdbcTypeInfo,
- ) -> Result<FlightInfo, ArrowError> {
+ ) -> Result<FlightInfo> {
self.get_flight_info_for_command(request).await
}
@@ -422,7 +373,7 @@ where
&mut self,
query: String,
transaction_id: Option<Bytes>,
- ) -> Result<PreparedStatement<T>, ArrowError>
+ ) -> Result<PreparedStatement<T>>
where
T: Clone,
{
@@ -435,18 +386,9 @@ where
body: cmd.as_any().encode_to_vec().into(),
};
let req = self.set_request_headers(action.into_request())?;
- let mut result = self
- .flight_client
- .do_action(req)
- .await
- .map_err(status_to_arrow_error)?
- .into_inner();
- let result = result
- .message()
- .await
- .map_err(status_to_arrow_error)?
- .unwrap();
- let any =
Any::decode(&*result.body).map_err(decode_error_to_arrow_error)?;
+ let mut result = self.flight_client.do_action(req).await?.into_inner();
+ let result = result.message().await?.unwrap();
+ let any = Any::decode(&*result.body)?;
let prepared_result: ActionCreatePreparedStatementResult =
any.unpack()?.unwrap();
let dataset_schema = match prepared_result.dataset_schema.len() {
0 => Schema::empty(),
@@ -465,25 +407,16 @@ where
}
/// Request to begin a transaction.
- pub async fn begin_transaction(&mut self) -> Result<Bytes, ArrowError> {
+ pub async fn begin_transaction(&mut self) -> Result<Bytes> {
let cmd = ActionBeginTransactionRequest {};
let action = Action {
r#type: BEGIN_TRANSACTION.to_string(),
body: cmd.as_any().encode_to_vec().into(),
};
let req = self.set_request_headers(action.into_request())?;
- let mut result = self
- .flight_client
- .do_action(req)
- .await
- .map_err(status_to_arrow_error)?
- .into_inner();
- let result = result
- .message()
- .await
- .map_err(status_to_arrow_error)?
- .unwrap();
- let any =
Any::decode(&*result.body).map_err(decode_error_to_arrow_error)?;
+ let mut result = self.flight_client.do_action(req).await?.into_inner();
+ let result = result.message().await?.unwrap();
+ let any = Any::decode(&*result.body)?;
let begin_result: ActionBeginTransactionResult =
any.unpack()?.unwrap();
Ok(begin_result.transaction_id)
}
@@ -493,7 +426,7 @@ where
&mut self,
transaction_id: Bytes,
action: EndTransaction,
- ) -> Result<(), ArrowError> {
+ ) -> Result<()> {
let cmd = ActionEndTransactionRequest {
transaction_id,
action: action as i32,
@@ -503,25 +436,17 @@ where
body: cmd.as_any().encode_to_vec().into(),
};
let req = self.set_request_headers(action.into_request())?;
- let _ = self
- .flight_client
- .do_action(req)
- .await
- .map_err(status_to_arrow_error)?
- .into_inner();
+ let _ = self.flight_client.do_action(req).await?.into_inner();
Ok(())
}
/// Explicitly shut down and clean up the client.
- pub async fn close(&mut self) -> Result<(), ArrowError> {
+ pub async fn close(&mut self) -> Result<()> {
// TODO: consume self instead of &mut self to explicitly prevent reuse?
Ok(())
}
- fn set_request_headers<M>(
- &self,
- mut req: tonic::Request<M>,
- ) -> Result<tonic::Request<M>, ArrowError> {
+ fn set_request_headers<M>(&self, mut req: tonic::Request<M>) ->
Result<tonic::Request<M>> {
for (k, v) in &self.headers {
let k = AsciiMetadataKey::from_str(k.as_str()).map_err(|e| {
ArrowError::ParseError(format!("Cannot convert header key
\"{k}\": {e}"))
@@ -584,7 +509,7 @@ where
}
/// Executes the prepared statement query on the server.
- pub async fn execute(&mut self) -> Result<FlightInfo, ArrowError> {
+ pub async fn execute(&mut self) -> Result<FlightInfo> {
self.write_bind_params().await?;
let cmd = CommandPreparedStatementQuery {
@@ -599,7 +524,7 @@ where
}
/// Executes the prepared statement update query on the server.
- pub async fn execute_update(&mut self) -> Result<i64, ArrowError> {
+ pub async fn execute_update(&mut self) -> Result<i64> {
self.write_bind_params().await?;
let cmd = CommandPreparedStatementUpdate {
@@ -613,35 +538,30 @@ where
..Default::default()
}]))
.await?;
- let result = result
- .message()
- .await
- .map_err(status_to_arrow_error)?
- .unwrap();
- let result: DoPutUpdateResult =
-
Message::decode(&*result.app_metadata).map_err(decode_error_to_arrow_error)?;
+ let result = result.message().await?.unwrap();
+ let result: DoPutUpdateResult =
Message::decode(&*result.app_metadata)?;
Ok(result.record_count)
}
/// Retrieve the parameter schema from the query.
- pub fn parameter_schema(&self) -> Result<&Schema, ArrowError> {
+ pub fn parameter_schema(&self) -> Result<&Schema> {
Ok(&self.parameter_schema)
}
/// Retrieve the ResultSet schema from the query.
- pub fn dataset_schema(&self) -> Result<&Schema, ArrowError> {
+ pub fn dataset_schema(&self) -> Result<&Schema> {
Ok(&self.dataset_schema)
}
/// Set a RecordBatch that contains the parameters that will be bind.
- pub fn set_parameters(&mut self, parameter_binding: RecordBatch) ->
Result<(), ArrowError> {
+ pub fn set_parameters(&mut self, parameter_binding: RecordBatch) ->
Result<()> {
self.parameter_binding = Some(parameter_binding);
Ok(())
}
/// Submit parameters to the server, if any have been set on this prepared
statement instance
/// Updates our stored prepared statement handle with the handle given by
the server response.
- async fn write_bind_params(&mut self) -> Result<(), ArrowError> {
+ async fn write_bind_params(&mut self) -> Result<()> {
if let Some(ref params_batch) = self.parameter_binding {
let cmd = CommandPreparedStatementQuery {
prepared_statement_handle: self.handle.clone(),
@@ -656,8 +576,7 @@ where
self.parameter_binding.clone().map(Ok),
))
.try_collect::<Vec<_>>()
- .await
- .map_err(flight_error_to_arrow_error)?;
+ .await?;
// Attempt to update the stored handle with any updated handle in
the DoPut result.
// Older servers do not respond with a result for DoPut, so skip
this step when
@@ -667,8 +586,7 @@ where
.do_put(stream::iter(flight_data))
.await?
.message()
- .await
- .map_err(status_to_arrow_error)?
+ .await?
{
if let Some(handle) =
self.unpack_prepared_statement_handle(&result)? {
self.handle = handle;
@@ -681,18 +599,14 @@ where
/// Decodes the app_metadata stored in a [`PutResult`] as a
/// [`DoPutPreparedStatementResult`] and then returns
/// the inner prepared statement handle as [`Bytes`]
- fn unpack_prepared_statement_handle(
- &self,
- put_result: &PutResult,
- ) -> Result<Option<Bytes>, ArrowError> {
- let result: DoPutPreparedStatementResult =
-
Message::decode(&*put_result.app_metadata).map_err(decode_error_to_arrow_error)?;
+ fn unpack_prepared_statement_handle(&self, put_result: &PutResult) ->
Result<Option<Bytes>> {
+ let result: DoPutPreparedStatementResult =
Message::decode(&*put_result.app_metadata)?;
Ok(result.prepared_statement_handle)
}
/// Close the prepared statement, so that this PreparedStatement can not
used
/// anymore and server can free up any resources.
- pub async fn close(mut self) -> Result<(), ArrowError> {
+ pub async fn close(mut self) -> Result<()> {
let cmd = ActionClosePreparedStatementRequest {
prepared_statement_handle: self.handle.clone(),
};
@@ -705,21 +619,6 @@ where
}
}
-fn decode_error_to_arrow_error(err: prost::DecodeError) -> ArrowError {
- ArrowError::IpcError(err.to_string())
-}
-
-fn status_to_arrow_error(status: tonic::Status) -> ArrowError {
- ArrowError::IpcError(format!("{status:?}"))
-}
-
-fn flight_error_to_arrow_error(err: FlightError) -> ArrowError {
- match err {
- FlightError::Arrow(e) => e,
- e => ArrowError::ExternalError(Box::new(e)),
- }
-}
-
/// A polymorphic structure to natively represent different types of data
contained in `FlightData`
pub enum ArrowFlightData {
/// A record batch
@@ -732,7 +631,7 @@ pub enum ArrowFlightData {
pub fn arrow_data_from_flight_data(
flight_data: FlightData,
arrow_schema_ref: &SchemaRef,
-) -> Result<ArrowFlightData, ArrowError> {
+) -> std::result::Result<ArrowFlightData, ArrowError> {
let ipc_message = root_as_message(&flight_data.data_header[..])
.map_err(|err| ArrowError::ParseError(format!("Unable to get root as
message: {err:?}")))?;