This is an automated email from the ASF dual-hosted git repository.
xyz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pulsar-client-go.git
The following commit(s) were added to refs/heads/master by this push:
new 35076ac0 Support acknowledging a list of message IDs (#1301)
35076ac0 is described below
commit 35076ac09138b5fa11a508ccc0ef8cd1b1993347
Author: Yunze Xu <[email protected]>
AuthorDate: Wed Nov 6 13:34:23 2024 +0800
Support acknowledging a list of message IDs (#1301)
---
pulsar/ack_grouping_tracker.go | 60 ++++++----
pulsar/consumer.go | 31 +++++
pulsar/consumer_impl.go | 10 ++
pulsar/consumer_multitopic.go | 43 +++++++
pulsar/consumer_multitopic_test.go | 99 ++++++++++++++++
pulsar/consumer_partition.go | 113 +++++++++++++++++-
pulsar/consumer_regex.go | 12 ++
pulsar/consumer_test.go | 132 +++++++++++++++++++++
pulsar/consumer_zero_queue.go | 4 +
pulsar/impl_message.go | 6 +
.../pulsartracing/consumer_interceptor_test.go | 4 +
11 files changed, 486 insertions(+), 28 deletions(-)
diff --git a/pulsar/ack_grouping_tracker.go b/pulsar/ack_grouping_tracker.go
index c4ecc003..bb9059ac 100644
--- a/pulsar/ack_grouping_tracker.go
+++ b/pulsar/ack_grouping_tracker.go
@@ -62,7 +62,7 @@ func newAckGroupingTracker(options *AckGroupingOptions,
maxNumAcks: int(options.MaxSize),
ackCumulative: ackCumulative,
ackList: ackList,
- pendingAcks: make(map[[2]uint64]*bitset.BitSet),
+ pendingAcks: make(map[position]*bitset.BitSet),
lastCumulativeAck: EarliestMessageID(),
}
@@ -110,6 +110,15 @@ func (i *immediateAckGroupingTracker) flushAndClean() {
func (i *immediateAckGroupingTracker) close() {
}
+// position identifies an entry by its ledger ID and entry ID. It is the key of
+// the pending-acks maps (replacing the former [2]uint64 key); the batch index is
+// not part of it, so all messages of one batch share the same position.
+type position struct {
+ ledgerID uint64
+ entryID uint64
+}
+
+// newPosition returns the (ledger ID, entry ID) position of msgID.
+func newPosition(msgID MessageID) position {
+ return position{ledgerID: uint64(msgID.LedgerID()), entryID:
uint64(msgID.EntryID())}
+}
+
type timedAckGroupingTracker struct {
sync.RWMutex
@@ -124,7 +133,7 @@ type timedAckGroupingTracker struct {
// in the batch whose batch size is 3 are not acknowledged.
// After the 1st message (i.e. batch index is 0) is acknowledged, the
bits will become "011".
// Value is nil if the entry represents a single message.
- pendingAcks map[[2]uint64]*bitset.BitSet
+ pendingAcks map[position]*bitset.BitSet
lastCumulativeAck MessageID
cumulativeAckRequired int32
@@ -138,35 +147,36 @@ func (t *timedAckGroupingTracker) add(id MessageID) {
}
}
-func (t *timedAckGroupingTracker) tryAddIndividual(id MessageID)
map[[2]uint64]*bitset.BitSet {
- t.Lock()
- defer t.Unlock()
- key := [2]uint64{uint64(id.LedgerID()), uint64(id.EntryID())}
-
+// addMsgIDToPendingAcks records id in pendingAcks. For a batched message the
+// entry's bitset is created with all batchSize bits set and the bit at the
+// batch index is then cleared; a non-batched message maps its position to nil.
+// It takes no lock itself; tryAddIndividual calls it while holding t's lock.
+func addMsgIDToPendingAcks(pendingAcks map[position]*bitset.BitSet, id
MessageID) {
+ key := newPosition(id)
batchIdx := id.BatchIdx()
batchSize := id.BatchSize()
if batchIdx >= 0 && batchSize > 0 {
- bs, found := t.pendingAcks[key]
+ bs, found := pendingAcks[key]
if !found {
- if batchSize > 1 {
- bs = bitset.New(uint(batchSize))
- for i := uint(0); i < uint(batchSize); i++ {
- bs.Set(i)
- }
+ bs = bitset.New(uint(batchSize))
+ for i := uint(0); i < uint(batchSize); i++ {
+ bs.Set(i)
}
- t.pendingAcks[key] = bs
+ pendingAcks[key] = bs
}
+ // An existing nil value already marks the whole entry, so there is no bit to clear.
if bs != nil {
bs.Clear(uint(batchIdx))
}
} else {
- t.pendingAcks[key] = nil
+ pendingAcks[key] = nil
}
+}
+func (t *timedAckGroupingTracker) tryAddIndividual(id MessageID)
map[position]*bitset.BitSet {
+ t.Lock()
+ defer t.Unlock()
+
+ addMsgIDToPendingAcks(t.pendingAcks, id)
if len(t.pendingAcks) >= t.maxNumAcks {
pendingAcks := t.pendingAcks
- t.pendingAcks = make(map[[2]uint64]*bitset.BitSet)
+ t.pendingAcks = make(map[position]*bitset.BitSet)
return pendingAcks
}
return nil
@@ -195,7 +205,7 @@ func (t *timedAckGroupingTracker) isDuplicate(id MessageID)
bool {
if messageIDCompare(t.lastCumulativeAck, id) >= 0 {
return true
}
- key := [2]uint64{uint64(id.LedgerID()), uint64(id.EntryID())}
+ key := newPosition(id)
if bs, found := t.pendingAcks[key]; found {
if bs == nil {
return true
@@ -232,11 +242,11 @@ func (t *timedAckGroupingTracker) flushAndClean() {
}
}
-func (t *timedAckGroupingTracker) clearPendingAcks()
map[[2]uint64]*bitset.BitSet {
+// clearPendingAcks replaces t.pendingAcks with an empty map under the lock and
+// returns the previous map to the caller.
+func (t *timedAckGroupingTracker) clearPendingAcks()
map[position]*bitset.BitSet {
t.Lock()
defer t.Unlock()
pendingAcks := t.pendingAcks
- t.pendingAcks = make(map[[2]uint64]*bitset.BitSet)
+ t.pendingAcks = make(map[position]*bitset.BitSet)
return pendingAcks
}
@@ -250,12 +260,10 @@ func (t *timedAckGroupingTracker) close() {
}
}
-func (t *timedAckGroupingTracker) flushIndividual(pendingAcks
map[[2]uint64]*bitset.BitSet) {
+// toMsgIDDataList converts pending acks into the protobuf message IDs carried by
+// an ACK command. AckSet is populated only for entries whose bitset is non-nil
+// and has at least one bit set.
+func toMsgIDDataList(pendingAcks map[position]*bitset.BitSet) []*pb.MessageIdData {
msgIDs := make([]*pb.MessageIdData, 0, len(pendingAcks))
for k, v := range pendingAcks {
- ledgerID := k[0]
- entryID := k[1]
- msgID := &pb.MessageIdData{LedgerId: &ledgerID, EntryId:
&entryID}
+		// Copy into per-iteration locals before taking addresses: with pre-Go 1.22
+		// loop semantics `k` is one variable reused by every iteration, so
+		// &k.ledgerID would make every MessageIdData point at the last key.
+		ledgerID, entryID := k.ledgerID, k.entryID
+		msgID := &pb.MessageIdData{LedgerId: &ledgerID, EntryId: &entryID}
if v != nil && !v.None() {
bytes := v.Bytes()
msgID.AckSet = make([]int64, len(bytes))
@@ -265,5 +273,9 @@ func (t *timedAckGroupingTracker)
flushIndividual(pendingAcks map[[2]uint64]*bit
}
msgIDs = append(msgIDs, msgID)
}
- t.ackList(msgIDs)
+ return msgIDs
+}
+
+// flushIndividual converts pendingAcks with toMsgIDDataList and hands the
+// result to the ackList callback.
+func (t *timedAckGroupingTracker) flushIndividual(pendingAcks
map[position]*bitset.BitSet) {
+ t.ackList(toMsgIDDataList(pendingAcks))
}
diff --git a/pulsar/consumer.go b/pulsar/consumer.go
index 880cad56..7aee9645 100644
--- a/pulsar/consumer.go
+++ b/pulsar/consumer.go
@@ -19,6 +19,8 @@ package pulsar
import (
"context"
+ "fmt"
+ "strings"
"time"
"github.com/apache/pulsar-client-go/pulsar/backoff"
@@ -266,6 +268,23 @@ type ConsumerOptions struct {
startMessageID *trackingMessageID
}
+// AckError is returned by `AckIDList` when the acknowledgment fails and `AckWithResponse` is true.
+// It maps each valid message ID that failed to be acknowledged to the corresponding error.
+// Invalid message IDs passed to `AckIDList` are not included; users should ignore them and not
+// acknowledge them again.
+type AckError map[MessageID]error
+
+// Error groups the failed message IDs by error text and writes one line per distinct error.
+// Lines follow map iteration order, so their order is not stable between calls.
+func (e AckError) Error() string {
+	var builder strings.Builder
+	errorMap := make(map[string][]MessageID)
+	for id, err := range e {
+		text := err.Error() // computed once per entry
+		errorMap[text] = append(errorMap[text], id)
+	}
+	for text, msgIDs := range errorMap {
+		fmt.Fprintf(&builder, "error: %s, failed message IDs: %v\n", text, msgIDs)
+	}
+	return builder.String()
+}
+
// Consumer is an interface that abstracts behavior of Pulsar's consumer
type Consumer interface {
// Subscription get a subscription for the consumer
@@ -305,8 +324,20 @@ type Consumer interface {
Ack(Message) error
// AckID the consumption of a single message, identified by its
MessageID
+ // When `EnableBatchIndexAcknowledgment` is false, if a message ID
represents a message in the batch,
+ // it will not be actually acknowledged by broker until all messages in
that batch are acknowledged via
+ // `AckID` or `AckIDList`.
AckID(MessageID) error
+ // AckIDList the consumption of a list of messages, identified by their
MessageIDs
+ //
+ // This method should be used when `AckWithResponse` is true.
Otherwise, it will be equivalent with calling
+ // `AckID` on each message ID in the list.
+ //
+ // When `AckWithResponse` is true, the returned error could be an
`AckError` which contains the failed message ID
+ // and the corresponding error.
+ AckIDList([]MessageID) error
+
// AckWithTxn the consumption of a single message with a transaction
AckWithTxn(Message, Transaction) error
diff --git a/pulsar/consumer_impl.go b/pulsar/consumer_impl.go
index 740a7df9..eafa4b47 100644
--- a/pulsar/consumer_impl.go
+++ b/pulsar/consumer_impl.go
@@ -41,6 +41,7 @@ const defaultNackRedeliveryDelay = 1 * time.Minute
type acker interface {
// AckID does not handle errors returned by the Broker side, so no need
to wait for doneCh to finish.
AckID(id MessageID) error
+ AckIDList(msgIDs []MessageID) error
AckIDWithResponse(id MessageID) error
AckIDWithTxn(msgID MessageID, txn Transaction) error
AckIDCumulative(msgID MessageID) error
@@ -559,6 +560,15 @@ func (c *consumer) AckID(msgID MessageID) error {
return c.consumers[msgID.PartitionIdx()].AckID(msgID)
}
+// AckIDList acknowledges msgIDs by grouping them per partition consumer, chosen by
+// PartitionIdx. IDs rejected by checkMsgIDPartition are logged and skipped by
+// ackIDListFromMultiTopics.
+func (c *consumer) AckIDList(msgIDs []MessageID) error {
+ return ackIDListFromMultiTopics(c.log, msgIDs, func(msgID MessageID)
(acker, error) {
+ if err := c.checkMsgIDPartition(msgID); err != nil {
+ return nil, err
+ }
+ return c.consumers[msgID.PartitionIdx()], nil
+ })
+}
+
// AckCumulative the reception of all the messages in the stream up to (and
including)
// the provided message, identified by its MessageID
func (c *consumer) AckCumulative(msg Message) error {
diff --git a/pulsar/consumer_multitopic.go b/pulsar/consumer_multitopic.go
index 3030bda3..26430add 100644
--- a/pulsar/consumer_multitopic.go
+++ b/pulsar/consumer_multitopic.go
@@ -167,6 +167,49 @@ func (c *multiTopicConsumer) AckID(msgID MessageID) error {
return mid.consumer.AckID(msgID)
}
+// AckIDList acknowledges msgIDs, routing each one to the consumer stored in its
+// tracking message ID. IDs of an unsupported type, or without a consumer, are
+// logged and skipped by ackIDListFromMultiTopics.
+func (c *multiTopicConsumer) AckIDList(msgIDs []MessageID) error {
+ return ackIDListFromMultiTopics(c.log, msgIDs, func(msgID MessageID)
(acker, error) {
+ if !checkMessageIDType(msgID) {
+ return nil, fmt.Errorf("invalid message id type %T",
msgID)
+ }
+ if mid := toTrackingMessageID(msgID); mid != nil &&
mid.consumer != nil {
+ return mid.consumer, nil
+ }
+ return nil, errors.New("consumer is nil")
+ })
+}
+
+// ackIDListFromMultiTopics groups msgIDs by the acker returned from findConsumer
+// and calls AckIDList once per group. IDs for which findConsumer fails are logged
+// and skipped. Failures from all groups are merged into a single AckError; nil is
+// returned when every group succeeds.
+func ackIDListFromMultiTopics(log log.Logger, msgIDs []MessageID, findConsumer func(MessageID) (acker, error)) error {
+	consumerToMsgIDs := make(map[acker][]MessageID)
+	for _, msgID := range msgIDs {
+		consumer, err := findConsumer(msgID)
+		if err != nil {
+			log.Warnf("Can not find consumer for %v: %v", msgID, err)
+			continue
+		}
+		consumerToMsgIDs[consumer] = append(consumerToMsgIDs[consumer], msgID)
+	}
+
+	ackError := AckError{}
+	for consumer, ids := range consumerToMsgIDs {
+		err := consumer.AckIDList(ids)
+		if err == nil {
+			continue
+		}
+		// Use the comma-ok assertion: a partition consumer without AckWithResponse
+		// returns the plain error from AckID, and an unchecked err.(AckError)
+		// would panic on it.
+		if topicAckError, ok := err.(AckError); ok {
+			for id, idErr := range topicAckError {
+				ackError[id] = idErr
+			}
+			continue
+		}
+		// A plain error is not attributed to specific IDs, so report it for the
+		// whole group.
+		for _, id := range ids {
+			ackError[id] = err
+		}
+	}
+	if len(ackError) == 0 {
+		return nil
+	}
+	return ackError
+}
+
// AckWithTxn the consumption of a single message with a transaction
func (c *multiTopicConsumer) AckWithTxn(msg Message, txn Transaction) error {
msgID := msg.ID()
diff --git a/pulsar/consumer_multitopic_test.go
b/pulsar/consumer_multitopic_test.go
index 7c6b898a..4b3ec908 100644
--- a/pulsar/consumer_multitopic_test.go
+++ b/pulsar/consumer_multitopic_test.go
@@ -21,6 +21,7 @@ import (
"fmt"
"strings"
"testing"
+ "time"
"github.com/apache/pulsar-client-go/pulsaradmin"
"github.com/apache/pulsar-client-go/pulsaradmin/pkg/admin/config"
@@ -218,3 +219,101 @@ func TestMultiTopicGetLastMessageIDs(t *testing.T) {
}
}
+
+// TestMultiTopicAckIDList runs runMultiTopicAckIDList with a pattern (regex)
+// subscription and with an explicit topic list.
+func TestMultiTopicAckIDList(t *testing.T) {
+ for _, params := range []bool{true, false} {
+ t.Run(fmt.Sprintf("TestMultiTopicConsumerAckIDList%v", params),
func(t *testing.T) {
+ runMultiTopicAckIDList(t, params)
+ })
+ }
+}
+
+// runMultiTopicAckIDList subscribes to two topics (by pattern when regex is true,
+// otherwise by topic list), acknowledges a subset of the received messages with
+// AckIDList and checks that only the rest are redelivered after re-subscribing.
+// It then checks that AckIDList on a closed consumer returns an AckError.
+func runMultiTopicAckIDList(t *testing.T, regex bool) {
+	topicPrefix := fmt.Sprintf("multiTopicAckIDList%v", time.Now().UnixNano())
+	topic1 := "persistent://public/default/" + topicPrefix + "1"
+	topic2 := "persistent://public/default/" + topicPrefix + "2"
+
+	client, err := NewClient(ClientOptions{URL: "pulsar://localhost:6650"})
+	assert.Nil(t, err)
+	defer client.Close()
+
+	if regex {
+		// Create the topics up front (presumably so the pattern subscription finds
+		// them at subscribe time). A failed creation must fail the test here rather
+		// than surface later as a confusing receive timeout.
+		admin, err := pulsaradmin.NewClient(&config.Config{})
+		assert.Nil(t, err)
+		for _, topic := range []string{topic1, topic2} {
+			topicName, err := utils.GetTopicName(topic)
+			assert.Nil(t, err)
+			assert.Nil(t, admin.Topics().Create(*topicName, 0))
+		}
+	}
+
+	createConsumer := func() Consumer {
+		options := ConsumerOptions{
+			SubscriptionName: "sub",
+			Type:             Shared,
+			AckWithResponse:  true,
+		}
+		if regex {
+			options.TopicsPattern = topicPrefix + ".*"
+		} else {
+			options.Topics = []string{topic1, topic2}
+		}
+		consumer, err := client.Subscribe(options)
+		assert.Nil(t, err)
+		return consumer
+	}
+	consumer := createConsumer()
+
+	sendMessages(t, client, topic1, 0, 3, false)
+	sendMessages(t, client, topic2, 0, 2, false)
+
+	receiveMessageMap := func(consumer Consumer, numMessages int) map[string][]Message {
+		msgs := receiveMessages(t, consumer, numMessages)
+		topicToMsgs := make(map[string][]Message)
+		for _, msg := range msgs {
+			topicToMsgs[msg.Topic()] = append(topicToMsgs[msg.Topic()], msg)
+		}
+		return topicToMsgs
+	}
+
+	topicToMsgs := receiveMessageMap(consumer, 5)
+	assert.Equal(t, 3, len(topicToMsgs[topic1]))
+	for i := 0; i < 3; i++ {
+		assert.Equal(t, fmt.Sprintf("msg-%d", i), string(topicToMsgs[topic1][i].Payload()))
+	}
+	assert.Equal(t, 2, len(topicToMsgs[topic2]))
+	for i := 0; i < 2; i++ {
+		assert.Equal(t, fmt.Sprintf("msg-%d", i), string(topicToMsgs[topic2][i].Payload()))
+	}
+
+	assert.Nil(t, consumer.AckIDList([]MessageID{
+		topicToMsgs[topic1][0].ID(),
+		topicToMsgs[topic1][2].ID(),
+		topicToMsgs[topic2][1].ID(),
+	}))
+
+	// After re-subscribing, only the unacknowledged messages are redelivered.
+	consumer.Close()
+	consumer = createConsumer()
+	topicToMsgs = receiveMessageMap(consumer, 2)
+	assert.Equal(t, 1, len(topicToMsgs[topic1]))
+	assert.Equal(t, "msg-1", string(topicToMsgs[topic1][0].Payload()))
+	assert.Equal(t, 1, len(topicToMsgs[topic2]))
+	assert.Equal(t, "msg-0", string(topicToMsgs[topic2][0].Payload()))
+	consumer.Close()
+
+	// Acknowledging through the closed consumer must fail.
+	msgID0 := topicToMsgs[topic1][0].ID()
+	err = consumer.AckIDList([]MessageID{msgID0})
+	assert.NotNil(t, err)
+	t.Logf("AckIDList error: %v", err)
+
+	msgID1 := topicToMsgs[topic2][0].ID()
+	if ackError, ok := consumer.AckIDList([]MessageID{msgID0, msgID1}).(AckError); ok {
+		assert.Equal(t, 2, len(ackError))
+		assert.Contains(t, ackError, msgID0)
+		assert.Equal(t, "consumer state is closed", ackError[msgID0].Error())
+		assert.Contains(t, ackError, msgID1)
+		assert.Equal(t, "consumer state is closed", ackError[msgID1].Error())
+	} else {
+		assert.Fail(t, "AckIDList should return AckError")
+	}
+}
diff --git a/pulsar/consumer_partition.go b/pulsar/consumer_partition.go
index 4e8fba5a..471d45a3 100644
--- a/pulsar/consumer_partition.go
+++ b/pulsar/consumer_partition.go
@@ -198,6 +198,10 @@ func (pc *partitionConsumer) pauseDispatchMessage() {
pc.dispatcherSeekingControlCh <- struct{}{}
}
+// Topic returns the topic name of this partition consumer (pc.topic).
+func (pc *partitionConsumer) Topic() string {
+ return pc.topic
+}
+
func (pc *partitionConsumer) ActiveConsumerChanged(isActive bool) {
listener := pc.options.consumerEventListener
if listener == nil {
@@ -375,7 +379,12 @@ func newPartitionConsumer(parent Consumer, client *client,
options *partitionCon
pc.ackGroupingTracker =
newAckGroupingTracker(options.ackGroupingOptions,
func(id MessageID) { pc.sendIndividualAck(id) },
func(id MessageID) { pc.sendCumulativeAck(id) },
- func(ids []*pb.MessageIdData) { pc.eventsCh <- ids })
+ func(ids []*pb.MessageIdData) {
+ pc.eventsCh <- &ackListRequest{
+ errCh: nil, // ignore the error
+ msgIDs: ids,
+ }
+ })
pc.setConsumerState(consumerInit)
pc.log = client.log.SubLogger(log.Fields{
"name": pc.name,
@@ -695,6 +704,86 @@ func (pc *partitionConsumer) AckID(msgID MessageID) error {
return pc.ackID(msgID, false)
}
+// AckIDList acknowledges a list of message IDs that belong to this partition.
+//
+// Without AckWithResponse, every ID is passed to AckID in order and the first
+// error is returned. With AckWithResponse, IDs with a different partition index,
+// an out-of-range batch index or an unsupported type are logged and skipped. The
+// rest are sent to the broker in one ACK request, and a failure is returned as an
+// AckError covering every valid ID.
+func (pc *partitionConsumer) AckIDList(msgIDs []MessageID) error {
+	if !pc.options.ackWithResponse {
+		for _, msgID := range msgIDs {
+			if err := pc.AckID(msgID); err != nil {
+				return err
+			}
+		}
+		return nil
+	}
+
+	chunkedMsgIDs := make([]*chunkMessageID, 0) // removed from unAckChunksTracker once the ack succeeds
+	pendingAcks := make(map[position]*bitset.BitSet)
+	validMsgIDs := make([]MessageID, 0, len(msgIDs))
+
+	// A batch may only become fully acknowledged after several of its messages in
+	// this list have been processed, so every ID is folded into pendingAcks first
+	// and the protocol representation is built afterwards.
+	for _, msgID := range msgIDs {
+		if msgID.PartitionIdx() != pc.partitionIdx {
+			pc.log.Errorf("%v inconsistent partition index %v (current: %v)", msgID, msgID.PartitionIdx(), pc.partitionIdx)
+			continue
+		}
+		if msgID.BatchIdx() >= 0 && msgID.BatchSize() > 0 && msgID.BatchIdx() >= msgID.BatchSize() {
+			pc.log.Errorf("%v invalid batch index %v (size: %v)", msgID, msgID.BatchIdx(), msgID.BatchSize())
+			continue
+		}
+		switch convertedMsgID := msgID.(type) {
+		case *trackingMessageID:
+			pos := newPosition(msgID) // named pos so it does not shadow the position type
+			if convertedMsgID.ack() {
+				pendingAcks[pos] = nil
+			} else if pc.options.enableBatchIndexAck {
+				pendingAcks[pos] = convertedMsgID.tracker.getAckBitSet()
+			}
+		case *chunkMessageID:
+			for _, id := range pc.unAckChunksTracker.get(convertedMsgID) {
+				pendingAcks[newPosition(id)] = nil
+			}
+			chunkedMsgIDs = append(chunkedMsgIDs, convertedMsgID)
+		case *messageID:
+			pendingAcks[newPosition(msgID)] = nil
+		default:
+			pc.log.Errorf("invalid message id type %T: %v", msgID, msgID)
+			continue // skips the outer for loop iteration, not just the switch
+		}
+		validMsgIDs = append(validMsgIDs, msgID)
+	}
+
+	if state := pc.getConsumerState(); state == consumerClosed || state == consumerClosing {
+		pc.log.WithField("state", state).Error("Failed to ack by closing or closed consumer")
+		return toAckError(map[error][]MessageID{errors.New("consumer state is closed"): validMsgIDs})
+	}
+
+	// Skip the broker round trip when no ID produced a pending ack, for example
+	// when every ID was invalid, or when only part of a batch was acked without
+	// batch index ack. Otherwise an empty ACK command would be sent.
+	if len(pendingAcks) > 0 {
+		req := &ackListRequest{
+			errCh:  make(chan error),
+			msgIDs: toMsgIDDataList(pendingAcks),
+		}
+		pc.eventsCh <- req
+		if err := <-req.errCh; err != nil {
+			return toAckError(map[error][]MessageID{err: validMsgIDs})
+		}
+	}
+	for _, id := range chunkedMsgIDs {
+		pc.unAckChunksTracker.remove(id)
+	}
+	// Notify interceptors only about IDs that passed validation; skipped IDs
+	// were never acknowledged.
+	for _, id := range validMsgIDs {
+		pc.options.interceptors.OnAcknowledge(pc.parentConsumer, id)
+	}
+	return nil
+}
+
+// toAckError flattens a map from an error to the message IDs it applies to into
+// an AckError keyed by message ID.
+func toAckError(errorMap map[error][]MessageID) AckError {
+ e := AckError{}
+ for err, ids := range errorMap {
+ for _, id := range ids {
+ e[id] = err
+ }
+ }
+ return e
+}
+
func (pc *partitionConsumer) AckIDCumulative(msgID MessageID) error {
if !checkMessageIDType(msgID) {
pc.log.Errorf("invalid message id type %T", msgID)
@@ -1027,11 +1116,22 @@ func (pc *partitionConsumer) internalAck(req
*ackRequest) {
}
}
-func (pc *partitionConsumer) internalAckList(msgIDs []*pb.MessageIdData) {
+// internalAckList sends an individual ACK command for request.msgIDs; it is
+// invoked from runEventsLoop. When request.errCh is set, the command carries a
+// request ID and the error from RequestOnCnx is delivered on errCh. Otherwise the
+// command is sent with RequestOnCnxNoWait and its outcome is not reported.
+// NOTE(review): if RequestOnCnx waits for the broker response, the events loop
+// is held up until it arrives — confirm this is acceptable.
+func (pc *partitionConsumer) internalAckList(request *ackListRequest) {
+ if request.errCh != nil {
+ reqID := pc.client.rpcClient.NewRequestID()
+ _, err := pc.client.rpcClient.RequestOnCnx(pc._getConn(),
reqID, pb.BaseCommand_ACK, &pb.CommandAck{
+ AckType: pb.CommandAck_Individual.Enum(),
+ ConsumerId: proto.Uint64(pc.consumerID),
+ MessageId: request.msgIDs,
+ RequestId: &reqID,
+ })
+ request.errCh <- err
+ return
+ }
pc.client.rpcClient.RequestOnCnxNoWait(pc._getConn(),
pb.BaseCommand_ACK, &pb.CommandAck{
AckType: pb.CommandAck_Individual.Enum(),
ConsumerId: proto.Uint64(pc.consumerID),
- MessageId: msgIDs,
+ MessageId: request.msgIDs,
})
}
@@ -1563,6 +1663,11 @@ type ackRequest struct {
err error
}
+// ackListRequest is queued on eventsCh to acknowledge msgIDs individually.
+// A nil errCh means the sender does not wait for the result.
+type ackListRequest struct {
+ errCh chan error
+ msgIDs []*pb.MessageIdData
+}
+
type ackWithTxnRequest struct {
doneCh chan struct{}
msgID trackingMessageID
@@ -1623,7 +1728,7 @@ func (pc *partitionConsumer) runEventsLoop() {
pc.internalAck(v)
case *ackWithTxnRequest:
pc.internalAckWithTxn(v)
- case []*pb.MessageIdData:
+ case *ackListRequest:
pc.internalAckList(v)
case *redeliveryRequest:
pc.internalRedeliver(v)
diff --git a/pulsar/consumer_regex.go b/pulsar/consumer_regex.go
index 58cfa80f..ced770e9 100644
--- a/pulsar/consumer_regex.go
+++ b/pulsar/consumer_regex.go
@@ -215,6 +215,18 @@ func (c *regexConsumer) AckID(msgID MessageID) error {
return mid.consumer.AckID(msgID)
}
+// AckIDList acknowledges msgIDs, routing each one to the consumer stored in its
+// tracking message ID. IDs of an unsupported type, or without a consumer, are
+// logged and skipped by ackIDListFromMultiTopics.
+func (c *regexConsumer) AckIDList(msgIDs []MessageID) error {
+	return ackIDListFromMultiTopics(c.log, msgIDs, func(msgID MessageID) (acker, error) {
+		if !checkMessageIDType(msgID) {
+			return nil, fmt.Errorf("invalid message id type %T", msgID)
+		}
+		// Guard against a nil result from toTrackingMessageID before reading its
+		// consumer field, as multiTopicConsumer.AckIDList does.
+		if mid := toTrackingMessageID(msgID); mid != nil && mid.consumer != nil {
+			return mid.consumer, nil
+		}
+		return nil, errors.New("consumer is nil in consumer_regex")
+	})
+}
+
// AckID the consumption of a single message, identified by its MessageID
func (c *regexConsumer) AckWithTxn(msg Message, txn Transaction) error {
msgID := msg.ID()
diff --git a/pulsar/consumer_test.go b/pulsar/consumer_test.go
index d8f31458..2524f681 100644
--- a/pulsar/consumer_test.go
+++ b/pulsar/consumer_test.go
@@ -4745,3 +4745,135 @@ func TestLookupConsumer(t *testing.T) {
consumer.Ack(msg)
}
}
+
+// TestAckIDList runs runAckIDListTest with batch index acknowledgment enabled
+// and disabled.
+func TestAckIDList(t *testing.T) {
+ for _, params := range []bool{true, false} {
+ t.Run(fmt.Sprintf("TestAckIDList_%v", params), func(t
*testing.T) {
+ runAckIDListTest(t, params)
+ })
+ }
+}
+
+// runAckIDListTest acknowledges a mix of batched and non-batched messages with
+// AckIDList, re-subscribes to check which messages are redelivered, and finally
+// checks that AckIDList on a closed consumer returns an AckError for every ID.
+func runAckIDListTest(t *testing.T, enableBatchIndexAck bool) {
+	client, err := NewClient(ClientOptions{URL: lookupURL})
+	assert.Nil(t, err)
+	defer client.Close()
+
+	// UnixNano rather than Nanosecond(): the latter only covers the current
+	// second, so topic names from different runs could collide.
+	topic := fmt.Sprintf("test-ack-id-list-%v", time.Now().UnixNano())
+
+	consumer := createSharedConsumer(t, client, topic, enableBatchIndexAck)
+	sendMessages(t, client, topic, 0, 5, true)  // entry 0: [0, 1, 2, 3, 4]
+	sendMessages(t, client, topic, 5, 3, false) // entry 2: [5], 3: [6], 4: [7]
+	sendMessages(t, client, topic, 8, 2, true)  // entry 5: [8, 9]
+
+	msgs := receiveMessages(t, consumer, 10)
+	originalMsgIDs := make([]MessageID, 0, 10)
+	for i := 0; i < 10; i++ {
+		originalMsgIDs = append(originalMsgIDs, msgs[i].ID())
+		assert.Equal(t, fmt.Sprintf("msg-%d", i), string(msgs[i].Payload()))
+	}
+
+	ackedIndexes := []int{0, 2, 3, 6, 8, 9}
+	unackedIndexes := []int{1, 4, 5, 7}
+	if !enableBatchIndexAck {
+		// Only part of the first batch [0, 4] is acked, so without batch index
+		// ack the whole batch is redelivered.
+		unackedIndexes = []int{0, 1, 2, 3, 4, 5, 7}
+	}
+	msgIDs := make([]MessageID, len(ackedIndexes))
+	for i, idx := range ackedIndexes {
+		msgIDs[i] = msgs[idx].ID()
+	}
+	assert.Nil(t, consumer.AckIDList(msgIDs))
+	consumer.Close()
+
+	consumer = createSharedConsumer(t, client, topic, enableBatchIndexAck)
+	msgs = receiveMessages(t, consumer, len(unackedIndexes))
+	for i, idx := range unackedIndexes {
+		assert.Equal(t, fmt.Sprintf("msg-%d", idx), string(msgs[i].Payload()))
+	}
+
+	if !enableBatchIndexAck {
+		// Acknowledge the whole first batch; only msg-5 and msg-7 should remain.
+		msgIDs = append(make([]MessageID, 0, 5), originalMsgIDs[:5]...)
+		assert.Nil(t, consumer.AckIDList(msgIDs))
+		consumer.Close()
+
+		consumer = createSharedConsumer(t, client, topic, enableBatchIndexAck)
+		msgs = receiveMessages(t, consumer, 2)
+		assert.Equal(t, "msg-5", string(msgs[0].Payload()))
+		assert.Equal(t, "msg-7", string(msgs[1].Payload()))
+	}
+	// Closed exactly once on both paths (the branch above no longer closes it too).
+	consumer.Close()
+
+	err = consumer.AckIDList(msgIDs)
+	assert.NotNil(t, err)
+	// Comma-ok assertion: an unchecked err.(AckError) panics on nil or any other
+	// error type instead of reporting a test failure.
+	ackError, ok := err.(AckError)
+	if !ok {
+		assert.Fail(t, "AckIDList should return AckError")
+		return
+	}
+	assert.Equal(t, len(msgIDs), len(ackError))
+	for _, id := range msgIDs {
+		assert.Contains(t, ackError, id)
+		assert.Equal(t, "consumer state is closed", ackError[id].Error())
+	}
+}
+
+// createSharedConsumer subscribes to topic with the "my-sub" Shared subscription
+// from the earliest position, with AckWithResponse enabled and batch index ack
+// set by enableBatchIndexAck.
+func createSharedConsumer(t *testing.T, client Client, topic string,
enableBatchIndexAck bool) Consumer {
+ consumer, err := client.Subscribe(ConsumerOptions{
+ Topic: topic,
+ SubscriptionName: "my-sub",
+ SubscriptionInitialPosition: SubscriptionPositionEarliest,
+ Type: Shared,
+ EnableBatchIndexAcknowledgment: enableBatchIndexAck,
+ AckWithResponse: true,
+ })
+ assert.Nil(t, err)
+ return consumer
+}
+
+// sendMessages produces numMessages messages with payloads "msg-<startIndex+i>"
+// to topic. With batching, the messages are sent asynchronously and flushed at the
+// end (presumably as a single batch, since BatchingMaxMessages equals
+// numMessages); otherwise each one is sent synchronously.
+func sendMessages(t *testing.T, client Client, topic string, startIndex int, numMessages int, batching bool) {
+	producer, err := client.CreateProducer(ProducerOptions{
+		Topic:                   topic,
+		DisableBatching:         !batching,
+		BatchingMaxMessages:     uint(numMessages),
+		BatchingMaxSize:         1024 * 1024 * 10,
+		BatchingMaxPublishDelay: 1 * time.Hour,
+	})
+	assert.Nil(t, err)
+	defer producer.Close()
+
+	ctx := context.Background()
+	for i := 0; i < numMessages; i++ {
+		msg := &ProducerMessage{Payload: []byte(fmt.Sprintf("msg-%d", startIndex+i))}
+		if batching {
+			producer.SendAsync(ctx, msg, func(_ MessageID, _ *ProducerMessage, err error) {
+				if err != nil {
+					t.Logf("Failed to send message: %v", err)
+				}
+			})
+			continue
+		}
+		// assert.Fail does not treat its extra arguments as format arguments, so
+		// "%v" was never filled in; NoError reports the actual error.
+		_, err := producer.Send(ctx, msg)
+		assert.NoError(t, err, "Failed to send message")
+	}
+	assert.Nil(t, producer.Flush())
+}
+
+// receiveMessages receives up to numMessages messages under one shared 5-second
+// timeout, stops at the first receive error, and asserts that exactly
+// numMessages were received.
+func receiveMessages(t *testing.T, consumer Consumer, numMessages int)
[]Message {
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
+ msgs := make([]Message, 0)
+ for i := 0; i < numMessages; i++ {
+ if msg, err := consumer.Receive(ctx); err == nil {
+ msgs = append(msgs, msg)
+ } else {
+ t.Logf("Failed to receive message: %v", err)
+ break
+ }
+ }
+ assert.Equal(t, numMessages, len(msgs))
+ return msgs
+}
diff --git a/pulsar/consumer_zero_queue.go b/pulsar/consumer_zero_queue.go
index 81171272..3f2862da 100644
--- a/pulsar/consumer_zero_queue.go
+++ b/pulsar/consumer_zero_queue.go
@@ -171,6 +171,10 @@ func (z *zeroQueueConsumer) AckID(msgID MessageID) error {
return z.pc.AckID(msgID)
}
+// AckIDList delegates to the underlying partition consumer z.pc.
+func (z *zeroQueueConsumer) AckIDList(msgIDs []MessageID) error {
+ return z.pc.AckIDList(msgIDs)
+}
+
func (z *zeroQueueConsumer) AckWithTxn(msg Message, txn Transaction) error {
msgID := msg.ID()
if err := z.checkMsgIDPartition(msgID); err != nil {
diff --git a/pulsar/impl_message.go b/pulsar/impl_message.go
index 478b1af2..0acd782b 100644
--- a/pulsar/impl_message.go
+++ b/pulsar/impl_message.go
@@ -404,6 +404,12 @@ type ackTracker struct {
prevBatchAcked uint32
}
+// getAckBitSet returns a clone of t.batchIDs taken while holding t's lock, so the
+// caller gets a copy instead of sharing the tracker's bitset.
+// NOTE(review): assumes batchIDs is non-nil for every tracker this is called on — confirm.
+func (t *ackTracker) getAckBitSet() *bitset.BitSet {
+ t.Lock()
+ defer t.Unlock()
+ return t.batchIDs.Clone()
+}
+
func (t *ackTracker) ack(batchID int) bool {
if batchID < 0 {
return true
diff --git a/pulsar/internal/pulsartracing/consumer_interceptor_test.go
b/pulsar/internal/pulsartracing/consumer_interceptor_test.go
index 1fa1bf0d..e7712356 100644
--- a/pulsar/internal/pulsartracing/consumer_interceptor_test.go
+++ b/pulsar/internal/pulsartracing/consumer_interceptor_test.go
@@ -79,6 +79,10 @@ func (c *mockConsumer) AckID(_ pulsar.MessageID) error {
return nil
}
+// AckIDList is a no-op in the mock; it always returns nil.
+func (c *mockConsumer) AckIDList(_ []pulsar.MessageID) error {
+ return nil
+}
+
func (c *mockConsumer) AckCumulative(_ pulsar.Message) error {
return nil
}