This is an automated email from the ASF dual-hosted git repository.
chunshao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-horaedb-client-rs.git
The following commit(s) were added to refs/heads/main by this push:
new 08ec55c feat: add basic auth (#50)
08ec55c is described below
commit 08ec55ccf9f43d3c59e3b55459dc55edca3a71ce
Author: Jiacai Liu <[email protected]>
AuthorDate: Thu May 16 10:13:03 2024 +0800
feat: add basic auth (#50)
---
.github/workflows/ci.yml | 2 +-
Cargo.lock | 18 ++++++++++++-----
Cargo.toml | 2 ++
examples/read_write.rs | 10 ++++++++--
src/config.rs | 6 ++++++
src/db_client/builder.rs | 15 ++++++++++++--
src/errors.rs | 6 ++++++
src/lib.rs | 2 +-
src/rpc_client/rpc_client_impl.rs | 42 +++++++++++++++++++++++++++++++++------
9 files changed, 86 insertions(+), 17 deletions(-)
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 3e34001..08b760f 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -45,7 +45,7 @@ jobs:
- name: Setup Build Environment
run: sudo apt update && sudo apt install -y protobuf-compiler
- name: Install cargo binaries
- run: cargo install cargo-sort
+ run: cargo install cargo-sort --locked
- name: Run Style Check
run: make fmt clippy check-toml
diff --git a/Cargo.lock b/Cargo.lock
index bd873f8..c3c030e 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -50,9 +50,9 @@ dependencies = [
[[package]]
name = "anyhow"
-version = "1.0.71"
+version = "1.0.83"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "9c7d0618f0e0b7e8ff11427422b64564d5fb0be1940354bfe2e0529b18a9d9b8"
+checksum = "25bdb32cbbdce2b519a9cd7df3a678443100e265d5e25ca763b7572a5104f5f3"
[[package]]
name = "arrow"
@@ -367,6 +367,12 @@ version = "0.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8"
+[[package]]
+name = "base64"
+version = "0.22.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6"
+
[[package]]
name = "bitflags"
version = "1.3.2"
@@ -381,9 +387,9 @@ checksum = "a3e2c3daef883ecc1b5d58c15adae93470a91d425f3532ba1695849656af3fc1"
[[package]]
name = "bytes"
-version = "1.4.0"
+version = "1.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "89b2fd2a0dcf38d7971e2194b6b6eebab45ae01067456a7fd93d5547a61b70be"
+checksum = "514de17de45fdb8dc022b1a7975556c53c86f9f0aa5f534b98977b171857c2c9"
[[package]]
name = "cc"
@@ -713,8 +719,10 @@ checksum = "fed44880c466736ef9a5c5b5facefb5ed0785676d0c02d612db14e54f0d84286"
name = "horaedb-client"
version = "1.0.2"
dependencies = [
+ "anyhow",
"arrow",
"async-trait",
+ "base64 0.22.1",
"chrono",
"dashmap",
"futures",
@@ -1730,7 +1738,7 @@ dependencies = [
"async-stream",
"async-trait",
"axum",
- "base64",
+ "base64 0.13.1",
"bytes",
"futures-core",
"futures-util",
diff --git a/Cargo.toml b/Cargo.toml
index d026f66..8e65906 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -23,8 +23,10 @@ description = "Rust implementation of HoraeDB client."
readme = "README.md"
[dependencies]
+anyhow = "1.0.83"
arrow = "38.0.0"
async-trait = "0.1.72"
+base64 = "0.22.1"
dashmap = "5.3.4"
futures = "0.3"
horaedbproto = "1.0.23"
diff --git a/examples/read_write.rs b/examples/read_write.rs
index ab73304..7c002c2 100644
--- a/examples/read_write.rs
+++ b/examples/read_write.rs
@@ -22,7 +22,7 @@ use horaedb_client::{
value::Value,
write::{point::PointBuilder, Request as WriteRequest},
},
- RpcContext,
+ Authorization, RpcContext,
};
async fn create_table(client: &Arc<dyn DbClient>, rpc_ctx: &RpcContext) {
@@ -112,7 +112,13 @@ async fn sql_query(client: &Arc<dyn DbClient>, rpc_ctx: &RpcContext) {
#[tokio::main]
async fn main() {
// you should ensure horaedb is running, and grpc port is set to 8831
- let client = Builder::new("127.0.0.1:8831".to_string(), Mode::Direct).build();
+ let client = Builder::new("127.0.0.1:8831".to_string(), Mode::Direct)
+ // Set authorization if needed
+ .authorization(Authorization {
+ username: "user".to_string(),
+ password: "pass".to_string(),
+ })
+ .build();
let rpc_ctx = RpcContext::default().database("public".to_string());
println!("------------------------------------------------------------------");
diff --git a/src/config.rs b/src/config.rs
index cdb7ee7..dcd6523 100644
--- a/src/config.rs
+++ b/src/config.rs
@@ -56,6 +56,12 @@ pub struct RpcConfig {
pub connect_timeout: Duration,
}
+#[derive(Debug, Clone)]
+pub struct Authorization {
+ pub username: String,
+ pub password: String,
+}
+
impl Default for RpcConfig {
fn default() -> Self {
Self {
diff --git a/src/db_client/builder.rs b/src/db_client/builder.rs
index 1749bc2..43f6c8a 100644
--- a/src/db_client/builder.rs
+++ b/src/db_client/builder.rs
@@ -17,7 +17,7 @@ use std::sync::Arc;
use crate::{
db_client::{raw::RawImpl, route_based::RouteBasedImpl, DbClient},
rpc_client::RpcClientImplFactory,
- RpcConfig,
+ Authorization, RpcConfig,
};
/// Access mode to HoraeDB server(s).
@@ -40,6 +40,7 @@ pub struct Builder {
endpoint: String,
default_database: Option<String>,
rpc_config: RpcConfig,
+ authorization: Option<Authorization>,
}
impl Builder {
@@ -50,6 +51,7 @@ impl Builder {
endpoint,
rpc_config: RpcConfig::default(),
default_database: None,
+ authorization: None,
}
}
@@ -65,8 +67,17 @@ impl Builder {
self
}
+ #[inline]
+ pub fn authorization(mut self, authorization: Authorization) -> Self {
+ self.authorization = Some(authorization);
+ self
+ }
+
pub fn build(self) -> Arc<dyn DbClient> {
- let rpc_client_factory = Arc::new(RpcClientImplFactory::new(self.rpc_config));
+ let rpc_client_factory = Arc::new(RpcClientImplFactory::new(
+ self.rpc_config,
+ self.authorization,
+ ));
match self.mode {
Mode::Direct => Arc::new(RouteBasedImpl::new(
diff --git a/src/errors.rs b/src/errors.rs
index 2b03f65..8dfe8cd 100644
--- a/src/errors.rs
+++ b/src/errors.rs
@@ -66,6 +66,12 @@ pub enum Error {
#[error("failed to find a database")]
NoDatabase,
+
+ #[error(transparent)]
+ Other {
+ #[from]
+ source: anyhow::Error,
+ },
}
#[derive(Debug)]
diff --git a/src/lib.rs b/src/lib.rs
index 3c94a54..2c7dc9f 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -85,7 +85,7 @@ mod util;
#[doc(inline)]
pub use crate::{
- config::RpcConfig,
+ config::{Authorization, RpcConfig},
db_client::{Builder, DbClient, Mode},
errors::{Error, Result},
model::{
diff --git a/src/rpc_client/rpc_client_impl.rs b/src/rpc_client/rpc_client_impl.rs
index 0cf54eb..16d5c7b 100644
--- a/src/rpc_client/rpc_client_impl.rs
+++ b/src/rpc_client/rpc_client_impl.rs
@@ -14,7 +14,9 @@
use std::{sync::Arc, time::Duration};
+use anyhow::Context;
use async_trait::async_trait;
+use base64::{prelude::BASE64_STANDARD, Engine};
use horaedbproto::{
common::ResponseHeader,
storage::{
@@ -24,6 +26,7 @@ use horaedbproto::{
},
};
use tonic::{
+ metadata::{Ascii, MetadataValue},
transport::{Channel, Endpoint},
Request,
};
@@ -33,12 +36,14 @@ use crate::{
errors::{Error, Result, ServerError},
rpc_client::{RpcClient, RpcClientFactory, RpcContext},
util::is_ok,
+ Authorization,
};
struct RpcClientImpl {
channel: Channel,
default_read_timeout: Duration,
default_write_timeout: Duration,
+ metadata: Option<MetadataValue<Ascii>>,
}
impl RpcClientImpl {
@@ -46,11 +51,13 @@ impl RpcClientImpl {
channel: Channel,
default_read_timeout: Duration,
default_write_timeout: Duration,
+ metadata: Option<MetadataValue<Ascii>>,
) -> Self {
Self {
channel,
default_read_timeout,
default_write_timeout,
+ metadata,
}
}
@@ -65,19 +72,22 @@ impl RpcClientImpl {
Ok(())
}
- fn make_request<T>(ctx: &RpcContext, req: T, default_timeout: Duration) -> Request<T> {
+ fn make_request<T>(&self, ctx: &RpcContext, req: T, default_timeout: Duration) -> Request<T> {
let timeout = ctx.timeout.unwrap_or(default_timeout);
let mut req = Request::new(req);
req.set_timeout(timeout);
+ if let Some(md) = &self.metadata {
+ req.metadata_mut().insert("authorization", md.clone());
+ }
req
}
fn make_query_request<T>(&self, ctx: &RpcContext, req: T) -> Request<T> {
- Self::make_request(ctx, req, self.default_read_timeout)
+ self.make_request(ctx, req, self.default_read_timeout)
}
fn make_write_request<T>(&self, ctx: &RpcContext, req: T) -> Request<T> {
- Self::make_request(ctx, req, self.default_write_timeout)
+ self.make_request(ctx, req, self.default_write_timeout)
}
}
@@ -119,7 +129,7 @@ impl RpcClient for RpcClientImpl {
let mut client = StorageServiceClient::<Channel>::new(self.channel.clone());
// use the write timeout for the route request.
- let route_req = Self::make_request(ctx, req, self.default_write_timeout);
+ let route_req = self.make_request(ctx, req, self.default_write_timeout);
let resp = client.route(route_req).await.map_err(Error::Rpc)?;
let mut resp = resp.into_inner();
@@ -133,11 +143,15 @@ impl RpcClient for RpcClientImpl {
pub struct RpcClientImplFactory {
rpc_config: RpcConfig,
+ authorization: Option<Authorization>,
}
impl RpcClientImplFactory {
- pub fn new(rpc_config: RpcConfig) -> Self {
- Self { rpc_config }
+ pub fn new(rpc_config: RpcConfig, authorization: Option<Authorization>) -> Self {
+ Self {
+ rpc_config,
+ authorization,
+ }
}
#[inline]
@@ -174,10 +188,26 @@ impl RpcClientFactory for RpcClientImplFactory {
addr: endpoint,
source: Box::new(e),
})?;
+
+ let metadata = if let Some(auth) = &self.authorization {
+ let mut buf = Vec::with_capacity(auth.username.len() + auth.password.len() + 1);
+ buf.extend_from_slice(auth.username.as_bytes());
+ buf.push(b':');
+ buf.extend_from_slice(auth.password.as_bytes());
+ let auth = BASE64_STANDARD.encode(&buf);
+ let metadata: MetadataValue<Ascii> = format!("Basic {}", auth)
+ .parse()
+ .context("invalid grpc metadata")?;
+
+ Some(metadata)
+ } else {
+ None
+ };
Ok(Arc::new(RpcClientImpl::new(
channel,
self.rpc_config.default_sql_query_timeout,
self.rpc_config.default_write_timeout,
+ metadata,
)))
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]