This is an automated email from the ASF dual-hosted git repository.

lizhanhui pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/rocketmq-clients.git


The following commit(s) were added to refs/heads/master by this push:
     new 6fe0f61f test(rust): mock client and session manager (#482)
6fe0f61f is described below

commit 6fe0f61ffe4f90b62372185d27dd21703d3ae38d
Author: SSpirits <[email protected]>
AuthorDate: Tue Apr 18 17:21:09 2023 +0800

    test(rust): mock client and session manager (#482)
---
 rust/Cargo.toml             |   1 +
 rust/src/client.rs          | 303 +++++++++++++++++++++++++-------------------
 rust/src/producer.rs        |  19 ++-
 rust/src/session.rs         |  17 +++
 rust/src/simple_consumer.rs |   2 +-
 5 files changed, 205 insertions(+), 137 deletions(-)

diff --git a/rust/Cargo.toml b/rust/Cargo.toml
index 35a229c6..dae0b60e 100644
--- a/rust/Cargo.toml
+++ b/rust/Cargo.toml
@@ -53,6 +53,7 @@ tokio-stream="0.1.12"
 minitrace = "0.4.1"
 
 mockall = "0.11.4"
+mockall_double= "0.3.0"
 
 siphasher = "0.3.10"
 
diff --git a/rust/src/client.rs b/rust/src/client.rs
index 7e50b3b6..7c0fbdf3 100644
--- a/rust/src/client.rs
+++ b/rust/src/client.rs
@@ -18,6 +18,8 @@ use std::clone::Clone;
 use std::string::ToString;
 use std::{collections::HashMap, sync::atomic::AtomicUsize, sync::Arc};
 
+use mockall::automock;
+use mockall_double::double;
 use parking_lot::Mutex;
 use prost_types::Duration;
 use slog::{debug, error, info, o, warn, Logger};
@@ -31,11 +33,13 @@ use crate::model::message::AckMessageEntry;
 use crate::pb;
 use crate::pb::receive_message_response::Content;
 use crate::pb::{
-    AckMessageRequest, AckMessageResultEntry, Code, FilterExpression, HeartbeatRequest, Message,
-    MessageQueue, QueryRouteRequest, ReceiveMessageRequest, Resource, SendMessageRequest,
-    SendResultEntry, Status, TelemetryCommand,
+    AckMessageRequest, AckMessageResultEntry, Code, FilterExpression, HeartbeatRequest,
+    HeartbeatResponse, Message, MessageQueue, QueryRouteRequest, ReceiveMessageRequest, Resource,
+    SendMessageRequest, SendResultEntry, Status, TelemetryCommand,
 };
-use crate::session::{RPCClient, Session, SessionManager};
+#[double]
+use crate::session::SessionManager;
+use crate::session::{RPCClient, Session};
 
 #[derive(Debug)]
 pub(crate) struct Client {
@@ -52,6 +56,7 @@ lazy_static::lazy_static! {
     static ref CLIENT_ID_SEQUENCE: AtomicUsize = AtomicUsize::new(0);
 }
 
+#[automock]
 impl Client {
     const OPERATION_CLIENT_NEW: &'static str = "client.new";
     const OPERATION_QUERY_ROUTE: &'static str = "client.query_route";
@@ -80,57 +85,6 @@ impl Client {
         })
     }
 
-    async fn heart_beat(
-        logger: &Logger,
-        session_manager: Arc<SessionManager>,
-        group: &str,
-        namespace: &str,
-        client_type: &ClientType,
-    ) {
-        let sessions = session_manager.get_all_sessions().await;
-        if sessions.is_err() {
-            error!(
-                logger,
-                "send heartbeat failed: failed to get sessions: {}",
-                sessions.unwrap_err()
-            );
-            return;
-        }
-        for mut session in sessions.unwrap() {
-            let request = HeartbeatRequest {
-                group: Some(Resource {
-                    name: group.to_string(),
-                    resource_namespace: namespace.to_string(),
-                }),
-                client_type: client_type.clone() as i32,
-            };
-            let response = session.heartbeat(request).await;
-            if response.is_err() {
-                error!(
-                    logger,
-                    "send heartbeat failed: failed to send heartbeat rpc: {}",
-                    response.unwrap_err()
-                );
-                return;
-            }
-            let result =
-                Self::handle_response_status(response.unwrap().status, Self::OPERATION_HEARTBEAT);
-            if result.is_err() {
-                error!(
-                    logger,
-                    "send heartbeat failed: server return error: {}",
-                    result.unwrap_err()
-                );
-                return;
-            }
-            debug!(
-                logger,
-                "send heartbeat to server success, peer={}",
-                session.peer()
-            );
-        }
-    }
-
     pub(crate) fn start(&self) {
         let logger = self.logger.clone();
         let session_manager = self.session_manager.clone();
@@ -144,7 +98,39 @@ impl Client {
             loop {
                 select! {
                     _ = interval.tick() => {
-                        Self::heart_beat(&logger, session_manager.clone(), &group, &namespace, &client_type).await;
+                        let sessions = session_manager.get_all_sessions().await;
+                        if sessions.is_err() {
+                            error!(
+                                logger,
+                                "send heartbeat failed: failed to get sessions: {}",
+                                sessions.unwrap_err()
+                            );
+                            continue;
+                        }
+
+                        for session in sessions.unwrap() {
+                            let peer = session.peer().to_string();
+                            let response = Self::heart_beat_inner(session, &group, &namespace, &client_type).await;
+                            if response.is_err() {
+                                error!(
+                                    logger,
+                                    "send heartbeat failed: failed to send heartbeat rpc: {}",
+                                    response.unwrap_err()
+                                );
+                                continue;
+                            }
+                            let result =
+                                Self::handle_response_status(response.unwrap().status, Self::OPERATION_HEARTBEAT);
+                            if result.is_err() {
+                                error!(
+                                    logger,
+                                    "send heartbeat failed: server return error: {}",
+                                    result.unwrap_err()
+                                );
+                                continue;
+                            }
+                            debug!(logger,"send heartbeat to server success, peer={}",peer);
+                        }
                     },
                 }
             }
@@ -242,9 +228,9 @@ impl Client {
             .await
     }
 
-    async fn query_topic_route(
+    async fn query_topic_route<T: RPCClient + 'static>(
         &self,
-        mut rpc_client: impl RPCClient,
+        mut rpc_client: T,
         topic: &str,
     ) -> Result<Route, ClientError> {
         let request = QueryRouteRequest {
@@ -265,9 +251,9 @@ impl Client {
         Ok(route)
     }
 
-    async fn topic_route_inner(
+    async fn topic_route_inner<T: RPCClient + 'static>(
         &self,
-        rpc_client: impl RPCClient,
+        rpc_client: T,
         topic: &str,
     ) -> Result<Arc<Route>, ClientError> {
         debug!(self.logger, "query route for topic={}", topic);
@@ -349,6 +335,23 @@ impl Client {
         }
     }
 
+    async fn heart_beat_inner<T: RPCClient + 'static>(
+        mut rpc_client: T,
+        group: &str,
+        namespace: &str,
+        client_type: &ClientType,
+    ) -> Result<HeartbeatResponse, ClientError> {
+        let request = HeartbeatRequest {
+            group: Some(Resource {
+                name: group.to_string(),
+                resource_namespace: namespace.to_string(),
+            }),
+            client_type: client_type.clone() as i32,
+        };
+        let response = rpc_client.heartbeat(request).await?;
+        Ok(response)
+    }
+
     pub(crate) async fn send_message(
         &self,
         endpoints: &Endpoints,
@@ -361,9 +364,9 @@ impl Client {
         .await
     }
 
-    pub(crate) async fn send_message_inner(
+    pub(crate) async fn send_message_inner<T: RPCClient + 'static>(
         &self,
-        mut rpc_client: impl RPCClient,
+        mut rpc_client: T,
         messages: Vec<Message>,
     ) -> Result<Vec<SendResultEntry>, ClientError> {
         let message_count = messages.len();
@@ -396,9 +399,9 @@ impl Client {
         .await
     }
 
-    pub(crate) async fn receive_message_inner(
+    pub(crate) async fn receive_message_inner<T: RPCClient + 'static>(
         &self,
-        mut rpc_client: impl RPCClient,
+        mut rpc_client: T,
         message_queue: MessageQueue,
         expression: FilterExpression,
         batch_size: i32,
@@ -435,9 +438,9 @@ impl Client {
         Ok(messages)
     }
 
-    pub(crate) async fn ack_message(
+    pub(crate) async fn ack_message<T: AckMessageEntry + 'static>(
         &self,
-        ack_entry: impl AckMessageEntry,
+        ack_entry: T,
     ) -> Result<AckMessageResultEntry, ClientError> {
         let result = self
             .ack_message_inner(
@@ -454,9 +457,9 @@ impl Client {
         Ok(result[0].clone())
     }
 
-    pub(crate) async fn ack_message_inner(
+    pub(crate) async fn ack_message_inner<T: RPCClient + 'static>(
         &self,
-        mut rpc_client: impl RPCClient,
+        mut rpc_client: T,
         topic: String,
         entries: Vec<pb::AckMessageEntry>,
     ) -> Result<Vec<AckMessageResultEntry>, ClientError> {
@@ -488,16 +491,40 @@ mod tests {
     use crate::conf::ClientOption;
     use crate::error::{ClientError, ErrorKind};
     use crate::log::terminal_logger;
-    use crate::model::common::Route;
+    use crate::model::common::{ClientType, Route};
     use crate::pb::receive_message_response::Content;
     use crate::pb::{
-        AckMessageEntry, AckMessageResponse, Code, FilterExpression, Message, MessageQueue,
-        QueryRouteResponse, ReceiveMessageResponse, Resource, SendMessageResponse, Status,
-        TelemetryCommand,
+        AckMessageEntry, AckMessageResponse, Code, FilterExpression, HeartbeatResponse, Message,
+        MessageQueue, QueryRouteResponse, ReceiveMessageResponse, Resource, SendMessageResponse,
+        Status, TelemetryCommand,
     };
     use crate::session;
 
-    use super::CLIENT_ID_SEQUENCE;
+    use super::*;
+
+    fn new_client_for_test() -> Client {
+        Client {
+            logger: terminal_logger(),
+            option: ClientOption::default(),
+            session_manager: Arc::new(SessionManager::default()),
+            route_table: Mutex::new(HashMap::new()),
+            id: Client::generate_client_id(),
+            access_endpoints: Endpoints::from_url("http://localhost:8081").unwrap(),
+            settings: TelemetryCommand::default(),
+        }
+    }
+
+    fn new_client_with_session_manager(session_manager: SessionManager) -> Client {
+        Client {
+            logger: terminal_logger(),
+            option: ClientOption::default(),
+            session_manager: Arc::new(session_manager),
+            route_table: Mutex::new(HashMap::new()),
+            id: Client::generate_client_id(),
+            access_endpoints: Endpoints::from_url("http://localhost:8081").unwrap(),
+            settings: TelemetryCommand::default(),
+        }
+    }
 
     #[test]
     fn client_id_sequence() {
@@ -506,6 +533,51 @@ mod tests {
         assert!(v2 > v1, "Client ID sequence should be increasing");
     }
 
+    #[test]
+    fn client_new() -> Result<(), ClientError> {
+        let ctx = crate::session::MockSessionManager::new_context();
+        ctx.expect()
+            .return_once(|_, _, _| SessionManager::default());
+        Client::new(
+            &terminal_logger(),
+            ClientOption::default(),
+            TelemetryCommand::default(),
+        )?;
+        Ok(())
+    }
+
+    #[tokio::test(flavor = "multi_thread")]
+    async fn client_start() -> Result<(), ClientError> {
+        let mut session_manager = SessionManager::default();
+        session_manager
+            .expect_get_all_sessions()
+            .returning(|| Ok(vec![]));
+
+        let client = new_client_with_session_manager(session_manager);
+        client.start();
+
+        // TODO use countdown latch instead sleeping
+        // wait for run
+        tokio::time::sleep(Duration::from_secs(1)).await;
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn client_get_session() {
+        let mut session_manager = SessionManager::default();
+        session_manager
+            .expect_get_or_create_session()
+            .returning(|_, _| Ok(Session::mock()));
+
+        let client = new_client_with_session_manager(session_manager);
+        let result = client.get_session().await;
+        assert!(result.is_ok());
+        let result = client
+            .get_session_with_endpoints(&Endpoints::from_url("localhost:8081").unwrap())
+            .await;
+        assert!(result.is_ok());
+    }
+
     #[test]
     fn handle_response_status() {
         let result = Client::handle_response_status(None, "test");
@@ -572,16 +644,10 @@ mod tests {
 
     #[tokio::test]
     async fn client_query_route_from_cache() {
-        let logger = terminal_logger();
-        let client = Client::new(
-            &logger,
-            ClientOption::default(),
-            TelemetryCommand::default(),
-        )
-        .unwrap();
+        let client = new_client_for_test();
         client.route_table.lock().insert(
             "DefaultCluster".to_string(),
-            super::RouteStatus::Found(Arc::new(Route {
+            RouteStatus::Found(Arc::new(Route {
                 index: AtomicUsize::new(0),
                 queue: vec![],
             })),
@@ -592,13 +658,7 @@ mod tests {
 
     #[tokio::test]
     async fn client_query_route() {
-        let logger = terminal_logger();
-        let client = Client::new(
-            &logger,
-            ClientOption::default(),
-            TelemetryCommand::default(),
-        )
-        .unwrap();
+        let client = new_client_for_test();
 
         let mut mock = session::MockRPCClient::new();
         mock.expect_query_route()
@@ -620,13 +680,7 @@ mod tests {
 
     #[tokio::test(flavor = "multi_thread")]
     async fn client_query_route_with_inflight_request() {
-        let logger = terminal_logger();
-        let client = Client::new(
-            &logger,
-            ClientOption::default(),
-            TelemetryCommand::default(),
-        )
-        .unwrap();
+        let client = new_client_for_test();
         let client = Arc::new(client);
 
         let client_clone = client.clone();
@@ -653,13 +707,7 @@ mod tests {
 
     #[tokio::test(flavor = "multi_thread")]
     async fn client_query_route_with_failed_request() {
-        let logger = terminal_logger();
-        let client = Client::new(
-            &logger,
-            ClientOption::default(),
-            TelemetryCommand::default(),
-        )
-        .unwrap();
+        let client = new_client_for_test();
         let client = Arc::new(client);
 
         let client_clone = client.clone();
@@ -688,6 +736,22 @@ mod tests {
         awaitility::at_most(Duration::from_secs(1)).until(|| handle.is_finished());
     }
 
+    #[tokio::test]
+    async fn client_heartbeat() {
+        let response = Ok(HeartbeatResponse {
+            status: Some(Status {
+                code: Code::Ok as i32,
+                message: "Success".to_string(),
+            }),
+        });
+        let mut mock = session::MockRPCClient::new();
+        mock.expect_heartbeat()
+            .return_once(|_| Box::pin(futures::future::ready(response)));
+
+        let send_result = Client::heart_beat_inner(mock, "", "", &ClientType::Producer).await;
+        assert!(send_result.is_ok());
+    }
+
     #[tokio::test]
     async fn client_send_message() {
         let response = Ok(SendMessageResponse {
@@ -701,12 +765,7 @@ mod tests {
         mock.expect_send_message()
             .return_once(|_| Box::pin(futures::future::ready(response)));
 
-        let client = Client::new(
-            &terminal_logger(),
-            ClientOption::default(),
-            TelemetryCommand::default(),
-        )
-        .unwrap();
+        let client = new_client_for_test();
         let send_result = client.send_message_inner(mock, vec![]).await;
         assert!(send_result.is_ok());
 
@@ -723,12 +782,7 @@ mod tests {
         mock.expect_receive_message()
             .return_once(|_| Box::pin(futures::future::ready(response)));
 
-        let client = Client::new(
-            &terminal_logger(),
-            ClientOption::default(),
-            TelemetryCommand::default(),
-        )
-        .unwrap();
+        let client = new_client_for_test();
         let receive_result = client
             .receive_message_inner(
                 mock,
@@ -756,12 +810,7 @@ mod tests {
         mock.expect_receive_message()
             .return_once(|_| Box::pin(futures::future::ready(response)));
 
-        let client = Client::new(
-            &terminal_logger(),
-            ClientOption::default(),
-            TelemetryCommand::default(),
-        )
-        .unwrap();
+        let client = new_client_for_test();
         let receive_result = client
             .receive_message_inner(
                 mock,
@@ -792,12 +841,7 @@ mod tests {
         mock.expect_ack_message()
             .return_once(|_| Box::pin(futures::future::ready(response)));
 
-        let client = Client::new(
-            &terminal_logger(),
-            ClientOption::default(),
-            TelemetryCommand::default(),
-        )
-        .unwrap();
+        let client = new_client_for_test();
         let ack_entries: Vec<AckMessageEntry> = vec![];
         let ack_result = client
             .ack_message_inner(mock, "test_topic".to_string(), ack_entries)
@@ -819,12 +863,7 @@ mod tests {
         mock.expect_ack_message()
             .return_once(|_| Box::pin(futures::future::ready(response)));
 
-        let client = Client::new(
-            &terminal_logger(),
-            ClientOption::default(),
-            TelemetryCommand::default(),
-        )
-        .unwrap();
+        let client = new_client_for_test();
         let ack_entries: Vec<AckMessageEntry> = vec![];
         let ack_result = client
             .ack_message_inner(mock, "test_topic".to_string(), ack_entries)
diff --git a/rust/src/producer.rs b/rust/src/producer.rs
index dda1e69f..983d40d0 100644
--- a/rust/src/producer.rs
+++ b/rust/src/producer.rs
@@ -22,9 +22,11 @@
 
 use std::time::{SystemTime, UNIX_EPOCH};
 
+use mockall_double::double;
 use prost_types::Timestamp;
 use slog::{info, Logger};
 
+#[double]
 use crate::client::Client;
 use crate::conf::{ClientOption, ProducerOption};
 use crate::error::{ClientError, ErrorKind};
@@ -205,10 +207,19 @@ impl Producer {
 
 #[cfg(test)]
 mod tests {
-    use crate::conf::{ClientOption, ProducerOption};
     use crate::error::ErrorKind;
+    use crate::log::terminal_logger;
     use crate::model::message::MessageImpl;
-    use crate::producer::Producer;
+
+    use super::*;
+
+    fn new_producer_for_test() -> Producer {
+        Producer {
+            option: Default::default(),
+            logger: terminal_logger(),
+            client: Client::default(),
+        }
+    }
 
     // #[tokio::test]
     // async fn producer_start() {
@@ -222,7 +233,7 @@ mod tests {
 
     #[tokio::test]
     async fn producer_transform_messages_to_protobuf() {
-        let producer = Producer::new(ProducerOption::default(), ClientOption::default()).unwrap();
+        let producer = new_producer_for_test();
         let messages = vec![MessageImpl::builder()
             .set_topic("DefaultCluster")
             .set_body("hello world".as_bytes().to_vec())
@@ -253,7 +264,7 @@ mod tests {
 
     #[tokio::test]
     async fn producer_transform_messages_to_protobuf_failed() {
-        let producer = Producer::new(ProducerOption::default(), ClientOption::default()).unwrap();
+        let producer = new_producer_for_test();
 
         let messages: Vec<MessageImpl> = vec![];
         let result = producer.transform_messages_to_protobuf(messages);
diff --git a/rust/src/session.rs b/rust/src/session.rs
index fcb212ad..268e5fb4 100644
--- a/rust/src/session.rs
+++ b/rust/src/session.rs
@@ -27,6 +27,7 @@ use tonic::transport::{Channel, Endpoint};
 
 use crate::conf::ClientOption;
 use crate::error::ErrorKind;
+use crate::log::terminal_logger;
 use crate::model::common::Endpoints;
 use crate::pb::{
     AckMessageRequest, AckMessageResponse, HeartbeatRequest, HeartbeatResponse, QueryRouteRequest,
@@ -87,6 +88,20 @@ impl Session {
     const HTTP_SCHEMA: &'static str = "http";
     const HTTPS_SCHEMA: &'static str = "https";
 
+    #[cfg(test)]
+    pub(crate) fn mock() -> Self {
+        Session {
+            logger: terminal_logger(),
+            client_id: "fake_id".to_string(),
+            option: ClientOption::default(),
+            endpoints: Endpoints::from_url("http://localhost:8081").unwrap(),
+            stub: MessagingServiceClient::new(
+                Channel::from_static("http://localhost:8081").connect_lazy(),
+            ),
+            telemetry_tx: Box::new(None),
+        }
+    }
+
     async fn new(
         logger: &Logger,
         endpoints: &Endpoints,
@@ -277,6 +292,7 @@ impl Session {
     }
 }
 
+#[automock]
 #[async_trait]
 impl RPCClient for Session {
     async fn query_route(
@@ -388,6 +404,7 @@ pub(crate) struct SessionManager {
     session_map: Mutex<HashMap<String, Session>>,
 }
 
+#[automock]
 impl SessionManager {
     pub(crate) fn new(logger: &Logger, client_id: String, option: &ClientOption) -> Self {
         let logger = logger.new(o!("component" => "session"));
diff --git a/rust/src/simple_consumer.rs b/rust/src/simple_consumer.rs
index e44f7cd9..bccada6b 100644
--- a/rust/src/simple_consumer.rs
+++ b/rust/src/simple_consumer.rs
@@ -120,7 +120,7 @@ impl SimpleConsumer {
             .collect())
     }
 
-    pub async fn ack(&self, ack_entry: impl AckMessageEntry) -> Result<(), ClientError> {
+    pub async fn ack(&self, ack_entry: impl AckMessageEntry + 'static) -> Result<(), ClientError> {
         self.client.ack_message(ack_entry).await?;
         Ok(())
     }

Reply via email to