This is an automated email from the ASF dual-hosted git repository.
dinglei pushed a commit to branch native
in repository https://gitbox.apache.org/repos/asf/rocketmq-client-go.git
The following commit(s) were added to refs/heads/native by this push:
new 68358ec Add transaction support. resolve #137 (#138)
68358ec is described below
commit 68358ecfc6ec1ee48b1c045f895a676e028c2bfa
Author: xujianhai666 <[email protected]>
AuthorDate: Wed Aug 14 14:41:22 2019 +0800
Add transaction support. resolve #137 (#138)
---
.travis.yml | 3 -
api.go | 10 +
consumer/consumer.go | 2 +-
consumer/interceptor.go | 2 +-
examples/producer/transaction/main.go | 108 +++++++++
internal/{utils/messagesysflag.go => callback.go} | 31 +--
internal/client.go | 58 ++++-
internal/model.go | 31 ++-
internal/model_test.go | 117 +++++++++
internal/remote/remote_client.go | 4 +-
internal/remote/remote_client_test.go | 16 +-
internal/request.go | 63 +++++
internal/trace.go | 54 +----
internal/utils/set.go | 81 +++++++
primitive/message.go | 280 +++++++++++++++++++++-
primitive/result.go | 147 +-----------
producer/interceptor.go | 2 +-
producer/option.go | 16 +-
producer/producer.go | 186 ++++++++++++--
19 files changed, 947 insertions(+), 264 deletions(-)
diff --git a/.travis.yml b/.travis.yml
index de0a42e..1a9683c 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -24,9 +24,6 @@ before_script:
- perl -i -pe's/-Xms8g -Xmx8g -Xmn4g/-Xms2g -Xmx2g -Xmn1g/g' bin/runbroker.sh
- nohup sh bin/mqnamesrv &
- nohup sh bin/mqbroker -n localhost:9876 &
-# - sleep 10
-# - ./bin/mqadmin updateTopic -n ${NAME_SERVER_ADDRESS} -b ${BROKER_ADDRESS}
-t ${TOPIC}
-# - ./bin/mqadmin updateSubGroup -n ${NAME_SERVER_ADDRESS} -b
${BROKER_ADDRESS} -g ${GROUP}
script:
- cd ${GOPATH}/src/github.com/apache/rocketmq-client-go
diff --git a/api.go b/api.go
index eec98bf..c6f6798 100644
--- a/api.go
+++ b/api.go
@@ -37,6 +37,16 @@ func NewProducer(opts ...producer.Option) (Producer, error) {
return producer.NewDefaultProducer(opts...)
}
+type TransactionProducer interface {
+ Start() error
+ Shutdown() error
+ SendMessageInTransaction(context.Context, *primitive.Message)
(*primitive.TransactionSendResult, error)
+}
+
+func NewTransactionProducer(listener primitive.TransactionListener, opts
...producer.Option) (TransactionProducer, error) {
+ return producer.NewTransactionProducer(listener, opts...)
+}
+
type PushConsumer interface {
Start() error
Shutdown() error
diff --git a/consumer/consumer.go b/consumer/consumer.go
index 25861f9..56836ae 100644
--- a/consumer/consumer.go
+++ b/consumer/consumer.go
@@ -275,7 +275,7 @@ func (dc *defaultConsumer) start() error {
dc.subscriptionDataTable.Store(retryTopic, sub)
}
- dc.client = internal.GetOrNewRocketMQClient(dc.option.ClientOptions)
+ dc.client = internal.GetOrNewRocketMQClient(dc.option.ClientOptions,
nil)
if dc.model == Clustering {
dc.option.ChangeInstanceNameToPID()
dc.storage = NewRemoteOffsetStore(dc.consumerGroup, dc.client)
diff --git a/consumer/interceptor.go b/consumer/interceptor.go
index 6b050df..657633e 100644
--- a/consumer/interceptor.go
+++ b/consumer/interceptor.go
@@ -50,7 +50,7 @@ func newTraceInterceptor(traceCfg primitive.TraceConfig)
primitive.Interceptor {
beginT := time.Now()
// before traceCtx
traceCx := internal.TraceContext{
- RequestId: internal.CreateUniqID(),
+ RequestId: primitive.CreateUniqID(),
TimeStamp: time.Now().UnixNano() /
int64(time.Millisecond),
TraceType: internal.SubBefore,
GroupName: consumerCtx.ConsumerGroup,
diff --git a/examples/producer/transaction/main.go
b/examples/producer/transaction/main.go
new file mode 100644
index 0000000..660db96
--- /dev/null
+++ b/examples/producer/transaction/main.go
@@ -0,0 +1,108 @@
+/*
+Licensed to the Apache Software Foundation (ASF) under one or more
+contributor license agreements. See the NOTICE file distributed with
+this work for additional information regarding copyright ownership.
+The ASF licenses this file to You under the Apache License, Version 2.0
+(the "License"); you may not use this file except in compliance with
+the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package main
+
+import (
+ "context"
+ "fmt"
+ "os"
+ "strconv"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "github.com/apache/rocketmq-client-go"
+ "github.com/apache/rocketmq-client-go/primitive"
+ "github.com/apache/rocketmq-client-go/producer"
+)
+
+type DemoListener struct {
+ localTrans *sync.Map
+ transactionIndex int32
+}
+
+func NewDemoListener() *DemoListener {
+ return &DemoListener{
+ localTrans: new(sync.Map),
+ }
+}
+
+// ExecuteLocalTransaction stores a state for msg.TransactionId in localTrans.
+// The state is derived from a running counter (nextIndex % 3, plus one), and
+// the method always returns UnknowState.
+// NOTE(review): returning UnknowState presumably makes the broker check back
+// through CheckLocalTransaction later — confirm against broker behavior.
+func (dl *DemoListener) ExecuteLocalTransaction(msg primitive.Message)
primitive.LocalTransactionState {
+ nextIndex := atomic.AddInt32(&dl.transactionIndex, 1)
+ fmt.Printf("nextIndex: %v for transactionID: %v\n", nextIndex,
msg.TransactionId)
+ status := nextIndex % 3
+ dl.localTrans.Store(msg.TransactionId,
primitive.LocalTransactionState(status+1))
+
+ fmt.Printf("dl")
+ return primitive.UnknowState
+}
+
+func (dl *DemoListener) CheckLocalTransaction(msg primitive.MessageExt)
primitive.LocalTransactionState {
+ fmt.Printf("msg transactionID : %v\n", msg.TransactionId)
+ v, existed := dl.localTrans.Load(msg.TransactionId)
+ if !existed {
+ fmt.Printf("unknow msg: %v, return Commit", msg)
+ return primitive.CommitMessageState
+ }
+ state := v.(primitive.LocalTransactionState)
+ switch state {
+ case 1:
+ fmt.Printf("checkLocalTransaction COMMIT_MESSAGE: %v\n", msg)
+ return primitive.CommitMessageState
+ case 2:
+ fmt.Printf("checkLocalTransaction ROLLBACK_MESSAGE: %v\n", msg)
+ return primitive.RollbackMessageState
+ case 3:
+ fmt.Printf("checkLocalTransaction unknow: %v\n", msg)
+ return primitive.UnknowState
+ default:
+ fmt.Printf("checkLocalTransaction default COMMIT_MESSAGE:
%v\n", msg)
+ return primitive.CommitMessageState
+ }
+
+ return primitive.UnknowState
+}
+
+func main() {
+ p, _ := rocketmq.NewTransactionProducer(
+ NewDemoListener(),
+ producer.WithNameServer([]string{"127.0.0.1:9876"}),
+ producer.WithRetry(1),
+ )
+ err := p.Start()
+ if err != nil {
+ fmt.Printf("start producer error: %s\n", err.Error())
+ os.Exit(1)
+ }
+
+ for i := 0; i < 10; i++ {
+ res, err := p.SendMessageInTransaction(context.Background(),
+ primitive.NewMessage("TopicTest5", []byte("Hello
RocketMQ again "+strconv.Itoa(i))))
+
+ if err != nil {
+ fmt.Printf("send message error: %s\n", err)
+ } else {
+ fmt.Printf("send message success: result=%s\n",
res.String())
+ }
+ }
+ time.Sleep(5 * time.Minute)
+ err = p.Shutdown()
+ if err != nil {
+ fmt.Printf("shundown producer error: %s", err.Error())
+ }
+}
diff --git a/internal/utils/messagesysflag.go b/internal/callback.go
similarity index 61%
rename from internal/utils/messagesysflag.go
rename to internal/callback.go
index 2410a5b..2ff182c 100644
--- a/internal/utils/messagesysflag.go
+++ b/internal/callback.go
@@ -15,30 +15,17 @@ See the License for the specific language governing
permissions and
limitations under the License.
*/
-package utils
+package internal
-var (
- CompressedFlag = 0x1
+import (
+ "net"
- MultiTagsFlag = 0x1 << 1
-
- TransactionNotType = 0
-
- TransactionPreparedType = 0x1 << 2
-
- TransactionCommitType = 0x2 << 2
-
- TransactionRollbackType = 0x3 << 2
+ "github.com/apache/rocketmq-client-go/primitive"
)
-func GetTransactionValue(flag int) int {
- return flag & TransactionRollbackType
-}
-
-func ResetTransactionValue(flag int, typeFlag int) int {
- return (flag & (^TransactionRollbackType)) | typeFlag
-}
-
-func ClearCompressedFlag(flag int) int {
- return flag & (^CompressedFlag)
+// CheckTransactionStateCallback bundles what the ReqCheckTransactionState
+// request handler extracts from a broker request: the remote address the
+// request came from, the first decoded message, and the decoded request
+// header. The handler sends one value of this type on its callback channel.
+type CheckTransactionStateCallback struct {
+ Addr net.Addr
+ Msg primitive.MessageExt
+ Header CheckTransactionStateRequestHeader
}
diff --git a/internal/client.go b/internal/client.go
index d58efc4..c6b3246 100644
--- a/internal/client.go
+++ b/internal/client.go
@@ -22,6 +22,7 @@ import (
"context"
"errors"
"fmt"
+ "net"
"os"
"strconv"
"strings"
@@ -165,18 +166,50 @@ type rmqClient struct {
var clientMap sync.Map
-func GetOrNewRocketMQClient(option ClientOptions) *rmqClient {
+func GetOrNewRocketMQClient(option ClientOptions, callbackCh chan interface{})
*rmqClient {
client := &rmqClient{
option: option,
remoteClient: remote.NewRemotingClient(),
}
actual, loaded := clientMap.LoadOrStore(client.ClientID(), client)
if !loaded {
-
client.remoteClient.RegisterRequestFunc(ReqNotifyConsumerIdsChanged, func(req
*remote.RemotingCommand) *remote.RemotingCommand {
+
client.remoteClient.RegisterRequestFunc(ReqNotifyConsumerIdsChanged, func(req
*remote.RemotingCommand, addr net.Addr) *remote.RemotingCommand {
rlog.Infof("receive broker's notification, the consumer
group: %s", req.ExtFields["consumerGroup"])
client.RebalanceImmediately()
return nil
})
+
client.remoteClient.RegisterRequestFunc(ReqCheckTransactionState, func(req
*remote.RemotingCommand, addr net.Addr) *remote.RemotingCommand {
+ header := new(CheckTransactionStateRequestHeader)
+ header.Decode(req.ExtFields)
+ msgExts := primitive.DecodeMessage(req.Body)
+ if len(msgExts) == 0 {
+ rlog.Warn("checkTransactionState, decode
message failed")
+ return nil
+ }
+ msgExt := msgExts[0]
+ // TODO: add namespace support
+ transactionID :=
msgExt.Properties[primitive.PropertyUniqueClientMessageIdKeyIndex]
+ if len(transactionID) > 0 {
+ msgExt.TransactionId = transactionID
+ }
+ group, existed :=
msgExt.Properties[primitive.PropertyProducerGroup]
+ if !existed {
+ rlog.Warn("checkTransactionState, pick producer
group failed")
+ return nil
+ }
+ if option.GroupName != group {
+ rlog.Warn("producer group is not equal.")
+ return nil
+ }
+ callback := CheckTransactionStateCallback{
+ Addr: addr,
+ Msg: *msgExt,
+ Header: *header,
+ }
+ callbackCh <- callback
+ return nil
+ })
+
}
return actual.(*rmqClient)
}
@@ -283,31 +316,30 @@ func (c *rmqClient) CheckClientInBroker() {
func (c *rmqClient) SendHeartbeatToAllBrokerWithLock() {
c.hbMutex.Lock()
defer c.hbMutex.Unlock()
- hbData := &heartbeatData{
- ClientId: c.ClientID(),
- }
- pData := make([]producerData, 0)
+ hbData := NewHeartbeatData(c.ClientID())
+
c.producerMap.Range(func(key, value interface{}) bool {
- pData = append(pData, producerData(key.(string)))
+ pData := producerData{
+ GroupName: key.(string),
+ }
+ hbData.ProducerDatas.Add(pData)
return true
})
- cData := make([]consumerData, 0)
c.consumerMap.Range(func(key, value interface{}) bool {
consumer := value.(InnerConsumer)
- cData = append(cData, consumerData{
+ cData := consumerData{
GroupName: key.(string),
CType: "PUSH",
MessageModel: "CLUSTERING",
Where: "CONSUME_FROM_FIRST_OFFSET",
UnitMode: consumer.IsUnitMode(),
SubscriptionDatas: consumer.SubscriptionDataList(),
- })
+ }
+ hbData.ConsumerDatas.Add(cData)
return true
})
- hbData.ProducerDatas = pData
- hbData.ConsumerDatas = cData
- if len(pData) == 0 && len(cData) == 0 {
+ if hbData.ProducerDatas.Len() == 0 && hbData.ConsumerDatas.Len() == 0 {
rlog.Info("sending heartbeat, but no producer and no consumer")
return
}
diff --git a/internal/model.go b/internal/model.go
index 00cfe63..f534234 100644
--- a/internal/model.go
+++ b/internal/model.go
@@ -20,6 +20,7 @@ package internal
import (
"encoding/json"
+ "github.com/apache/rocketmq-client-go/internal/utils"
"github.com/apache/rocketmq-client-go/rlog"
)
@@ -31,8 +32,6 @@ type FindBrokerResult struct {
type (
// groupName of consumer
- producerData string
-
consumeType string
ServiceState int
@@ -55,6 +54,14 @@ type SubscriptionData struct {
ExpType string
}
+type producerData struct {
+ GroupName string `json:"groupName"`
+}
+
+func (p producerData) UniqueID() string {
+ return p.GroupName
+}
+
type consumerData struct {
GroupName string `json:"groupName"`
CType consumeType `json:"consumeType"`
@@ -64,10 +71,22 @@ type consumerData struct {
UnitMode bool `json:"unitMode"`
}
+func (c consumerData) UniqueID() string {
+ return c.GroupName
+}
+
type heartbeatData struct {
- ClientId string `json:"clientID"`
- ProducerDatas []producerData `json:"producerDataSet"`
- ConsumerDatas []consumerData `json:"consumerDataSet"`
+ ClientId string `json:"clientID"`
+ ProducerDatas utils.Set `json:"producerDataSet"`
+ ConsumerDatas utils.Set `json:"consumerDataSet"`
+}
+
+func NewHeartbeatData(clientID string) *heartbeatData {
+ return &heartbeatData{
+ ClientId: clientID,
+ ProducerDatas: utils.NewSet(),
+ ConsumerDatas: utils.NewSet(),
+ }
}
func (data *heartbeatData) encode() []byte {
@@ -76,6 +95,6 @@ func (data *heartbeatData) encode() []byte {
rlog.Errorf("marshal heartbeatData error: %s", err.Error())
return nil
}
- rlog.Info(string(d))
+ rlog.Info("heartbeat: " + string(d))
return d
}
diff --git a/internal/model_test.go b/internal/model_test.go
new file mode 100644
index 0000000..1dac4ec
--- /dev/null
+++ b/internal/model_test.go
@@ -0,0 +1,117 @@
+/*
+Licensed to the Apache Software Foundation (ASF) under one or more
+contributor license agreements. See the NOTICE file distributed with
+this work for additional information regarding copyright ownership.
+The ASF licenses this file to You under the Apache License, Version 2.0
+(the "License"); you may not use this file except in compliance with
+the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package internal
+
+import (
+ "encoding/json"
+ "fmt"
+ "testing"
+
+ . "github.com/smartystreets/goconvey/convey"
+ "github.com/stretchr/testify/assert"
+
+ "github.com/apache/rocketmq-client-go/internal/utils"
+)
+
+func TestHeartbeatData(t *testing.T) {
+ Convey("test heatbeat json", t, func() {
+
+ Convey("producerData set marshal", func() {
+ pData := &producerData{
+ GroupName: "group name",
+ }
+ pData2 := &producerData{
+ GroupName: "group name 2",
+ }
+ set := utils.NewSet()
+ set.Add(pData)
+ set.Add(pData2)
+
+ v, err := json.Marshal(set)
+ assert.Nil(t, err)
+ fmt.Printf("json producer set: %s", string(v))
+ })
+
+ Convey("producer heatbeat", func() {
+
+ hbt := NewHeartbeatData("producer client id")
+ p1 := &producerData{
+ GroupName: "group name",
+ }
+ p2 := &producerData{
+ GroupName: "group name 2",
+ }
+
+ hbt.ProducerDatas.Add(p1)
+ hbt.ProducerDatas.Add(p2)
+
+ v, err := json.Marshal(hbt)
+ //ShouldBeNil(t, err)
+ assert.Nil(t, err)
+ fmt.Printf("json producer: %s\n", string(v))
+ })
+
+ Convey("consumer heartbeat", func() {
+
+ hbt := NewHeartbeatData("consumer client id")
+ c1 := consumerData{
+ GroupName: "consumer data 1",
+ }
+ c2 := consumerData{
+ GroupName: "consumer data 2",
+ }
+ hbt.ConsumerDatas.Add(c1)
+ hbt.ConsumerDatas.Add(c2)
+
+ v, err := json.Marshal(hbt)
+ //ShouldBeNil(t, err)
+ assert.Nil(t, err)
+ fmt.Printf("json consumer: %s\n", string(v))
+ })
+
+ Convey("producer & consumer heartbeat", func() {
+
+ hbt := NewHeartbeatData("consumer client id")
+
+ p1 := &producerData{
+ GroupName: "group name",
+ }
+ p2 := &producerData{
+ GroupName: "group name 2",
+ }
+
+ hbt.ProducerDatas.Add(p1)
+ hbt.ProducerDatas.Add(p2)
+
+ c1 := consumerData{
+ GroupName: "consumer data 1",
+ }
+ c2 := consumerData{
+ GroupName: "consumer data 2",
+ }
+ hbt.ConsumerDatas.Add(c1)
+ hbt.ConsumerDatas.Add(c2)
+
+ v, err := json.Marshal(hbt)
+ //ShouldBeNil(t, err)
+ assert.Nil(t, err)
+ fmt.Printf("json producer & consumer: %s\n", string(v))
+ })
+ })
+
+}
diff --git a/internal/remote/remote_client.go b/internal/remote/remote_client.go
index cfeb909..e8f10c2 100644
--- a/internal/remote/remote_client.go
+++ b/internal/remote/remote_client.go
@@ -31,7 +31,7 @@ import (
"github.com/apache/rocketmq-client-go/rlog"
)
-type ClientRequestFunc func(*RemotingCommand) *RemotingCommand
+type ClientRequestFunc func(*RemotingCommand, net.Addr) *RemotingCommand
type TcpOption struct {
// TODO
@@ -148,7 +148,7 @@ func (c *RemotingClient) receiveResponse(r net.Conn) {
f := c.processors[cmd.Code]
if f != nil {
go func() { // 单个goroutine会造成死锁
- res := f(cmd)
+ res := f(cmd, r.RemoteAddr())
if res != nil {
err := c.sendRequest(r, res)
if err != nil {
diff --git a/internal/remote/remote_client_test.go
b/internal/remote/remote_client_test.go
index ff6900e..9cbb117 100644
--- a/internal/remote/remote_client_test.go
+++ b/internal/remote/remote_client_test.go
@@ -167,7 +167,12 @@ func TestInvokeSync(t *testing.T) {
var wg sync.WaitGroup
wg.Add(1)
client := NewRemotingClient()
+
+ var clientSend sync.WaitGroup // blocking client send message until the
server listen success.
+ clientSend.Add(1)
+
go func() {
+ clientSend.Wait()
receiveCommand, err := client.InvokeSync(addr,
clientSendRemtingCommand, time.Second)
if err != nil {
@@ -189,6 +194,7 @@ func TestInvokeSync(t *testing.T) {
t.Fatal(err)
}
defer l.Close()
+ clientSend.Done()
for {
conn, err := l.Accept()
if err != nil {
@@ -337,7 +343,11 @@ func TestInvokeOneWay(t *testing.T) {
var wg sync.WaitGroup
wg.Add(1)
client := NewRemotingClient()
+
+ var clientSend sync.WaitGroup // blocking client send message until the
server listen success.
+ clientSend.Add(1)
go func() {
+ clientSend.Wait()
err := client.InvokeOneWay(addr, clientSendRemtingCommand,
3*time.Second)
if err != nil {
t.Fatalf("failed to invoke synchronous. %s", err)
@@ -350,6 +360,7 @@ func TestInvokeOneWay(t *testing.T) {
t.Fatal(err)
}
defer l.Close()
+ clientSend.Done()
for {
conn, err := l.Accept()
if err != nil {
@@ -366,8 +377,9 @@ func TestInvokeOneWay(t *testing.T) {
t.Errorf("wrong code. want=%d, got=%d",
receivedRemotingCommand.Code,
clientSendRemtingCommand.Code)
}
- return
+ goto done
}
}
- wg.Done()
+done:
+ wg.Wait()
}
diff --git a/internal/request.go b/internal/request.go
index e09a986..3dce875 100644
--- a/internal/request.go
+++ b/internal/request.go
@@ -32,6 +32,7 @@ const (
ReqGetMaxOffset = int16(30)
ReqHeartBeat = int16(34)
ReqConsumerSendMsgBack = int16(36)
+ ReqENDTransaction = int16(37)
ReqGetConsumerListByGroup = int16(38)
ReqLockBatchMQ = int16(41)
ReqUnlockBatchMQ = int16(42)
@@ -81,6 +82,68 @@ func (request *SendMessageRequest) Decode(properties
map[string]string) error {
return nil
}
+type EndTransactionRequestHeader struct {
+ ProducerGroup string `json:"producerGroup"`
+ TranStateTableOffset int64 `json:"tranStateTableOffset"`
+ CommitLogOffset int64 `json:"commitLogOffset"`
+ CommitOrRollback int `json:"commitOrRollback"`
+ FromTransactionCheck bool `json:"fromTransactionCheck"`
+ MsgID string `json:"msgId"`
+ TransactionId string `json:"transactionId"`
+}
+
+// Encode flattens the header into a map[string]string whose keys match the
+// struct's JSON tag names. Numeric and boolean fields are rendered in their
+// decimal / "true"/"false" text form.
+func (request *EndTransactionRequestHeader) Encode() map[string]string {
+	maps := make(map[string]string)
+	maps["producerGroup"] = request.ProducerGroup
+	maps["tranStateTableOffset"] = strconv.FormatInt(request.TranStateTableOffset, 10)
+	// FormatInt keeps the full int64 range; converting through int would
+	// truncate the offset on platforms where int is 32 bits wide.
+	maps["commitLogOffset"] = strconv.FormatInt(request.CommitLogOffset, 10)
+	maps["commitOrRollback"] = strconv.Itoa(request.CommitOrRollback)
+	maps["fromTransactionCheck"] = strconv.FormatBool(request.FromTransactionCheck)
+	maps["msgId"] = request.MsgID
+	maps["transactionId"] = request.TransactionId
+	return maps
+}
+
+type CheckTransactionStateRequestHeader struct {
+ TranStateTableOffset int64
+ CommitLogOffset int64
+ MsgId string
+ TransactionId string
+ OffsetMsgId string
+}
+
+func (request *CheckTransactionStateRequestHeader) Encode() map[string]string {
+ maps := make(map[string]string)
+ maps["tranStateTableOffset"] =
strconv.FormatInt(request.TranStateTableOffset, 10)
+ maps["commitLogOffset"] = strconv.FormatInt(request.CommitLogOffset, 10)
+ maps["msgId"] = request.MsgId
+ maps["transactionId"] = request.TransactionId
+ maps["offsetMsgId"] = request.OffsetMsgId
+
+ return maps
+}
+
+// Decode fills the header from ext, the string map form produced by Encode.
+// Keys that are absent leave the corresponding field untouched. Errors from
+// parsing the two numeric offsets are ignored, as the method has no error
+// return.
+func (request *CheckTransactionStateRequestHeader) Decode(ext map[string]string) {
+	if len(ext) == 0 {
+		return
+	}
+	// Both offsets are int64 fields, so parse with an explicit 64-bit size
+	// instead of the platform-dependent int size selected by bitSize 0.
+	if v, existed := ext["tranStateTableOffset"]; existed {
+		request.TranStateTableOffset, _ = strconv.ParseInt(v, 10, 64)
+	}
+	if v, existed := ext["commitLogOffset"]; existed {
+		request.CommitLogOffset, _ = strconv.ParseInt(v, 10, 64)
+	}
+	if v, existed := ext["msgId"]; existed {
+		request.MsgId = v
+	}
+	// Each key is stored in its own field; writing all three into MsgId would
+	// leave TransactionId and OffsetMsgId permanently empty.
+	if v, existed := ext["transactionId"]; existed {
+		request.TransactionId = v
+	}
+	if v, existed := ext["offsetMsgId"]; existed {
+		request.OffsetMsgId = v
+	}
+}
+
type ConsumerSendMsgBackRequest struct {
Group string `json:"group"`
Offset int64 `json:"offset"`
diff --git a/internal/trace.go b/internal/trace.go
index dc3b246..ebd3be5 100644
--- a/internal/trace.go
+++ b/internal/trace.go
@@ -20,70 +20,18 @@ package internal
import (
"bytes"
"context"
- "encoding/binary"
- "encoding/hex"
"fmt"
- "os"
"runtime"
"strconv"
"strings"
- "sync"
"sync/atomic"
"time"
"github.com/apache/rocketmq-client-go/internal/remote"
- "github.com/apache/rocketmq-client-go/internal/utils"
"github.com/apache/rocketmq-client-go/primitive"
"github.com/apache/rocketmq-client-go/rlog"
)
-var (
- counter int16 = 0
- startTimestamp int64 = 0
- nextTimestamp int64 = 0
- prefix string
- locker sync.Mutex
- classLoadId int32 = 0
-)
-
-func init() {
- buf := new(bytes.Buffer)
-
- ip, err := utils.ClientIP4()
- if err != nil {
- ip = utils.FakeIP()
- }
- _, _ = buf.Write(ip)
- _ = binary.Write(buf, binary.BigEndian, Pid())
- _ = binary.Write(buf, binary.BigEndian, classLoadId)
- prefix = strings.ToUpper(hex.EncodeToString(buf.Bytes()))
-}
-
-func CreateUniqID() string {
- locker.Lock()
- defer locker.Unlock()
-
- if time.Now().Unix() > nextTimestamp {
- updateTimestamp()
- }
- counter++
- buf := new(bytes.Buffer)
- _ = binary.Write(buf, binary.BigEndian,
int32((time.Now().Unix()-startTimestamp)*1000))
- _ = binary.Write(buf, binary.BigEndian, counter)
-
- return prefix + hex.EncodeToString(buf.Bytes())
-}
-
-func updateTimestamp() {
- year, month := time.Now().Year(), time.Now().Month()
- startTimestamp = time.Date(year, month, 1, 0, 0, 0, 0,
time.Local).Unix()
- nextTimestamp = time.Date(year, month, 1, 0, 0, 0, 0,
time.Local).AddDate(0, 1, 0).Unix()
-}
-
-func Pid() int16 {
- return int16(os.Getpid())
-}
-
type TraceBean struct {
Topic string
MsgId string
@@ -276,7 +224,7 @@ func NewTraceDispatcher(traceTopic string, access
primitive.AccessChannel) *trac
cliOp := DefaultClientOptions()
cliOp.RetryTimes = 0
- cli := GetOrNewRocketMQClient(cliOp)
+ cli := GetOrNewRocketMQClient(cliOp, nil)
return &traceDispatcher{
ctx: ctx,
cancel: cancel,
diff --git a/internal/utils/set.go b/internal/utils/set.go
new file mode 100644
index 0000000..2f9f214
--- /dev/null
+++ b/internal/utils/set.go
@@ -0,0 +1,81 @@
+/*
+Licensed to the Apache Software Foundation (ASF) under one or more
+contributor license agreements. See the NOTICE file distributed with
+this work for additional information regarding copyright ownership.
+The ASF licenses this file to You under the Apache License, Version 2.0
+(the "License"); you may not use this file except in compliance with
+the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package utils
+
+import (
+ "bytes"
+ "encoding/json"
+ "sort"
+)
+
+type UniqueItem interface {
+ UniqueID() string
+}
+
+type Set struct {
+ items map[string]UniqueItem
+}
+
+func NewSet() Set {
+ return Set{
+ items: make(map[string]UniqueItem, 0),
+ }
+}
+
+func (s *Set) Add(v UniqueItem) {
+ s.items[v.UniqueID()] = v
+}
+
+func (s *Set) Len() int {
+ return len(s.items)
+}
+
+var _ json.Marshaler = &Set{}
+
+// MarshalJSON encodes the set as a JSON array of its items. Each item is
+// marshalled on its own and the resulting JSON texts are sorted, so the output
+// is deterministic despite map iteration order. An empty set encodes as "[]".
+//
+// The receiver is a value so that encoding/json also invokes this method when
+// it is handed a non-addressable Set (for example json.Marshal(set) with a
+// Set value); *Set still satisfies json.Marshaler.
+func (s Set) MarshalJSON() ([]byte, error) {
+	if len(s.items) == 0 {
+		return []byte("[]"), nil
+	}
+
+	encoded := make([]string, 0, len(s.items))
+	for _, item := range s.items {
+		v, err := json.Marshal(item)
+		if err != nil {
+			return nil, err
+		}
+		encoded = append(encoded, string(v))
+	}
+	sort.Strings(encoded)
+
+	buffer := new(bytes.Buffer)
+	buffer.WriteByte('[')
+	for i, e := range encoded {
+		if i > 0 {
+			buffer.WriteByte(',')
+		}
+		buffer.WriteString(e)
+	}
+	buffer.WriteByte(']')
+
+	return buffer.Bytes(), nil
+}
+
+// UnmarshalJSON is a no-op: it ignores data, leaves the Set unchanged and
+// always returns nil.
+func (s Set) UnmarshalJSON(data []byte) (err error) {
+ return nil
+}
diff --git a/primitive/message.go b/primitive/message.go
index 28c9eab..3dd4a0c 100644
--- a/primitive/message.go
+++ b/primitive/message.go
@@ -18,9 +18,17 @@ limitations under the License.
package primitive
import (
+ "bytes"
+ "encoding/binary"
+ "encoding/hex"
"fmt"
+ "os"
"strconv"
"strings"
+ "sync"
+ "time"
+
+ "github.com/pkg/errors"
"github.com/apache/rocketmq-client-go/internal/utils"
)
@@ -68,11 +76,13 @@ type Message struct {
}
func NewMessage(topic string, body []byte) *Message {
- return &Message{
+ msg := &Message{
Topic: topic,
Body: body,
Properties: make(map[string]string),
}
+ msg.Properties[PropertyWaitStoreMsgOk] = strconv.FormatBool(true)
+ return msg
}
// SetDelayTimeLevel set message delay time to consume.
@@ -164,6 +174,135 @@ func (msgExt *MessageExt) String() string {
msgExt.PreparedTransactionOffset)
}
+func DecodeMessage(data []byte) []*MessageExt {
+ msgs := make([]*MessageExt, 0)
+ buf := bytes.NewBuffer(data)
+ count := 0
+ for count < len(data) {
+ msg := &MessageExt{}
+
+ // 1. total size
+ binary.Read(buf, binary.BigEndian, &msg.StoreSize)
+ count += 4
+
+ // 2. magic code
+ buf.Next(4)
+ count += 4
+
+ // 3. body CRC32
+ binary.Read(buf, binary.BigEndian, &msg.BodyCRC)
+ count += 4
+
+ // 4. queueID
+ binary.Read(buf, binary.BigEndian, &msg.QueueId)
+ count += 4
+
+ // 5. Flag
+ binary.Read(buf, binary.BigEndian, &msg.Flag)
+ count += 4
+
+ // 6. QueueOffset
+ binary.Read(buf, binary.BigEndian, &msg.QueueOffset)
+ count += 8
+
+ // 7. physical offset
+ binary.Read(buf, binary.BigEndian, &msg.CommitLogOffset)
+ count += 8
+
+ // 8. SysFlag
+ binary.Read(buf, binary.BigEndian, &msg.SysFlag)
+ count += 4
+
+ // 9. BornTimestamp
+ binary.Read(buf, binary.BigEndian, &msg.BornTimestamp)
+ count += 8
+
+ // 10. born host
+ hostBytes := buf.Next(4)
+ var port int32
+ binary.Read(buf, binary.BigEndian, &port)
+ msg.BornHost = fmt.Sprintf("%s:%d",
utils.GetAddressByBytes(hostBytes), port)
+ count += 8
+
+ // 11. store timestamp
+ binary.Read(buf, binary.BigEndian, &msg.StoreTimestamp)
+ count += 8
+
+ // 12. store host
+ hostBytes = buf.Next(4)
+ binary.Read(buf, binary.BigEndian, &port)
+ msg.StoreHost = fmt.Sprintf("%s:%d",
utils.GetAddressByBytes(hostBytes), port)
+ count += 8
+
+ // 13. reconsume times
+ binary.Read(buf, binary.BigEndian, &msg.ReconsumeTimes)
+ count += 4
+
+ // 14. prepared transaction offset
+ binary.Read(buf, binary.BigEndian,
&msg.PreparedTransactionOffset)
+ count += 8
+
+ // 15. body
+ var length int32
+ binary.Read(buf, binary.BigEndian, &length)
+ msg.Body = buf.Next(int(length))
+ if (msg.SysFlag & FlagCompressed) == FlagCompressed {
+ msg.Body = utils.UnCompress(msg.Body)
+ }
+ count += 4 + int(length)
+
+ // 16. topic
+ _byte, _ := buf.ReadByte()
+ msg.Topic = string(buf.Next(int(_byte)))
+ count += 1 + int(_byte)
+
+ // 17. properties
+ var propertiesLength int16
+ binary.Read(buf, binary.BigEndian, &propertiesLength)
+ if propertiesLength > 0 {
+ msg.Properties =
unmarshalProperties(buf.Next(int(propertiesLength)))
+ }
+ count += 2 + int(propertiesLength)
+
+ msg.MsgId = createMessageId(hostBytes, port,
msg.CommitLogOffset)
+ //count += 16
+ if msg.Properties == nil {
+ msg.Properties = make(map[string]string, 0)
+ }
+ msgs = append(msgs, msg)
+ }
+
+ return msgs
+}
+
+// unmarshalProperties parse data into property kv pairs.
+func unmarshalProperties(data []byte) map[string]string {
+ m := make(map[string]string)
+ items := bytes.Split(data, []byte{propertySeparator})
+ for _, item := range items {
+ kv := bytes.Split(item, []byte{nameValueSeparator})
+ if len(kv) == 2 {
+ m[string(kv[0])] = string(kv[1])
+ }
+ }
+ return m
+}
+
+// MarshalPropeties serializes properties into a single string: for every
+// entry it writes the key, nameValueSeparator, the value and then
+// propertySeparator (so the result ends with a trailing propertySeparator).
+// A nil map yields "". Entries are written in map iteration order, which Go
+// does not specify, so the order of pairs may differ between calls.
+func MarshalPropeties(properties map[string]string) string {
+ if properties == nil {
+ return ""
+ }
+ buffer := bytes.NewBufferString("")
+
+ for k, v := range properties {
+ buffer.WriteString(k)
+ buffer.WriteRune(nameValueSeparator)
+ buffer.WriteString(v)
+ buffer.WriteRune(propertySeparator)
+ }
+ return buffer.String()
+}
+
// MessageQueue message queue
type MessageQueue struct {
Topic string `json:"topic"`
@@ -206,3 +345,142 @@ const (
TransMsgCommit
DelayMsg
)
+
+type LocalTransactionState int
+
+const (
+ CommitMessageState LocalTransactionState = iota + 1
+ RollbackMessageState
+ UnknowState
+)
+
+type TransactionListener interface {
+ // When send transactional prepare(half) message succeed, this method
will be invoked to execute local transaction.
+ ExecuteLocalTransaction(Message) LocalTransactionState
+
+ // When no response to prepare(half) message. broker will send check
message to check the transaction status, and this
+ // method will be invoked to get local transaction status.
+ CheckLocalTransaction(MessageExt) LocalTransactionState
+}
+
+type MessageID struct {
+ Addr string
+ Port int
+ Offset int64
+}
+
+// createMessageId builds a message ID string by concatenating the raw addr
+// bytes, the port as a big-endian int32 and the offset as a big-endian int64,
+// then hex-encoding the result in upper case.
+func createMessageId(addr []byte, port int32, offset int64) string {
+ buffer := new(bytes.Buffer)
+ buffer.Write(addr)
+ binary.Write(buffer, binary.BigEndian, port)
+ binary.Write(buffer, binary.BigEndian, offset)
+ return strings.ToUpper(hex.EncodeToString(buffer.Bytes()))
+}
+
+func UnmarshalMsgID(msgID []byte) (*MessageID, error) {
+ if len(msgID) < 32 {
+ return nil, errors.Errorf("%s len < 32", string(msgID))
+ }
+ ip := make([]byte, 8)
+ port := make([]byte, 8)
+ offset := make([]byte, 16)
+ var portVal int
+ var offsetVal int64
+
+ _, err := hex.Decode(ip, msgID[0:8])
+ if err != nil {
+ _, err = hex.Decode(port, msgID[8:16])
+ }
+ if err != nil {
+ _, err = hex.Decode(offset, msgID[16:32])
+ }
+ if err != nil {
+ portVal, err = strconv.Atoi(string(port))
+ }
+ if err != nil {
+ offsetVal, err = strconv.ParseInt(string(offset), 10, 0)
+ }
+
+ if err != nil {
+ return nil, err
+ }
+
+ return &MessageID{
+ Addr: string(ip),
+ Port: portVal,
+ Offset: offsetVal,
+ }, nil
+}
+
+var (
+ CompressedFlag = 0x1
+
+ MultiTagsFlag = 0x1 << 1
+
+ TransactionNotType = 0
+
+ TransactionPreparedType = 0x1 << 2
+
+ TransactionCommitType = 0x2 << 2
+
+ TransactionRollbackType = 0x3 << 2
+)
+
+func GetTransactionValue(flag int) int {
+ return flag & TransactionRollbackType
+}
+
+func ResetTransactionValue(flag int, typeFlag int) int {
+ return (flag & (^TransactionRollbackType)) | typeFlag
+}
+
+func ClearCompressedFlag(flag int) int {
+ return flag & (^CompressedFlag)
+}
+
+var (
+ counter int16 = 0
+ startTimestamp int64 = 0
+ nextTimestamp int64 = 0
+ prefix string
+ locker sync.Mutex
+ classLoadId int32 = 0
+)
+
+func init() {
+ buf := new(bytes.Buffer)
+
+ ip, err := utils.ClientIP4()
+ if err != nil {
+ ip = utils.FakeIP()
+ }
+ _, _ = buf.Write(ip)
+ _ = binary.Write(buf, binary.BigEndian, Pid())
+ _ = binary.Write(buf, binary.BigEndian, classLoadId)
+ prefix = strings.ToUpper(hex.EncodeToString(buf.Bytes()))
+}
+
+func CreateUniqID() string {
+ locker.Lock()
+ defer locker.Unlock()
+
+ if time.Now().Unix() > nextTimestamp {
+ updateTimestamp()
+ }
+ counter++
+ buf := new(bytes.Buffer)
+ _ = binary.Write(buf, binary.BigEndian,
int32((time.Now().Unix()-startTimestamp)*1000))
+ _ = binary.Write(buf, binary.BigEndian, counter)
+
+ return prefix + hex.EncodeToString(buf.Bytes())
+}
+
+func updateTimestamp() {
+ year, month := time.Now().Year(), time.Now().Month()
+ startTimestamp = time.Date(year, month, 1, 0, 0, 0, 0,
time.Local).Unix()
+ nextTimestamp = time.Date(year, month, 1, 0, 0, 0, 0,
time.Local).AddDate(0, 1, 0).Unix()
+}
+
+func Pid() int16 {
+ return int16(os.Getpid())
+}
diff --git a/primitive/result.go b/primitive/result.go
index 4664cf0..201b36a 100644
--- a/primitive/result.go
+++ b/primitive/result.go
@@ -18,13 +18,7 @@ limitations under the License.
package primitive
import (
- "bytes"
- "encoding/binary"
- "encoding/hex"
"fmt"
- "strings"
-
- "github.com/apache/rocketmq-client-go/internal/utils"
)
// SendStatus of message
@@ -62,6 +56,12 @@ func (result *SendResult) String() string {
result.Status, result.MsgID, result.OffsetMsgID,
result.QueueOffset, result.MessageQueue.String())
}
+// SendResult RocketMQ send result
+type TransactionSendResult struct {
+ *SendResult
+ State LocalTransactionState
+}
+
// PullStatus pull Status
type PullStatus int
@@ -115,141 +115,6 @@ func (result *PullResult) String() string {
return ""
}
-func DecodeMessage(data []byte) []*MessageExt {
- msgs := make([]*MessageExt, 0)
- buf := bytes.NewBuffer(data)
- count := 0
- for count < len(data) {
- msg := &MessageExt{}
-
- // 1. total size
- binary.Read(buf, binary.BigEndian, &msg.StoreSize)
- count += 4
-
- // 2. magic code
- buf.Next(4)
- count += 4
-
- // 3. body CRC32
- binary.Read(buf, binary.BigEndian, &msg.BodyCRC)
- count += 4
-
- // 4. queueID
- binary.Read(buf, binary.BigEndian, &msg.QueueId)
- count += 4
-
- // 5. Flag
- binary.Read(buf, binary.BigEndian, &msg.Flag)
- count += 4
-
- // 6. QueueOffset
- binary.Read(buf, binary.BigEndian, &msg.QueueOffset)
- count += 8
-
- // 7. physical offset
- binary.Read(buf, binary.BigEndian, &msg.CommitLogOffset)
- count += 8
-
- // 8. SysFlag
- binary.Read(buf, binary.BigEndian, &msg.SysFlag)
- count += 4
-
- // 9. BornTimestamp
- binary.Read(buf, binary.BigEndian, &msg.BornTimestamp)
- count += 8
-
- // 10. born host
- hostBytes := buf.Next(4)
- var port int32
- binary.Read(buf, binary.BigEndian, &port)
- msg.BornHost = fmt.Sprintf("%s:%d",
utils.GetAddressByBytes(hostBytes), port)
- count += 8
-
- // 11. store timestamp
- binary.Read(buf, binary.BigEndian, &msg.StoreTimestamp)
- count += 8
-
- // 12. store host
- hostBytes = buf.Next(4)
- binary.Read(buf, binary.BigEndian, &port)
- msg.StoreHost = fmt.Sprintf("%s:%d",
utils.GetAddressByBytes(hostBytes), port)
- count += 8
-
- // 13. reconsume times
- binary.Read(buf, binary.BigEndian, &msg.ReconsumeTimes)
- count += 4
-
- // 14. prepared transaction offset
- binary.Read(buf, binary.BigEndian,
&msg.PreparedTransactionOffset)
- count += 8
-
- // 15. body
- var length int32
- binary.Read(buf, binary.BigEndian, &length)
- msg.Body = buf.Next(int(length))
- if (msg.SysFlag & FlagCompressed) == FlagCompressed {
- msg.Body = utils.UnCompress(msg.Body)
- }
- count += 4 + int(length)
-
- // 16. topic
- _byte, _ := buf.ReadByte()
- msg.Topic = string(buf.Next(int(_byte)))
- count += 1 + int(_byte)
-
- // 17. properties
- var propertiesLength int16
- binary.Read(buf, binary.BigEndian, &propertiesLength)
- if propertiesLength > 0 {
- msg.Properties =
unmarshalProperties(buf.Next(int(propertiesLength)))
- }
- count += 2 + int(propertiesLength)
-
- msg.MsgId = createMessageId(hostBytes, port,
msg.CommitLogOffset)
- //count += 16
-
- msgs = append(msgs, msg)
- }
-
- return msgs
-}
-
-func createMessageId(addr []byte, port int32, offset int64) string {
- buffer := new(bytes.Buffer)
- buffer.Write(addr)
- binary.Write(buffer, binary.BigEndian, port)
- binary.Write(buffer, binary.BigEndian, offset)
- return strings.ToUpper(hex.EncodeToString(buffer.Bytes()))
-}
-
-// unmarshalProperties parse data into property kv pairs.
-func unmarshalProperties(data []byte) map[string]string {
- m := make(map[string]string)
- items := bytes.Split(data, []byte{propertySeparator})
- for _, item := range items {
- kv := bytes.Split(item, []byte{nameValueSeparator})
- if len(kv) == 2 {
- m[string(kv[0])] = string(kv[1])
- }
- }
- return m
-}
-
-func MarshalPropeties(properties map[string]string) string {
- if properties == nil {
- return ""
- }
- buffer := bytes.NewBufferString("")
-
- for k, v := range properties {
- buffer.WriteString(k)
- buffer.WriteRune(nameValueSeparator)
- buffer.WriteString(v)
- buffer.WriteRune(propertySeparator)
- }
- return buffer.String()
-}
-
func toMessages(messageExts []*MessageExt) []*Message {
msgs := make([]*Message, 0)
diff --git a/producer/interceptor.go b/producer/interceptor.go
index 46f8042..d545f93 100644
--- a/producer/interceptor.go
+++ b/producer/interceptor.go
@@ -81,7 +81,7 @@ func newTraceInterceptor(traceCfg primitive.TraceConfig)
primitive.Interceptor {
}
traceCtx := internal.TraceContext{
- RequestId: internal.CreateUniqID(), // set id
+ RequestId: primitive.CreateUniqID(), // set id
TimeStamp: time.Now().UnixNano() /
int64(time.Millisecond),
TraceType: internal.Pub,
diff --git a/producer/option.go b/producer/option.go
index ad6c98e..5ec003c 100644
--- a/producer/option.go
+++ b/producer/option.go
@@ -18,14 +18,17 @@ limitations under the License.
package producer
import (
+ "time"
+
"github.com/apache/rocketmq-client-go/internal"
"github.com/apache/rocketmq-client-go/primitive"
)
func defaultProducerOptions() producerOptions {
opts := producerOptions{
- ClientOptions: internal.DefaultClientOptions(),
- Selector: NewRoundRobinQueueSelector(),
+ ClientOptions: internal.DefaultClientOptions(),
+ Selector: NewRoundRobinQueueSelector(),
+ SendMsgTimeout: 3 * time.Second,
}
opts.ClientOptions.GroupName = "DEFAULT_CONSUMER"
return opts
@@ -33,7 +36,8 @@ func defaultProducerOptions() producerOptions {
type producerOptions struct {
internal.ClientOptions
- Selector QueueSelector
+ Selector QueueSelector
+ SendMsgTimeout time.Duration
}
type Option func(*producerOptions)
@@ -57,6 +61,12 @@ func WithNameServer(nameServers []string) Option {
}
}
+func WithSendMsgTimeout(duration time.Duration) Option {
+ return func(opts *producerOptions) {
+ opts.SendMsgTimeout = duration
+ }
+}
+
func WithVIPChannel(enable bool) Option {
return func(opts *producerOptions) {
opts.VIPChannelEnabled = enable
diff --git a/producer/producer.go b/producer/producer.go
index 7a3f40f..1d082e9 100644
--- a/producer/producer.go
+++ b/producer/producer.go
@@ -20,15 +20,17 @@ package producer
import (
"context"
"fmt"
+ "strconv"
"sync"
"time"
+ "github.com/pkg/errors"
+
"github.com/apache/rocketmq-client-go/internal"
"github.com/apache/rocketmq-client-go/internal/remote"
"github.com/apache/rocketmq-client-go/internal/utils"
"github.com/apache/rocketmq-client-go/primitive"
"github.com/apache/rocketmq-client-go/rlog"
- "github.com/pkg/errors"
)
var (
@@ -37,6 +39,17 @@ var (
ErrNotRunning = errors.New("producer not started")
)
+type defaultProducer struct {
+ group string
+ client internal.RMQClient
+ state internal.ServiceState
+ options producerOptions
+ publishInfo sync.Map
+ callbackCh chan interface{}
+
+ interceptor primitive.Interceptor
+}
+
func NewDefaultProducer(opts ...Option) (*defaultProducer, error) {
defaultOpts := defaultProducerOptions()
for _, apply := range opts {
@@ -49,26 +62,17 @@ func NewDefaultProducer(opts ...Option) (*defaultProducer,
error) {
internal.RegisterNamsrv(srvs)
producer := &defaultProducer{
- group: "default",
- client:
internal.GetOrNewRocketMQClient(defaultOpts.ClientOptions),
- options: defaultOpts,
+ group: defaultOpts.GroupName,
+ callbackCh: make(chan interface{}),
+ options: defaultOpts,
}
+ producer.client =
internal.GetOrNewRocketMQClient(defaultOpts.ClientOptions, producer.callbackCh)
producer.interceptor =
primitive.ChainInterceptors(producer.options.Interceptors...)
return producer, nil
}
-type defaultProducer struct {
- group string
- client internal.RMQClient
- state internal.ServiceState
- options producerOptions
- publishInfo sync.Map
-
- interceptor primitive.Interceptor
-}
-
func (p *defaultProducer) Start() error {
p.state = internal.StateRunning
p.client.RegisterProducer(p.group, p)
@@ -247,11 +251,23 @@ func (p *defaultProducer) sendOneWay(ctx context.Context,
msg *primitive.Message
func (p *defaultProducer) buildSendRequest(mq *primitive.MessageQueue,
msg *primitive.Message) *remote.RemotingCommand {
+ if !msg.Batch &&
msg.Properties[primitive.PropertyUniqueClientMessageIdKeyIndex] == "" {
+ msg.Properties[primitive.PropertyUniqueClientMessageIdKeyIndex]
= primitive.CreateUniqID()
+ }
+ sysFlag := 0
+ v, ok := msg.Properties[primitive.PropertyTransactionPrepared]
+ if ok {
+ tranMsg, err := strconv.ParseBool(v)
+ if err == nil && tranMsg {
+ sysFlag |= primitive.TransactionPreparedType
+ }
+ }
+
req := &internal.SendMessageRequest{
ProducerGroup: p.group,
Topic: mq.Topic,
QueueId: mq.QueueId,
- SysFlag: 0,
+ SysFlag: sysFlag,
BornTimestamp: time.Now().UnixNano() / int64(time.Millisecond),
Flag: msg.Flag,
Properties: primitive.MarshalPropeties(msg.Properties),
@@ -318,3 +334,143 @@ func (p *defaultProducer) IsPublishTopicNeedUpdate(topic
string) bool {
func (p *defaultProducer) IsUnitMode() bool {
return false
}
+
+type transactionProducer struct {
+ producer *defaultProducer
+ listener primitive.TransactionListener
+}
+
+// TODO: checkLocalTransaction
+func NewTransactionProducer(listener primitive.TransactionListener, opts
...Option) (*transactionProducer, error) {
+ producer, err := NewDefaultProducer(opts...)
+ if err != nil {
+ return nil, errors.Wrap(err, "NewDefaultProducer failed.")
+ }
+ return &transactionProducer{
+ producer: producer,
+ listener: listener,
+ }, nil
+}
+
+func (tp *transactionProducer) Start() error {
+ go tp.checkTransactionState()
+ return tp.producer.Start()
+}
+func (tp *transactionProducer) Shutdown() error {
+ return tp.producer.Shutdown()
+}
+
+// TODO: check addr
+func (tp *transactionProducer) checkTransactionState() {
+ for ch := range tp.producer.callbackCh {
+ switch callback := ch.(type) {
+ case internal.CheckTransactionStateCallback:
+ localTransactionState :=
tp.listener.CheckLocalTransaction(callback.Msg)
+ uniqueKey, existed :=
callback.Msg.Properties[primitive.PropertyUniqueClientMessageIdKeyIndex]
+ if !existed {
+ uniqueKey = callback.Msg.MsgId
+ }
+ header := &internal.EndTransactionRequestHeader{
+ CommitLogOffset:
callback.Header.CommitLogOffset,
+ ProducerGroup: tp.producer.group,
+ TranStateTableOffset:
callback.Header.TranStateTableOffset,
+ FromTransactionCheck: true,
+ MsgID: uniqueKey,
+ TransactionId:
callback.Header.TransactionId,
+ CommitOrRollback:
tp.transactionState(localTransactionState),
+ }
+
+ req :=
remote.NewRemotingCommand(internal.ReqENDTransaction, header, nil)
+ req.Remark = tp.errRemark(nil)
+
+ tp.producer.client.InvokeOneWay(callback.Addr.String(),
req, tp.producer.options.SendMsgTimeout)
+ default:
+ rlog.Error("unknow type %v", ch)
+ }
+ }
+}
+
+func (tp *transactionProducer) SendMessageInTransaction(ctx context.Context,
msg *primitive.Message) (*primitive.TransactionSendResult, error) {
+ if msg.Properties == nil {
+ msg.Properties = make(map[string]string, 0)
+ }
+ msg.Properties[primitive.PropertyTransactionPrepared] = "true"
+ msg.Properties[primitive.PropertyProducerGroup] =
tp.producer.options.GroupName
+
+ rsp, err := tp.producer.SendSync(ctx, msg)
+ if err != nil {
+ return nil, err
+ }
+ localTransactionState := primitive.UnknowState
+ switch rsp.Status {
+ case primitive.SendOK:
+ if len(rsp.TransactionID) > 0 {
+ msg.Properties["__transactionId__"] = rsp.TransactionID
+ }
+ transactionId :=
msg.Properties[primitive.PropertyUniqueClientMessageIdKeyIndex]
+ if len(transactionId) > 0 {
+ msg.TransactionId = transactionId
+ }
+ localTransactionState =
tp.listener.ExecuteLocalTransaction(*msg)
+ if localTransactionState != primitive.CommitMessageState {
+ rlog.Errorf("executeLocalTransactionBranch return %v
with msg: %v\n", localTransactionState, msg)
+ }
+
+ case primitive.SendFlushDiskTimeout, primitive.SendFlushSlaveTimeout,
primitive.SendSlaveNotAvailable:
+ localTransactionState = primitive.RollbackMessageState
+ default:
+ }
+
+ tp.endTransaction(*rsp, err, localTransactionState)
+
+ transactionSendResult := &primitive.TransactionSendResult{
+ SendResult: rsp,
+ State: localTransactionState,
+ }
+
+ return transactionSendResult, nil
+}
+
+func (tp *transactionProducer) endTransaction(result primitive.SendResult, err
error, state primitive.LocalTransactionState) error {
+ var msgID *primitive.MessageID
+ if len(result.OffsetMsgID) > 0 {
+ msgID, _ = primitive.UnmarshalMsgID([]byte(result.OffsetMsgID))
+ } else {
+ msgID, _ = primitive.UnmarshalMsgID([]byte(result.MsgID))
+ }
+ // 估计没有反序列化回来
+ brokerAddr :=
internal.FindBrokerAddrByName(result.MessageQueue.BrokerName)
+ requestHeader := &internal.EndTransactionRequestHeader{
+ TransactionId: result.TransactionID,
+ CommitLogOffset: msgID.Offset,
+ ProducerGroup: tp.producer.group,
+ TranStateTableOffset: result.QueueOffset,
+ MsgID: result.MsgID,
+ CommitOrRollback: tp.transactionState(state),
+ }
+
+ req := remote.NewRemotingCommand(internal.ReqENDTransaction,
requestHeader, nil)
+ req.Remark = tp.errRemark(err)
+
+ return tp.producer.client.InvokeOneWay(brokerAddr, req,
tp.producer.options.SendMsgTimeout)
+}
+
+func (tp *transactionProducer) errRemark(err error) string {
+ if err != nil {
+ return "executeLocalTransactionBranch exception: " + err.Error()
+ }
+ return ""
+}
+
+func (tp *transactionProducer) transactionState(state
primitive.LocalTransactionState) int {
+ switch state {
+ case primitive.CommitMessageState:
+ return primitive.TransactionCommitType
+ case primitive.RollbackMessageState:
+ return primitive.TransactionRollbackType
+ case primitive.UnknowState:
+ return primitive.TransactionNotType
+ default:
+ return primitive.TransactionNotType
+ }
+}