This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new 284556e55a feat: improve flight CLI error handling (#4873)
284556e55a is described below

commit 284556e55ae073b88cd24cbc0749e19f394d21fd
Author: Marco Neumann <[email protected]>
AuthorDate: Thu Sep 28 13:51:15 2023 +0200

    feat: improve flight CLI error handling (#4873)
    
    **Before:**
    
    ```text
    thread 'main' panicked at 'collect data stream: Status { code: Internal, message: "h2 protocol error: error reading a body from connection: stream error received: unexpected internal error encountered", source: Some(hyper::Error(Body, Error { kind: Reset(StreamId(3), INTERNAL_ERROR, Remote) })) }', arrow-flight/src/bin/flight_sql_client.rs:130:14
    note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace
    ```
    
    **After:**
    
    ```text
    Error: read flight data
    
    Caused by:
        0: collect data stream
        1: status: Internal, message: "h2 protocol error: error reading a body from connection: stream error received: unexpected internal error encountered", details: [], metadata: MetadataMap { headers: {} }
        2: error reading a body from connection: stream error received: unexpected internal error encountered
        3: stream error received: unexpected internal error encountered
    ```
---
 arrow-flight/Cargo.toml                   |  3 +-
 arrow-flight/src/bin/flight_sql_client.rs | 67 ++++++++++++++++---------------
 2 files changed, 36 insertions(+), 34 deletions(-)

diff --git a/arrow-flight/Cargo.toml b/arrow-flight/Cargo.toml
index 29a8109d88..54c5cdf5e2 100644
--- a/arrow-flight/Cargo.toml
+++ b/arrow-flight/Cargo.toml
@@ -49,6 +49,7 @@ tokio = { version = "1.0", default-features = false, features 
= ["macros", "rt",
 tonic = { version = "0.10.0", default-features = false, features = 
["transport", "codegen", "prost"] }
 
 # CLI-related dependencies
+anyhow = { version = "1.0", optional = true }
 clap = { version = "4.1", default-features = false, features = ["std", 
"derive", "env", "help", "error-context", "usage"], optional = true }
 tracing-log = { version = "0.1", optional = true }
 tracing-subscriber = { version = "0.3.1", default-features = false, features = 
["ansi", "fmt"], optional = true }
@@ -62,7 +63,7 @@ flight-sql-experimental = ["arrow-arith", "arrow-data", 
"arrow-ord", "arrow-row"
 tls = ["tonic/tls"]
 
 # Enable CLI tools
-cli = ["arrow-cast/prettyprint", "clap", "tracing-log", "tracing-subscriber", 
"tonic/tls-webpki-roots"]
+cli = ["anyhow", "arrow-cast/prettyprint", "clap", "tracing-log", 
"tracing-subscriber", "tonic/tls-webpki-roots"]
 
 [dev-dependencies]
 arrow-cast = { workspace = true, features = ["prettyprint"] }
diff --git a/arrow-flight/src/bin/flight_sql_client.rs 
b/arrow-flight/src/bin/flight_sql_client.rs
index d7b02414c5..c6aaccf376 100644
--- a/arrow-flight/src/bin/flight_sql_client.rs
+++ b/arrow-flight/src/bin/flight_sql_client.rs
@@ -17,13 +17,14 @@
 
 use std::{error::Error, sync::Arc, time::Duration};
 
+use anyhow::{Context, Result};
 use arrow_array::{ArrayRef, Datum, RecordBatch, StringArray};
 use arrow_cast::{cast_with_options, pretty::pretty_format_batches, 
CastOptions};
 use arrow_flight::{
     sql::client::FlightSqlServiceClient, utils::flight_data_to_batches, 
FlightData,
     FlightInfo,
 };
-use arrow_schema::{ArrowError, Schema};
+use arrow_schema::Schema;
 use clap::{Parser, Subcommand};
 use futures::TryStreamExt;
 use tonic::transport::{Channel, ClientTlsConfig, Endpoint};
@@ -116,21 +117,23 @@ enum Command {
 }
 
 #[tokio::main]
-async fn main() {
+async fn main() -> Result<()> {
     let args = Args::parse();
-    setup_logging();
-    let mut client = setup_client(args.client_args).await.expect("setup 
client");
+    setup_logging()?;
+    let mut client = setup_client(args.client_args)
+        .await
+        .context("setup client")?;
 
     let flight_info = match args.cmd {
         Command::StatementQuery { query } => client
             .execute(query, None)
             .await
-            .expect("execute statement"),
+            .context("execute statement")?,
         Command::PreparedStatementQuery { query, params } => {
             let mut prepared_stmt = client
                 .prepare(query, None)
                 .await
-                .expect("prepare statement");
+                .context("prepare statement")?;
 
             if !params.is_empty() {
                 prepared_stmt
@@ -139,33 +142,35 @@ async fn main() {
                             &params,
                             prepared_stmt
                                 .parameter_schema()
-                                .expect("get parameter schema"),
+                                .context("get parameter schema")?,
                         )
-                        .expect("construct parameters"),
+                        .context("construct parameters")?,
                     )
-                    .expect("bind parameters")
+                    .context("bind parameters")?;
             }
 
             prepared_stmt
                 .execute()
                 .await
-                .expect("execute prepared statement")
+                .context("execute prepared statement")?
         }
     };
 
     let batches = execute_flight(&mut client, flight_info)
         .await
-        .expect("read flight data");
+        .context("read flight data")?;
 
-    let res = pretty_format_batches(batches.as_slice()).expect("format 
results");
+    let res = pretty_format_batches(batches.as_slice()).context("format 
results")?;
     println!("{res}");
+
+    Ok(())
 }
 
 async fn execute_flight(
     client: &mut FlightSqlServiceClient<Channel>,
     info: FlightInfo,
-) -> Result<Vec<RecordBatch>, ArrowError> {
-    let schema = Arc::new(Schema::try_from(info.clone()).expect("valid 
schema"));
+) -> Result<Vec<RecordBatch>> {
+    let schema = Arc::new(Schema::try_from(info.clone()).context("valid 
schema")?);
     let mut batches = Vec::with_capacity(info.endpoint.len() + 1);
     batches.push(RecordBatch::new_empty(schema));
     info!("decoded schema");
@@ -174,13 +179,13 @@ async fn execute_flight(
         let Some(ticket) = &endpoint.ticket else {
             panic!("did not get ticket");
         };
-        let flight_data = client.do_get(ticket.clone()).await.expect("do get");
+        let flight_data = client.do_get(ticket.clone()).await.context("do 
get")?;
         let flight_data: Vec<FlightData> = flight_data
             .try_collect()
             .await
-            .expect("collect data stream");
+            .context("collect data stream")?;
         let mut endpoint_batches = flight_data_to_batches(&flight_data)
-            .expect("convert flight data to record batches");
+            .context("convert flight data to record batches")?;
         batches.append(&mut endpoint_batches);
     }
     info!("received data");
@@ -191,7 +196,7 @@ async fn execute_flight(
 fn construct_record_batch_from_params(
     params: &[(String, String)],
     parameter_schema: &Schema,
-) -> Result<RecordBatch, ArrowError> {
+) -> Result<RecordBatch> {
     let mut items = Vec::<(&String, ArrayRef)>::new();
 
     for (name, value) in params {
@@ -205,23 +210,22 @@ fn construct_record_batch_from_params(
         items.push((name, casted))
     }
 
-    RecordBatch::try_from_iter(items)
+    Ok(RecordBatch::try_from_iter(items)?)
 }
 
-fn setup_logging() {
-    tracing_log::LogTracer::init().expect("tracing log init");
+fn setup_logging() -> Result<()> {
+    tracing_log::LogTracer::init().context("tracing log init")?;
     tracing_subscriber::fmt::init();
+    Ok(())
 }
 
-async fn setup_client(
-    args: ClientArgs,
-) -> Result<FlightSqlServiceClient<Channel>, ArrowError> {
+async fn setup_client(args: ClientArgs) -> 
Result<FlightSqlServiceClient<Channel>> {
     let port = args.port.unwrap_or(if args.tls { 443 } else { 80 });
 
     let protocol = if args.tls { "https" } else { "http" };
 
     let mut endpoint = Endpoint::new(format!("{}://{}:{}", protocol, 
args.host, port))
-        .map_err(|_| ArrowError::IpcError("Cannot create 
endpoint".to_string()))?
+        .context("create endpoint")?
         .connect_timeout(Duration::from_secs(20))
         .timeout(Duration::from_secs(20))
         .tcp_nodelay(true) // Disable Nagle's Algorithm since we don't want 
packets to wait
@@ -232,15 +236,12 @@ async fn setup_client(
 
     if args.tls {
         let tls_config = ClientTlsConfig::new();
-        endpoint = endpoint.tls_config(tls_config).map_err(|_| {
-            ArrowError::IpcError("Cannot create TLS endpoint".to_string())
-        })?;
+        endpoint = endpoint
+            .tls_config(tls_config)
+            .context("create TLS endpoint")?;
     }
 
-    let channel = endpoint
-        .connect()
-        .await
-        .map_err(|e| ArrowError::IpcError(format!("Cannot connect to endpoint: 
{e}")))?;
+    let channel = endpoint.connect().await.context("connect to endpoint")?;
 
     let mut client = FlightSqlServiceClient::new(channel);
     info!("connected");
@@ -260,7 +261,7 @@ async fn setup_client(
             client
                 .handshake(&username, &password)
                 .await
-                .expect("handshake");
+                .context("handshake")?;
             info!("performed handshake");
         }
         (Some(_), None) => {

Reply via email to