This is an automated email from the ASF dual-hosted git repository.

dinglei pushed a commit to branch native
in repository https://gitbox.apache.org/repos/asf/rocketmq-client-go.git


The following commit(s) were added to refs/heads/native by this push:
     new 68358ec  Add transaction support. resolve #137 (#138)
68358ec is described below

commit 68358ecfc6ec1ee48b1c045f895a676e028c2bfa
Author: xujianhai666 <[email protected]>
AuthorDate: Wed Aug 14 14:41:22 2019 +0800

    Add transaction support. resolve #137 (#138)
---
 .travis.yml                                       |   3 -
 api.go                                            |  10 +
 consumer/consumer.go                              |   2 +-
 consumer/interceptor.go                           |   2 +-
 examples/producer/transaction/main.go             | 108 +++++++++
 internal/{utils/messagesysflag.go => callback.go} |  31 +--
 internal/client.go                                |  58 ++++-
 internal/model.go                                 |  31 ++-
 internal/model_test.go                            | 117 +++++++++
 internal/remote/remote_client.go                  |   4 +-
 internal/remote/remote_client_test.go             |  16 +-
 internal/request.go                               |  63 +++++
 internal/trace.go                                 |  54 +----
 internal/utils/set.go                             |  81 +++++++
 primitive/message.go                              | 280 +++++++++++++++++++++-
 primitive/result.go                               | 147 +-----------
 producer/interceptor.go                           |   2 +-
 producer/option.go                                |  16 +-
 producer/producer.go                              | 186 ++++++++++++--
 19 files changed, 947 insertions(+), 264 deletions(-)

diff --git a/.travis.yml b/.travis.yml
index de0a42e..1a9683c 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -24,9 +24,6 @@ before_script:
   - perl -i -pe's/-Xms8g -Xmx8g -Xmn4g/-Xms2g -Xmx2g -Xmn1g/g' bin/runbroker.sh
   - nohup sh bin/mqnamesrv &
   - nohup sh bin/mqbroker -n localhost:9876 &
-# - sleep 10
-# - ./bin/mqadmin updateTopic -n ${NAME_SERVER_ADDRESS} -b ${BROKER_ADDRESS}  
-t ${TOPIC}
-# - ./bin/mqadmin updateSubGroup -n ${NAME_SERVER_ADDRESS} -b 
${BROKER_ADDRESS} -g ${GROUP}
 
 script:
   - cd ${GOPATH}/src/github.com/apache/rocketmq-client-go
diff --git a/api.go b/api.go
index eec98bf..c6f6798 100644
--- a/api.go
+++ b/api.go
@@ -37,6 +37,16 @@ func NewProducer(opts ...producer.Option) (Producer, error) {
        return producer.NewDefaultProducer(opts...)
 }
 
