This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new abf5367d4 Fix(flight_sql): PreparedStatement has no token for auth. 
(#3948)
abf5367d4 is described below

commit abf5367d4828b71e152bc159ce7c70c86181eebc
Author: Yang Xiufeng <[email protected]>
AuthorDate: Thu Mar 30 18:55:47 2023 +0800

    Fix(flight_sql): PreparedStatement has no token for auth. (#3948)
    
    * Fix(flight_sql): PreparedStatement needs FlightSqlServiceClient to set 
headers.
    
    In particular, the token is required for auth in each request.
    
    * refactor: make FlightSqlServiceClient generic.
    
    * test: example FlightSqlServiceImpl checks the token for each request.
    
    * remove FlightSqlServiceClient::get_flight_info.
    
    * keep do_get/do_action/do_put consistent.
    
    * code reuse in tests of example FlightSqlServiceImpl.
    
    * add cases for auth failure.
---
 arrow-flight/examples/flight_sql_server.rs | 180 +++++++++++++++++++----------
 arrow-flight/src/bin/flight_sql_client.rs  |   6 +-
 arrow-flight/src/sql/client.rs             |  67 +++++++----
 3 files changed, 168 insertions(+), 85 deletions(-)

diff --git a/arrow-flight/examples/flight_sql_server.rs 
b/arrow-flight/examples/flight_sql_server.rs
index bc9d24656..08744b65f 100644
--- a/arrow-flight/examples/flight_sql_server.rs
+++ b/arrow-flight/examples/flight_sql_server.rs
@@ -59,10 +59,35 @@ macro_rules! status {
     };
 }
 
+const FAKE_TOKEN: &str = "uuid_token";
+const FAKE_HANDLE: &str = "uuid_handle";
+const FAKE_UPDATE_RESULT: i64 = 1;
+
 #[derive(Clone)]
 pub struct FlightSqlServiceImpl {}
 
 impl FlightSqlServiceImpl {
+    fn check_token<T>(&self, req: &Request<T>) -> Result<(), Status> {
+        let metadata = req.metadata();
+        let auth = metadata.get("authorization").ok_or_else(|| {
+            Status::internal(format!("No authorization header! metadata = 
{metadata:?}"))
+        })?;
+        let str = auth
+            .to_str()
+            .map_err(|e| Status::internal(format!("Error parsing header: 
{e}")))?;
+        let authorization = str.to_string();
+        let bearer = "Bearer ";
+        if !authorization.starts_with(bearer) {
+            Err(Status::internal("Invalid auth header!"))?;
+        }
+        let token = authorization[bearer.len()..].to_string();
+        if token == FAKE_TOKEN {
+            Ok(())
+        } else {
+            Err(Status::unauthenticated("invalid token "))
+        }
+    }
+
     fn fake_result() -> Result<RecordBatch, ArrowError> {
         let schema = Schema::new(vec![Field::new("salutation", DataType::Utf8, 
false)]);
         let mut builder = StringBuilder::new();
@@ -70,10 +95,6 @@ impl FlightSqlServiceImpl {
         let cols = vec![Arc::new(builder.finish()) as ArrayRef];
         RecordBatch::try_new(Arc::new(schema), cols)
     }
-
-    fn fake_update_result() -> i64 {
-        1
-    }
 }
 
 #[tonic::async_trait]
@@ -118,7 +139,7 @@ impl FlightSqlService for FlightSqlServiceImpl {
 
         let result = HandshakeResponse {
             protocol_version: 0,
-            payload: "random_uuid_token".into(),
+            payload: FAKE_TOKEN.into(),
         };
         let result = Ok(result);
         let output = futures::stream::iter(vec![result]);
@@ -127,9 +148,10 @@ impl FlightSqlService for FlightSqlServiceImpl {
 
     async fn do_get_fallback(
         &self,
-        _request: Request<Ticket>,
+        request: Request<Ticket>,
         _message: Any,
     ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
+        self.check_token(&request)?;
         let batch =
             Self::fake_result().map_err(|e| status!("Could not fake a result", 
e))?;
         let schema = (*batch.schema()).clone();
@@ -158,8 +180,9 @@ impl FlightSqlService for FlightSqlServiceImpl {
     async fn get_flight_info_prepared_statement(
         &self,
         cmd: CommandPreparedStatementQuery,
-        _request: Request<FlightDescriptor>,
+        request: Request<FlightDescriptor>,
     ) -> Result<Response<FlightInfo>, Status> {
+        self.check_token(&request)?;
         let handle = std::str::from_utf8(&cmd.prepared_statement_handle)
             .map_err(|e| status!("Unable to parse handle", e))?;
         let batch =
@@ -395,7 +418,7 @@ impl FlightSqlService for FlightSqlServiceImpl {
         _ticket: CommandStatementUpdate,
         _request: Request<Streaming<FlightData>>,
     ) -> Result<i64, Status> {
-        Ok(FlightSqlServiceImpl::fake_update_result())
+        Ok(FAKE_UPDATE_RESULT)
     }
 
     async fn do_put_prepared_statement_query(
@@ -421,9 +444,9 @@ impl FlightSqlService for FlightSqlServiceImpl {
     async fn do_action_create_prepared_statement(
         &self,
         _query: ActionCreatePreparedStatementRequest,
-        _request: Request<Action>,
+        request: Request<Action>,
     ) -> Result<ActionCreatePreparedStatementResult, Status> {
-        let handle = "some_uuid";
+        self.check_token(&request)?;
         let schema = Self::fake_result()
             .map_err(|e| status!("Error getting result schema", e))?
             .schema();
@@ -432,7 +455,7 @@ impl FlightSqlService for FlightSqlServiceImpl {
             .map_err(|e| status!("Unable to serialize schema", e))?;
         let IpcMessage(schema_bytes) = message;
         let res = ActionCreatePreparedStatementResult {
-            prepared_statement_handle: handle.into(),
+            prepared_statement_handle: FAKE_HANDLE.into(),
             dataset_schema: schema_bytes,
             parameter_schema: Default::default(), // TODO: parameters
         };
@@ -505,12 +528,13 @@ mod tests {
     use super::*;
     use futures::TryStreamExt;
     use std::fs;
+    use std::future::Future;
     use std::time::Duration;
     use tempfile::NamedTempFile;
     use tokio::net::{UnixListener, UnixStream};
     use tokio::time::sleep;
     use tokio_stream::wrappers::UnixListenerStream;
-    use tonic::transport::ClientTlsConfig;
+    use tonic::transport::{Channel, ClientTlsConfig};
 
     use arrow_cast::pretty::pretty_format_batches;
     use arrow_flight::sql::client::FlightSqlServiceClient;
@@ -518,7 +542,7 @@ mod tests {
     use tonic::transport::{Certificate, Endpoint};
     use tower::service_fn;
 
-    async fn client_with_uds(path: String) -> FlightSqlServiceClient {
+    async fn client_with_uds(path: String) -> FlightSqlServiceClient<Channel> {
         let connector = service_fn(move |_| UnixStream::connect(path.clone()));
         let channel = Endpoint::try_from("http://example.com")
             .unwrap()
@@ -549,6 +573,20 @@ mod tests {
             .await
     }
 
+    fn endpoint(addr: String) -> Result<Endpoint, ArrowError> {
+        let endpoint = Endpoint::new(addr)
+            .map_err(|_| ArrowError::IoError("Cannot create 
endpoint".to_string()))?
+            .connect_timeout(Duration::from_secs(20))
+            .timeout(Duration::from_secs(20))
+            .tcp_nodelay(true) // Disable Nagle's Algorithm since we don't 
want packets to wait
+            .tcp_keepalive(Option::Some(Duration::from_secs(3600)))
+            .http2_keep_alive_interval(Duration::from_secs(300))
+            .keep_alive_timeout(Duration::from_secs(20))
+            .keep_alive_while_idle(true);
+
+        Ok(endpoint)
+    }
+
     #[tokio::test]
     async fn test_select_https() {
         tokio::spawn(async {
@@ -573,6 +611,7 @@ mod tests {
             let channel = endpoint.connect().await.unwrap();
             let mut client = FlightSqlServiceClient::new(channel);
             let token = client.handshake("admin", "password").await.unwrap();
+            client.set_token(String::from_utf8(token.to_vec()).unwrap());
             println!("Auth succeeded with token: {:?}", token);
             let mut stmt = client.prepare("select 
1;".to_string()).await.unwrap();
             let flight_info = stmt.execute().await.unwrap();
@@ -597,8 +636,16 @@ mod tests {
         }
     }
 
-    #[tokio::test]
-    async fn test_select_1() {
+    async fn auth_client(client: &mut FlightSqlServiceClient<Channel>) {
+        let token = client.handshake("admin", "password").await.unwrap();
+        client.set_token(String::from_utf8(token.to_vec()).unwrap());
+    }
+
+    async fn test_client<F, C>(f: F)
+    where
+        F: FnOnce(FlightSqlServiceClient<Channel>) -> C,
+        C: Future<Output = ()>,
+    {
         let file = NamedTempFile::new().unwrap();
         let path = file.into_temp_path().to_str().unwrap().to_string();
         let _ = fs::remove_file(path.clone());
@@ -613,9 +660,20 @@ mod tests {
             .serve_with_incoming(stream);
 
         let request_future = async {
-            let mut client = client_with_uds(path).await;
-            let token = client.handshake("admin", "password").await.unwrap();
-            println!("Auth succeeded with token: {:?}", token);
+            let client = client_with_uds(path).await;
+            f(client).await
+        };
+
+        tokio::select! {
+            _ = serve_future => panic!("server returned first"),
+            _ = request_future => println!("Client finished!"),
+        }
+    }
+
+    #[tokio::test]
+    async fn test_select_1() {
+        test_client(|mut client| async move {
+            auth_client(&mut client).await;
             let mut stmt = client.prepare("select 
1;".to_string()).await.unwrap();
             let flight_info = stmt.execute().await.unwrap();
             let ticket = 
flight_info.endpoint[0].ticket.as_ref().unwrap().clone();
@@ -632,57 +690,61 @@ mod tests {
                 .trim()
                 .to_string();
             assert_eq!(res.to_string(), expected);
-        };
-
-        tokio::select! {
-            _ = serve_future => panic!("server returned first"),
-            _ = request_future => println!("Client finished!"),
-        }
+        })
+        .await
     }
 
     #[tokio::test]
     async fn test_execute_update() {
-        let file = NamedTempFile::new().unwrap();
-        let path = file.into_temp_path().to_str().unwrap().to_string();
-        let _ = fs::remove_file(path.clone());
-
-        let uds = UnixListener::bind(path.clone()).unwrap();
-        let stream = UnixListenerStream::new(uds);
-
-        // We would just listen on TCP, but it seems impossible to know when 
tonic is ready to serve
-        let service = FlightSqlServiceImpl {};
-        let serve_future = Server::builder()
-            .add_service(FlightServiceServer::new(service))
-            .serve_with_incoming(stream);
-
-        let request_future = async {
-            let mut client = client_with_uds(path).await;
-            let token = client.handshake("admin", "password").await.unwrap();
-            println!("Auth succeeded with token: {:?}", token);
+        test_client(|mut client| async move {
+            auth_client(&mut client).await;
             let res = client
                 .execute_update("creat table test(a int);".to_string())
                 .await
                 .unwrap();
-            assert_eq!(res, FlightSqlServiceImpl::fake_update_result());
-        };
-
-        tokio::select! {
-            _ = serve_future => panic!("server returned first"),
-            _ = request_future => println!("Client finished!"),
-        }
+            assert_eq!(res, FAKE_UPDATE_RESULT);
+        })
+        .await
     }
 
-    fn endpoint(addr: String) -> Result<Endpoint, ArrowError> {
-        let endpoint = Endpoint::new(addr)
-            .map_err(|_| ArrowError::IoError("Cannot create 
endpoint".to_string()))?
-            .connect_timeout(Duration::from_secs(20))
-            .timeout(Duration::from_secs(20))
-            .tcp_nodelay(true) // Disable Nagle's Algorithm since we don't 
want packets to wait
-            .tcp_keepalive(Option::Some(Duration::from_secs(3600)))
-            .http2_keep_alive_interval(Duration::from_secs(300))
-            .keep_alive_timeout(Duration::from_secs(20))
-            .keep_alive_while_idle(true);
+    #[tokio::test]
+    async fn test_auth() {
+        test_client(|mut client| async move {
+            // no handshake
+            assert!(client
+                .prepare("select 1;".to_string())
+                .await
+                .unwrap_err()
+                .to_string()
+                .contains("No authorization header"));
 
-        Ok(endpoint)
+            // Invalid credentials
+            assert!(client
+                .handshake("admin", "password2")
+                .await
+                .unwrap_err()
+                .to_string()
+                .contains("Invalid credentials"));
+
+            // forget to set_token
+            client.handshake("admin", "password").await.unwrap();
+            assert!(client
+                .prepare("select 1;".to_string())
+                .await
+                .unwrap_err()
+                .to_string()
+                .contains("No authorization header"));
+
+            // Invalid Tokens
+            client.handshake("admin", "password").await.unwrap();
+            client.set_token("wrong token".to_string());
+            assert!(client
+                .prepare("select 1;".to_string())
+                .await
+                .unwrap_err()
+                .to_string()
+                .contains("invalid token"));
+        })
+        .await
     }
 }
diff --git a/arrow-flight/src/bin/flight_sql_client.rs 
b/arrow-flight/src/bin/flight_sql_client.rs
index c6a46a387..1891a331b 100644
--- a/arrow-flight/src/bin/flight_sql_client.rs
+++ b/arrow-flight/src/bin/flight_sql_client.rs
@@ -25,7 +25,7 @@ use arrow_flight::{
 use arrow_schema::{ArrowError, Schema};
 use clap::Parser;
 use futures::TryStreamExt;
-use tonic::transport::{ClientTlsConfig, Endpoint};
+use tonic::transport::{Channel, ClientTlsConfig, Endpoint};
 use tracing_log::log::info;
 
 /// A ':' separated key value pair
@@ -140,7 +140,9 @@ fn setup_logging() {
     tracing_subscriber::fmt::init();
 }
 
-async fn setup_client(args: ClientArgs) -> Result<FlightSqlServiceClient, 
ArrowError> {
+async fn setup_client(
+    args: ClientArgs,
+) -> Result<FlightSqlServiceClient<Channel>, ArrowError> {
     let port = args.port.unwrap_or(if args.tls { 443 } else { 80 });
 
     let protocol = if args.tls { "https" } else { "http" };
diff --git a/arrow-flight/src/sql/client.rs b/arrow-flight/src/sql/client.rs
index a61f06d32..a8868fba1 100644
--- a/arrow-flight/src/sql/client.rs
+++ b/arrow-flight/src/sql/client.rs
@@ -35,7 +35,7 @@ use crate::sql::{
 };
 use crate::{
     Action, FlightData, FlightDescriptor, FlightInfo, HandshakeRequest,
-    HandshakeResponse, IpcMessage, Ticket,
+    HandshakeResponse, IpcMessage, PutResult, Ticket,
 };
 use arrow_array::RecordBatch;
 use arrow_buffer::Buffer;
@@ -51,16 +51,16 @@ use tonic::{IntoRequest, Streaming};
 /// A FlightSQLServiceClient is an endpoint for retrieving or storing Arrow 
data
 /// by FlightSQL protocol.
 #[derive(Debug, Clone)]
-pub struct FlightSqlServiceClient {
+pub struct FlightSqlServiceClient<T> {
     token: Option<String>,
     headers: HashMap<String, String>,
-    flight_client: FlightServiceClient<Channel>,
+    flight_client: FlightServiceClient<T>,
 }
 
 /// A FlightSql protocol client that can run queries against FlightSql servers
 /// This client is in the "experimental" stage. It is not guaranteed to follow 
the spec in all instances.
 /// Github issues are welcomed.
-impl FlightSqlServiceClient {
+impl FlightSqlServiceClient<Channel> {
     /// Creates a new FlightSql client that connects to a server over an 
arbitrary tonic `Channel`
     pub fn new(channel: Channel) -> Self {
         let flight_client = FlightServiceClient::new(channel);
@@ -212,7 +212,7 @@ impl FlightSqlServiceClient {
     /// Given a flight ticket, request to be sent the stream. Returns record 
batch stream reader
     pub async fn do_get(
         &mut self,
-        ticket: Ticket,
+        ticket: impl IntoRequest<Ticket>,
     ) -> Result<Streaming<FlightData>, ArrowError> {
         let req = self.set_request_headers(ticket.into_request())?;
         Ok(self
@@ -223,6 +223,34 @@ impl FlightSqlServiceClient {
             .into_inner())
     }
 
+    /// Push a stream to the flight service associated with a particular 
flight stream.
+    pub async fn do_put(
+        &mut self,
+        request: impl tonic::IntoStreamingRequest<Message = FlightData>,
+    ) -> Result<Streaming<PutResult>, ArrowError> {
+        let req = self.set_request_headers(request.into_streaming_request())?;
+        Ok(self
+            .flight_client
+            .do_put(req)
+            .await
+            .map_err(status_to_arrow_error)?
+            .into_inner())
+    }
+
+    /// DoAction allows a flight client to do a specific action against a 
flight service
+    pub async fn do_action(
+        &mut self,
+        request: impl IntoRequest<Action>,
+    ) -> Result<Streaming<crate::Result>, ArrowError> {
+        let req = self.set_request_headers(request.into_request())?;
+        Ok(self
+            .flight_client
+            .do_action(req)
+            .await
+            .map_err(status_to_arrow_error)?
+            .into_inner())
+    }
+
     /// Request a list of tables.
     pub async fn get_tables(
         &mut self,
@@ -316,7 +344,7 @@ impl FlightSqlServiceClient {
             _ => 
Schema::try_from(IpcMessage(prepared_result.parameter_schema))?,
         };
         Ok(PreparedStatement::new(
-            self.flight_client.clone(),
+            self.clone(),
             prepared_result.prepared_statement_handle,
             dataset_schema,
             parameter_schema,
@@ -354,7 +382,7 @@ impl FlightSqlServiceClient {
 /// A PreparedStatement
 #[derive(Debug, Clone)]
 pub struct PreparedStatement<T> {
-    flight_client: FlightServiceClient<T>,
+    flight_sql_client: FlightSqlServiceClient<T>,
     parameter_binding: Option<RecordBatch>,
     handle: Bytes,
     dataset_schema: Schema,
@@ -363,13 +391,13 @@ pub struct PreparedStatement<T> {
 
 impl PreparedStatement<Channel> {
     pub(crate) fn new(
-        flight_client: FlightServiceClient<Channel>,
+        flight_client: FlightSqlServiceClient<Channel>,
         handle: impl Into<Bytes>,
         dataset_schema: Schema,
         parameter_schema: Schema,
     ) -> Self {
         PreparedStatement {
-            flight_client,
+            flight_sql_client: flight_client,
             parameter_binding: None,
             handle: handle.into(),
             dataset_schema,
@@ -382,13 +410,10 @@ impl PreparedStatement<Channel> {
         let cmd = CommandPreparedStatementQuery {
             prepared_statement_handle: self.handle.clone(),
         };
-        let descriptor = 
FlightDescriptor::new_cmd(cmd.as_any().encode_to_vec());
         let result = self
-            .flight_client
-            .get_flight_info(descriptor)
-            .await
-            .map_err(status_to_arrow_error)?
-            .into_inner();
+            .flight_sql_client
+            .get_flight_info_for_command(cmd)
+            .await?;
         Ok(result)
     }
 
@@ -399,14 +424,12 @@ impl PreparedStatement<Channel> {
         };
         let descriptor = 
FlightDescriptor::new_cmd(cmd.as_any().encode_to_vec());
         let mut result = self
-            .flight_client
+            .flight_sql_client
             .do_put(stream::iter(vec![FlightData {
                 flight_descriptor: Some(descriptor),
                 ..Default::default()
             }]))
-            .await
-            .map_err(status_to_arrow_error)?
-            .into_inner();
+            .await?;
         let result = result
             .message()
             .await
@@ -447,11 +470,7 @@ impl PreparedStatement<Channel> {
             r#type: CLOSE_PREPARED_STATEMENT.to_string(),
             body: cmd.as_any().encode_to_vec().into(),
         };
-        let _ = self
-            .flight_client
-            .do_action(action)
-            .await
-            .map_err(status_to_arrow_error)?;
+        let _ = self.flight_sql_client.do_action(action).await?;
         Ok(())
     }
 }

Reply via email to