This is an automated email from the ASF dual-hosted git repository.
tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new 7f460aff0 feat: add simple flight SQL CLI client (#3789)
7f460aff0 is described below
commit 7f460aff0a6438f2ff90087fb9ecd6aa0e1e891b
Author: Marco Neumann <[email protected]>
AuthorDate: Mon Mar 6 18:40:50 2023 +0100
feat: add simple flight SQL CLI client (#3789)
---
arrow-flight/Cargo.toml | 13 ++
arrow-flight/src/bin/flight_sql_client.rs | 199 ++++++++++++++++++++++++++++++
2 files changed, 212 insertions(+)
diff --git a/arrow-flight/Cargo.toml b/arrow-flight/Cargo.toml
index fd77a814a..f1cd7d4fb 100644
--- a/arrow-flight/Cargo.toml
+++ b/arrow-flight/Cargo.toml
@@ -41,6 +41,12 @@ prost-derive = { version = "0.11", default-features = false }
tokio = { version = "1.0", default-features = false, features = ["macros", "rt", "rt-multi-thread"] }
futures = { version = "0.3", default-features = false, features = ["alloc"] }
+# CLI-related dependencies
+arrow = { version = "34.0.0", path = "../arrow", optional = true }
+clap = { version = "4.1", default-features = false, features = ["std", "derive", "env", "help", "error-context", "usage"], optional = true }
+tracing-log = { version = "0.1", optional = true }
+tracing-subscriber = { version = "0.3.1", default-features = false, features = ["ansi", "fmt"], optional = true }
+
[package.metadata.docs.rs]
all-features = true
@@ -49,6 +55,9 @@ default = []
flight-sql-experimental = []
tls = ["tonic/tls"]
+# Enable CLI tools
+cli = ["arrow/prettyprint", "clap", "tracing-log", "tracing-subscriber", "tonic/tls-webpki-roots"]
+
[dev-dependencies]
arrow = { version = "34.0.0", path = "../arrow", features = ["prettyprint"] }
tempfile = "3.3"
@@ -65,3 +74,7 @@ tonic-build = { version = "=0.8.4", default-features = false, features = ["trans
[[example]]
name = "flight_sql_server"
required-features = ["flight-sql-experimental", "tls"]
+
+[[bin]]
+name = "flight_sql_client"
+required-features = ["cli", "flight-sql-experimental", "tls"]
diff --git a/arrow-flight/src/bin/flight_sql_client.rs b/arrow-flight/src/bin/flight_sql_client.rs
new file mode 100644
index 000000000..9f211eaf6
--- /dev/null
+++ b/arrow-flight/src/bin/flight_sql_client.rs
@@ -0,0 +1,199 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::{sync::Arc, time::Duration};
+
+use arrow::error::Result;
+use arrow::util::pretty::pretty_format_batches;
+use arrow_array::RecordBatch;
+use arrow_flight::{
+ sql::client::FlightSqlServiceClient, utils::flight_data_to_batches,
FlightData,
+};
+use arrow_schema::{ArrowError, Schema};
+use clap::Parser;
+use futures::TryStreamExt;
+use tonic::transport::{ClientTlsConfig, Endpoint};
+use tracing_log::log::info;
+
+/// A key/value pair written as `KEY:VALUE`.
+///
+/// Only the first ':' acts as the separator, so the value part may itself
+/// contain further ':' characters.
+#[derive(Debug, Clone)]
+struct KeyValue<K, V> {
+    pub key: K,
+    pub value: V,
+}
+
+impl<K, V> std::str::FromStr for KeyValue<K, V>
+where
+    K: std::str::FromStr,
+    V: std::str::FromStr,
+    K::Err: std::fmt::Display,
+    V::Err: std::fmt::Display,
+{
+    type Err = String;
+
+    /// Splits `s` at its first ':' and parses both halves.
+    ///
+    /// Returns an error string when `s` contains no ':' at all, or when
+    /// either half fails to parse into its target type (in which case the
+    /// error is the `Display` output of that parser's error).
+    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
+        let Some((raw_key, raw_value)) = s.split_once(':') else {
+            return Err(format!(
+                "Invalid key value pair - expected 'KEY:VALUE' got '{s}'"
+            ));
+        };
+
+        // Key is parsed first so its error takes precedence over the value's.
+        let key = raw_key.parse::<K>().map_err(|e| e.to_string())?;
+        let value = raw_value.parse::<V>().map_err(|e| e.to_string())?;
+        Ok(Self { key, value })
+    }
+}
+
+// Connection and authentication options of the CLI. This struct is embedded
+// into `Args` via `#[clap(flatten)]` and consumed by `setup_client`.
+//
+// NOTE: the `///` comments on the fields are picked up by clap's derive as the
+// help text of each option, so they are user-facing output; explanatory notes
+// for maintainers are therefore written as plain `//` comments.
+#[derive(Debug, Parser)]
+struct ClientArgs {
+ /// Additional headers.
+ ///
+ /// Values should be key value pairs separated by ':'
+ // Several pairs may be passed in one argument, separated by ','; each pair
+ // is parsed through `KeyValue`'s `FromStr` implementation.
+ #[clap(long, value_delimiter = ',')]
+ headers: Vec<KeyValue<String, String>>,
+
+ /// Username
+ // `setup_client` requires username and password to be given together.
+ #[clap(long)]
+ username: Option<String>,
+
+ /// Password
+ // See `username`: supplying only one of the two is rejected by `setup_client`.
+ #[clap(long)]
+ password: Option<String>,
+
+ /// Auth token.
+ // When present, handed to the client via `set_token` in `setup_client`.
+ #[clap(long)]
+ token: Option<String>,
+
+ /// Use TLS.
+ // Makes `setup_client` install a TLS config on the endpoint and changes the
+ // default port from 80 to 443.
+ #[clap(long)]
+ tls: bool,
+
+ /// Server host.
+ #[clap(long)]
+ host: String,
+
+ /// Server port.
+ // When omitted, `setup_client` picks 443 if `tls` is set and 80 otherwise.
+ #[clap(long)]
+ port: Option<u16>,
+}
+
+// Top-level command line of the binary: the shared connection options plus the
+// SQL text to run. Parsed once in `main` via `Args::parse()`.
+//
+// NOTE: the `///` comments are used by clap's derive as help text, so
+// maintainer notes are written as plain `//` comments.
+#[derive(Debug, Parser)]
+struct Args {
+ /// Client args.
+ // `flatten` merges the fields of `ClientArgs` into this command's options.
+ #[clap(flatten)]
+ client_args: ClientArgs,
+
+ /// SQL query.
+ // No `long`/`short` attribute, so clap treats this as a positional argument.
+ // `main` passes it to `client.execute`.
+ query: String,
+}
+
+/// Entry point: runs one SQL query against a Flight SQL server and prints the
+/// result as a text table on stdout.
+///
+/// Steps: parse the command line, initialise logging, connect via
+/// [`setup_client`], execute the query to obtain the flight info, fetch the
+/// data of every returned endpoint, and pretty-print all record batches.
+///
+/// # Panics
+///
+/// Panics (aborting the CLI with the `expect` message) if any step fails or if
+/// an endpoint carries no ticket.
+#[tokio::main]
+async fn main() {
+    let args = Args::parse();
+    setup_logging();
+    let mut client = setup_client(args.client_args).await.expect("setup client");
+
+    let info = client.execute(args.query).await.expect("prepare statement");
+    info!("got flight info");
+
+    // `info` is cloned because the schema conversion consumes its argument
+    // while the endpoints are still needed below.
+    let schema = Arc::new(Schema::try_from(info.clone()).expect("valid schema"));
+    let mut batches = Vec::with_capacity(info.endpoint.len() + 1);
+    // Seed the output with an empty batch carrying the schema - presumably so
+    // the column header is still rendered when no endpoint returns any rows.
+    batches.push(RecordBatch::new_empty(schema));
+    info!("decoded schema");
+
+    for endpoint in info.endpoint {
+        // The endpoint is owned here, so the ticket can be moved out directly.
+        let Some(ticket) = endpoint.ticket else {
+            panic!("did not get ticket");
+        };
+        let flight_data = client.do_get(ticket).await.expect("do get");
+        let flight_data: Vec<FlightData> = flight_data
+            .try_collect()
+            .await
+            .expect("collect data stream");
+        let endpoint_batches = flight_data_to_batches(&flight_data)
+            .expect("convert flight data to record batches");
+        batches.extend(endpoint_batches);
+    }
+    info!("received data");
+
+    let res = pretty_format_batches(batches.as_slice()).expect("format results");
+    println!("{res}");
+}
+
+/// Initialises logging for the CLI.
+///
+/// First installs the `tracing_log` tracer (panicking with
+/// "tracing log init" if that fails), then initialises the
+/// `tracing_subscriber` fmt subscriber.
+fn setup_logging() {
+    let tracer_installed = tracing_log::LogTracer::init();
+    tracer_installed.expect("tracing log init");
+
+    tracing_subscriber::fmt::init();
+}
+
+/// Connects to the Flight SQL server described by `args` and returns a client
+/// with headers and authentication applied.
+///
+/// The port defaults to 443 when `args.tls` is set and to 80 otherwise; the
+/// URI scheme follows the same switch (`https` / `http`). After connecting,
+/// every entry of `args.headers` is set on the client, an optional token is
+/// installed, and - if both username and password are given - a handshake is
+/// performed.
+///
+/// # Errors
+///
+/// Returns an [`ArrowError`] when the endpoint or its TLS configuration cannot
+/// be created, when the connection or the handshake fails, or when only one
+/// of username / password is supplied.
+async fn setup_client(args: ClientArgs) -> Result<FlightSqlServiceClient> {
+    let port = args.port.unwrap_or(if args.tls { 443 } else { 80 });
+    // The scheme must match the TLS switch: a TLS config is only installed
+    // when `--tls` is given, so plain connections have to use `http`.
+    let protocol = if args.tls { "https" } else { "http" };
+
+    let mut endpoint = Endpoint::new(format!("{}://{}:{}", protocol, args.host, port))
+        .map_err(|_| ArrowError::IoError("Cannot create endpoint".to_string()))?
+        .connect_timeout(Duration::from_secs(20))
+        .timeout(Duration::from_secs(20))
+        // Disable Nagle's Algorithm since we don't want packets to wait
+        .tcp_nodelay(true)
+        .tcp_keepalive(Some(Duration::from_secs(3600)))
+        .http2_keep_alive_interval(Duration::from_secs(300))
+        .keep_alive_timeout(Duration::from_secs(20))
+        .keep_alive_while_idle(true);
+
+    if args.tls {
+        let tls_config = ClientTlsConfig::new();
+        endpoint = endpoint
+            .tls_config(tls_config)
+            .map_err(|_| ArrowError::IoError("Cannot create TLS endpoint".to_string()))?;
+    }
+
+    let channel = endpoint
+        .connect()
+        .await
+        .map_err(|e| ArrowError::IoError(format!("Cannot connect to endpoint: {e}")))?;
+
+    let mut client = FlightSqlServiceClient::new(channel);
+    info!("connected");
+
+    for kv in args.headers {
+        client.set_header(kv.key, kv.value);
+    }
+
+    if let Some(token) = args.token {
+        client.set_token(token);
+        info!("token set");
+    }
+
+    // This function already reports failures through `Result`, so credential
+    // problems are returned as errors instead of panicking.
+    match (args.username, args.password) {
+        (None, None) => {}
+        (Some(username), Some(password)) => {
+            client
+                .handshake(&username, &password)
+                .await
+                .map_err(|e| ArrowError::IoError(format!("handshake failed: {e}")))?;
+            info!("performed handshake");
+        }
+        (Some(_), None) => {
+            return Err(ArrowError::InvalidArgumentError(
+                "when username is set, you also need to set a password".to_string(),
+            ));
+        }
+        (None, Some(_)) => {
+            return Err(ArrowError::InvalidArgumentError(
+                "when password is set, you also need to set a username".to_string(),
+            ));
+        }
+    }
+
+    Ok(client)
+}