+// TransactionProducer sends messages in transaction mode. Construct one with
+// NewTransactionProducer; SendMessageInTransaction returns a
+// primitive.TransactionSendResult for each message.
+type TransactionProducer interface {
+       Start() error
+       Shutdown() error
+       SendMessageInTransaction(context.Context, *primitive.Message) 
(*primitive.TransactionSendResult, error)
+}
+
+// NewTransactionProducer creates a TransactionProducer whose local-transaction
+// decisions are made by listener. It simply delegates to
+// producer.NewTransactionProducer with the given options.
+func NewTransactionProducer(listener primitive.TransactionListener, opts 
...producer.Option) (TransactionProducer, error) {
+       return producer.NewTransactionProducer(listener, opts...)
+}
+
 type PushConsumer interface {
        Start() error
        Shutdown() error
diff --git a/consumer/consumer.go b/consumer/consumer.go
index 25861f9..56836ae 100644
--- a/consumer/consumer.go
+++ b/consumer/consumer.go
@@ -275,7 +275,7 @@ func (dc *defaultConsumer) start() error {
                dc.subscriptionDataTable.Store(retryTopic, sub)
        }
 
-       dc.client = internal.GetOrNewRocketMQClient(dc.option.ClientOptions)
+       dc.client = internal.GetOrNewRocketMQClient(dc.option.ClientOptions, 
nil)
        if dc.model == Clustering {
                dc.option.ChangeInstanceNameToPID()
                dc.storage = NewRemoteOffsetStore(dc.consumerGroup, dc.client)
diff --git a/consumer/interceptor.go b/consumer/interceptor.go
index 6b050df..657633e 100644
--- a/consumer/interceptor.go
+++ b/consumer/interceptor.go
@@ -50,7 +50,7 @@ func newTraceInterceptor(traceCfg primitive.TraceConfig) 
primitive.Interceptor {
                beginT := time.Now()
                // before traceCtx
                traceCx := internal.TraceContext{
-                       RequestId: internal.CreateUniqID(),
+                       RequestId: primitive.CreateUniqID(),
                        TimeStamp: time.Now().UnixNano() / 
int64(time.Millisecond),
                        TraceType: internal.SubBefore,
                        GroupName: consumerCtx.ConsumerGroup,
diff --git a/examples/producer/transaction/main.go 
b/examples/producer/transaction/main.go
new file mode 100644
index 0000000..660db96
--- /dev/null
+++ b/examples/producer/transaction/main.go
@@ -0,0 +1,108 @@
+/*
+Licensed to the Apache Software Foundation (ASF) under one or more
+contributor license agreements.  See the NOTICE file distributed with
+this work for additional information regarding copyright ownership.
+The ASF licenses this file to You under the Apache License, Version 2.0
+(the "License"); you may not use this file except in compliance with
+the License.  You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package main
+
+import (
+       "context"
+       "fmt"
+       "os"
+       "strconv"
+       "sync"
+       "sync/atomic"
+       "time"
+
+       "github.com/apache/rocketmq-client-go"
+       "github.com/apache/rocketmq-client-go/primitive"
+       "github.com/apache/rocketmq-client-go/producer"
+)
+
+// DemoListener is an example primitive.TransactionListener. It remembers the
+// state chosen for each transaction ID so CheckLocalTransaction can report it
+// back later.
+type DemoListener struct {
+       localTrans       *sync.Map // transaction ID -> primitive.LocalTransactionState
+       transactionIndex int32     // incremented atomically for every executed transaction
+}
+
+// NewDemoListener returns a DemoListener with an empty state map.
+func NewDemoListener() *DemoListener {
+       return &DemoListener{
+               localTrans: new(sync.Map),
+       }
+}
+
+// ExecuteLocalTransaction records a state for msg's transaction ID, rotating
+// through the values 1, 2 and 3. CheckLocalTransaction later maps these to
+// commit, rollback and unknown. It always answers UnknowState, leaving the
+// final outcome to be decided when the broker checks back through
+// CheckLocalTransaction.
+func (dl *DemoListener) ExecuteLocalTransaction(msg primitive.Message) primitive.LocalTransactionState {
+       nextIndex := atomic.AddInt32(&dl.transactionIndex, 1)
+       fmt.Printf("nextIndex: %v for transactionID: %v\n", nextIndex, msg.TransactionId)
+       status := nextIndex % 3
+       dl.localTrans.Store(msg.TransactionId, primitive.LocalTransactionState(status+1))
+       return primitive.UnknowState
+}
+
+// CheckLocalTransaction answers a broker's transaction check for msg using the
+// state recorded by ExecuteLocalTransaction: 1 commits, 2 rolls back and 3
+// stays unknown. Transaction IDs that were never recorded, and any other
+// state value, are committed.
+func (dl *DemoListener) CheckLocalTransaction(msg primitive.MessageExt) primitive.LocalTransactionState {
+       fmt.Printf("msg transactionID : %v\n", msg.TransactionId)
+       v, existed := dl.localTrans.Load(msg.TransactionId)
+       if !existed {
+               fmt.Printf("unknown msg: %v, return Commit\n", msg)
+               return primitive.CommitMessageState
+       }
+       state := v.(primitive.LocalTransactionState)
+       // Every case returns and there is a default, so the switch is a
+       // terminating statement and no trailing return is needed.
+       switch state {
+       case 1:
+               fmt.Printf("checkLocalTransaction COMMIT_MESSAGE: %v\n", msg)
+               return primitive.CommitMessageState
+       case 2:
+               fmt.Printf("checkLocalTransaction ROLLBACK_MESSAGE: %v\n", msg)
+               return primitive.RollbackMessageState
+       case 3:
+               fmt.Printf("checkLocalTransaction unknown: %v\n", msg)
+               return primitive.UnknowState
+       default:
+               fmt.Printf("checkLocalTransaction default COMMIT_MESSAGE: %v\n", msg)
+               return primitive.CommitMessageState
+       }
+}
+
+// main starts a transaction producer backed by DemoListener, sends ten
+// transactional messages to TopicTest5, waits, and shuts the producer down.
+func main() {
+       p, err := rocketmq.NewTransactionProducer(
+               NewDemoListener(),
+               producer.WithNameServer([]string{"127.0.0.1:9876"}),
+               producer.WithRetry(1),
+       )
+       if err != nil {
+               // p is unusable when construction fails; calling Start would panic.
+               fmt.Printf("new transaction producer error: %s\n", err.Error())
+               os.Exit(1)
+       }
+       err = p.Start()
+       if err != nil {
+               fmt.Printf("start producer error: %s\n", err.Error())
+               os.Exit(1)
+       }
+
+       for i := 0; i < 10; i++ {
+               res, err := p.SendMessageInTransaction(context.Background(),
+                       primitive.NewMessage("TopicTest5", []byte("Hello RocketMQ again "+strconv.Itoa(i))))
+
+               if err != nil {
+                       fmt.Printf("send message error: %s\n", err)
+               } else {
+                       fmt.Printf("send message success: result=%s\n", res.String())
+               }
+       }
+       // Keep the process alive for a while, presumably so the broker has time
+       // to issue transaction checks answered by CheckLocalTransaction.
+       time.Sleep(5 * time.Minute)
+       err = p.Shutdown()
+       if err != nil {
+               fmt.Printf("shutdown producer error: %s\n", err.Error())
+       }
+}
diff --git a/internal/utils/messagesysflag.go b/internal/callback.go
similarity index 61%
rename from internal/utils/messagesysflag.go
rename to internal/callback.go
index 2410a5b..2ff182c 100644
--- a/internal/utils/messagesysflag.go
+++ b/internal/callback.go
@@ -15,30 +15,17 @@ See the License for the specific language governing 
permissions and
 limitations under the License.
 */
 
-package utils
+package internal
 
-var (
-       CompressedFlag = 0x1
+import (
+       "net"
 
-       MultiTagsFlag = 0x1 << 1
-
-       TransactionNotType = 0
-
-       TransactionPreparedType = 0x1 << 2
-
-       TransactionCommitType = 0x2 << 2
-
-       TransactionRollbackType = 0x3 << 2
+       "github.com/apache/rocketmq-client-go/primitive"
 )
 
-func GetTransactionValue(flag int) int {
-       return flag & TransactionRollbackType
-}
-
-func ResetTransactionValue(flag int, typeFlag int) int {
-       return (flag & (^TransactionRollbackType)) | typeFlag
-}
-
-func ClearCompressedFlag(flag int) int {
-       return flag & (^CompressedFlag)
+// CheckTransactionStateCallback is what the remoting client hands to the
+// transaction producer when a broker asks for a transaction's state: the
+// address the request arrived from, the decoded message and the decoded
+// request header.
+type CheckTransactionStateCallback struct {
+       Addr   net.Addr
+       Msg    primitive.MessageExt
+       Header CheckTransactionStateRequestHeader
 }
diff --git a/internal/client.go b/internal/client.go
index d58efc4..c6b3246 100644
--- a/internal/client.go
+++ b/internal/client.go
@@ -22,6 +22,7 @@ import (
        "context"
        "errors"
        "fmt"
+       "net"
        "os"
        "strconv"
        "strings"
@@ -165,18 +166,50 @@ type rmqClient struct {
 
 var clientMap sync.Map
 
-func GetOrNewRocketMQClient(option ClientOptions) *rmqClient {
+// GetOrNewRocketMQClient returns the shared rmqClient for option's client ID,
+// creating it and registering its request handlers on first use.
+//
+// callbackCh receives a CheckTransactionStateCallback for each broker
+// check-transaction-state request whose producer group equals
+// option.GroupName. It may be nil for clients that host no transaction
+// producer; such requests are then logged and dropped. callbackCh is only
+// wired up by the call that creates the client; later calls for the same
+// client ID get the existing client and their callbackCh is ignored.
+func GetOrNewRocketMQClient(option ClientOptions, callbackCh chan interface{}) *rmqClient {
        client := &rmqClient{
                option:       option,
                remoteClient: remote.NewRemotingClient(),
        }
        actual, loaded := clientMap.LoadOrStore(client.ClientID(), client)
        if !loaded {
-               client.remoteClient.RegisterRequestFunc(ReqNotifyConsumerIdsChanged, func(req *remote.RemotingCommand) *remote.RemotingCommand {
+               client.remoteClient.RegisterRequestFunc(ReqNotifyConsumerIdsChanged, func(req *remote.RemotingCommand, addr net.Addr) *remote.RemotingCommand {
                        rlog.Infof("receive broker's notification, the consumer group: %s", req.ExtFields["consumerGroup"])
                        client.RebalanceImmediately()
                        return nil
                })
+               client.remoteClient.RegisterRequestFunc(ReqCheckTransactionState, func(req *remote.RemotingCommand, addr net.Addr) *remote.RemotingCommand {
+                       // Consumers and the trace dispatcher pass a nil channel. A send
+                       // on a nil channel blocks forever, which would leak this
+                       // request's goroutine.
+                       if callbackCh == nil {
+                               rlog.Warn("checkTransactionState, no transaction producer registered, ignore request")
+                               return nil
+                       }
+                       header := new(CheckTransactionStateRequestHeader)
+                       header.Decode(req.ExtFields)
+                       msgExts := primitive.DecodeMessage(req.Body)
+                       if len(msgExts) == 0 {
+                               rlog.Warn("checkTransactionState, decode message failed")
+                               return nil
+                       }
+                       msgExt := msgExts[0]
+                       // TODO: add namespace support
+                       transactionID := msgExt.Properties[primitive.PropertyUniqueClientMessageIdKeyIndex]
+                       if len(transactionID) > 0 {
+                               msgExt.TransactionId = transactionID
+                       }
+                       group, existed := msgExt.Properties[primitive.PropertyProducerGroup]
+                       if !existed {
+                               rlog.Warn("checkTransactionState, pick producer group failed")
+                               return nil
+                       }
+                       if option.GroupName != group {
+                               rlog.Warn("producer group is not equal.")
+                               return nil
+                       }
+                       callbackCh <- CheckTransactionStateCallback{
+                               Addr:   addr,
+                               Msg:    *msgExt,
+                               Header: *header,
+                       }
+                       return nil
+               })
        }
        return actual.(*rmqClient)
 }
@@ -283,31 +316,30 @@ func (c *rmqClient) CheckClientInBroker() {
 func (c *rmqClient) SendHeartbeatToAllBrokerWithLock() {
        c.hbMutex.Lock()
        defer c.hbMutex.Unlock()
-       hbData := &heartbeatData{
-               ClientId: c.ClientID(),
-       }
-       pData := make([]producerData, 0)
+       hbData := NewHeartbeatData(c.ClientID())
+
        c.producerMap.Range(func(key, value interface{}) bool {
-               pData = append(pData, producerData(key.(string)))
+               pData := producerData{
+                       GroupName: key.(string),
+               }
+               hbData.ProducerDatas.Add(pData)
                return true
        })
 
-       cData := make([]consumerData, 0)
        c.consumerMap.Range(func(key, value interface{}) bool {
                consumer := value.(InnerConsumer)
-               cData = append(cData, consumerData{
+               cData := consumerData{
                        GroupName:         key.(string),
                        CType:             "PUSH",
                        MessageModel:      "CLUSTERING",
                        Where:             "CONSUME_FROM_FIRST_OFFSET",
                        UnitMode:          consumer.IsUnitMode(),
                        SubscriptionDatas: consumer.SubscriptionDataList(),
-               })
+               }
+               hbData.ConsumerDatas.Add(cData)
                return true
        })
-       hbData.ProducerDatas = pData
-       hbData.ConsumerDatas = cData
-       if len(pData) == 0 && len(cData) == 0 {
+       if hbData.ProducerDatas.Len() == 0 && hbData.ConsumerDatas.Len() == 0 {
                rlog.Info("sending heartbeat, but no producer and no consumer")
                return
        }
diff --git a/internal/model.go b/internal/model.go
index 00cfe63..f534234 100644
--- a/internal/model.go
+++ b/internal/model.go
@@ -20,6 +20,7 @@ package internal
 import (
        "encoding/json"
 
+       "github.com/apache/rocketmq-client-go/internal/utils"
        "github.com/apache/rocketmq-client-go/rlog"
 )
 
@@ -31,8 +32,6 @@ type FindBrokerResult struct {
 
 type (
        // groupName of consumer
-       producerData string
-
        consumeType string
 
        ServiceState int
@@ -55,6 +54,14 @@ type SubscriptionData struct {
        ExpType         string
 }
 
+// producerData is one producer-group entry of a heartbeat; it encodes as
+// {"groupName": ...}.
+type producerData struct {
+       GroupName string `json:"groupName"`
+}
+
+// UniqueID keys producerData by group name, so a heartbeat's producer set
+// holds at most one entry per group.
+func (p producerData) UniqueID() string {
+       return p.GroupName
+}
+
 type consumerData struct {
        GroupName         string              `json:"groupName"`
        CType             consumeType         `json:"consumeType"`
@@ -64,10 +71,22 @@ type consumerData struct {
        UnitMode          bool                `json:"unitMode"`
 }
 
+// UniqueID keys consumerData by group name, so a heartbeat's consumer set
+// holds at most one entry per group.
+func (c consumerData) UniqueID() string {
+       return c.GroupName
+}
+
+// heartbeatData is the heartbeat payload: the client ID plus sets of producer
+// and consumer entries, de-duplicated by group name.
 type heartbeatData struct {
-       ClientId      string         `json:"clientID"`
-       ProducerDatas []producerData `json:"producerDataSet"`
-       ConsumerDatas []consumerData `json:"consumerDataSet"`
+       ClientId      string    `json:"clientID"`
+       ProducerDatas utils.Set `json:"producerDataSet"`
+       ConsumerDatas utils.Set `json:"consumerDataSet"`
+}
+
+// NewHeartbeatData returns heartbeatData for clientID whose producer and
+// consumer sets are empty sets created by utils.NewSet.
+func NewHeartbeatData(clientID string) *heartbeatData {
+       return &heartbeatData{
+               ClientId:      clientID,
+               ProducerDatas: utils.NewSet(),
+               ConsumerDatas: utils.NewSet(),
+       }
 }
 
 func (data *heartbeatData) encode() []byte {
@@ -76,6 +95,6 @@ func (data *heartbeatData) encode() []byte {
                rlog.Errorf("marshal heartbeatData error: %s", err.Error())
                return nil
        }
-       rlog.Info(string(d))
+       rlog.Info("heartbeat: " + string(d))
        return d
 }
diff --git a/internal/model_test.go b/internal/model_test.go
new file mode 100644
index 0000000..1dac4ec
--- /dev/null
+++ b/internal/model_test.go
@@ -0,0 +1,117 @@
+/*
+Licensed to the Apache Software Foundation (ASF) under one or more
+contributor license agreements.  See the NOTICE file distributed with
+this work for additional information regarding copyright ownership.
+The ASF licenses this file to You under the Apache License, Version 2.0
+(the "License"); you may not use this file except in compliance with
+the License.  You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package internal
+
+import (
+       "encoding/json"
+       "fmt"
+       "testing"
+
+       . "github.com/smartystreets/goconvey/convey"
+       "github.com/stretchr/testify/assert"
+
+       "github.com/apache/rocketmq-client-go/internal/utils"
+)
+
+// TestHeartbeatData checks the JSON produced for heartbeat payloads and for a
+// bare utils.Set of producerData.
+func TestHeartbeatData(t *testing.T) {
+       Convey("test heartbeat json", t, func() {
+
+               Convey("producerData set marshal", func() {
+                       pData := &producerData{
+                               GroupName: "group name",
+                       }
+                       pData2 := &producerData{
+                               GroupName: "group name 2",
+                       }
+                       set := utils.NewSet()
+                       set.Add(pData)
+                       set.Add(pData2)
+
+                       // Pass &set: encoding/json only calls a pointer-receiver
+                       // MarshalJSON on addressable values, so marshaling a bare Set
+                       // value could silently fall back to "{}".
+                       v, err := json.Marshal(&set)
+                       assert.Nil(t, err)
+                       // Elements are emitted sorted by their JSON encoding, and ' '
+                       // sorts before '"'.
+                       assert.Equal(t, `[{"groupName":"group name 2"},{"groupName":"group name"}]`, string(v))
+                       fmt.Printf("json producer set: %s\n", string(v))
+               })
+
+               Convey("producer heartbeat", func() {
+
+                       hbt := NewHeartbeatData("producer client id")
+                       p1 := &producerData{
+                               GroupName: "group name",
+                       }
+                       p2 := &producerData{
+                               GroupName: "group name 2",
+                       }
+
+                       hbt.ProducerDatas.Add(p1)
+                       hbt.ProducerDatas.Add(p2)
+
+                       v, err := json.Marshal(hbt)
+                       assert.Nil(t, err)
+                       assert.Equal(t, `{"clientID":"producer client id","producerDataSet":[{"groupName":"group name 2"},{"groupName":"group name"}],"consumerDataSet":[]}`, string(v))
+                       fmt.Printf("json producer: %s\n", string(v))
+               })
+
+               Convey("consumer heartbeat", func() {
+
+                       hbt := NewHeartbeatData("consumer client id")
+                       c1 := consumerData{
+                               GroupName: "consumer data 1",
+                       }
+                       c2 := consumerData{
+                               GroupName: "consumer data 2",
+                       }
+                       hbt.ConsumerDatas.Add(c1)
+                       hbt.ConsumerDatas.Add(c2)
+
+                       v, err := json.Marshal(hbt)
+                       assert.Nil(t, err)
+                       assert.Contains(t, string(v), `"clientID":"consumer client id"`)
+                       assert.Contains(t, string(v), `"producerDataSet":[]`)
+                       assert.Contains(t, string(v), `"groupName":"consumer data 1"`)
+                       assert.Contains(t, string(v), `"groupName":"consumer data 2"`)
+                       fmt.Printf("json consumer: %s\n", string(v))
+               })
+
+               Convey("producer & consumer heartbeat", func() {
+
+                       hbt := NewHeartbeatData("consumer client id")
+
+                       p1 := &producerData{
+                               GroupName: "group name",
+                       }
+                       p2 := &producerData{
+                               GroupName: "group name 2",
+                       }
+
+                       hbt.ProducerDatas.Add(p1)
+                       hbt.ProducerDatas.Add(p2)
+
+                       c1 := consumerData{
+                               GroupName: "consumer data 1",
+                       }
+                       c2 := consumerData{
+                               GroupName: "consumer data 2",
+                       }
+                       hbt.ConsumerDatas.Add(c1)
+                       hbt.ConsumerDatas.Add(c2)
+
+                       v, err := json.Marshal(hbt)
+                       assert.Nil(t, err)
+                       assert.Contains(t, string(v), `"producerDataSet":[{"groupName":"group name 2"},{"groupName":"group name"}]`)
+                       assert.Contains(t, string(v), `"groupName":"consumer data 1"`)
+                       assert.Contains(t, string(v), `"groupName":"consumer data 2"`)
+                       fmt.Printf("json producer & consumer: %s\n", string(v))
+               })
+       })
+
+}
diff --git a/internal/remote/remote_client.go b/internal/remote/remote_client.go
index cfeb909..e8f10c2 100644
--- a/internal/remote/remote_client.go
+++ b/internal/remote/remote_client.go
@@ -31,7 +31,7 @@ import (
        "github.com/apache/rocketmq-client-go/rlog"
 )
 
-type ClientRequestFunc func(*RemotingCommand) *RemotingCommand
+// ClientRequestFunc handles a request command received by the client and
+// dispatched by its command code. The net.Addr argument is the remote address
+// of the connection the command arrived on. A non-nil return value is sent
+// back over that same connection.
+type ClientRequestFunc func(*RemotingCommand, net.Addr) *RemotingCommand
 
 type TcpOption struct {
        // TODO
@@ -148,7 +148,7 @@ func (c *RemotingClient) receiveResponse(r net.Conn) {
                        f := c.processors[cmd.Code]
                        if f != nil {
                                go func() { // 单个goroutine会造成死锁
-                                       res := f(cmd)
+                                       res := f(cmd, r.RemoteAddr())
                                        if res != nil {
                                                err := c.sendRequest(r, res)
                                                if err != nil {
diff --git a/internal/remote/remote_client_test.go 
b/internal/remote/remote_client_test.go
index ff6900e..9cbb117 100644
--- a/internal/remote/remote_client_test.go
+++ b/internal/remote/remote_client_test.go
@@ -167,7 +167,12 @@ func TestInvokeSync(t *testing.T) {
        var wg sync.WaitGroup
        wg.Add(1)
        client := NewRemotingClient()
+
+       var clientSend sync.WaitGroup // blocking client send message until the 
server listen success.
+       clientSend.Add(1)
+
        go func() {
+               clientSend.Wait()
                receiveCommand, err := client.InvokeSync(addr,
                        clientSendRemtingCommand, time.Second)
                if err != nil {
@@ -189,6 +194,7 @@ func TestInvokeSync(t *testing.T) {
                t.Fatal(err)
        }
        defer l.Close()
+       clientSend.Done()
        for {
                conn, err := l.Accept()
                if err != nil {
@@ -337,7 +343,11 @@ func TestInvokeOneWay(t *testing.T) {
        var wg sync.WaitGroup
        wg.Add(1)
        client := NewRemotingClient()
+
+       var clientSend sync.WaitGroup // blocking client send message until the 
server listen success.
+       clientSend.Add(1)
        go func() {
+               clientSend.Wait()
                err := client.InvokeOneWay(addr, clientSendRemtingCommand, 
3*time.Second)
                if err != nil {
                        t.Fatalf("failed to invoke synchronous. %s", err)
@@ -350,6 +360,7 @@ func TestInvokeOneWay(t *testing.T) {
                t.Fatal(err)
        }
        defer l.Close()
+       clientSend.Done()
        for {
                conn, err := l.Accept()
                if err != nil {
@@ -366,8 +377,9 @@ func TestInvokeOneWay(t *testing.T) {
                                t.Errorf("wrong code. want=%d, got=%d", 
receivedRemotingCommand.Code,
                                        clientSendRemtingCommand.Code)
                        }
-                       return
+                       goto done
                }
        }
-       wg.Done()
+done:
+       wg.Wait()
 }
diff --git a/internal/request.go b/internal/request.go
index e09a986..3dce875 100644
--- a/internal/request.go
+++ b/internal/request.go
@@ -32,6 +32,7 @@ const (
        ReqGetMaxOffset             = int16(30)
        ReqHeartBeat                = int16(34)
        ReqConsumerSendMsgBack      = int16(36)
+       ReqENDTransaction           = int16(37)
        ReqGetConsumerListByGroup   = int16(38)
        ReqLockBatchMQ              = int16(41)
        ReqUnlockBatchMQ            = int16(42)
@@ -81,6 +82,68 @@ func (request *SendMessageRequest) Decode(properties 
map[string]string) error {
        return nil
 }
 
+// EndTransactionRequestHeader is the header that reports a producer's
+// transaction outcome (CommitOrRollback) for a previously sent message;
+// presumably carried by a ReqENDTransaction command.
+type EndTransactionRequestHeader struct {
+       ProducerGroup        string `json:"producerGroup"`
+       TranStateTableOffset int64  `json:"tranStateTableOffset"`
+       CommitLogOffset      int64  `json:"commitLogOffset"`
+       CommitOrRollback     int    `json:"commitOrRollback"`
+       FromTransactionCheck bool   `json:"fromTransactionCheck"`
+       MsgID                string `json:"msgId"`
+       TransactionId        string `json:"transactionId"`
+}
+
+// Encode converts the header into its string-map form, one entry per field,
+// keyed by the same names as the json tags.
+func (request *EndTransactionRequestHeader) Encode() map[string]string {
+       maps := make(map[string]string, 7)
+       maps["producerGroup"] = request.ProducerGroup
+       maps["tranStateTableOffset"] = strconv.FormatInt(request.TranStateTableOffset, 10)
+       // FormatInt keeps the full 64-bit offset; converting through int would
+       // truncate it on platforms where int is 32 bits.
+       maps["commitLogOffset"] = strconv.FormatInt(request.CommitLogOffset, 10)
+       maps["commitOrRollback"] = strconv.Itoa(request.CommitOrRollback)
+       maps["fromTransactionCheck"] = strconv.FormatBool(request.FromTransactionCheck)
+       maps["msgId"] = request.MsgID
+       maps["transactionId"] = request.TransactionId
+       return maps
+}
+
+// CheckTransactionStateRequestHeader is the header of a broker's
+// check-transaction-state request; it is decoded from the request's ext fields.
+type CheckTransactionStateRequestHeader struct {
+       TranStateTableOffset int64
+       CommitLogOffset      int64
+       MsgId                string
+       TransactionId        string
+       OffsetMsgId          string
+}
+
+// Encode converts the header into its string-map form.
+func (request *CheckTransactionStateRequestHeader) Encode() map[string]string {
+       maps := make(map[string]string, 5)
+       maps["tranStateTableOffset"] = strconv.FormatInt(request.TranStateTableOffset, 10)
+       maps["commitLogOffset"] = strconv.FormatInt(request.CommitLogOffset, 10)
+       maps["msgId"] = request.MsgId
+       maps["transactionId"] = request.TransactionId
+       maps["offsetMsgId"] = request.OffsetMsgId
+
+       return maps
+}
+
+// Decode fills the header from ext, the inverse of Encode. Missing keys leave
+// the corresponding field untouched. Decoding is best effort: integer parse
+// errors are ignored, and the field then receives strconv.ParseInt's
+// error-case value.
+func (request *CheckTransactionStateRequestHeader) Decode(ext map[string]string) {
+       if len(ext) == 0 {
+               return
+       }
+       // bitSize 64 matches the int64 fields regardless of the platform's int size.
+       if v, existed := ext["tranStateTableOffset"]; existed {
+               request.TranStateTableOffset, _ = strconv.ParseInt(v, 10, 64)
+       }
+       if v, existed := ext["commitLogOffset"]; existed {
+               request.CommitLogOffset, _ = strconv.ParseInt(v, 10, 64)
+       }
+       if v, existed := ext["msgId"]; existed {
+               request.MsgId = v
+       }
+       if v, existed := ext["transactionId"]; existed {
+               request.TransactionId = v
+       }
+       if v, existed := ext["offsetMsgId"]; existed {
+               request.OffsetMsgId = v
+       }
+}
+
 type ConsumerSendMsgBackRequest struct {
        Group             string `json:"group"`
        Offset            int64  `json:"offset"`
diff --git a/internal/trace.go b/internal/trace.go
index dc3b246..ebd3be5 100644
--- a/internal/trace.go
+++ b/internal/trace.go
@@ -20,70 +20,18 @@ package internal
 import (
        "bytes"
        "context"
-       "encoding/binary"
-       "encoding/hex"
        "fmt"
-       "os"
        "runtime"
        "strconv"
        "strings"
-       "sync"
        "sync/atomic"
        "time"
 
        "github.com/apache/rocketmq-client-go/internal/remote"
-       "github.com/apache/rocketmq-client-go/internal/utils"
        "github.com/apache/rocketmq-client-go/primitive"
        "github.com/apache/rocketmq-client-go/rlog"
 )
 
-var (
-       counter        int16 = 0
-       startTimestamp int64 = 0
-       nextTimestamp  int64 = 0
-       prefix         string
-       locker         sync.Mutex
-       classLoadId    int32 = 0
-)
-
-func init() {
-       buf := new(bytes.Buffer)
-
-       ip, err := utils.ClientIP4()
-       if err != nil {
-               ip = utils.FakeIP()
-       }
-       _, _ = buf.Write(ip)
-       _ = binary.Write(buf, binary.BigEndian, Pid())
-       _ = binary.Write(buf, binary.BigEndian, classLoadId)
-       prefix = strings.ToUpper(hex.EncodeToString(buf.Bytes()))
-}
-
-func CreateUniqID() string {
-       locker.Lock()
-       defer locker.Unlock()
-
-       if time.Now().Unix() > nextTimestamp {
-               updateTimestamp()
-       }
-       counter++
-       buf := new(bytes.Buffer)
-       _ = binary.Write(buf, binary.BigEndian, 
int32((time.Now().Unix()-startTimestamp)*1000))
-       _ = binary.Write(buf, binary.BigEndian, counter)
-
-       return prefix + hex.EncodeToString(buf.Bytes())
-}
-
-func updateTimestamp() {
-       year, month := time.Now().Year(), time.Now().Month()
-       startTimestamp = time.Date(year, month, 1, 0, 0, 0, 0, 
time.Local).Unix()
-       nextTimestamp = time.Date(year, month, 1, 0, 0, 0, 0, 
time.Local).AddDate(0, 1, 0).Unix()
-}
-
-func Pid() int16 {
-       return int16(os.Getpid())
-}
-
 type TraceBean struct {
        Topic       string
        MsgId       string
@@ -276,7 +224,7 @@ func NewTraceDispatcher(traceTopic string, access 
primitive.AccessChannel) *trac
 
        cliOp := DefaultClientOptions()
        cliOp.RetryTimes = 0
-       cli := GetOrNewRocketMQClient(cliOp)
+       cli := GetOrNewRocketMQClient(cliOp, nil)
        return &traceDispatcher{
                ctx:    ctx,
                cancel: cancel,
diff --git a/internal/utils/set.go b/internal/utils/set.go
new file mode 100644
index 0000000..2f9f214
--- /dev/null
+++ b/internal/utils/set.go
@@ -0,0 +1,81 @@
+/*
+Licensed to the Apache Software Foundation (ASF) under one or more
+contributor license agreements.  See the NOTICE file distributed with
+this work for additional information regarding copyright ownership.
+The ASF licenses this file to You under the Apache License, Version 2.0
+(the "License"); you may not use this file except in compliance with
+the License.  You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package utils
+
+import (
+       "bytes"
+       "encoding/json"
+       "sort"
+)
+
+// UniqueItem is an element that can be stored in a Set; items reporting the
+// same UniqueID are treated as the same element.
+type UniqueItem interface {
+       UniqueID() string
+}
+
+// Set is a collection of UniqueItems de-duplicated by UniqueID. Adding an item
+// whose UniqueID is already present replaces the stored item. It guards its
+// map with no lock, so it must not be mutated concurrently.
+type Set struct {
+       items map[string]UniqueItem
+}
+
+// NewSet returns an empty Set.
+func NewSet() Set {
+       return Set{
+               items: make(map[string]UniqueItem),
+       }
+}
+
+// Add inserts v, replacing any item with the same UniqueID. The map is created
+// on demand, so Add also works on a zero Set instead of panicking on a nil map.
+func (s *Set) Add(v UniqueItem) {
+       if s.items == nil {
+               s.items = make(map[string]UniqueItem)
+       }
+       s.items[v.UniqueID()] = v
+}
+
+// Len returns the number of distinct items in the set.
+func (s *Set) Len() int {
+       return len(s.items)
+}
+
+var (
+       _ json.Marshaler = Set{}
+       _ json.Marshaler = &Set{}
+)
+
+// MarshalJSON encodes the set as a JSON array of its items. An empty set
+// encodes as []. Items are ordered by their encoded form so the output is
+// deterministic.
+//
+// It uses a value receiver so both Set and *Set use it. encoding/json
+// skips a pointer-receiver MarshalJSON for non-addressable Set values and
+// would emit "{}".
+func (s Set) MarshalJSON() ([]byte, error) {
+       if len(s.items) == 0 {
+               return []byte("[]"), nil
+       }
+
+       encoded := make([]string, 0, len(s.items))
+       for _, item := range s.items {
+               v, err := json.Marshal(item)
+               if err != nil {
+                       return nil, err
+               }
+               encoded = append(encoded, string(v))
+       }
+       sort.Strings(encoded)
+
+       buffer := new(bytes.Buffer)
+       buffer.WriteByte('[')
+       for i, item := range encoded {
+               if i > 0 {
+                       buffer.WriteByte(',')
+               }
+               buffer.WriteString(item)
+       }
+       buffer.WriteByte(']')
+
+       return buffer.Bytes(), nil
+}
+
+// UnmarshalJSON is a deliberate no-op. Items are stored as the UniqueItem
+// interface, so their concrete types cannot be recovered from JSON. Decoding
+// into a Set therefore leaves it unchanged and reports no error.
+func (s Set) UnmarshalJSON(data []byte) (err error) {
+       return nil
+}
diff --git a/primitive/message.go b/primitive/message.go
index 28c9eab..3dd4a0c 100644
--- a/primitive/message.go
+++ b/primitive/message.go
@@ -18,9 +18,17 @@ limitations under the License.
 package primitive
 
 import (
+       "bytes"
+       "encoding/binary"
+       "encoding/hex"
        "fmt"
+       "os"
        "strconv"
        "strings"
+       "sync"
+       "time"
+
+       "github.com/pkg/errors"
 
        "github.com/apache/rocketmq-client-go/internal/utils"
 )
@@ -68,11 +76,13 @@ type Message struct {
 }
 
 func NewMessage(topic string, body []byte) *Message {
-       return &Message{
+       msg := &Message{
                Topic:      topic,
                Body:       body,
                Properties: make(map[string]string),
        }
+       msg.Properties[PropertyWaitStoreMsgOk] = strconv.FormatBool(true)
+       return msg
 }
 
 // SetDelayTimeLevel set message delay time to consume.
@@ -164,6 +174,135 @@ func (msgExt *MessageExt) String() string {
                msgExt.PreparedTransactionOffset)
 }
 
+func DecodeMessage(data []byte) []*MessageExt {
+       msgs := make([]*MessageExt, 0)
+       buf := bytes.NewBuffer(data)
+       count := 0
+       for count < len(data) {
+               msg := &MessageExt{}
+
+               // 1. total size
+               binary.Read(buf, binary.BigEndian, &msg.StoreSize)
+               count += 4
+
+               // 2. magic code
+               buf.Next(4)
+               count += 4
+
+               // 3. body CRC32
+               binary.Read(buf, binary.BigEndian, &msg.BodyCRC)
+               count += 4
+
+               // 4. queueID
+               binary.Read(buf, binary.BigEndian, &msg.QueueId)
+               count += 4
+
+               // 5. Flag
+               binary.Read(buf, binary.BigEndian, &msg.Flag)
+               count += 4
+
+               // 6. QueueOffset
+               binary.Read(buf, binary.BigEndian, &msg.QueueOffset)
+               count += 8
+
+               // 7. physical offset
+               binary.Read(buf, binary.BigEndian, &msg.CommitLogOffset)
+               count += 8
+
+               // 8. SysFlag
+               binary.Read(buf, binary.BigEndian, &msg.SysFlag)
+               count += 4
+
+               // 9. BornTimestamp
+               binary.Read(buf, binary.BigEndian, &msg.BornTimestamp)
+               count += 8
+
+               // 10. born host
+               hostBytes := buf.Next(4)
+               var port int32
+               binary.Read(buf, binary.BigEndian, &port)
+               msg.BornHost = fmt.Sprintf("%s:%d", 
utils.GetAddressByBytes(hostBytes), port)
+               count += 8
+
+               // 11. store timestamp
+               binary.Read(buf, binary.BigEndian, &msg.StoreTimestamp)
+               count += 8
+
+               // 12. store host
+               hostBytes = buf.Next(4)
+               binary.Read(buf, binary.BigEndian, &port)
+               msg.StoreHost = fmt.Sprintf("%s:%d", 
utils.GetAddressByBytes(hostBytes), port)
+               count += 8
+
+               // 13. reconsume times
+               binary.Read(buf, binary.BigEndian, &msg.ReconsumeTimes)
+               count += 4
+
+               // 14. prepared transaction offset
+               binary.Read(buf, binary.BigEndian, 
&msg.PreparedTransactionOffset)
+               count += 8
+
+               // 15. body
+               var length int32
+               binary.Read(buf, binary.BigEndian, &length)
+               msg.Body = buf.Next(int(length))
+               if (msg.SysFlag & FlagCompressed) == FlagCompressed {
+                       msg.Body = utils.UnCompress(msg.Body)
+               }
+               count += 4 + int(length)
+
+               // 16. topic
+               _byte, _ := buf.ReadByte()
+               msg.Topic = string(buf.Next(int(_byte)))
+               count += 1 + int(_byte)
+
+               // 17. properties
+               var propertiesLength int16
+               binary.Read(buf, binary.BigEndian, &propertiesLength)
+               if propertiesLength > 0 {
+                       msg.Properties = 
unmarshalProperties(buf.Next(int(propertiesLength)))
+               }
+               count += 2 + int(propertiesLength)
+
+               msg.MsgId = createMessageId(hostBytes, port, 
msg.CommitLogOffset)
+               //count += 16
+               if msg.Properties == nil {
+                       msg.Properties = make(map[string]string, 0)
+               }
+               msgs = append(msgs, msg)
+       }
+
+       return msgs
+}
+
+// unmarshalProperties parse data into property kv pairs.
+func unmarshalProperties(data []byte) map[string]string {
+       m := make(map[string]string)
+       items := bytes.Split(data, []byte{propertySeparator})
+       for _, item := range items {
+               kv := bytes.Split(item, []byte{nameValueSeparator})
+               if len(kv) == 2 {
+                       m[string(kv[0])] = string(kv[1])
+               }
+       }
+       return m
+}
+
+func MarshalPropeties(properties map[string]string) string {
+       if properties == nil {
+               return ""
+       }
+       buffer := bytes.NewBufferString("")
+
+       for k, v := range properties {
+               buffer.WriteString(k)
+               buffer.WriteRune(nameValueSeparator)
+               buffer.WriteString(v)
+               buffer.WriteRune(propertySeparator)
+       }
+       return buffer.String()
+}
+
 // MessageQueue message queue
 type MessageQueue struct {
        Topic      string `json:"topic"`
@@ -206,3 +345,142 @@ const (
        TransMsgCommit
        DelayMsg
 )
+
// LocalTransactionState is the outcome of a local transaction, as returned by
// a TransactionListener.
type LocalTransactionState int

// LocalTransactionState values. The zero value is not one of them.
const (
	// CommitMessageState: the local transaction succeeded; commit the message.
	CommitMessageState LocalTransactionState = 1
	// RollbackMessageState: the local transaction failed; roll the message back.
	RollbackMessageState LocalTransactionState = 2
	// UnknowState: the outcome is not (yet) known.
	UnknowState LocalTransactionState = 3
)
+
+type TransactionListener interface {
+       //  When send transactional prepare(half) message succeed, this method 
will be invoked to execute local transaction.
+       ExecuteLocalTransaction(Message) LocalTransactionState
+
+       // When no response to prepare(half) message. broker will send check 
message to check the transaction status, and this
+       // method will be invoked to get local transaction status.
+       CheckLocalTransaction(MessageExt) LocalTransactionState
+}
+
// MessageID is the decoded form of a hex message ID (see UnmarshalMsgID): a
// host address, a port and an offset.
type MessageID struct {
	Addr   string // host address
	Port   int    // host port
	Offset int64  // offset encoded in the ID
}
+
// createMessageId builds an upper-case hex message ID from addr (copied
// verbatim), then port as a big-endian int32, then offset as a big-endian
// int64.
func createMessageId(addr []byte, port int32, offset int64) string {
	var tail [12]byte
	binary.BigEndian.PutUint32(tail[0:4], uint32(port))
	binary.BigEndian.PutUint64(tail[4:12], uint64(offset))

	raw := make([]byte, 0, len(addr)+len(tail))
	raw = append(raw, addr...)
	raw = append(raw, tail[:]...)
	return strings.ToUpper(hex.EncodeToString(raw))
}
+
+func UnmarshalMsgID(msgID []byte) (*MessageID, error) {
+       if len(msgID) < 32 {
+               return nil, errors.Errorf("%s len < 32", string(msgID))
+       }
+       ip := make([]byte, 8)
+       port := make([]byte, 8)
+       offset := make([]byte, 16)
+       var portVal int
+       var offsetVal int64
+
+       _, err := hex.Decode(ip, msgID[0:8])
+       if err != nil {
+               _, err = hex.Decode(port, msgID[8:16])
+       }
+       if err != nil {
+               _, err = hex.Decode(offset, msgID[16:32])
+       }
+       if err != nil {
+               portVal, err = strconv.Atoi(string(port))
+       }
+       if err != nil {
+               offsetVal, err = strconv.ParseInt(string(offset), 10, 0)
+       }
+
+       if err != nil {
+               return nil, err
+       }
+
+       return &MessageID{
+               Addr:   string(ip),
+               Port:   portVal,
+               Offset: offsetVal,
+       }, nil
+}
+
// Bit flags carried in a message's sys flag. The transaction type occupies
// bits 2-3. TransactionRollbackType (0x3<<2) doubles as the mask for those
// bits, as used by GetTransactionValue and ResetTransactionValue.
//
// These values are never reassigned, so they are constants rather than
// mutable package variables.
const (
	// CompressedFlag is bit 0.
	CompressedFlag = 0x1

	// MultiTagsFlag is bit 1.
	MultiTagsFlag = 0x1 << 1

	// TransactionNotType means no transaction type is set.
	TransactionNotType = 0

	// TransactionPreparedType marks a prepared (half) transactional message.
	TransactionPreparedType = 0x1 << 2

	// TransactionCommitType requests that a prepared message be committed.
	TransactionCommitType = 0x2 << 2

	// TransactionRollbackType requests that a prepared message be rolled back.
	TransactionRollbackType = 0x3 << 2
)

// GetTransactionValue returns only the transaction-type bits of flag.
func GetTransactionValue(flag int) int {
	return flag & TransactionRollbackType
}

// ResetTransactionValue replaces the transaction-type bits of flag with
// typeFlag and keeps every other bit.
func ResetTransactionValue(flag int, typeFlag int) int {
	return (flag &^ TransactionRollbackType) | typeFlag
}

// ClearCompressedFlag returns flag with CompressedFlag cleared.
func ClearCompressedFlag(flag int) int {
	return flag &^ CompressedFlag
}
+
// State behind CreateUniqID. CreateUniqID holds locker while it reads and
// updates counter, startTimestamp and nextTimestamp.
var (
	locker sync.Mutex

	// counter is incremented once per generated ID. It is an int16 and wraps
	// on overflow.
	counter int16

	// startTimestamp and nextTimestamp are Unix seconds for the start of the
	// current and the next month. updateTimestamp sets them.
	startTimestamp int64
	nextTimestamp  int64

	// prefix is the upper-case hex of client IP, Pid() and classLoadId,
	// computed once in init.
	prefix string

	// classLoadId is part of prefix. It is not reassigned in the code shown.
	classLoadId int32
)
+
+func init() {
+       buf := new(bytes.Buffer)
+
+       ip, err := utils.ClientIP4()
+       if err != nil {
+               ip = utils.FakeIP()
+       }
+       _, _ = buf.Write(ip)
+       _ = binary.Write(buf, binary.BigEndian, Pid())
+       _ = binary.Write(buf, binary.BigEndian, classLoadId)
+       prefix = strings.ToUpper(hex.EncodeToString(buf.Bytes()))
+}
+
+func CreateUniqID() string {
+       locker.Lock()
+       defer locker.Unlock()
+
+       if time.Now().Unix() > nextTimestamp {
+               updateTimestamp()
+       }
+       counter++
+       buf := new(bytes.Buffer)
+       _ = binary.Write(buf, binary.BigEndian, 
int32((time.Now().Unix()-startTimestamp)*1000))
+       _ = binary.Write(buf, binary.BigEndian, counter)
+
+       return prefix + hex.EncodeToString(buf.Bytes())
+}
+
+func updateTimestamp() {
+       year, month := time.Now().Year(), time.Now().Month()
+       startTimestamp = time.Date(year, month, 1, 0, 0, 0, 0, 
time.Local).Unix()
+       nextTimestamp = time.Date(year, month, 1, 0, 0, 0, 0, 
time.Local).AddDate(0, 1, 0).Unix()
+}
+
// Pid returns the current process ID truncated to 16 bits, the width written
// into the unique-ID prefix.
func Pid() int16 {
	pid := os.Getpid()
	return int16(pid)
}
diff --git a/primitive/result.go b/primitive/result.go
index 4664cf0..201b36a 100644
--- a/primitive/result.go
+++ b/primitive/result.go
@@ -18,13 +18,7 @@ limitations under the License.
 package primitive
 
 import (
-       "bytes"
-       "encoding/binary"
-       "encoding/hex"
        "fmt"
-       "strings"
-
-       "github.com/apache/rocketmq-client-go/internal/utils"
 )
 
 // SendStatus of message
@@ -62,6 +56,12 @@ func (result *SendResult) String() string {
                result.Status, result.MsgID, result.OffsetMsgID, 
result.QueueOffset, result.MessageQueue.String())
 }
 
+// SendResult RocketMQ send result
+type TransactionSendResult struct {
+       *SendResult
+       State LocalTransactionState
+}
+
 // PullStatus pull Status
 type PullStatus int
 
@@ -115,141 +115,6 @@ func (result *PullResult) String() string {
        return ""
 }
 
-func DecodeMessage(data []byte) []*MessageExt {
-       msgs := make([]*MessageExt, 0)
-       buf := bytes.NewBuffer(data)
-       count := 0
-       for count < len(data) {
-               msg := &MessageExt{}
-
-               // 1. total size
-               binary.Read(buf, binary.BigEndian, &msg.StoreSize)
-               count += 4
-
-               // 2. magic code
-               buf.Next(4)
-               count += 4
-
-               // 3. body CRC32
-               binary.Read(buf, binary.BigEndian, &msg.BodyCRC)
-               count += 4
-
-               // 4. queueID
-               binary.Read(buf, binary.BigEndian, &msg.QueueId)
-               count += 4
-
-               // 5. Flag
-               binary.Read(buf, binary.BigEndian, &msg.Flag)
-               count += 4
-
-               // 6. QueueOffset
-               binary.Read(buf, binary.BigEndian, &msg.QueueOffset)
-               count += 8
-
-               // 7. physical offset
-               binary.Read(buf, binary.BigEndian, &msg.CommitLogOffset)
-               count += 8
-
-               // 8. SysFlag
-               binary.Read(buf, binary.BigEndian, &msg.SysFlag)
-               count += 4
-
-               // 9. BornTimestamp
-               binary.Read(buf, binary.BigEndian, &msg.BornTimestamp)
-               count += 8
-
-               // 10. born host
-               hostBytes := buf.Next(4)
-               var port int32
-               binary.Read(buf, binary.BigEndian, &port)
-               msg.BornHost = fmt.Sprintf("%s:%d", 
utils.GetAddressByBytes(hostBytes), port)
-               count += 8
-
-               // 11. store timestamp
-               binary.Read(buf, binary.BigEndian, &msg.StoreTimestamp)
-               count += 8
-
-               // 12. store host
-               hostBytes = buf.Next(4)
-               binary.Read(buf, binary.BigEndian, &port)
-               msg.StoreHost = fmt.Sprintf("%s:%d", 
utils.GetAddressByBytes(hostBytes), port)
-               count += 8
-
-               // 13. reconsume times
-               binary.Read(buf, binary.BigEndian, &msg.ReconsumeTimes)
-               count += 4
-
-               // 14. prepared transaction offset
-               binary.Read(buf, binary.BigEndian, 
&msg.PreparedTransactionOffset)
-               count += 8
-
-               // 15. body
-               var length int32
-               binary.Read(buf, binary.BigEndian, &length)
-               msg.Body = buf.Next(int(length))
-               if (msg.SysFlag & FlagCompressed) == FlagCompressed {
-                       msg.Body = utils.UnCompress(msg.Body)
-               }
-               count += 4 + int(length)
-
-               // 16. topic
-               _byte, _ := buf.ReadByte()
-               msg.Topic = string(buf.Next(int(_byte)))
-               count += 1 + int(_byte)
-
-               // 17. properties
-               var propertiesLength int16
-               binary.Read(buf, binary.BigEndian, &propertiesLength)
-               if propertiesLength > 0 {
-                       msg.Properties = 
unmarshalProperties(buf.Next(int(propertiesLength)))
-               }
-               count += 2 + int(propertiesLength)
-
-               msg.MsgId = createMessageId(hostBytes, port, 
msg.CommitLogOffset)
-               //count += 16
-
-               msgs = append(msgs, msg)
-       }
-
-       return msgs
-}
-
-func createMessageId(addr []byte, port int32, offset int64) string {
-       buffer := new(bytes.Buffer)
-       buffer.Write(addr)
-       binary.Write(buffer, binary.BigEndian, port)
-       binary.Write(buffer, binary.BigEndian, offset)
-       return strings.ToUpper(hex.EncodeToString(buffer.Bytes()))
-}
-
-// unmarshalProperties parse data into property kv pairs.
-func unmarshalProperties(data []byte) map[string]string {
-       m := make(map[string]string)
-       items := bytes.Split(data, []byte{propertySeparator})
-       for _, item := range items {
-               kv := bytes.Split(item, []byte{nameValueSeparator})
-               if len(kv) == 2 {
-                       m[string(kv[0])] = string(kv[1])
-               }
-       }
-       return m
-}
-
-func MarshalPropeties(properties map[string]string) string {
-       if properties == nil {
-               return ""
-       }
-       buffer := bytes.NewBufferString("")
-
-       for k, v := range properties {
-               buffer.WriteString(k)
-               buffer.WriteRune(nameValueSeparator)
-               buffer.WriteString(v)
-               buffer.WriteRune(propertySeparator)
-       }
-       return buffer.String()
-}
-
 func toMessages(messageExts []*MessageExt) []*Message {
        msgs := make([]*Message, 0)
 
diff --git a/producer/interceptor.go b/producer/interceptor.go
index 46f8042..d545f93 100644
--- a/producer/interceptor.go
+++ b/producer/interceptor.go
@@ -81,7 +81,7 @@ func newTraceInterceptor(traceCfg primitive.TraceConfig) 
primitive.Interceptor {
                }
 
                traceCtx := internal.TraceContext{
-                       RequestId: internal.CreateUniqID(), // set id
+                       RequestId: primitive.CreateUniqID(), // set id
                        TimeStamp: time.Now().UnixNano() / 
int64(time.Millisecond),
 
                        TraceType:  internal.Pub,
diff --git a/producer/option.go b/producer/option.go
index ad6c98e..5ec003c 100644
--- a/producer/option.go
+++ b/producer/option.go
@@ -18,14 +18,17 @@ limitations under the License.
 package producer
 
 import (
+       "time"
+
        "github.com/apache/rocketmq-client-go/internal"
        "github.com/apache/rocketmq-client-go/primitive"
 )
 
 func defaultProducerOptions() producerOptions {
        opts := producerOptions{
-               ClientOptions: internal.DefaultClientOptions(),
-               Selector:      NewRoundRobinQueueSelector(),
+               ClientOptions:  internal.DefaultClientOptions(),
+               Selector:       NewRoundRobinQueueSelector(),
+               SendMsgTimeout: 3 * time.Second,
        }
        opts.ClientOptions.GroupName = "DEFAULT_CONSUMER"
        return opts
@@ -33,7 +36,8 @@ func defaultProducerOptions() producerOptions {
 
 type producerOptions struct {
        internal.ClientOptions
-       Selector QueueSelector
+       Selector       QueueSelector
+       SendMsgTimeout time.Duration
 }
 
 type Option func(*producerOptions)
@@ -57,6 +61,12 @@ func WithNameServer(nameServers []string) Option {
        }
 }
 
+func WithSendMsgTimeout(duration time.Duration) Option {
+       return func(opts *producerOptions) {
+               opts.SendMsgTimeout = duration
+       }
+}
+
 func WithVIPChannel(enable bool) Option {
        return func(opts *producerOptions) {
                opts.VIPChannelEnabled = enable
diff --git a/producer/producer.go b/producer/producer.go
index 7a3f40f..1d082e9 100644
--- a/producer/producer.go
+++ b/producer/producer.go
@@ -20,15 +20,17 @@ package producer
 import (
        "context"
        "fmt"
+       "strconv"
        "sync"
        "time"
 
+       "github.com/pkg/errors"
+
        "github.com/apache/rocketmq-client-go/internal"
        "github.com/apache/rocketmq-client-go/internal/remote"
        "github.com/apache/rocketmq-client-go/internal/utils"
        "github.com/apache/rocketmq-client-go/primitive"
        "github.com/apache/rocketmq-client-go/rlog"
-       "github.com/pkg/errors"
 )
 
 var (
@@ -37,6 +39,17 @@ var (
        ErrNotRunning   = errors.New("producer not started")
 )
 
+type defaultProducer struct {
+       group       string
+       client      internal.RMQClient
+       state       internal.ServiceState
+       options     producerOptions
+       publishInfo sync.Map
+       callbackCh  chan interface{}
+
+       interceptor primitive.Interceptor
+}
+
 func NewDefaultProducer(opts ...Option) (*defaultProducer, error) {
        defaultOpts := defaultProducerOptions()
        for _, apply := range opts {
@@ -49,26 +62,17 @@ func NewDefaultProducer(opts ...Option) (*defaultProducer, 
error) {
        internal.RegisterNamsrv(srvs)
 
        producer := &defaultProducer{
-               group:   "default",
-               client:  
internal.GetOrNewRocketMQClient(defaultOpts.ClientOptions),
-               options: defaultOpts,
+               group:      defaultOpts.GroupName,
+               callbackCh: make(chan interface{}),
+               options:    defaultOpts,
        }
+       producer.client = 
internal.GetOrNewRocketMQClient(defaultOpts.ClientOptions, producer.callbackCh)
 
        producer.interceptor = 
primitive.ChainInterceptors(producer.options.Interceptors...)
 
        return producer, nil
 }
 
-type defaultProducer struct {
-       group       string
-       client      internal.RMQClient
-       state       internal.ServiceState
-       options     producerOptions
-       publishInfo sync.Map
-
-       interceptor primitive.Interceptor
-}
-
 func (p *defaultProducer) Start() error {
        p.state = internal.StateRunning
        p.client.RegisterProducer(p.group, p)
@@ -247,11 +251,23 @@ func (p *defaultProducer) sendOneWay(ctx context.Context, 
msg *primitive.Message
 
 func (p *defaultProducer) buildSendRequest(mq *primitive.MessageQueue,
        msg *primitive.Message) *remote.RemotingCommand {
+       if !msg.Batch && 
msg.Properties[primitive.PropertyUniqueClientMessageIdKeyIndex] == "" {
+               msg.Properties[primitive.PropertyUniqueClientMessageIdKeyIndex] 
= primitive.CreateUniqID()
+       }
+       sysFlag := 0
+       v, ok := msg.Properties[primitive.PropertyTransactionPrepared]
+       if ok {
+               tranMsg, err := strconv.ParseBool(v)
+               if err == nil && tranMsg {
+                       sysFlag |= primitive.TransactionPreparedType
+               }
+       }
+
        req := &internal.SendMessageRequest{
                ProducerGroup:  p.group,
                Topic:          mq.Topic,
                QueueId:        mq.QueueId,
-               SysFlag:        0,
+               SysFlag:        sysFlag,
                BornTimestamp:  time.Now().UnixNano() / int64(time.Millisecond),
                Flag:           msg.Flag,
                Properties:     primitive.MarshalPropeties(msg.Properties),
@@ -318,3 +334,143 @@ func (p *defaultProducer) IsPublishTopicNeedUpdate(topic 
string) bool {
 func (p *defaultProducer) IsUnitMode() bool {
        return false
 }
+
+type transactionProducer struct {
+       producer *defaultProducer
+       listener primitive.TransactionListener
+}
+
+// TODO: checkLocalTransaction
+func NewTransactionProducer(listener primitive.TransactionListener, opts 
...Option) (*transactionProducer, error) {
+       producer, err := NewDefaultProducer(opts...)
+       if err != nil {
+               return nil, errors.Wrap(err, "NewDefaultProducer failed.")
+       }
+       return &transactionProducer{
+               producer: producer,
+               listener: listener,
+       }, nil
+}
+
+func (tp *transactionProducer) Start() error {
+       go tp.checkTransactionState()
+       return tp.producer.Start()
+}
+func (tp *transactionProducer) Shutdown() error {
+       return tp.producer.Shutdown()
+}
+
+// TODO: check addr
+func (tp *transactionProducer) checkTransactionState() {
+       for ch := range tp.producer.callbackCh {
+               switch callback := ch.(type) {
+               case internal.CheckTransactionStateCallback:
+                       localTransactionState := 
tp.listener.CheckLocalTransaction(callback.Msg)
+                       uniqueKey, existed := 
callback.Msg.Properties[primitive.PropertyUniqueClientMessageIdKeyIndex]
+                       if !existed {
+                               uniqueKey = callback.Msg.MsgId
+                       }
+                       header := &internal.EndTransactionRequestHeader{
+                               CommitLogOffset:      
callback.Header.CommitLogOffset,
+                               ProducerGroup:        tp.producer.group,
+                               TranStateTableOffset: 
callback.Header.TranStateTableOffset,
+                               FromTransactionCheck: true,
+                               MsgID:                uniqueKey,
+                               TransactionId:        
callback.Header.TransactionId,
+                               CommitOrRollback:     
tp.transactionState(localTransactionState),
+                       }
+
+                       req := 
remote.NewRemotingCommand(internal.ReqENDTransaction, header, nil)
+                       req.Remark = tp.errRemark(nil)
+
+                       tp.producer.client.InvokeOneWay(callback.Addr.String(), 
req, tp.producer.options.SendMsgTimeout)
+               default:
+                       rlog.Error("unknow type %v", ch)
+               }
+       }
+}
+
+func (tp *transactionProducer) SendMessageInTransaction(ctx context.Context, 
msg *primitive.Message) (*primitive.TransactionSendResult, error) {
+       if msg.Properties == nil {
+               msg.Properties = make(map[string]string, 0)
+       }
+       msg.Properties[primitive.PropertyTransactionPrepared] = "true"
+       msg.Properties[primitive.PropertyProducerGroup] = 
tp.producer.options.GroupName
+
+       rsp, err := tp.producer.SendSync(ctx, msg)
+       if err != nil {
+               return nil, err
+       }
+       localTransactionState := primitive.UnknowState
+       switch rsp.Status {
+       case primitive.SendOK:
+               if len(rsp.TransactionID) > 0 {
+                       msg.Properties["__transactionId__"] = rsp.TransactionID
+               }
+               transactionId := 
msg.Properties[primitive.PropertyUniqueClientMessageIdKeyIndex]
+               if len(transactionId) > 0 {
+                       msg.TransactionId = transactionId
+               }
+               localTransactionState = 
tp.listener.ExecuteLocalTransaction(*msg)
+               if localTransactionState != primitive.CommitMessageState {
+                       rlog.Errorf("executeLocalTransactionBranch return %v 
with msg: %v\n", localTransactionState, msg)
+               }
+
+       case primitive.SendFlushDiskTimeout, primitive.SendFlushSlaveTimeout, 
primitive.SendSlaveNotAvailable:
+               localTransactionState = primitive.RollbackMessageState
+       default:
+       }
+
+       tp.endTransaction(*rsp, err, localTransactionState)
+
+       transactionSendResult := &primitive.TransactionSendResult{
+               SendResult: rsp,
+               State:      localTransactionState,
+       }
+
+       return transactionSendResult, nil
+}
+
+func (tp *transactionProducer) endTransaction(result primitive.SendResult, err 
error, state primitive.LocalTransactionState) error {
+       var msgID *primitive.MessageID
+       if len(result.OffsetMsgID) > 0 {
+               msgID, _ = primitive.UnmarshalMsgID([]byte(result.OffsetMsgID))
+       } else {
+               msgID, _ = primitive.UnmarshalMsgID([]byte(result.MsgID))
+       }
+       // 估计没有反序列化回来
+       brokerAddr := 
internal.FindBrokerAddrByName(result.MessageQueue.BrokerName)
+       requestHeader := &internal.EndTransactionRequestHeader{
+               TransactionId:        result.TransactionID,
+               CommitLogOffset:      msgID.Offset,
+               ProducerGroup:        tp.producer.group,
+               TranStateTableOffset: result.QueueOffset,
+               MsgID:                result.MsgID,
+               CommitOrRollback:     tp.transactionState(state),
+       }
+
+       req := remote.NewRemotingCommand(internal.ReqENDTransaction, 
requestHeader, nil)
+       req.Remark = tp.errRemark(err)
+
+       return tp.producer.client.InvokeOneWay(brokerAddr, req, 
tp.producer.options.SendMsgTimeout)
+}
+
+func (tp *transactionProducer) errRemark(err error) string {
+       if err != nil {
+               return "executeLocalTransactionBranch exception: " + err.Error()
+       }
+       return ""
+}
+
+func (tp *transactionProducer) transactionState(state 
primitive.LocalTransactionState) int {
+       switch state {
+       case primitive.CommitMessageState:
+               return primitive.TransactionCommitType
+       case primitive.RollbackMessageState:
+               return primitive.TransactionRollbackType
+       case primitive.UnknowState:
+               return primitive.TransactionNotType
+       default:
+               return primitive.TransactionNotType
+       }
+}

Reply via email to