This is an automated email from the ASF dual-hosted git repository. lizhanhui pushed a commit to branch develop in repository https://gitbox.apache.org/repos/asf/rocketmq-client-rust.git
commit 4d2bd2e480ba2a928ef130cf2303b79d586490d3 Author: Li Zhanhui <[email protected]> AuthorDate: Thu Apr 7 03:47:28 2022 +0000 WIP: --- Cargo.lock | 32 +++++++++++++++ Cargo.toml | 8 +++- src/bin/task.rs | 25 ++++++++++++ src/client.rs | 118 +++++++++++++++++++++++++++++++++++++++++++++++------ src/credentials.rs | 1 - src/error.rs | 7 ++++ src/lib.rs | 1 + 7 files changed, 178 insertions(+), 14 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index f880360..af8e4df 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -265,6 +265,16 @@ dependencies = [ "pin-utils", ] +[[package]] +name = "gethostname" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c1ebd34e35c46e00bb73e81363248d627782724609fe1b6396f553f68fe3862e" +dependencies = [ + "libc", + "winapi", +] + [[package]] name = "getrandom" version = "0.2.5" @@ -795,9 +805,11 @@ name = "rocketmq" version = "0.1.0" dependencies = [ "futures", + "gethostname", "prost", "prost-types", "rustls", + "thiserror", "tokio", "tokio-rustls", "tonic", @@ -969,6 +981,26 @@ dependencies = [ "winapi", ] +[[package]] +name = "thiserror" +version = "1.0.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "854babe52e4df1653706b98fcfc05843010039b406875930a70e4d9644e5c417" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aa32fd3f627f367fe16f893e2597ae3c05020f8bba2666a4e6ea73d377e5714b" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "tokio" version = "1.17.0" diff --git a/Cargo.toml b/Cargo.toml index 52cdba9..1a31fac 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,6 +13,8 @@ prost-types = "0.10" tokio = { version = "1", features = ["full"]} tokio-rustls = "0.23" rustls = {version = "0.20", features = ["default", "dangerous_configuration"]} +gethostname = "0.2" +thiserror = "1.0" [build-dependencies] tonic-build = {version = "0.7", features = ["default", "compression"]} @@ -20,4 +22,8 @@ tonic-build = {version = "0.7", features = ["default", "compression"]} [[bin]] name = "server" -path = "src/bin/server.rs" \ No newline at end of file +path = "src/bin/server.rs" + +[[bin]] +name = "task" +path = "src/bin/task.rs" \ No newline at end of file diff --git a/src/bin/task.rs b/src/bin/task.rs new file mode 100644 index 0000000..bf23764 --- /dev/null +++ b/src/bin/task.rs @@ -0,0 +1,25 @@ +use std::sync::{ + atomic::{self, Ordering}, + Arc, +}; +use tokio::time; + +#[tokio::main] +async fn main() { + let stopped = Arc::new(atomic::AtomicBool::new(false)); + + let stop_flag = Arc::clone(&stopped); + let handle = tokio::spawn(async move { + while (!stop_flag.load(Ordering::Relaxed)) { + tokio::time::sleep(time::Duration::from_secs(1)).await; + println!("Ping"); + } + }); + + let terminate_task = tokio::spawn(async move { + tokio::time::sleep(tokio::time::Duration::from_secs(10)).await; + stopped.store(true, Ordering::Relaxed); + }); + + terminate_task.await; +} diff --git a/src/client.rs b/src/client.rs index 3b1cdde..0b0bd85 100644 --- a/src/client.rs +++ b/src/client.rs @@ -2,6 +2,7 @@ use crate::pb::{ messaging_service_client::MessagingServiceClient, QueryRouteRequest, QueryRouteResponse, SendMessageRequest, SendMessageResponse, }; +use rustls::client; use tonic::{ metadata::MetadataMap, transport::{Channel, ClientTlsConfig}, @@ -9,38 +10,93 @@ use tonic::{ }; use crate::credentials::CredentialProvider; +use std::collections::HashMap; +use std::rc::Rc; +use std::{ + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, Mutex, RwLock, + }, + thread, +}; -#[derive(Default)] -struct ClientConfig { +static CLIENT_SEQUENCE: AtomicUsize = AtomicUsize::new(0); + +pub struct ClientConfig { region: String, service_name: String, - resource_namespace: String, + resource_namespace: Option<String>, credential_provider: Option<Box<dyn CredentialProvider>>, - tenant_id: String, + tenant_id: Option<String>, + connect_timeout: std::time::Duration, io_timeout: std::time::Duration, long_polling_timeout: std::time::Duration, - group: String, + group: Option<String>, client_id: String, tracing: bool, } +fn build_client_id() -> String { + let mut client_id = String::new(); + match gethostname::gethostname().into_string() { + Ok(hostname) => { + client_id.push_str(&hostname); + } + Err(_) => { + client_id.push_str("localhost"); + } + }; + client_id.push('@'); + let pid = std::process::id(); + client_id.push_str(&pid.to_string()); + client_id.push('#'); + let sequence = CLIENT_SEQUENCE.fetch_add(1usize, Ordering::Relaxed); + client_id.push_str(&sequence.to_string()); + client_id +} + +impl Default for ClientConfig { + fn default() -> Self { + let client_id = build_client_id(); + Self { + region: String::from("cn-hangzhou"), + service_name: String::from("RocketMQ"), + resource_namespace: None, + credential_provider: None, + tenant_id: None, + connect_timeout: std::time::Duration::from_secs(3), + io_timeout: std::time::Duration::from_secs(3), + long_polling_timeout: std::time::Duration::from_secs(3), + group: None, + client_id, + tracing: false, + } + } +} + pub struct RpcClient { + client_config: Arc<RwLock<ClientConfig>>, stub: MessagingServiceClient<Channel>, peer_address: String, // client_config: std::rc::Rc<ClientConfig>, } impl RpcClient { - pub async fn new(target: String) -> Result<RpcClient, Box<dyn std::error::Error>> { + pub async fn new( + target: String, + client_config: Arc<RwLock<ClientConfig>>, + ) -> Result<RpcClient, Box<dyn std::error::Error>> { + let config = Arc::clone(&client_config); let mut channel = Channel::from_shared(target.clone())? .tcp_nodelay(true) - .connect_timeout(std::time::Duration::from_secs(3)); + .connect_timeout(config.read().unwrap().connect_timeout); if target.starts_with("https://") { channel = channel.tls_config(ClientTlsConfig::new())?; } let channel = channel.connect().await?; let stub = MessagingServiceClient::new(channel); Ok(RpcClient { + client_config, stub, peer_address: target, }) @@ -61,13 +117,16 @@ impl RpcClient { &mut self, request: SendMessageRequest, ) -> Result<Response<SendMessageResponse>, Box<dyn std::error::Error>> { - let mut req = Request::new(request); + let req = Request::new(request); Ok(self.stub.send_message(req).await?) } } #[derive(Default)] -pub struct ClientManager {} +pub struct ClientManager { + client_config: Arc<RwLock<ClientConfig>>, + clients: Mutex<HashMap<String, Rc<RpcClient>>>, +} impl ClientManager { pub async fn start(&self) { @@ -80,25 +139,48 @@ impl ClientManager { }); let _result = handle.await; } + + pub async fn get_rpc_client( + &'static mut self, + endpoint: &str, + ) -> Result<Rc<RpcClient>, Box<dyn std::error::Error>> { + let mut rpc_clients = self.clients.lock()?; + let key = endpoint.to_owned(); + match rpc_clients.get(&key) { + Some(value) => { + return Ok(Rc::clone(value)); + } + None => { + let rpc_client = + RpcClient::new(key.clone(), Arc::clone(&self.client_config)).await?; + let client = Rc::new(rpc_client); + rpc_clients.insert(key, Rc::clone(&client)); + Ok(client) + } + } + } } #[cfg(test)] mod test { use super::*; + use std::collections::HashSet; use crate::pb::{Code, Resource}; #[tokio::test] async fn test_connect() { let target = "http://127.0.0.1:5001"; - let _rpc_client = RpcClient::new(target.to_owned()) + let client_config = Arc::new(RwLock::new(ClientConfig::default())); + let _rpc_client = RpcClient::new(target.to_owned(), client_config) .await .expect("Should be able to connect"); } #[tokio::test] async fn test_connect_staging() { + let client_config = Arc::new(RwLock::new(ClientConfig::default())); let target = "https://mq-inst-1080056302921134-bxuibml7.mq.cn-hangzhou.aliyuncs.com:80"; - let _rpc_client = RpcClient::new(target.to_owned()) + let _rpc_client = RpcClient::new(target.to_owned(), client_config) .await .expect("Failed to connect to staging proxy server"); } @@ -106,7 +188,8 @@ mod test { #[tokio::test] async fn test_query_route() { let target = "http://127.0.0.1:5001"; - let mut rpc_client = RpcClient::new(target.to_owned()) + let client_config = Arc::new(RwLock::new(ClientConfig::default())); + let mut rpc_client = RpcClient::new(target.to_owned(), client_config) .await .expect("Should be able to connect"); let topic = Resource { @@ -126,6 +209,17 @@ mod test { assert_eq!(route_response.status.unwrap().code, Code::Ok as i32); } + #[test] + fn test_build_client_id() { + let mut set = HashSet::new(); + let cnt = 1000; + for _ in 0..cnt { + let client_id = build_client_id(); + set.insert(client_id); + } + assert_eq!(cnt, set.len()); + } + #[tokio::test] async fn test_periodic_task() { let client_manager = ClientManager::default(); diff --git a/src/credentials.rs b/src/credentials.rs index 43d44a9..27812f8 100644 --- a/src/credentials.rs +++ b/src/credentials.rs @@ -82,6 +82,5 @@ mod test { fn test_environment_variable_credential_provider() { let env_credentials_provider = EnvironmentVariableCredentialProvider::new(); assert_eq!(true, env_credentials_provider.is_err()); - } } diff --git a/src/error.rs b/src/error.rs new file mode 100644 index 0000000..6d4fb44 --- /dev/null +++ b/src/error.rs @@ -0,0 +1,7 @@ +use thiserror::Error; + +#[derive(Debug, Error)] +pub enum ClientError { + #[error("Host name inconvertible to UTF-8")] + InvalidHostName, +} diff --git a/src/lib.rs b/src/lib.rs index 7ffd5e9..a7427ff 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,5 @@ pub mod client; pub mod credentials; +pub mod error; pub mod pb; pub mod server;
