This is an automated email from the ASF dual-hosted git repository.

wangxin pushed a commit to branch 2.x
in repository https://gitbox.apache.org/repos/asf/incubator-seata.git


The following commit(s) were added to refs/heads/2.x by this push:
     new dbeba6120b optimize: expand unit test coverage for the [rocketmq] module. (#6927)
dbeba6120b is described below

commit dbeba6120ba36851f884271132df6b6363a843d0
Author: psxjoy <[email protected]>
AuthorDate: Wed Oct 23 18:24:10 2024 +0800

    optimize: expand unit test coverage for the [rocketmq] module. (#6927)
    
    feature: add unit-test for rocketmq module
---
 changes/en-us/2.x.md                               |   2 +-
 changes/zh-cn/2.x.md                               |   2 +-
 .../integration/rocketmq/SeataMQProducer.java      |  34 ++-
 .../integration/rocketmq/SeataMQProducerTest.java  | 340 ++++++++++++++++++++-
 .../integration/rocketmq/TCCRocketMQImplTest.java  | 267 ++++++++++++++++
 5 files changed, 629 insertions(+), 16 deletions(-)

diff --git a/changes/en-us/2.x.md b/changes/en-us/2.x.md
index a9afb386fa..9156a0b1b3 100644
--- a/changes/en-us/2.x.md
+++ b/changes/en-us/2.x.md
@@ -38,7 +38,7 @@ Add changes here for all PR submitted to the 2.x branch.
 ### security:
 
 ### test:
-
+- [[#6927](https://github.com/apache/incubator-seata/pull/6927)] Add unit tests for the `seata-rocketmq` module
 
 Thanks to these contributors for their code commits. Please report an unintended omission.
 
diff --git a/changes/zh-cn/2.x.md b/changes/zh-cn/2.x.md
index f7934c454a..eafcc9b80e 100644
--- a/changes/zh-cn/2.x.md
+++ b/changes/zh-cn/2.x.md
@@ -41,7 +41,7 @@
 ### security:
 
 ### test:
-
+- [[#6927](https://github.com/apache/incubator-seata/pull/6927)] 增加`seata-rocketmq`模块的测试用例
 
 非常感谢以下 contributors 的代码贡献。若有无意遗漏,请报告。
 
diff --git a/rocketmq/src/main/java/org/apache/seata/integration/rocketmq/SeataMQProducer.java b/rocketmq/src/main/java/org/apache/seata/integration/rocketmq/SeataMQProducer.java
index 2846d00073..decd1c9060 100644
--- a/rocketmq/src/main/java/org/apache/seata/integration/rocketmq/SeataMQProducer.java
+++ b/rocketmq/src/main/java/org/apache/seata/integration/rocketmq/SeataMQProducer.java
@@ -16,16 +16,12 @@
  */
 package org.apache.seata.integration.rocketmq;
 
-import org.apache.rocketmq.client.producer.SendStatus;
-import org.apache.seata.common.util.StringUtils;
-import org.apache.seata.core.context.RootContext;
-import org.apache.seata.core.model.GlobalStatus;
-import org.apache.seata.rm.DefaultResourceManager;
 import org.apache.rocketmq.client.Validators;
 import org.apache.rocketmq.client.exception.MQBrokerException;
 import org.apache.rocketmq.client.exception.MQClientException;
 import org.apache.rocketmq.client.producer.LocalTransactionState;
 import org.apache.rocketmq.client.producer.SendResult;
+import org.apache.rocketmq.client.producer.SendStatus;
 import org.apache.rocketmq.client.producer.TransactionListener;
 import org.apache.rocketmq.client.producer.TransactionMQProducer;
 import org.apache.rocketmq.common.message.Message;
@@ -34,6 +30,10 @@ import org.apache.rocketmq.common.message.MessageConst;
 import org.apache.rocketmq.common.message.MessageExt;
 import org.apache.rocketmq.remoting.RPCHook;
 import org.apache.rocketmq.remoting.exception.RemotingException;
+import org.apache.seata.common.util.StringUtils;
+import org.apache.seata.core.context.RootContext;
+import org.apache.seata.core.model.GlobalStatus;
+import org.apache.seata.rm.DefaultResourceManager;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -47,8 +47,10 @@ public class SeataMQProducer extends TransactionMQProducer {
 
     private static final Logger LOGGER = 
LoggerFactory.getLogger(SeataMQProducer.class);
 
-    private static final List<GlobalStatus> COMMIT_STATUSES = 
Arrays.asList(GlobalStatus.Committed, GlobalStatus.Committing, 
GlobalStatus.CommitRetrying);
-    private static final List<GlobalStatus> ROLLBACK_STATUSES = 
Arrays.asList(GlobalStatus.Rollbacked, GlobalStatus.Rollbacking, 
GlobalStatus.RollbackRetrying);
+    private static final List<GlobalStatus> COMMIT_STATUSES =
+        Arrays.asList(GlobalStatus.Committed, GlobalStatus.Committing, 
GlobalStatus.CommitRetrying);
+    private static final List<GlobalStatus> ROLLBACK_STATUSES =
+        Arrays.asList(GlobalStatus.Rollbacked, GlobalStatus.Rollbacking, 
GlobalStatus.RollbackRetrying);
 
     public static String PROPERTY_SEATA_XID = RootContext.KEY_XID;
     public static String PROPERTY_SEATA_BRANCHID = RootContext.KEY_BRANCHID;
@@ -75,7 +77,8 @@ public class SeataMQProducer extends TransactionMQProducer {
                     LOGGER.error("msg has no xid, msgTransactionId: {}, msg 
will be rollback", msg.getTransactionId());
                     return LocalTransactionState.ROLLBACK_MESSAGE;
                 }
-                GlobalStatus globalStatus = 
DefaultResourceManager.get().getGlobalStatus(SeataMQProducerFactory.ROCKET_BRANCH_TYPE,
 xid);
+                GlobalStatus globalStatus =
+                    
DefaultResourceManager.get().getGlobalStatus(SeataMQProducerFactory.ROCKET_BRANCH_TYPE,
 xid);
                 if (COMMIT_STATUSES.contains(globalStatus)) {
                     return LocalTransactionState.COMMIT_MESSAGE;
                 } else if (ROLLBACK_STATUSES.contains(globalStatus) || 
GlobalStatus.isOnePhaseTimeout(globalStatus)) {
@@ -90,12 +93,14 @@ public class SeataMQProducer extends TransactionMQProducer {
     }
 
+    /**
+     * Sends {@code msg} using this producer's configured send timeout.
+     * <p>
+     * Delegates to {@link #send(Message, long)} with {@link #getSendMsgTimeout()}, so the
+     * global-transaction routing implemented in that override applies to this overload too.
+     */
     @Override
-    public SendResult send(Message msg) throws MQClientException, MQBrokerException, RemotingException, InterruptedException {
+    public SendResult send(Message msg)
+        throws MQClientException, MQBrokerException, RemotingException, InterruptedException {
         return send(msg, this.getSendMsgTimeout());
     }
 
     @Override
-    public SendResult send(Message msg, long timeout) throws 
MQClientException, MQBrokerException, RemotingException, InterruptedException {
+    public SendResult send(Message msg, long timeout)
+        throws MQClientException, MQBrokerException, RemotingException, 
InterruptedException {
         if (RootContext.inGlobalTransaction()) {
             if (tccRocketMQ == null) {
                 throw new RuntimeException("TCCRocketMQ is not initialized");
@@ -106,7 +111,8 @@ public class SeataMQProducer extends TransactionMQProducer {
         }
     }
 
-    public SendResult doSendMessageInTransaction(final Message msg, long 
timeout, String xid, long branchId) throws MQClientException {
+    public SendResult doSendMessageInTransaction(final Message msg, long 
timeout, String xid, long branchId)
+        throws MQClientException {
         msg.setTopic(withNamespace(msg.getTopic()));
         if (msg.getDelayTimeLevel() != 0) {
             MessageAccessor.clearProperty(msg, 
MessageConst.PROPERTY_DELAY_TIME_LEVEL);
@@ -119,7 +125,7 @@ public class SeataMQProducer extends TransactionMQProducer {
         MessageAccessor.putProperty(msg, PROPERTY_SEATA_XID, xid);
         MessageAccessor.putProperty(msg, PROPERTY_SEATA_BRANCHID, 
String.valueOf(branchId));
         try {
-            sendResult = super.send(msg, timeout);
+            sendResult = superSend(msg, timeout);
         } catch (Exception e) {
             throw new MQClientException("send message Exception", e);
         }
@@ -137,6 +143,10 @@ public class SeataMQProducer extends TransactionMQProducer {
         return sendResult;
     }
 
+    /**
+     * Sends {@code msg} through the superclass {@code send(Message, long)} implementation.
+     * <p>
+     * Calls {@code super.send(msg, timeout)} directly, bypassing this class's
+     * {@link #send(Message, long)} override, which routes to {@code TCCRocketMQ} while a global
+     * transaction is bound. {@link #doSendMessageInTransaction} uses it for the actual send.
+     * It is presumably a separate overridable method so tests can stub the broker call (the
+     * accompanying tests stub it); NOTE(review): it is public, so production callers can also
+     * bypass the transactional routing.
+     *
+     * @param msg     message to send
+     * @param timeout send timeout, passed through unchanged
+     * @return the result returned by the superclass send
+     */
+    public SendResult superSend(Message msg, long timeout)
+        throws MQClientException, MQBrokerException, RemotingException, InterruptedException {
+        return super.send(msg, timeout);
+    }
 
     @Override
     public TransactionListener getTransactionListener() {
diff --git a/rocketmq/src/test/java/org/apache/seata/integration/rocketmq/SeataMQProducerTest.java b/rocketmq/src/test/java/org/apache/seata/integration/rocketmq/SeataMQProducerTest.java
index 7b8ab979d5..a43bb0fb20 100644
--- a/rocketmq/src/test/java/org/apache/seata/integration/rocketmq/SeataMQProducerTest.java
+++ b/rocketmq/src/test/java/org/apache/seata/integration/rocketmq/SeataMQProducerTest.java
@@ -16,16 +16,352 @@
  */
 package org.apache.seata.integration.rocketmq;
 
+import org.apache.rocketmq.client.exception.MQBrokerException;
+import org.apache.rocketmq.client.exception.MQClientException;
+import org.apache.rocketmq.client.producer.LocalTransactionState;
+import org.apache.rocketmq.client.producer.SendResult;
+import org.apache.rocketmq.client.producer.SendStatus;
+import org.apache.rocketmq.client.producer.TransactionListener;
+import org.apache.rocketmq.client.producer.TransactionMQProducer;
+import org.apache.rocketmq.common.message.Message;
+import org.apache.rocketmq.common.message.MessageAccessor;
+import org.apache.rocketmq.common.message.MessageConst;
+import org.apache.rocketmq.common.message.MessageExt;
+import org.apache.rocketmq.remoting.exception.RemotingException;
+import org.apache.seata.core.context.RootContext;
+import org.apache.seata.core.model.GlobalStatus;
+import org.apache.seata.rm.DefaultResourceManager;
+import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
+import org.mockito.InjectMocks;
+import org.mockito.Mock;
+import org.mockito.MockedStatic;
+import org.mockito.Mockito;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertSame;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyInt;
+import static org.mockito.ArgumentMatchers.anyLong;
+import static org.mockito.ArgumentMatchers.anyString;
+import static org.mockito.Mockito.doCallRealMethod;
+import static org.mockito.Mockito.doReturn;
+import static org.mockito.Mockito.doThrow;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.mockStatic;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.verifyNoInteractions;
+import static org.mockito.Mockito.when;
 
 /**
  * seata mq producer test
  **/
 public class SeataMQProducerTest {
 
+    // NOTE: no MockitoExtension / MockitoAnnotations.openMocks is active in this class, so the
+    // @Mock and @InjectMocks annotations below are inert; every fixture is built in setUp().
+    @Mock
+    private TransactionMQProducer transactionMQProducer;
+
+    // Mocked TCC bridge installed on 'producer'.
+    private TCCRocketMQ tccRocketMQ;
+    // Spy with a mocked TCCRocketMQ; exercises the send(...) routing.
+    @InjectMocks
+    private SeataMQProducer producer;
+    // Namespaced producer whose transaction listener is under test.
+    private SeataMQProducer producerTwo;
+    // Spy without TCCRocketMQ; its superSend(...) is stubbed in doSendMessageInTransaction tests.
+    private SeataMQProducer seataMQProducer;
+    private TransactionListener transactionListener;
+
+    @BeforeEach
+    void setUp() {
+        producer = Mockito.spy(new SeataMQProducer("testGroup"));
+        seataMQProducer = spy(new SeataMQProducer("testGroup"));
+        tccRocketMQ = mock(TCCRocketMQImpl.class);
+        producer.setTccRocketMQ(tccRocketMQ);
+        producerTwo = new SeataMQProducer("namespace", "producerGroup", null);
+        transactionListener = producerTwo.getTransactionListener();
+    }
+
     @Test
-    public void testCreate(){
+    public void testCreate() {
         new SeataMQProducer("testProducerGroup");
-        new SeataMQProducer("testNamespace", "testProducerGroup",null);
+        new SeataMQProducer("testNamespace", "testProducerGroup", null);
+    }
+
+    @Test
+    void testExecuteLocalTransaction() {
+        Message msg = new Message();
+        assertEquals(LocalTransactionState.UNKNOW, transactionListener.executeLocalTransaction(msg, null));
+    }
+
+    @Test
+    void testCheckLocalTransactionWithNoXid() {
+        MessageExt msg = new MessageExt();
+        msg.setTransactionId("testTransactionId");
+        assertEquals(LocalTransactionState.ROLLBACK_MESSAGE, transactionListener.checkLocalTransaction(msg));
+    }
+
+    @Test
+    void testCheckLocalTransactionWithCommitStatus() {
+        MessageExt msg = new MessageExt();
+        msg.putUserProperty(SeataMQProducer.PROPERTY_SEATA_XID, "testXid");
+
+        try (MockedStatic<DefaultResourceManager> mockedStatic = mockStatic(DefaultResourceManager.class)) {
+            DefaultResourceManager mockResourceManager = mock(DefaultResourceManager.class);
+            mockedStatic.when(DefaultResourceManager::get).thenReturn(mockResourceManager);
+            when(mockResourceManager.getGlobalStatus(SeataMQProducerFactory.ROCKET_BRANCH_TYPE, "testXid")).thenReturn(
+                GlobalStatus.Committed);
+
+            assertEquals(LocalTransactionState.COMMIT_MESSAGE, transactionListener.checkLocalTransaction(msg));
+        }
+    }
+
+    @Test
+    void testCheckLocalTransactionWithRollbackStatus() {
+        MessageExt msg = new MessageExt();
+        msg.putUserProperty(SeataMQProducer.PROPERTY_SEATA_XID, "testXid");
+
+        try (MockedStatic<DefaultResourceManager> mockedStatic = mockStatic(DefaultResourceManager.class)) {
+            DefaultResourceManager mockResourceManager = mock(DefaultResourceManager.class);
+            mockedStatic.when(DefaultResourceManager::get).thenReturn(mockResourceManager);
+            when(mockResourceManager.getGlobalStatus(SeataMQProducerFactory.ROCKET_BRANCH_TYPE, "testXid")).thenReturn(
+                GlobalStatus.Rollbacked);
+
+            assertEquals(LocalTransactionState.ROLLBACK_MESSAGE, transactionListener.checkLocalTransaction(msg));
+        }
+    }
+
+    @Test
+    void testCheckLocalTransactionWithFinishedStatus() {
+        MessageExt msg = new MessageExt();
+        msg.putUserProperty(SeataMQProducer.PROPERTY_SEATA_XID, "testXid");
+
+        try (MockedStatic<DefaultResourceManager> mockedStatic = mockStatic(DefaultResourceManager.class)) {
+            DefaultResourceManager mockResourceManager = mock(DefaultResourceManager.class);
+            mockedStatic.when(DefaultResourceManager::get).thenReturn(mockResourceManager);
+            when(mockResourceManager.getGlobalStatus(SeataMQProducerFactory.ROCKET_BRANCH_TYPE, "testXid")).thenReturn(
+                GlobalStatus.Finished);
+
+            assertEquals(LocalTransactionState.ROLLBACK_MESSAGE, transactionListener.checkLocalTransaction(msg));
+        }
+    }
+
+    @Test
+    void testCheckLocalTransactionWithUnknownStatus() {
+        MessageExt msg = new MessageExt();
+        msg.putUserProperty(SeataMQProducer.PROPERTY_SEATA_XID, "testXid");
+
+        try (MockedStatic<DefaultResourceManager> mockedStatic = mockStatic(DefaultResourceManager.class)) {
+            DefaultResourceManager mockResourceManager = mock(DefaultResourceManager.class);
+            mockedStatic.when(DefaultResourceManager::get).thenReturn(mockResourceManager);
+            when(mockResourceManager.getGlobalStatus(SeataMQProducerFactory.ROCKET_BRANCH_TYPE, "testXid")).thenReturn(
+                GlobalStatus.Begin);
+
+            assertEquals(LocalTransactionState.UNKNOW, transactionListener.checkLocalTransaction(msg));
+        }
+    }
+
+    @Test
+    void testSendWithoutGlobalTransaction() {
+        Message msg = new Message("testTopic", "testBody".getBytes());
+        long timeout = 3000L;
+
+        // No XID is bound, so send(...) must take the plain producer path and never touch
+        // TCCRocketMQ. The producer was never started, so that path fails with MQClientException.
+        // (Stubbing send(msg, timeout) itself, as before, would only re-test Mockito.)
+        assertThrows(MQClientException.class, () -> producer.send(msg, timeout));
+        verifyNoInteractions(tccRocketMQ);
+    }
+
+    @Test
+    void testSendWithGlobalTransaction()
+        throws MQClientException, RemotingException, MQBrokerException, InterruptedException {
+        Message msg = new Message("testTopic", "testBody".getBytes());
+        long timeout = 3000L;
+        SendResult expectedResult = mock(SendResult.class);
+
+        RootContext.bind("DummyXID");
+        try {
+            when(tccRocketMQ.prepare(msg, timeout)).thenReturn(expectedResult);
+
+            SendResult result = producer.send(msg, timeout);
+
+            assertSame(expectedResult, result);
+            verify(tccRocketMQ).prepare(msg, timeout);
+        } finally {
+            RootContext.unbind();
+        }
+    }
+
+    @Test
+    void testSendWithGlobalTransactionAndNullTCCRocketMQ() {
+        Message msg = new Message("testTopic", "testBody".getBytes());
+        long timeout = 3000L;
+
+        producer.setTccRocketMQ(null);
+        RootContext.bind("DummyXID");
+        try {
+            assertThrows(RuntimeException.class, () -> producer.send(msg, timeout));
+        } finally {
+            RootContext.unbind();
+        }
+    }
+
+    @Test
+    void testSend() throws MQClientException, RemotingException, MQBrokerException, InterruptedException {
+
+        Message msg = new Message("testTopic", "testBody".getBytes());
+        SendResult expectedResult = mock(SendResult.class);
+        int expectedTimeout = 3000;
+
+        doReturn(expectedTimeout).when(producer).getSendMsgTimeout();
+        doReturn(expectedResult).when(producer).send(any(Message.class), anyLong());
+
+        SendResult result = producer.send(msg);
+
+        assertSame(expectedResult, result);
+        verify(producer).send(msg, expectedTimeout);
     }
+
+    @Test
+    void testSendWithException() throws MQClientException, RemotingException, MQBrokerException, InterruptedException {
+
+        Message msg = new Message("testTopic", "testBody".getBytes());
+        int expectedTimeout = 3000;
+        MQClientException stubbed = new MQClientException("Test exception", null);
+
+        doReturn(expectedTimeout).when(producer).getSendMsgTimeout();
+        // The timeout parameter of send(Message, long) is a long; anyInt() never matches a Long
+        // argument, so the stub must use anyLong().
+        doThrow(stubbed).when(producer).send(any(Message.class), anyLong());
+
+        // assertSame proves the stub fired, rather than the unstarted producer failing on its own.
+        assertSame(stubbed, assertThrows(MQClientException.class, () -> producer.send(msg)));
+        verify(producer).send(msg, expectedTimeout);
+    }
+
+    @Test
+    void testDoSendMessageInTransactionWithNonOkStatus() throws Exception {
+
+        Message msg = new Message("testTopic", "testBody".getBytes());
+        long timeout = 3000L;
+        String xid = "testXid";
+        long branchId = 123L;
+
+        SendResult mockSendResult = mock(SendResult.class);
+        when(mockSendResult.getSendStatus()).thenReturn(SendStatus.FLUSH_DISK_TIMEOUT);
+
+        // doSendMessageInTransaction performs the broker call through superSend(...), not send(...),
+        // so that is the seam to stub.
+        doReturn(mockSendResult).when(producer).superSend(any(Message.class), anyLong());
+
+        assertThrows(RuntimeException.class, () -> producer.doSendMessageInTransaction(msg, timeout, xid, branchId));
+        verify(producer).superSend(msg, timeout);
+    }
+
+    @Test
+    void testDoSendMessageInTransactionWithException() throws Exception {
+
+        Message msg = new Message("testTopic", "testBody".getBytes());
+        long timeout = 3000L;
+        String xid = "testXid";
+        long branchId = 123L;
+
+        // A checked broker failure from the underlying send must surface wrapped in MQClientException.
+        MQBrokerException brokerException = new MQBrokerException(1, "Test exception");
+        doThrow(brokerException).when(producer).superSend(any(Message.class), anyLong());
+        // Explicit: the spy runs the real method under test.
+        doCallRealMethod().when(producer)
+            .doSendMessageInTransaction(any(Message.class), anyLong(), anyString(), anyLong());
+
+        MQClientException thrown = assertThrows(MQClientException.class,
+            () -> producer.doSendMessageInTransaction(msg, timeout, xid, branchId));
+        assertSame(brokerException, thrown.getCause());
+    }
+
+    @Test
+    void testDoSendMessageInTransactionSuccess() throws Exception {
+
+        Message msg = new Message("testTopic", "testTag", "testKey", "testBody".getBytes());
+        long timeout = 3000L;
+        String xid = "testXid";
+        long branchId = 123L;
+
+        SendResult mockSendResult = new SendResult();
+        mockSendResult.setSendStatus(SendStatus.SEND_OK);
+        mockSendResult.setTransactionId("testTransactionId");
+
+        doReturn(mockSendResult).when(seataMQProducer).superSend(any(Message.class), anyLong());
+
+        SendResult result = seataMQProducer.doSendMessageInTransaction(msg, timeout, xid, branchId);
+
+        assertNotNull(result);
+        assertEquals(SendStatus.SEND_OK, result.getSendStatus());
+        assertEquals("testTransactionId", msg.getUserProperty("__transactionId__"));
+        assertEquals("true", msg.getProperty(MessageConst.PROPERTY_TRANSACTION_PREPARED));
+        assertEquals(seataMQProducer.getProducerGroup(), msg.getProperty(MessageConst.PROPERTY_PRODUCER_GROUP));
+        assertEquals(xid, msg.getProperty(SeataMQProducer.PROPERTY_SEATA_XID));
+        assertEquals(String.valueOf(branchId), msg.getProperty(SeataMQProducer.PROPERTY_SEATA_BRANCHID));
+
+        verify(seataMQProducer).superSend(msg, timeout);
+    }
+
+    @Test
+    void testDoSendMessageInTransactionSendException() throws Exception {
+
+        Message msg = new Message("testTopic", "testTag", "testKey", "testBody".getBytes());
+        long timeout = 3000L;
+        String xid = "testXid";
+        long branchId = 123L;
+
+        doThrow(new RuntimeException("Send failed")).when(seataMQProducer).superSend(any(Message.class), anyLong());
+
+        assertThrows(MQClientException.class,
+            () -> seataMQProducer.doSendMessageInTransaction(msg, timeout, xid, branchId));
+
+        verify(seataMQProducer).superSend(msg, timeout);
+    }
+
+    @Test
+    void testDoSendMessageInTransactionSendStatusNotOk() throws Exception {
+
+        Message msg = new Message("testTopic", "testTag", "testKey", "testBody".getBytes());
+        long timeout = 3000L;
+        String xid = "testXid";
+        long branchId = 123L;
+
+        SendResult mockSendResult = new SendResult();
+        mockSendResult.setSendStatus(SendStatus.FLUSH_DISK_TIMEOUT);
+
+        doReturn(mockSendResult).when(seataMQProducer).superSend(any(Message.class), anyLong());
+
+        assertThrows(RuntimeException.class,
+            () -> seataMQProducer.doSendMessageInTransaction(msg, timeout, xid, branchId));
+
+        verify(seataMQProducer).superSend(msg, timeout);
+    }
+
+    @Test
+    void testDoSendMessageInTransactionWithTransactionId() throws Exception {
+
+        Message msg = new Message("testTopic", "testTag", "testKey", "testBody".getBytes());
+        long timeout = 3000L;
+        String xid = "testXid";
+        long branchId = 123L;
+
+        SendResult mockSendResult = new SendResult();
+        mockSendResult.setSendStatus(SendStatus.SEND_OK);
+        mockSendResult.setTransactionId("testTransactionId");
+
+        MessageAccessor.putProperty(msg, MessageConst.PROPERTY_UNIQ_CLIENT_MESSAGE_ID_KEYIDX, "clientTransactionId");
+
+        doReturn(mockSendResult).when(seataMQProducer).superSend(any(Message.class), anyLong());
+
+        SendResult result = seataMQProducer.doSendMessageInTransaction(msg, timeout, xid, branchId);
+
+        assertNotNull(result);
+        assertEquals(SendStatus.SEND_OK, result.getSendStatus());
+        assertEquals("testTransactionId", msg.getUserProperty("__transactionId__"));
+        assertEquals("clientTransactionId", msg.getTransactionId());
+
+        verify(seataMQProducer).superSend(msg, timeout);
+    }
+
+    @Test
+    void getTransactionListenerShouldReturnNonNullTransactionListener() {
+        TransactionListener transactionListener = producer.getTransactionListener();
+        assertNotNull(transactionListener, "TransactionListener should not be null");
+    }
+
 }
diff --git a/rocketmq/src/test/java/org/apache/seata/integration/rocketmq/TCCRocketMQImplTest.java b/rocketmq/src/test/java/org/apache/seata/integration/rocketmq/TCCRocketMQImplTest.java
new file mode 100644
index 0000000000..af9ef218f2
--- /dev/null
+++ b/rocketmq/src/test/java/org/apache/seata/integration/rocketmq/TCCRocketMQImplTest.java
@@ -0,0 +1,267 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.seata.integration.rocketmq;
+
+import java.lang.reflect.Field;
+import java.net.UnknownHostException;
+import java.util.concurrent.TimeoutException;
+import org.apache.rocketmq.client.exception.MQBrokerException;
+import org.apache.rocketmq.client.exception.MQClientException;
+import org.apache.rocketmq.client.impl.producer.DefaultMQProducerImpl;
+import org.apache.rocketmq.client.producer.LocalTransactionState;
+import org.apache.rocketmq.client.producer.SendResult;
+import org.apache.rocketmq.client.producer.SendStatus;
+import org.apache.rocketmq.common.message.Message;
+import org.apache.rocketmq.remoting.exception.RemotingException;
+import org.apache.seata.core.exception.TransactionException;
+import org.apache.seata.rm.tcc.api.BusinessActionContext;
+import org.apache.seata.rm.tcc.api.BusinessActionContextUtil;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.mockito.Mock;
+import org.mockito.MockedStatic;
+import org.mockito.MockitoAnnotations;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.ArgumentMatchers.isNull;
+import static org.mockito.Mockito.doThrow;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.mockStatic;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+/**
+ * Unit tests for {@link TCCRocketMQImpl}: the prepare, commit and rollback phases, driven through
+ * mocked producers and a mocked {@link BusinessActionContext}.
+ */
+public class TCCRocketMQImplTest {
+    // Producer used by prepare(); wired in through setProducer(...).
+    @Mock
+    private SeataMQProducer producer;
+
+    // Producer implementation used by commit()/rollback(); injected reflectively in setUp().
+    @Mock
+    private DefaultMQProducerImpl producerImpl;
+    @Mock
+    private BusinessActionContext businessActionContext;
+
+    // Instance under test for commit/rollback (its producerImpl field holds the mock).
+    private TCCRocketMQImpl tccRocketMQ;
+    // Instance under test for prepare (its producer is the mock).
+    private TCCRocketMQImpl prepareTccRocketMQ;
+
+    @BeforeEach
+    void setUp() throws Exception {
+        MockitoAnnotations.openMocks(this);
+        tccRocketMQ = new TCCRocketMQImpl();
+        prepareTccRocketMQ = new TCCRocketMQImpl();
+
+        // Write the mock straight into the private field so commit/rollback call endTransaction on it.
+        Field producerImplField = TCCRocketMQImpl.class.getDeclaredField("producerImpl");
+        producerImplField.setAccessible(true);
+        producerImplField.set(tccRocketMQ, producerImpl);
+        prepareTccRocketMQ.setProducer(producer);
+    }
+
+    @Test
+    void testPrepare() throws MQClientException {
+        // MockedStatic is AutoCloseable; try-with-resources deregisters the static mock afterwards.
+        try (MockedStatic<BusinessActionContextUtil> mockedStatic = mockStatic(BusinessActionContextUtil.class)) {
+            Message message = new Message("testTopic", "testBody".getBytes());
+            long timeout = 3000L;
+            String xid = "testXid";
+            long branchId = 123L;
+
+            mockedStatic.when(BusinessActionContextUtil::getContext).thenReturn(businessActionContext);
+            when(businessActionContext.getXid()).thenReturn(xid);
+            when(businessActionContext.getBranchId()).thenReturn(branchId);
+
+            SendResult mockSendResult = mock(SendResult.class);
+            when(mockSendResult.getSendStatus()).thenReturn(SendStatus.SEND_OK);
+            when(producer.doSendMessageInTransaction(message, timeout, xid, branchId)).thenReturn(mockSendResult);
+
+            SendResult result = prepareTccRocketMQ.prepare(message, timeout);
+
+            assertNotNull(result);
+            assertEquals(SendStatus.SEND_OK, result.getSendStatus());
+            assertEquals(0, message.getDelayTimeLevel());
+
+            verify(producer).doSendMessageInTransaction(message, timeout, xid, branchId);
+            mockedStatic.verify(BusinessActionContextUtil::getContext, times(1));
+        }
+    }
+
+    @Test
+    void testPrepareWithException() throws MQClientException {
+        try (MockedStatic<BusinessActionContextUtil> mockedStatic = mockStatic(BusinessActionContextUtil.class)) {
+            Message message = new Message("testTopic", "testBody".getBytes());
+            long timeout = 3000L;
+            String xid = "testXid";
+            long branchId = 123L;
+
+            mockedStatic.when(BusinessActionContextUtil::getContext).thenReturn(businessActionContext);
+            when(businessActionContext.getXid()).thenReturn(xid);
+            when(businessActionContext.getBranchId()).thenReturn(branchId);
+
+            when(producer.doSendMessageInTransaction(message, timeout, xid, branchId)).thenThrow(
+                new MQClientException("Test exception", null));
+
+            assertThrows(MQClientException.class, () -> prepareTccRocketMQ.prepare(message, timeout));
+
+            verify(producer).doSendMessageInTransaction(message, timeout, xid, branchId);
+            mockedStatic.verify(BusinessActionContextUtil::getContext, times(1));
+            // A failed send must not record anything in the action context.
+            mockedStatic.verify(() -> BusinessActionContextUtil.addContext(any()), never());
+        }
+    }
+
+    @Test
+    void testCommitSuccess()
+        throws UnknownHostException, MQBrokerException, RemotingException, InterruptedException, TimeoutException,
+        TransactionException {
+
+        Message message = new Message("testTopic", "testBody".getBytes());
+        SendResult sendResult = mock(SendResult.class);
+
+        when(businessActionContext.getActionContext("ROCKET_MSG", Message.class)).thenReturn(message);
+        when(businessActionContext.getActionContext("ROCKET_SEND_RESULT", SendResult.class)).thenReturn(sendResult);
+        when(businessActionContext.getXid()).thenReturn("testXid");
+        when(businessActionContext.getBranchId()).thenReturn(123L);
+
+        boolean result = tccRocketMQ.commit(businessActionContext);
+
+        assertTrue(result);
+        verify(producerImpl).endTransaction(eq(message), eq(sendResult), eq(LocalTransactionState.COMMIT_MESSAGE),
+            isNull());
+    }
+
+    @Test
+    void testCommitWithNullMessage() throws Exception {
+
+        when(businessActionContext.getActionContext("ROCKET_MSG", Message.class)).thenReturn(null);
+        when(businessActionContext.getActionContext("ROCKET_SEND_RESULT", SendResult.class)).thenReturn(
+            mock(SendResult.class));
+
+        assertThrows(TransactionException.class, () -> tccRocketMQ.commit(businessActionContext));
+        // The half message must not be committed when its context entry is missing.
+        verify(producerImpl, never()).endTransaction(any(), any(), any(), any());
+    }
+
+    @Test
+    void testCommitWithNullSendResult() throws Exception {
+
+        when(businessActionContext.getActionContext("ROCKET_MSG", Message.class)).thenReturn(new Message());
+        when(businessActionContext.getActionContext("ROCKET_SEND_RESULT", SendResult.class)).thenReturn(null);
+
+        assertThrows(TransactionException.class, () -> tccRocketMQ.commit(businessActionContext));
+        // The half message must not be committed when its send result is missing.
+        verify(producerImpl, never()).endTransaction(any(), any(), any(), any());
+    }
+
+    @Test
+    void testCommitWithException()
+        throws UnknownHostException, MQBrokerException, RemotingException, InterruptedException, TimeoutException {
+
+        Message message = new Message("testTopic", "testBody".getBytes());
+        SendResult sendResult = mock(SendResult.class);
+
+        when(businessActionContext.getActionContext("ROCKET_MSG", Message.class)).thenReturn(message);
+        when(businessActionContext.getActionContext("ROCKET_SEND_RESULT", SendResult.class)).thenReturn(sendResult);
+
+        doThrow(new MQBrokerException(1, "Test exception")).when(producerImpl)
+            .endTransaction(any(), any(), any(), any());
+
+        assertThrows(MQBrokerException.class, () -> tccRocketMQ.commit(businessActionContext));
+    }
+
+    @Test
+    void testRollbackSuccess()
+        throws UnknownHostException, MQBrokerException, RemotingException, InterruptedException, TransactionException {
+
+        Message message = new Message("testTopic", "testBody".getBytes());
+        SendResult sendResult = mock(SendResult.class);
+
+        when(businessActionContext.getActionContext("ROCKET_MSG", Message.class)).thenReturn(message);
+        when(businessActionContext.getActionContext("ROCKET_SEND_RESULT", SendResult.class)).thenReturn(sendResult);
+        when(businessActionContext.getXid()).thenReturn("testXid");
+        when(businessActionContext.getBranchId()).thenReturn(123L);
+
+        boolean result = tccRocketMQ.rollback(businessActionContext);
+
+        assertTrue(result);
+        verify(producerImpl).endTransaction(eq(message), eq(sendResult), eq(LocalTransactionState.ROLLBACK_MESSAGE),
+            isNull());
+    }
+
+    @Test
+    void testRollbackWithNullMessage()
+        throws UnknownHostException, MQBrokerException, RemotingException, InterruptedException, TransactionException {
+
+        SendResult sendResult = mock(SendResult.class);
+
+        when(businessActionContext.getActionContext("ROCKET_MSG", Message.class)).thenReturn(null);
+        when(businessActionContext.getActionContext("ROCKET_SEND_RESULT", SendResult.class)).thenReturn(sendResult);
+        when(businessActionContext.getXid()).thenReturn("testXid");
+        when(businessActionContext.getBranchId()).thenReturn(123L);
+
+        boolean result = tccRocketMQ.rollback(businessActionContext);
+
+        assertTrue(result);
+        verify(producerImpl).endTransaction(isNull(), eq(sendResult), eq(LocalTransactionState.ROLLBACK_MESSAGE),
+            isNull());
+    }
+
+    @Test
+    void testRollbackWithNullSendResult()
+        throws UnknownHostException, MQBrokerException, RemotingException, InterruptedException, TransactionException {
+
+        Message message = new Message("testTopic", "testBody".getBytes());
+
+        when(businessActionContext.getActionContext("ROCKET_MSG", Message.class)).thenReturn(message);
+        when(businessActionContext.getActionContext("ROCKET_SEND_RESULT", SendResult.class)).thenReturn(null);
+        when(businessActionContext.getXid()).thenReturn("testXid");
+        when(businessActionContext.getBranchId()).thenReturn(123L);
+
+        boolean result = tccRocketMQ.rollback(businessActionContext);
+
+        assertTrue(result);
+        verify(producerImpl).endTransaction(eq(message), isNull(), eq(LocalTransactionState.ROLLBACK_MESSAGE),
+            isNull());
+    }
+
+    @Test
+    void testRollbackWithException()
+        throws UnknownHostException, MQBrokerException, RemotingException, InterruptedException {
+
+        Message message = new Message("testTopic", "testBody".getBytes());
+        SendResult sendResult = mock(SendResult.class);
+
+        when(businessActionContext.getActionContext("ROCKET_MSG", Message.class)).thenReturn(message);
+        when(businessActionContext.getActionContext("ROCKET_SEND_RESULT", SendResult.class)).thenReturn(sendResult);
+        when(businessActionContext.getXid()).thenReturn("testXid");
+        when(businessActionContext.getBranchId()).thenReturn(123L);
+
+        doThrow(new MQBrokerException(1, "Test exception")).when(producerImpl)
+            .endTransaction(any(), any(), any(), any());
+
+        assertThrows(MQBrokerException.class, () -> tccRocketMQ.rollback(businessActionContext));
+    }
+
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to