This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new 40e2874e1 refactor: assorted `FlightSqlServiceClient` improvements 
(#3788)
40e2874e1 is described below

commit 40e2874e1d83dd8dc64981b7f4a19f894befe615
Author: Marco Neumann <[email protected]>
AuthorDate: Fri Mar 3 12:38:34 2023 +0100

    refactor: assorted `FlightSqlServiceClient` improvements (#3788)
    
    * refactor: assorted `FlightSqlServiceClient` improvements
    
    - **TLS config:** Do NOT alter existing method signatures if the TLS
      feature is enabled. Features should be purely additive in Rust.
      Instead, use a new method to pass TLS configs. The config is now passed
      as `ClientTlsConfig` to allow more flexibility, e.g. to use TLS
      without any client certificates.
    - **token handling:** Allow the token to be passed in from an external
      source. The [auth spec] is very flexible ("application-defined"), so
      we cannot derive a way to determine the token that works in all cases.
      The current handshake-based mechanism is still fine, though. Also make
      sure the token is used in all relevant methods.
    - **headers:** Allow users to pass in additional headers. This is
      helpful for certain applications.
    
    [auth spec]: https://arrow.apache.org/docs/format/Flight.html#authentication
    
    * refactor: simplify flight SQL client construction
    
    Just accept a channel and let the caller set it up to their liking.
    Simplify the example as well so that it no longer does entirely different
    things under different features (since features should be additive).
    Instead, use a single example.
---
 .github/workflows/arrow_flight.yml         |   3 -
 arrow-flight/Cargo.toml                    |   2 +-
 arrow-flight/examples/flight_sql_server.rs |  85 +++++++++---------
 arrow-flight/src/sql/client.rs             | 134 ++++++++++++-----------------
 4 files changed, 99 insertions(+), 125 deletions(-)

diff --git a/.github/workflows/arrow_flight.yml 
b/.github/workflows/arrow_flight.yml
index 02c149aaa..7facf1719 100644
--- a/.github/workflows/arrow_flight.yml
+++ b/.github/workflows/arrow_flight.yml
@@ -60,9 +60,6 @@ jobs:
         run: |
           cargo test -p arrow-flight --all-features
       - name: Test --examples
-        run: |
-          cargo test -p arrow-flight  --features=flight-sql-experimental 
--examples
-      - name: Test --examples with TLS
         run: |
           cargo test -p arrow-flight  --features=flight-sql-experimental,tls 
--examples
       - name: Verify workspace clean
diff --git a/arrow-flight/Cargo.toml b/arrow-flight/Cargo.toml
index 61959143e..fd77a814a 100644
--- a/arrow-flight/Cargo.toml
+++ b/arrow-flight/Cargo.toml
@@ -64,4 +64,4 @@ tonic-build = { version = "=0.8.4", default-features = false, 
features = ["trans
 
 [[example]]
 name = "flight_sql_server"
-required-features = ["flight-sql-experimental"]
+required-features = ["flight-sql-experimental", "tls"]
diff --git a/arrow-flight/examples/flight_sql_server.rs 
b/arrow-flight/examples/flight_sql_server.rs
index 28aef4e92..425ceab42 100644
--- a/arrow-flight/examples/flight_sql_server.rs
+++ b/arrow-flight/examples/flight_sql_server.rs
@@ -31,7 +31,6 @@ use prost::Message;
 use std::pin::Pin;
 use std::sync::Arc;
 use tonic::transport::Server;
-#[cfg(feature = "tls")]
 use tonic::transport::{Certificate, Identity, ServerTlsConfig};
 use tonic::{Request, Response, Status, Streaming};
 
@@ -451,7 +450,6 @@ impl FlightSqlService for FlightSqlServiceImpl {
 
 /// This example shows how to run a FlightSql server
 #[tokio::main]
-#[cfg(not(feature = "tls"))]
 async fn main() -> Result<(), Box<dyn std::error::Error>> {
     let addr = "0.0.0.0:50051".parse()?;
 
@@ -459,34 +457,24 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> 
{
 
     println!("Listening on {:?}", addr);
 
-    Server::builder().add_service(svc).serve(addr).await?;
+    if std::env::var("USE_TLS").ok().is_some() {
+        let cert = 
std::fs::read_to_string("arrow-flight/examples/data/server.pem")?;
+        let key = 
std::fs::read_to_string("arrow-flight/examples/data/server.key")?;
+        let client_ca =
+            
std::fs::read_to_string("arrow-flight/examples/data/client_ca.pem")?;
 
-    Ok(())
-}
-
-/// This example shows how to run a HTTPs FlightSql server
-#[tokio::main]
-#[cfg(feature = "tls")]
-async fn main() -> Result<(), Box<dyn std::error::Error>> {
-    let addr = "0.0.0.0:50051".parse()?;
-
-    let svc = FlightServiceServer::new(FlightSqlServiceImpl {});
-
-    println!("Listening on {addr:?}");
-
-    let cert = 
std::fs::read_to_string("arrow-flight/examples/data/server.pem")?;
-    let key = 
std::fs::read_to_string("arrow-flight/examples/data/server.key")?;
-    let client_ca = 
std::fs::read_to_string("arrow-flight/examples/data/client_ca.pem")?;
-
-    let tls_config = ServerTlsConfig::new()
-        .identity(Identity::from_pem(&cert, &key))
-        .client_ca_root(Certificate::from_pem(&client_ca));
+        let tls_config = ServerTlsConfig::new()
+            .identity(Identity::from_pem(&cert, &key))
+            .client_ca_root(Certificate::from_pem(&client_ca));
 
-    Server::builder()
-        .tls_config(tls_config)?
-        .add_service(svc)
-        .serve(addr)
-        .await?;
+        Server::builder()
+            .tls_config(tls_config)?
+            .add_service(svc)
+            .serve(addr)
+            .await?;
+    } else {
+        Server::builder().add_service(svc).serve(addr).await?;
+    }
 
     Ok(())
 }
@@ -523,8 +511,6 @@ mod tests {
     use tokio_stream::wrappers::UnixListenerStream;
     use tonic::body::BoxBody;
     use tonic::codegen::{http, Body, Service};
-
-    #[cfg(feature = "tls")]
     use tonic::transport::ClientTlsConfig;
 
     use arrow::util::pretty::pretty_format_batches;
@@ -533,10 +519,9 @@ mod tests {
     use tonic::transport::{Certificate, Channel, Endpoint};
     use tower::{service_fn, ServiceExt};
 
-    #[cfg(not(feature = "tls"))]
     async fn client_with_uds(path: String) -> FlightSqlServiceClient {
         let connector = service_fn(move |_| UnixStream::connect(path.clone()));
-        let channel = Endpoint::try_from("https://example.com";)
+        let channel = Endpoint::try_from("http://example.com";)
             .unwrap()
             .connect_with_connector(connector)
             .await
@@ -544,7 +529,6 @@ mod tests {
         FlightSqlServiceClient::new(channel)
     }
 
-    #[cfg(feature = "tls")]
     async fn create_https_server() -> Result<(), tonic::transport::Error> {
         let cert = 
std::fs::read_to_string("examples/data/server.pem").unwrap();
         let key = std::fs::read_to_string("examples/data/server.key").unwrap();
@@ -567,7 +551,6 @@ mod tests {
     }
 
     #[tokio::test]
-    #[cfg(feature = "tls")]
     async fn test_select_https() {
         tokio::spawn(async {
             create_https_server().await.unwrap();
@@ -580,15 +563,16 @@ mod tests {
             let key = 
std::fs::read_to_string("examples/data/client1.key").unwrap();
             let server_ca = 
std::fs::read_to_string("examples/data/ca.pem").unwrap();
 
-            let mut client = FlightSqlServiceClient::new_with_endpoint(
-                Identity::from_pem(cert, key),
-                Certificate::from_pem(&server_ca),
-                "localhost",
-                "127.0.0.1",
-                50051,
-            )
-            .await
-            .unwrap();
+            let tls_config = ClientTlsConfig::new()
+                .domain_name("localhost")
+                .ca_certificate(Certificate::from_pem(&server_ca))
+                .identity(Identity::from_pem(cert, key));
+            let endpoint = endpoint(String::from("https://127.0.0.1:50051";))
+                .unwrap()
+                .tls_config(tls_config)
+                .unwrap();
+            let channel = endpoint.connect().await.unwrap();
+            let mut client = FlightSqlServiceClient::new(channel);
             let token = client.handshake("admin", "password").await.unwrap();
             println!("Auth succeeded with token: {:?}", token);
             let mut stmt = client.prepare("select 
1;".to_string()).await.unwrap();
@@ -615,7 +599,6 @@ mod tests {
     }
 
     #[tokio::test]
-    #[cfg(not(feature = "tls"))]
     async fn test_select_1() {
         let file = NamedTempFile::new().unwrap();
         let path = file.into_temp_path().to_str().unwrap().to_string();
@@ -657,4 +640,18 @@ mod tests {
             _ = request_future => println!("Client finished!"),
         }
     }
+
+    fn endpoint(addr: String) -> Result<Endpoint, ArrowError> {
+        let endpoint = Endpoint::new(addr)
+            .map_err(|_| ArrowError::IoError("Cannot create 
endpoint".to_string()))?
+            .connect_timeout(Duration::from_secs(20))
+            .timeout(Duration::from_secs(20))
+            .tcp_nodelay(true) // Disable Nagle's Algorithm since we don't 
want packets to wait
+            .tcp_keepalive(Option::Some(Duration::from_secs(3600)))
+            .http2_keep_alive_interval(Duration::from_secs(300))
+            .keep_alive_timeout(Duration::from_secs(20))
+            .keep_alive_while_idle(true);
+
+        Ok(endpoint)
+    }
 }
diff --git a/arrow-flight/src/sql/client.rs b/arrow-flight/src/sql/client.rs
index 31ba1e274..a61f06d32 100644
--- a/arrow-flight/src/sql/client.rs
+++ b/arrow-flight/src/sql/client.rs
@@ -19,7 +19,8 @@ use base64::prelude::BASE64_STANDARD;
 use base64::Engine;
 use bytes::Bytes;
 use std::collections::HashMap;
-use std::time::Duration;
+use std::str::FromStr;
+use tonic::metadata::AsciiMetadataKey;
 
 use crate::flight_service_client::FlightServiceClient;
 use crate::sql::server::{CLOSE_PREPARED_STATEMENT, CREATE_PREPARED_STATEMENT};
@@ -44,16 +45,15 @@ use arrow_ipc::{root_as_message, MessageHeader};
 use arrow_schema::{ArrowError, Schema, SchemaRef};
 use futures::{stream, TryStreamExt};
 use prost::Message;
-#[cfg(feature = "tls")]
-use tonic::transport::{Certificate, ClientTlsConfig, Identity};
-use tonic::transport::{Channel, Endpoint};
-use tonic::Streaming;
+use tonic::transport::Channel;
+use tonic::{IntoRequest, Streaming};
 
 /// A FlightSQLServiceClient is an endpoint for retrieving or storing Arrow 
data
 /// by FlightSQL protocol.
 #[derive(Debug, Clone)]
 pub struct FlightSqlServiceClient {
     token: Option<String>,
+    headers: HashMap<String, String>,
     flight_client: FlightServiceClient<Channel>,
 }
 
@@ -61,68 +61,13 @@ pub struct FlightSqlServiceClient {
 /// This client is in the "experimental" stage. It is not guaranteed to follow 
the spec in all instances.
 /// Github issues are welcomed.
 impl FlightSqlServiceClient {
-    /// Creates a new FlightSql Client that connects via TCP to a server
-    #[cfg(not(feature = "tls"))]
-    pub async fn new_with_endpoint(host: &str, port: u16) -> Result<Self, 
ArrowError> {
-        let addr = format!("http://{}:{}";, host, port);
-        let endpoint = Endpoint::new(addr)
-            .map_err(|_| ArrowError::IoError("Cannot create 
endpoint".to_string()))?
-            .connect_timeout(Duration::from_secs(20))
-            .timeout(Duration::from_secs(20))
-            .tcp_nodelay(true) // Disable Nagle's Algorithm since we don't 
want packets to wait
-            .tcp_keepalive(Option::Some(Duration::from_secs(3600)))
-            .http2_keep_alive_interval(Duration::from_secs(300))
-            .keep_alive_timeout(Duration::from_secs(20))
-            .keep_alive_while_idle(true);
-
-        let channel = endpoint.connect().await.map_err(|e| {
-            ArrowError::IoError(format!("Cannot connect to endpoint: {}", e))
-        })?;
-        Ok(Self::new(channel))
-    }
-
-    /// Creates a new HTTPs FlightSql Client that connects via TCP to a server
-    #[cfg(feature = "tls")]
-    pub async fn new_with_endpoint(
-        client_ident: Identity,
-        server_ca: Certificate,
-        domain: &str,
-        host: &str,
-        port: u16,
-    ) -> Result<Self, ArrowError> {
-        let addr = format!("https://{host}:{port}";);
-
-        let endpoint = Endpoint::new(addr)
-            .map_err(|_| ArrowError::IoError("Cannot create 
endpoint".to_string()))?
-            .connect_timeout(Duration::from_secs(20))
-            .timeout(Duration::from_secs(20))
-            .tcp_nodelay(true) // Disable Nagle's Algorithm since we don't 
want packets to wait
-            .tcp_keepalive(Option::Some(Duration::from_secs(3600)))
-            .http2_keep_alive_interval(Duration::from_secs(300))
-            .keep_alive_timeout(Duration::from_secs(20))
-            .keep_alive_while_idle(true);
-
-        let tls_config = ClientTlsConfig::new()
-            .domain_name(domain)
-            .ca_certificate(server_ca)
-            .identity(client_ident);
-
-        let endpoint = endpoint
-            .tls_config(tls_config)
-            .map_err(|_| ArrowError::IoError("Cannot create 
endpoint".to_string()))?;
-
-        let channel = endpoint.connect().await.map_err(|e| {
-            ArrowError::IoError(format!("Cannot connect to endpoint: {e}"))
-        })?;
-        Ok(Self::new(channel))
-    }
-
     /// Creates a new FlightSql client that connects to a server over an 
arbitrary tonic `Channel`
     pub fn new(channel: Channel) -> Self {
         let flight_client = FlightServiceClient::new(channel);
         FlightSqlServiceClient {
             token: None,
             flight_client,
+            headers: HashMap::default(),
         }
     }
 
@@ -141,14 +86,27 @@ impl FlightSqlServiceClient {
         self.flight_client
     }
 
+    /// Set auth token to the given value.
+    pub fn set_token(&mut self, token: String) {
+        self.token = Some(token);
+    }
+
+    /// Set header value.
+    pub fn set_header(&mut self, key: impl Into<String>, value: impl 
Into<String>) {
+        let key: String = key.into();
+        let value: String = value.into();
+        self.headers.insert(key, value);
+    }
+
     async fn get_flight_info_for_command<M: ProstMessageExt>(
         &mut self,
         cmd: M,
     ) -> Result<FlightInfo, ArrowError> {
         let descriptor = 
FlightDescriptor::new_cmd(cmd.as_any().encode_to_vec());
+        let req = self.set_request_headers(descriptor.into_request())?;
         let fi = self
             .flight_client
-            .get_flight_info(descriptor)
+            .get_flight_info(req)
             .await
             .map_err(status_to_arrow_error)?
             .into_inner();
@@ -178,6 +136,7 @@ impl FlightSqlServiceClient {
             .parse()
             .map_err(|_| ArrowError::ParseError("Cannot parse 
header".to_string()))?;
         req.metadata_mut().insert("authorization", val);
+        let req = self.set_request_headers(req)?;
         let resp = self
             .flight_client
             .handshake(req)
@@ -199,25 +158,29 @@ impl FlightSqlServiceClient {
                 ArrowError::ParseError("Can't collect responses".to_string())
             })?;
         let resp = match responses.as_slice() {
-            [resp] => resp,
-            [] => Err(ArrowError::ParseError("No handshake 
response".to_string()))?,
+            [resp] => resp.payload.clone(),
+            [] => Bytes::new(),
             _ => Err(ArrowError::ParseError(
                 "Multiple handshake responses".to_string(),
             ))?,
         };
-        Ok(resp.payload.clone())
+        Ok(resp)
     }
 
     /// Execute a update query on the server, and return the number of records 
affected
     pub async fn execute_update(&mut self, query: String) -> Result<i64, 
ArrowError> {
         let cmd = CommandStatementUpdate { query };
         let descriptor = 
FlightDescriptor::new_cmd(cmd.as_any().encode_to_vec());
-        let mut result = self
-            .flight_client
-            .do_put(stream::iter(vec![FlightData {
+        let req = self.set_request_headers(
+            stream::iter(vec![FlightData {
                 flight_descriptor: Some(descriptor),
                 ..Default::default()
-            }]))
+            }])
+            .into_request(),
+        )?;
+        let mut result = self
+            .flight_client
+            .do_put(req)
             .await
             .map_err(status_to_arrow_error)?
             .into_inner();
@@ -251,9 +214,10 @@ impl FlightSqlServiceClient {
         &mut self,
         ticket: Ticket,
     ) -> Result<Streaming<FlightData>, ArrowError> {
+        let req = self.set_request_headers(ticket.into_request())?;
         Ok(self
             .flight_client
-            .do_get(ticket)
+            .do_get(req)
             .await
             .map_err(status_to_arrow_error)?
             .into_inner())
@@ -329,13 +293,7 @@ impl FlightSqlServiceClient {
             r#type: CREATE_PREPARED_STATEMENT.to_string(),
             body: cmd.as_any().encode_to_vec().into(),
         };
-        let mut req = tonic::Request::new(action);
-        if let Some(token) = &self.token {
-            let val = format!("Bearer {token}").parse().map_err(|_| {
-                ArrowError::IoError("Statement already closed.".to_string())
-            })?;
-            req.metadata_mut().insert("authorization", val);
-        }
+        let req = self.set_request_headers(action.into_request())?;
         let mut result = self
             .flight_client
             .do_action(req)
@@ -369,6 +327,28 @@ impl FlightSqlServiceClient {
     pub async fn close(&mut self) -> Result<(), ArrowError> {
         Ok(())
     }
+
+    fn set_request_headers<T>(
+        &self,
+        mut req: tonic::Request<T>,
+    ) -> Result<tonic::Request<T>, ArrowError> {
+        for (k, v) in &self.headers {
+            let k = AsciiMetadataKey::from_str(k.as_str()).map_err(|e| {
+                ArrowError::IoError(format!("Cannot convert header key 
\"{k}\": {e}"))
+            })?;
+            let v = v.parse().map_err(|e| {
+                ArrowError::IoError(format!("Cannot convert header value 
\"{v}\": {e}"))
+            })?;
+            req.metadata_mut().insert(k, v);
+        }
+        if let Some(token) = &self.token {
+            let val = format!("Bearer {token}").parse().map_err(|e| {
+                ArrowError::IoError(format!("Cannot convert token to header 
value: {e}"))
+            })?;
+            req.metadata_mut().insert("authorization", val);
+        }
+        Ok(req)
+    }
 }
 
 /// A PreparedStatement

Reply via email to