This is an automated email from the ASF dual-hosted git repository.
tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new abf5367d4 Fix(flight_sql): PreparedStatement has no token for auth.
(#3948)
abf5367d4 is described below
commit abf5367d4828b71e152bc159ce7c70c86181eebc
Author: Yang Xiufeng <[email protected]>
AuthorDate: Thu Mar 30 18:55:47 2023 +0800
Fix(flight_sql): PreparedStatement has no token for auth. (#3948)
* Fix(flight_sql): PreparedStatement needs FlightSqlServiceClient to set
headers.
In particular, the token is required for auth in each request.
* refactor: make FlightSqlServiceClient generic.
* test: example FlightSqlServiceImpl checks the token for each request.
* remove FlightSqlServiceClient::get_flight_info.
* keep do_get/do_action/do_put consistent.
* code reuse in tests of example FlightSqlServiceImpl.
* add cases for auth failure.
---
arrow-flight/examples/flight_sql_server.rs | 180 +++++++++++++++++++----------
arrow-flight/src/bin/flight_sql_client.rs | 6 +-
arrow-flight/src/sql/client.rs | 67 +++++++----
3 files changed, 168 insertions(+), 85 deletions(-)
diff --git a/arrow-flight/examples/flight_sql_server.rs
b/arrow-flight/examples/flight_sql_server.rs
index bc9d24656..08744b65f 100644
--- a/arrow-flight/examples/flight_sql_server.rs
+++ b/arrow-flight/examples/flight_sql_server.rs
@@ -59,10 +59,35 @@ macro_rules! status {
};
}
+const FAKE_TOKEN: &str = "uuid_token";
+const FAKE_HANDLE: &str = "uuid_handle";
+const FAKE_UPDATE_RESULT: i64 = 1;
+
#[derive(Clone)]
pub struct FlightSqlServiceImpl {}
impl FlightSqlServiceImpl {
+ fn check_token<T>(&self, req: &Request<T>) -> Result<(), Status> {
+ let metadata = req.metadata();
+ let auth = metadata.get("authorization").ok_or_else(|| {
+ Status::internal(format!("No authorization header! metadata =
{metadata:?}"))
+ })?;
+ let str = auth
+ .to_str()
+ .map_err(|e| Status::internal(format!("Error parsing header:
{e}")))?;
+ let authorization = str.to_string();
+ let bearer = "Bearer ";
+ if !authorization.starts_with(bearer) {
+ Err(Status::internal("Invalid auth header!"))?;
+ }
+ let token = authorization[bearer.len()..].to_string();
+ if token == FAKE_TOKEN {
+ Ok(())
+ } else {
+ Err(Status::unauthenticated("invalid token "))
+ }
+ }
+
fn fake_result() -> Result<RecordBatch, ArrowError> {
let schema = Schema::new(vec![Field::new("salutation", DataType::Utf8,
false)]);
let mut builder = StringBuilder::new();
@@ -70,10 +95,6 @@ impl FlightSqlServiceImpl {
let cols = vec![Arc::new(builder.finish()) as ArrayRef];
RecordBatch::try_new(Arc::new(schema), cols)
}
-
- fn fake_update_result() -> i64 {
- 1
- }
}
#[tonic::async_trait]
@@ -118,7 +139,7 @@ impl FlightSqlService for FlightSqlServiceImpl {
let result = HandshakeResponse {
protocol_version: 0,
- payload: "random_uuid_token".into(),
+ payload: FAKE_TOKEN.into(),
};
let result = Ok(result);
let output = futures::stream::iter(vec![result]);
@@ -127,9 +148,10 @@ impl FlightSqlService for FlightSqlServiceImpl {
async fn do_get_fallback(
&self,
- _request: Request<Ticket>,
+ request: Request<Ticket>,
_message: Any,
) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
+ self.check_token(&request)?;
let batch =
Self::fake_result().map_err(|e| status!("Could not fake a result",
e))?;
let schema = (*batch.schema()).clone();
@@ -158,8 +180,9 @@ impl FlightSqlService for FlightSqlServiceImpl {
async fn get_flight_info_prepared_statement(
&self,
cmd: CommandPreparedStatementQuery,
- _request: Request<FlightDescriptor>,
+ request: Request<FlightDescriptor>,
) -> Result<Response<FlightInfo>, Status> {
+ self.check_token(&request)?;
let handle = std::str::from_utf8(&cmd.prepared_statement_handle)
.map_err(|e| status!("Unable to parse handle", e))?;
let batch =
@@ -395,7 +418,7 @@ impl FlightSqlService for FlightSqlServiceImpl {
_ticket: CommandStatementUpdate,
_request: Request<Streaming<FlightData>>,
) -> Result<i64, Status> {
- Ok(FlightSqlServiceImpl::fake_update_result())
+ Ok(FAKE_UPDATE_RESULT)
}
async fn do_put_prepared_statement_query(
@@ -421,9 +444,9 @@ impl FlightSqlService for FlightSqlServiceImpl {
async fn do_action_create_prepared_statement(
&self,
_query: ActionCreatePreparedStatementRequest,
- _request: Request<Action>,
+ request: Request<Action>,
) -> Result<ActionCreatePreparedStatementResult, Status> {
- let handle = "some_uuid";
+ self.check_token(&request)?;
let schema = Self::fake_result()
.map_err(|e| status!("Error getting result schema", e))?
.schema();
@@ -432,7 +455,7 @@ impl FlightSqlService for FlightSqlServiceImpl {
.map_err(|e| status!("Unable to serialize schema", e))?;
let IpcMessage(schema_bytes) = message;
let res = ActionCreatePreparedStatementResult {
- prepared_statement_handle: handle.into(),
+ prepared_statement_handle: FAKE_HANDLE.into(),
dataset_schema: schema_bytes,
parameter_schema: Default::default(), // TODO: parameters
};
@@ -505,12 +528,13 @@ mod tests {
use super::*;
use futures::TryStreamExt;
use std::fs;
+ use std::future::Future;
use std::time::Duration;
use tempfile::NamedTempFile;
use tokio::net::{UnixListener, UnixStream};
use tokio::time::sleep;
use tokio_stream::wrappers::UnixListenerStream;
- use tonic::transport::ClientTlsConfig;
+ use tonic::transport::{Channel, ClientTlsConfig};
use arrow_cast::pretty::pretty_format_batches;
use arrow_flight::sql::client::FlightSqlServiceClient;
@@ -518,7 +542,7 @@ mod tests {
use tonic::transport::{Certificate, Endpoint};
use tower::service_fn;
- async fn client_with_uds(path: String) -> FlightSqlServiceClient {
+ async fn client_with_uds(path: String) -> FlightSqlServiceClient<Channel> {
let connector = service_fn(move |_| UnixStream::connect(path.clone()));
let channel = Endpoint::try_from("http://example.com")
.unwrap()
@@ -549,6 +573,20 @@ mod tests {
.await
}
+ fn endpoint(addr: String) -> Result<Endpoint, ArrowError> {
+ let endpoint = Endpoint::new(addr)
+ .map_err(|_| ArrowError::IoError("Cannot create
endpoint".to_string()))?
+ .connect_timeout(Duration::from_secs(20))
+ .timeout(Duration::from_secs(20))
+ .tcp_nodelay(true) // Disable Nagle's Algorithm since we don't
want packets to wait
+ .tcp_keepalive(Option::Some(Duration::from_secs(3600)))
+ .http2_keep_alive_interval(Duration::from_secs(300))
+ .keep_alive_timeout(Duration::from_secs(20))
+ .keep_alive_while_idle(true);
+
+ Ok(endpoint)
+ }
+
#[tokio::test]
async fn test_select_https() {
tokio::spawn(async {
@@ -573,6 +611,7 @@ mod tests {
let channel = endpoint.connect().await.unwrap();
let mut client = FlightSqlServiceClient::new(channel);
let token = client.handshake("admin", "password").await.unwrap();
+ client.set_token(String::from_utf8(token.to_vec()).unwrap());
println!("Auth succeeded with token: {:?}", token);
let mut stmt = client.prepare("select
1;".to_string()).await.unwrap();
let flight_info = stmt.execute().await.unwrap();
@@ -597,8 +636,16 @@ mod tests {
}
}
- #[tokio::test]
- async fn test_select_1() {
+ async fn auth_client(client: &mut FlightSqlServiceClient<Channel>) {
+ let token = client.handshake("admin", "password").await.unwrap();
+ client.set_token(String::from_utf8(token.to_vec()).unwrap());
+ }
+
+ async fn test_client<F, C>(f: F)
+ where
+ F: FnOnce(FlightSqlServiceClient<Channel>) -> C,
+ C: Future<Output = ()>,
+ {
let file = NamedTempFile::new().unwrap();
let path = file.into_temp_path().to_str().unwrap().to_string();
let _ = fs::remove_file(path.clone());
@@ -613,9 +660,20 @@ mod tests {
.serve_with_incoming(stream);
let request_future = async {
- let mut client = client_with_uds(path).await;
- let token = client.handshake("admin", "password").await.unwrap();
- println!("Auth succeeded with token: {:?}", token);
+ let client = client_with_uds(path).await;
+ f(client).await
+ };
+
+ tokio::select! {
+ _ = serve_future => panic!("server returned first"),
+ _ = request_future => println!("Client finished!"),
+ }
+ }
+
+ #[tokio::test]
+ async fn test_select_1() {
+ test_client(|mut client| async move {
+ auth_client(&mut client).await;
let mut stmt = client.prepare("select
1;".to_string()).await.unwrap();
let flight_info = stmt.execute().await.unwrap();
let ticket =
flight_info.endpoint[0].ticket.as_ref().unwrap().clone();
@@ -632,57 +690,61 @@ mod tests {
.trim()
.to_string();
assert_eq!(res.to_string(), expected);
- };
-
- tokio::select! {
- _ = serve_future => panic!("server returned first"),
- _ = request_future => println!("Client finished!"),
- }
+ })
+ .await
}
#[tokio::test]
async fn test_execute_update() {
- let file = NamedTempFile::new().unwrap();
- let path = file.into_temp_path().to_str().unwrap().to_string();
- let _ = fs::remove_file(path.clone());
-
- let uds = UnixListener::bind(path.clone()).unwrap();
- let stream = UnixListenerStream::new(uds);
-
- // We would just listen on TCP, but it seems impossible to know when
tonic is ready to serve
- let service = FlightSqlServiceImpl {};
- let serve_future = Server::builder()
- .add_service(FlightServiceServer::new(service))
- .serve_with_incoming(stream);
-
- let request_future = async {
- let mut client = client_with_uds(path).await;
- let token = client.handshake("admin", "password").await.unwrap();
- println!("Auth succeeded with token: {:?}", token);
+ test_client(|mut client| async move {
+ auth_client(&mut client).await;
let res = client
.execute_update("creat table test(a int);".to_string())
.await
.unwrap();
- assert_eq!(res, FlightSqlServiceImpl::fake_update_result());
- };
-
- tokio::select! {
- _ = serve_future => panic!("server returned first"),
- _ = request_future => println!("Client finished!"),
- }
+ assert_eq!(res, FAKE_UPDATE_RESULT);
+ })
+ .await
}
- fn endpoint(addr: String) -> Result<Endpoint, ArrowError> {
- let endpoint = Endpoint::new(addr)
- .map_err(|_| ArrowError::IoError("Cannot create
endpoint".to_string()))?
- .connect_timeout(Duration::from_secs(20))
- .timeout(Duration::from_secs(20))
- .tcp_nodelay(true) // Disable Nagle's Algorithm since we don't
want packets to wait
- .tcp_keepalive(Option::Some(Duration::from_secs(3600)))
- .http2_keep_alive_interval(Duration::from_secs(300))
- .keep_alive_timeout(Duration::from_secs(20))
- .keep_alive_while_idle(true);
+ #[tokio::test]
+ async fn test_auth() {
+ test_client(|mut client| async move {
+ // no handshake
+ assert!(client
+ .prepare("select 1;".to_string())
+ .await
+ .unwrap_err()
+ .to_string()
+ .contains("No authorization header"));
- Ok(endpoint)
+ // Invalid credentials
+ assert!(client
+ .handshake("admin", "password2")
+ .await
+ .unwrap_err()
+ .to_string()
+ .contains("Invalid credentials"));
+
+ // forget to set_token
+ client.handshake("admin", "password").await.unwrap();
+ assert!(client
+ .prepare("select 1;".to_string())
+ .await
+ .unwrap_err()
+ .to_string()
+ .contains("No authorization header"));
+
+ // Invalid Tokens
+ client.handshake("admin", "password").await.unwrap();
+ client.set_token("wrong token".to_string());
+ assert!(client
+ .prepare("select 1;".to_string())
+ .await
+ .unwrap_err()
+ .to_string()
+ .contains("invalid token"));
+ })
+ .await
}
}
diff --git a/arrow-flight/src/bin/flight_sql_client.rs
b/arrow-flight/src/bin/flight_sql_client.rs
index c6a46a387..1891a331b 100644
--- a/arrow-flight/src/bin/flight_sql_client.rs
+++ b/arrow-flight/src/bin/flight_sql_client.rs
@@ -25,7 +25,7 @@ use arrow_flight::{
use arrow_schema::{ArrowError, Schema};
use clap::Parser;
use futures::TryStreamExt;
-use tonic::transport::{ClientTlsConfig, Endpoint};
+use tonic::transport::{Channel, ClientTlsConfig, Endpoint};
use tracing_log::log::info;
/// A ':' separated key value pair
@@ -140,7 +140,9 @@ fn setup_logging() {
tracing_subscriber::fmt::init();
}
-async fn setup_client(args: ClientArgs) -> Result<FlightSqlServiceClient,
ArrowError> {
+async fn setup_client(
+ args: ClientArgs,
+) -> Result<FlightSqlServiceClient<Channel>, ArrowError> {
let port = args.port.unwrap_or(if args.tls { 443 } else { 80 });
let protocol = if args.tls { "https" } else { "http" };
diff --git a/arrow-flight/src/sql/client.rs b/arrow-flight/src/sql/client.rs
index a61f06d32..a8868fba1 100644
--- a/arrow-flight/src/sql/client.rs
+++ b/arrow-flight/src/sql/client.rs
@@ -35,7 +35,7 @@ use crate::sql::{
};
use crate::{
Action, FlightData, FlightDescriptor, FlightInfo, HandshakeRequest,
- HandshakeResponse, IpcMessage, Ticket,
+ HandshakeResponse, IpcMessage, PutResult, Ticket,
};
use arrow_array::RecordBatch;
use arrow_buffer::Buffer;
@@ -51,16 +51,16 @@ use tonic::{IntoRequest, Streaming};
/// A FlightSQLServiceClient is an endpoint for retrieving or storing Arrow
data
/// by FlightSQL protocol.
#[derive(Debug, Clone)]
-pub struct FlightSqlServiceClient {
+pub struct FlightSqlServiceClient<T> {
token: Option<String>,
headers: HashMap<String, String>,
- flight_client: FlightServiceClient<Channel>,
+ flight_client: FlightServiceClient<T>,
}
/// A FlightSql protocol client that can run queries against FlightSql servers
/// This client is in the "experimental" stage. It is not guaranteed to follow
the spec in all instances.
/// Github issues are welcomed.
-impl FlightSqlServiceClient {
+impl FlightSqlServiceClient<Channel> {
/// Creates a new FlightSql client that connects to a server over an
arbitrary tonic `Channel`
pub fn new(channel: Channel) -> Self {
let flight_client = FlightServiceClient::new(channel);
@@ -212,7 +212,7 @@ impl FlightSqlServiceClient {
/// Given a flight ticket, request to be sent the stream. Returns record
batch stream reader
pub async fn do_get(
&mut self,
- ticket: Ticket,
+ ticket: impl IntoRequest<Ticket>,
) -> Result<Streaming<FlightData>, ArrowError> {
let req = self.set_request_headers(ticket.into_request())?;
Ok(self
@@ -223,6 +223,34 @@ impl FlightSqlServiceClient {
.into_inner())
}
+ /// Push a stream to the flight service associated with a particular
flight stream.
+ pub async fn do_put(
+ &mut self,
+ request: impl tonic::IntoStreamingRequest<Message = FlightData>,
+ ) -> Result<Streaming<PutResult>, ArrowError> {
+ let req = self.set_request_headers(request.into_streaming_request())?;
+ Ok(self
+ .flight_client
+ .do_put(req)
+ .await
+ .map_err(status_to_arrow_error)?
+ .into_inner())
+ }
+
+ /// DoAction allows a flight client to do a specific action against a
flight service
+ pub async fn do_action(
+ &mut self,
+ request: impl IntoRequest<Action>,
+ ) -> Result<Streaming<crate::Result>, ArrowError> {
+ let req = self.set_request_headers(request.into_request())?;
+ Ok(self
+ .flight_client
+ .do_action(req)
+ .await
+ .map_err(status_to_arrow_error)?
+ .into_inner())
+ }
+
/// Request a list of tables.
pub async fn get_tables(
&mut self,
@@ -316,7 +344,7 @@ impl FlightSqlServiceClient {
_ =>
Schema::try_from(IpcMessage(prepared_result.parameter_schema))?,
};
Ok(PreparedStatement::new(
- self.flight_client.clone(),
+ self.clone(),
prepared_result.prepared_statement_handle,
dataset_schema,
parameter_schema,
@@ -354,7 +382,7 @@ impl FlightSqlServiceClient {
/// A PreparedStatement
#[derive(Debug, Clone)]
pub struct PreparedStatement<T> {
- flight_client: FlightServiceClient<T>,
+ flight_sql_client: FlightSqlServiceClient<T>,
parameter_binding: Option<RecordBatch>,
handle: Bytes,
dataset_schema: Schema,
@@ -363,13 +391,13 @@ pub struct PreparedStatement<T> {
impl PreparedStatement<Channel> {
pub(crate) fn new(
- flight_client: FlightServiceClient<Channel>,
+ flight_client: FlightSqlServiceClient<Channel>,
handle: impl Into<Bytes>,
dataset_schema: Schema,
parameter_schema: Schema,
) -> Self {
PreparedStatement {
- flight_client,
+ flight_sql_client: flight_client,
parameter_binding: None,
handle: handle.into(),
dataset_schema,
@@ -382,13 +410,10 @@ impl PreparedStatement<Channel> {
let cmd = CommandPreparedStatementQuery {
prepared_statement_handle: self.handle.clone(),
};
- let descriptor =
FlightDescriptor::new_cmd(cmd.as_any().encode_to_vec());
let result = self
- .flight_client
- .get_flight_info(descriptor)
- .await
- .map_err(status_to_arrow_error)?
- .into_inner();
+ .flight_sql_client
+ .get_flight_info_for_command(cmd)
+ .await?;
Ok(result)
}
@@ -399,14 +424,12 @@ impl PreparedStatement<Channel> {
};
let descriptor =
FlightDescriptor::new_cmd(cmd.as_any().encode_to_vec());
let mut result = self
- .flight_client
+ .flight_sql_client
.do_put(stream::iter(vec![FlightData {
flight_descriptor: Some(descriptor),
..Default::default()
}]))
- .await
- .map_err(status_to_arrow_error)?
- .into_inner();
+ .await?;
let result = result
.message()
.await
@@ -447,11 +470,7 @@ impl PreparedStatement<Channel> {
r#type: CLOSE_PREPARED_STATEMENT.to_string(),
body: cmd.as_any().encode_to_vec().into(),
};
- let _ = self
- .flight_client
- .do_action(action)
- .await
- .map_err(status_to_arrow_error)?;
+ let _ = self.flight_sql_client.do_action(action).await?;
Ok(())
}
}