This is an automated email from the ASF dual-hosted git repository.

chunshao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-horaedb-client-rs.git


The following commit(s) were added to refs/heads/main by this push:
     new 08ec55c  feat: add basic auth (#50)
08ec55c is described below

commit 08ec55ccf9f43d3c59e3b55459dc55edca3a71ce
Author: Jiacai Liu <[email protected]>
AuthorDate: Thu May 16 10:13:03 2024 +0800

    feat: add basic auth (#50)
---
 .github/workflows/ci.yml          |  2 +-
 Cargo.lock                        | 18 ++++++++++++-----
 Cargo.toml                        |  2 ++
 examples/read_write.rs            | 10 ++++++++--
 src/config.rs                     |  6 ++++++
 src/db_client/builder.rs          | 15 ++++++++++++--
 src/errors.rs                     |  6 ++++++
 src/lib.rs                        |  2 +-
 src/rpc_client/rpc_client_impl.rs | 42 +++++++++++++++++++++++++++++++++------
 9 files changed, 86 insertions(+), 17 deletions(-)

diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 3e34001..08b760f 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -45,7 +45,7 @@ jobs:
       - name: Setup Build Environment
         run: sudo apt update && sudo apt install -y protobuf-compiler
       - name: Install cargo binaries
-        run: cargo install cargo-sort
+        run: cargo install cargo-sort --locked
       - name: Run Style Check
         run: make fmt clippy check-toml
 
diff --git a/Cargo.lock b/Cargo.lock
index bd873f8..c3c030e 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -50,9 +50,9 @@ dependencies = [
 
 [[package]]
 name = "anyhow"
-version = "1.0.71"
+version = "1.0.83"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "9c7d0618f0e0b7e8ff11427422b64564d5fb0be1940354bfe2e0529b18a9d9b8"
+checksum = "25bdb32cbbdce2b519a9cd7df3a678443100e265d5e25ca763b7572a5104f5f3"
 
 [[package]]
 name = "arrow"
@@ -367,6 +367,12 @@ version = "0.13.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8"
 
+[[package]]
+name = "base64"
+version = "0.22.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6"
+
 [[package]]
 name = "bitflags"
 version = "1.3.2"
@@ -381,9 +387,9 @@ checksum = "a3e2c3daef883ecc1b5d58c15adae93470a91d425f3532ba1695849656af3fc1"
 
 [[package]]
 name = "bytes"
-version = "1.4.0"
+version = "1.6.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "89b2fd2a0dcf38d7971e2194b6b6eebab45ae01067456a7fd93d5547a61b70be"
+checksum = "514de17de45fdb8dc022b1a7975556c53c86f9f0aa5f534b98977b171857c2c9"
 
 [[package]]
 name = "cc"
@@ -713,8 +719,10 @@ checksum = "fed44880c466736ef9a5c5b5facefb5ed0785676d0c02d612db14e54f0d84286"
 name = "horaedb-client"
 version = "1.0.2"
 dependencies = [
+ "anyhow",
  "arrow",
  "async-trait",
+ "base64 0.22.1",
  "chrono",
  "dashmap",
  "futures",
@@ -1730,7 +1738,7 @@ dependencies = [
  "async-stream",
  "async-trait",
  "axum",
- "base64",
+ "base64 0.13.1",
  "bytes",
  "futures-core",
  "futures-util",
diff --git a/Cargo.toml b/Cargo.toml
index d026f66..8e65906 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -23,8 +23,10 @@ description = "Rust implementation of HoraeDB client."
 readme = "README.md"
 
 [dependencies]
+anyhow = "1.0.83"
 arrow = "38.0.0"
 async-trait = "0.1.72"
+base64 = "0.22.1"
 dashmap = "5.3.4"
 futures = "0.3"
 horaedbproto = "1.0.23"
diff --git a/examples/read_write.rs b/examples/read_write.rs
index ab73304..7c002c2 100644
--- a/examples/read_write.rs
+++ b/examples/read_write.rs
@@ -22,7 +22,7 @@ use horaedb_client::{
         value::Value,
         write::{point::PointBuilder, Request as WriteRequest},
     },
-    RpcContext,
+    Authorization, RpcContext,
 };
 
 async fn create_table(client: &Arc<dyn DbClient>, rpc_ctx: &RpcContext) {
@@ -112,7 +112,13 @@ async fn sql_query(client: &Arc<dyn DbClient>, rpc_ctx: &RpcContext) {
 #[tokio::main]
 async fn main() {
     // you should ensure horaedb is running, and grpc port is set to 8831
-    let client = Builder::new("127.0.0.1:8831".to_string(), Mode::Direct).build();
+    let client = Builder::new("127.0.0.1:8831".to_string(), Mode::Direct)
+        // Set authorization if needed
+        .authorization(Authorization {
+            username: "user".to_string(),
+            password: "pass".to_string(),
+        })
+        .build();
     let rpc_ctx = RpcContext::default().database("public".to_string());
 
     println!("------------------------------------------------------------------");
diff --git a/src/config.rs b/src/config.rs
index cdb7ee7..dcd6523 100644
--- a/src/config.rs
+++ b/src/config.rs
@@ -56,6 +56,12 @@ pub struct RpcConfig {
     pub connect_timeout: Duration,
 }
 
+#[derive(Debug, Clone)]
+pub struct Authorization {
+    pub username: String,
+    pub password: String,
+}
+
 impl Default for RpcConfig {
     fn default() -> Self {
         Self {
diff --git a/src/db_client/builder.rs b/src/db_client/builder.rs
index 1749bc2..43f6c8a 100644
--- a/src/db_client/builder.rs
+++ b/src/db_client/builder.rs
@@ -17,7 +17,7 @@ use std::sync::Arc;
 use crate::{
     db_client::{raw::RawImpl, route_based::RouteBasedImpl, DbClient},
     rpc_client::RpcClientImplFactory,
-    RpcConfig,
+    Authorization, RpcConfig,
 };
 
 /// Access mode to HoraeDB server(s).
@@ -40,6 +40,7 @@ pub struct Builder {
     endpoint: String,
     default_database: Option<String>,
     rpc_config: RpcConfig,
+    authorization: Option<Authorization>,
 }
 
 impl Builder {
@@ -50,6 +51,7 @@ impl Builder {
             endpoint,
             rpc_config: RpcConfig::default(),
             default_database: None,
+            authorization: None,
         }
     }
 
@@ -65,8 +67,17 @@ impl Builder {
         self
     }
 
+    #[inline]
+    pub fn authorization(mut self, authorization: Authorization) -> Self {
+        self.authorization = Some(authorization);
+        self
+    }
+
     pub fn build(self) -> Arc<dyn DbClient> {
-        let rpc_client_factory = Arc::new(RpcClientImplFactory::new(self.rpc_config));
+        let rpc_client_factory = Arc::new(RpcClientImplFactory::new(
+            self.rpc_config,
+            self.authorization,
+        ));
 
         match self.mode {
             Mode::Direct => Arc::new(RouteBasedImpl::new(
diff --git a/src/errors.rs b/src/errors.rs
index 2b03f65..8dfe8cd 100644
--- a/src/errors.rs
+++ b/src/errors.rs
@@ -66,6 +66,12 @@ pub enum Error {
 
     #[error("failed to find a database")]
     NoDatabase,
+
+    #[error(transparent)]
+    Other {
+        #[from]
+        source: anyhow::Error,
+    },
 }
 
 #[derive(Debug)]
diff --git a/src/lib.rs b/src/lib.rs
index 3c94a54..2c7dc9f 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -85,7 +85,7 @@ mod util;
 
 #[doc(inline)]
 pub use crate::{
-    config::RpcConfig,
+    config::{Authorization, RpcConfig},
     db_client::{Builder, DbClient, Mode},
     errors::{Error, Result},
     model::{
diff --git a/src/rpc_client/rpc_client_impl.rs b/src/rpc_client/rpc_client_impl.rs
index 0cf54eb..16d5c7b 100644
--- a/src/rpc_client/rpc_client_impl.rs
+++ b/src/rpc_client/rpc_client_impl.rs
@@ -14,7 +14,9 @@
 
 use std::{sync::Arc, time::Duration};
 
+use anyhow::Context;
 use async_trait::async_trait;
+use base64::{prelude::BASE64_STANDARD, Engine};
 use horaedbproto::{
     common::ResponseHeader,
     storage::{
@@ -24,6 +26,7 @@ use horaedbproto::{
     },
 };
 use tonic::{
+    metadata::{Ascii, MetadataValue},
     transport::{Channel, Endpoint},
     Request,
 };
@@ -33,12 +36,14 @@ use crate::{
     errors::{Error, Result, ServerError},
     rpc_client::{RpcClient, RpcClientFactory, RpcContext},
     util::is_ok,
+    Authorization,
 };
 
 struct RpcClientImpl {
     channel: Channel,
     default_read_timeout: Duration,
     default_write_timeout: Duration,
+    metadata: Option<MetadataValue<Ascii>>,
 }
 
 impl RpcClientImpl {
@@ -46,11 +51,13 @@ impl RpcClientImpl {
         channel: Channel,
         default_read_timeout: Duration,
         default_write_timeout: Duration,
+        metadata: Option<MetadataValue<Ascii>>,
     ) -> Self {
         Self {
             channel,
             default_read_timeout,
             default_write_timeout,
+            metadata,
         }
     }
 
@@ -65,19 +72,22 @@ impl RpcClientImpl {
         Ok(())
     }
 
-    fn make_request<T>(ctx: &RpcContext, req: T, default_timeout: Duration) -> Request<T> {
+    fn make_request<T>(&self, ctx: &RpcContext, req: T, default_timeout: Duration) -> Request<T> {
         let timeout = ctx.timeout.unwrap_or(default_timeout);
         let mut req = Request::new(req);
         req.set_timeout(timeout);
+        if let Some(md) = &self.metadata {
+            req.metadata_mut().insert("authorization", md.clone());
+        }
         req
     }
 
     fn make_query_request<T>(&self, ctx: &RpcContext, req: T) -> Request<T> {
-        Self::make_request(ctx, req, self.default_read_timeout)
+        self.make_request(ctx, req, self.default_read_timeout)
     }
 
     fn make_write_request<T>(&self, ctx: &RpcContext, req: T) -> Request<T> {
-        Self::make_request(ctx, req, self.default_write_timeout)
+        self.make_request(ctx, req, self.default_write_timeout)
     }
 }
 
@@ -119,7 +129,7 @@ impl RpcClient for RpcClientImpl {
         let mut client = StorageServiceClient::<Channel>::new(self.channel.clone());
 
         // use the write timeout for the route request.
-        let route_req = Self::make_request(ctx, req, self.default_write_timeout);
+        let route_req = self.make_request(ctx, req, self.default_write_timeout);
         let resp = client.route(route_req).await.map_err(Error::Rpc)?;
         let mut resp = resp.into_inner();
 
@@ -133,11 +143,15 @@ impl RpcClient for RpcClientImpl {
 
 pub struct RpcClientImplFactory {
     rpc_config: RpcConfig,
+    authorization: Option<Authorization>,
 }
 
 impl RpcClientImplFactory {
-    pub fn new(rpc_config: RpcConfig) -> Self {
-        Self { rpc_config }
+    pub fn new(rpc_config: RpcConfig, authorization: Option<Authorization>) -> Self {
+        Self {
+            rpc_config,
+            authorization,
+        }
     }
 
     #[inline]
@@ -174,10 +188,26 @@ impl RpcClientFactory for RpcClientImplFactory {
                 addr: endpoint,
                 source: Box::new(e),
             })?;
+
+        let metadata = if let Some(auth) = &self.authorization {
+            let mut buf = Vec::with_capacity(auth.username.len() + auth.password.len() + 1);
+            buf.extend_from_slice(auth.username.as_bytes());
+            buf.push(b':');
+            buf.extend_from_slice(auth.password.as_bytes());
+            let auth = BASE64_STANDARD.encode(&buf);
+            let metadata: MetadataValue<Ascii> = format!("Basic {}", auth)
+                .parse()
+                .context("invalid grpc metadata")?;
+
+            Some(metadata)
+        } else {
+            None
+        };
         Ok(Arc::new(RpcClientImpl::new(
             channel,
             self.rpc_config.default_sql_query_timeout,
             self.rpc_config.default_write_timeout,
+            metadata,
         )))
     }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to