This is an automated email from the ASF dual-hosted git repository.

xyz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pulsar-client-go.git


The following commit(s) were added to refs/heads/master by this push:
     new 35076ac0 Support acknowledging a list of message IDs (#1301)
35076ac0 is described below

commit 35076ac09138b5fa11a508ccc0ef8cd1b1993347
Author: Yunze Xu <[email protected]>
AuthorDate: Wed Nov 6 13:34:23 2024 +0800

    Support acknowledging a list of message IDs (#1301)
---
 pulsar/ack_grouping_tracker.go                     |  60 ++++++----
 pulsar/consumer.go                                 |  31 +++++
 pulsar/consumer_impl.go                            |  10 ++
 pulsar/consumer_multitopic.go                      |  43 +++++++
 pulsar/consumer_multitopic_test.go                 |  99 ++++++++++++++++
 pulsar/consumer_partition.go                       | 113 +++++++++++++++++-
 pulsar/consumer_regex.go                           |  12 ++
 pulsar/consumer_test.go                            | 132 +++++++++++++++++++++
 pulsar/consumer_zero_queue.go                      |   4 +
 pulsar/impl_message.go                             |   6 +
 .../pulsartracing/consumer_interceptor_test.go     |   4 +
 11 files changed, 486 insertions(+), 28 deletions(-)

diff --git a/pulsar/ack_grouping_tracker.go b/pulsar/ack_grouping_tracker.go
index c4ecc003..bb9059ac 100644
--- a/pulsar/ack_grouping_tracker.go
+++ b/pulsar/ack_grouping_tracker.go
@@ -62,7 +62,7 @@ func newAckGroupingTracker(options *AckGroupingOptions,
                maxNumAcks:        int(options.MaxSize),
                ackCumulative:     ackCumulative,
                ackList:           ackList,
-               pendingAcks:       make(map[[2]uint64]*bitset.BitSet),
+               pendingAcks:       make(map[position]*bitset.BitSet),
                lastCumulativeAck: EarliestMessageID(),
        }
 
@@ -110,6 +110,16 @@ func (i *immediateAckGroupingTracker) flushAndClean() {
 func (i *immediateAckGroupingTracker) close() {
 }
 
+// position identifies an entry by its ledger ID and entry ID; it is the key type of pendingAcks.
+type position struct {
+       ledgerID uint64
+       entryID  uint64
+}
+
+func newPosition(msgID MessageID) position {
+       return position{ledgerID: uint64(msgID.LedgerID()), entryID: uint64(msgID.EntryID())}
+}
+
 type timedAckGroupingTracker struct {
        sync.RWMutex
 
@@ -124,7 +134,7 @@ type timedAckGroupingTracker struct {
        // in the batch whose batch size is 3 are not acknowledged.
        // After the 1st message (i.e. batch index is 0) is acknowledged, the bits will become "011".
        // Value is nil if the entry represents a single message.
-       pendingAcks map[[2]uint64]*bitset.BitSet
+       pendingAcks map[position]*bitset.BitSet
 
        lastCumulativeAck     MessageID
        cumulativeAckRequired int32
@@ -138,35 +148,37 @@ func (t *timedAckGroupingTracker) add(id MessageID) {
        }
 }
 
-func (t *timedAckGroupingTracker) tryAddIndividual(id MessageID) map[[2]uint64]*bitset.BitSet {
-       t.Lock()
-       defer t.Unlock()
-       key := [2]uint64{uint64(id.LedgerID()), uint64(id.EntryID())}
-
+// addMsgIDToPendingAcks records id in pendingAcks, clearing its bit in the entry's bitset for batched messages.
+func addMsgIDToPendingAcks(pendingAcks map[position]*bitset.BitSet, id MessageID) {
+       key := newPosition(id)
        batchIdx := id.BatchIdx()
        batchSize := id.BatchSize()
 
        if batchIdx >= 0 && batchSize > 0 {
-               bs, found := t.pendingAcks[key]
+               bs, found := pendingAcks[key]
                if !found {
-                       if batchSize > 1 {
-                               bs = bitset.New(uint(batchSize))
-                               for i := uint(0); i < uint(batchSize); i++ {
-                                       bs.Set(i)
-                               }
+                       bs = bitset.New(uint(batchSize))
+                       for i := uint(0); i < uint(batchSize); i++ {
+                               bs.Set(i)
                        }
-                       t.pendingAcks[key] = bs
+                       pendingAcks[key] = bs
                }
                if bs != nil {
                        bs.Clear(uint(batchIdx))
                }
        } else {
-               t.pendingAcks[key] = nil
+               pendingAcks[key] = nil
        }
+}
 
+func (t *timedAckGroupingTracker) tryAddIndividual(id MessageID) map[position]*bitset.BitSet {
+       t.Lock()
+       defer t.Unlock()
+
+       addMsgIDToPendingAcks(t.pendingAcks, id)
        if len(t.pendingAcks) >= t.maxNumAcks {
                pendingAcks := t.pendingAcks
-               t.pendingAcks = make(map[[2]uint64]*bitset.BitSet)
+               t.pendingAcks = make(map[position]*bitset.BitSet)
                return pendingAcks
        }
        return nil
@@ -195,7 +207,7 @@ func (t *timedAckGroupingTracker) isDuplicate(id MessageID) bool {
        if messageIDCompare(t.lastCumulativeAck, id) >= 0 {
                return true
        }
-       key := [2]uint64{uint64(id.LedgerID()), uint64(id.EntryID())}
+       key := newPosition(id)
        if bs, found := t.pendingAcks[key]; found {
                if bs == nil {
                        return true
@@ -232,11 +244,11 @@ func (t *timedAckGroupingTracker) flushAndClean() {
        }
 }
 
-func (t *timedAckGroupingTracker) clearPendingAcks() map[[2]uint64]*bitset.BitSet {
+func (t *timedAckGroupingTracker) clearPendingAcks() map[position]*bitset.BitSet {
        t.Lock()
        defer t.Unlock()
        pendingAcks := t.pendingAcks
-       t.pendingAcks = make(map[[2]uint64]*bitset.BitSet)
+       t.pendingAcks = make(map[position]*bitset.BitSet)
        return pendingAcks
 }
 
@@ -250,12 +262,13 @@ func (t *timedAckGroupingTracker) close() {
        }
 }
 
-func (t *timedAckGroupingTracker) flushIndividual(pendingAcks map[[2]uint64]*bitset.BitSet) {
+// toMsgIDDataList converts pendingAcks into MessageIdData, attaching an ack set for partially acked batches.
+func toMsgIDDataList(pendingAcks map[position]*bitset.BitSet) []*pb.MessageIdData {
        msgIDs := make([]*pb.MessageIdData, 0, len(pendingAcks))
        for k, v := range pendingAcks {
-               ledgerID := k[0]
-               entryID := k[1]
+               // Copy the key fields: before Go 1.22 k is one reused variable, so &k.ledgerID would alias across entries.
+               ledgerID, entryID := k.ledgerID, k.entryID
                msgID := &pb.MessageIdData{LedgerId: &ledgerID, EntryId: &entryID}
                if v != nil && !v.None() {
                        bytes := v.Bytes()
                        msgID.AckSet = make([]int64, len(bytes))
@@ -265,5 +278,9 @@ func (t *timedAckGroupingTracker) flushIndividual(pendingAcks map[[2]uint64]*bit
                }
                msgIDs = append(msgIDs, msgID)
        }
-       t.ackList(msgIDs)
+       return msgIDs
+}
+
+func (t *timedAckGroupingTracker) flushIndividual(pendingAcks map[position]*bitset.BitSet) {
+       t.ackList(toMsgIDDataList(pendingAcks))
 }
diff --git a/pulsar/consumer.go b/pulsar/consumer.go
index 880cad56..7aee9645 100644
--- a/pulsar/consumer.go
+++ b/pulsar/consumer.go
@@ -19,6 +19,8 @@ package pulsar
 
 import (
        "context"
+       "fmt"
+       "strings"
        "time"
 
        "github.com/apache/pulsar-client-go/pulsar/backoff"
@@ -266,6 +268,24 @@ type ConsumerOptions struct {
        startMessageID *trackingMessageID
 }
 
+// AckError is returned by `AckIDList` when it fails and `AckWithResponse` is true.
+// It only contains the valid message IDs that failed to be acknowledged in the `AckIDList` call.
+// For those invalid message IDs, users should ignore them and not acknowledge them again.
+type AckError map[MessageID]error
+
+// Error groups the failed message IDs by error message; the order of the groups is not deterministic.
+func (e AckError) Error() string {
+       builder := strings.Builder{}
+       errorMap := make(map[string][]MessageID)
+       for id, err := range e {
+               errorMap[err.Error()] = append(errorMap[err.Error()], id)
+       }
+       for err, msgIDs := range errorMap {
+               builder.WriteString(fmt.Sprintf("error: %s, failed message IDs: %v\n", err, msgIDs))
+       }
+       return builder.String()
+}
+
 // Consumer is an interface that abstracts behavior of Pulsar's consumer
 type Consumer interface {
        // Subscription get a subscription for the consumer
@@ -305,8 +325,20 @@ type Consumer interface {
        Ack(Message) error
 
        // AckID the consumption of a single message, identified by its MessageID
+       // When `EnableBatchIndexAcknowledgment` is false, if a message ID represents a message in the batch,
+       // it will not be actually acknowledged by broker until all messages in that batch are acknowledged via
+       // `AckID` or `AckIDList`.
        AckID(MessageID) error
 
+       // AckIDList the consumption of a list of messages, identified by their MessageIDs
+       //
+       // This method should be used when `AckWithResponse` is true. Otherwise, it will be equivalent to calling
+       // `AckID` on each message ID in the list.
+       //
+       // When `AckWithResponse` is true, the returned error could be an `AckError` which contains the failed message IDs
+       // and the corresponding errors.
+       AckIDList([]MessageID) error
+
        // AckWithTxn the consumption of a single message with a transaction
        AckWithTxn(Message, Transaction) error
 
diff --git a/pulsar/consumer_impl.go b/pulsar/consumer_impl.go
index 740a7df9..eafa4b47 100644
--- a/pulsar/consumer_impl.go
+++ b/pulsar/consumer_impl.go
@@ -41,6 +41,7 @@ const defaultNackRedeliveryDelay = 1 * time.Minute
 type acker interface {
        // AckID does not handle errors returned by the Broker side, so no need to wait for doneCh to finish.
        AckID(id MessageID) error
+       AckIDList(msgIDs []MessageID) error
        AckIDWithResponse(id MessageID) error
        AckIDWithTxn(msgID MessageID, txn Transaction) error
        AckIDCumulative(msgID MessageID) error
@@ -559,6 +560,16 @@ func (c *consumer) AckID(msgID MessageID) error {
        return c.consumers[msgID.PartitionIdx()].AckID(msgID)
 }
 
+// AckIDList acknowledges msgIDs, grouping them by the partition consumer selected from each ID's partition index.
+func (c *consumer) AckIDList(msgIDs []MessageID) error {
+       return ackIDListFromMultiTopics(c.log, msgIDs, func(msgID MessageID) (acker, error) {
+               if err := c.checkMsgIDPartition(msgID); err != nil {
+                       return nil, err
+               }
+               return c.consumers[msgID.PartitionIdx()], nil
+       })
+}
+
 // AckCumulative the reception of all the messages in the stream up to (and including)
 // the provided message, identified by its MessageID
 func (c *consumer) AckCumulative(msg Message) error {
diff --git a/pulsar/consumer_multitopic.go b/pulsar/consumer_multitopic.go
index 3030bda3..26430add 100644
--- a/pulsar/consumer_multitopic.go
+++ b/pulsar/consumer_multitopic.go
@@ -167,6 +167,53 @@ func (c *multiTopicConsumer) AckID(msgID MessageID) error {
        return mid.consumer.AckID(msgID)
 }
 
+// AckIDList acknowledges msgIDs, routing each ID to the consumer recorded in its tracking message ID.
+func (c *multiTopicConsumer) AckIDList(msgIDs []MessageID) error {
+       return ackIDListFromMultiTopics(c.log, msgIDs, func(msgID MessageID) (acker, error) {
+               if !checkMessageIDType(msgID) {
+                       return nil, fmt.Errorf("invalid message id type %T", msgID)
+               }
+               if mid := toTrackingMessageID(msgID); mid != nil && mid.consumer != nil {
+                       return mid.consumer, nil
+               }
+               return nil, errors.New("consumer is nil")
+       })
+}
+
+// ackIDListFromMultiTopics groups msgIDs by the acker that findConsumer returns and calls AckIDList
+// once per group. IDs whose acker cannot be found are logged and skipped. Per-group failures are merged
+// into a single AckError; nil is returned when every group succeeds.
+func ackIDListFromMultiTopics(logger log.Logger, msgIDs []MessageID, findConsumer func(MessageID) (acker, error)) error {
+       consumerToMsgIDs := make(map[acker][]MessageID)
+       for _, msgID := range msgIDs {
+               if consumer, err := findConsumer(msgID); err == nil {
+                       consumerToMsgIDs[consumer] = append(consumerToMsgIDs[consumer], msgID)
+               } else {
+                       logger.Warnf("Can not find consumer for %v: %v", msgID, err)
+               }
+       }
+
+       ackError := AckError{}
+       for consumer, ids := range consumerToMsgIDs {
+               if err := consumer.AckIDList(ids); err != nil {
+                       if topicAckError, ok := err.(AckError); ok {
+                               for id, err := range topicAckError {
+                                       ackError[id] = err
+                               }
+                       } else {
+                               // Plain error (e.g. AckWithResponse is false): attribute it to every ID of this consumer
+                               for _, id := range ids {
+                                       ackError[id] = err
+                               }
+                       }
+               }
+       }
+       if len(ackError) == 0 {
+               return nil
+       }
+       return ackError
+}
+
 // AckWithTxn the consumption of a single message with a transaction
 func (c *multiTopicConsumer) AckWithTxn(msg Message, txn Transaction) error {
        msgID := msg.ID()
diff --git a/pulsar/consumer_multitopic_test.go b/pulsar/consumer_multitopic_test.go
index 7c6b898a..4b3ec908 100644
--- a/pulsar/consumer_multitopic_test.go
+++ b/pulsar/consumer_multitopic_test.go
@@ -21,6 +21,7 @@ import (
        "fmt"
        "strings"
        "testing"
+       "time"
 
        "github.com/apache/pulsar-client-go/pulsaradmin"
        "github.com/apache/pulsar-client-go/pulsaradmin/pkg/admin/config"
@@ -218,3 +219,101 @@ func TestMultiTopicGetLastMessageIDs(t *testing.T) {
        }
 
 }
+
+func TestMultiTopicAckIDList(t *testing.T) {
+       for _, params := range []bool{true, false} {
+               t.Run(fmt.Sprintf("TestMultiTopicConsumerAckIDList%v", params), func(t *testing.T) {
+                       runMultiTopicAckIDList(t, params)
+               })
+       }
+}
+
+func runMultiTopicAckIDList(t *testing.T, regex bool) {
+       topicPrefix := fmt.Sprintf("multiTopicAckIDList%v", time.Now().UnixNano())
+       topic1 := "persistent://public/default/" + topicPrefix + "1"
+       topic2 := "persistent://public/default/" + topicPrefix + "2"
+
+       client, err := NewClient(ClientOptions{URL: lookupURL})
+       assert.Nil(t, err)
+       defer client.Close()
+
+       if regex {
+               admin, err := pulsaradmin.NewClient(&config.Config{})
+               assert.Nil(t, err)
+               for _, topic := range []string{topic1, topic2} {
+                       topicName, err := utils.GetTopicName(topic)
+                       assert.Nil(t, err)
+                       assert.Nil(t, admin.Topics().Create(*topicName, 0))
+               }
+       }
+
+       createConsumer := func() Consumer {
+               options := ConsumerOptions{
+                       SubscriptionName: "sub",
+                       Type:             Shared,
+                       AckWithResponse:  true,
+               }
+               if regex {
+                       options.TopicsPattern = topicPrefix + ".*"
+               } else {
+                       options.Topics = []string{topic1, topic2}
+               }
+               consumer, err := client.Subscribe(options)
+               assert.Nil(t, err)
+               return consumer
+       }
+       consumer := createConsumer()
+
+       sendMessages(t, client, topic1, 0, 3, false)
+       sendMessages(t, client, topic2, 0, 2, false)
+
+       receiveMessageMap := func(consumer Consumer, numMessages int) map[string][]Message {
+               msgs := receiveMessages(t, consumer, numMessages)
+               topicToMsgs := make(map[string][]Message)
+               for _, msg := range msgs {
+                       topicToMsgs[msg.Topic()] = append(topicToMsgs[msg.Topic()], msg)
+               }
+               return topicToMsgs
+       }
+
+       topicToMsgs := receiveMessageMap(consumer, 5)
+       assert.Equal(t, 3, len(topicToMsgs[topic1]))
+       for i := 0; i < 3; i++ {
+               assert.Equal(t, fmt.Sprintf("msg-%d", i), string(topicToMsgs[topic1][i].Payload()))
+       }
+       assert.Equal(t, 2, len(topicToMsgs[topic2]))
+       for i := 0; i < 2; i++ {
+               assert.Equal(t, fmt.Sprintf("msg-%d", i), string(topicToMsgs[topic2][i].Payload()))
+       }
+
+       assert.Nil(t, consumer.AckIDList([]MessageID{
+               topicToMsgs[topic1][0].ID(),
+               topicToMsgs[topic1][2].ID(),
+               topicToMsgs[topic2][1].ID(),
+       }))
+
+       consumer.Close()
+       consumer = createConsumer()
+       topicToMsgs = receiveMessageMap(consumer, 2)
+       assert.Equal(t, 1, len(topicToMsgs[topic1]))
+       assert.Equal(t, "msg-1", string(topicToMsgs[topic1][0].Payload()))
+       assert.Equal(t, 1, len(topicToMsgs[topic2]))
+       assert.Equal(t, "msg-0", string(topicToMsgs[topic2][0].Payload()))
+       consumer.Close()
+
+       msgID0 := topicToMsgs[topic1][0].ID()
+       err = consumer.AckIDList([]MessageID{msgID0})
+       assert.NotNil(t, err)
+       t.Logf("AckIDList error: %v", err)
+
+       msgID1 := topicToMsgs[topic2][0].ID()
+       if ackError, ok := consumer.AckIDList([]MessageID{msgID0, msgID1}).(AckError); ok {
+               assert.Equal(t, 2, len(ackError))
+               assert.Contains(t, ackError, msgID0)
+               assert.Equal(t, "consumer state is closed", ackError[msgID0].Error())
+               assert.Contains(t, ackError, msgID1)
+               assert.Equal(t, "consumer state is closed", ackError[msgID1].Error())
+       } else {
+               assert.Fail(t, "AckIDList should return AckError")
+       }
+}
diff --git a/pulsar/consumer_partition.go b/pulsar/consumer_partition.go
index 4e8fba5a..471d45a3 100644
--- a/pulsar/consumer_partition.go
+++ b/pulsar/consumer_partition.go
@@ -198,6 +198,11 @@ func (pc *partitionConsumer) pauseDispatchMessage() {
        pc.dispatcherSeekingControlCh <- struct{}{}
 }
 
+// Topic returns the topic of this partition consumer.
+func (pc *partitionConsumer) Topic() string {
+       return pc.topic
+}
+
 func (pc *partitionConsumer) ActiveConsumerChanged(isActive bool) {
        listener := pc.options.consumerEventListener
        if listener == nil {
@@ -375,7 +380,12 @@ func newPartitionConsumer(parent Consumer, client *client, options *partitionCon
        pc.ackGroupingTracker = newAckGroupingTracker(options.ackGroupingOptions,
                func(id MessageID) { pc.sendIndividualAck(id) },
                func(id MessageID) { pc.sendCumulativeAck(id) },
-               func(ids []*pb.MessageIdData) { pc.eventsCh <- ids })
+               func(ids []*pb.MessageIdData) {
+                       pc.eventsCh <- &ackListRequest{
+                               errCh:  nil, // ignore the error
+                               msgIDs: ids,
+                       }
+               })
        pc.setConsumerState(consumerInit)
        pc.log = client.log.SubLogger(log.Fields{
                "name":         pc.name,
@@ -695,6 +705,90 @@ func (pc *partitionConsumer) AckID(msgID MessageID) error {
        return pc.ackID(msgID, false)
 }
 
+// AckIDList acknowledges msgIDs. Without AckWithResponse it calls AckID for each ID and stops at the first error.
+// Otherwise it sends one ACK command covering every valid ID, waits for the broker, and reports failures as an AckError.
+func (pc *partitionConsumer) AckIDList(msgIDs []MessageID) error {
+       if !pc.options.ackWithResponse {
+               for _, msgID := range msgIDs {
+                       if err := pc.AckID(msgID); err != nil {
+                               return err
+                       }
+               }
+               return nil
+       }
+
+       chunkedMsgIDs := make([]*chunkMessageID, 0) // we need to remove them after acknowledging
+       pendingAcks := make(map[position]*bitset.BitSet)
+       validMsgIDs := make([]MessageID, 0, len(msgIDs))
+
+       // Batch trackers are updated in place, so a batch may only become fully acked by a later ID in this loop
+       for _, msgID := range msgIDs {
+               if msgID.PartitionIdx() != pc.partitionIdx {
+                       pc.log.Errorf("%v inconsistent partition index %v (current: %v)", msgID, msgID.PartitionIdx(), pc.partitionIdx)
+               } else if msgID.BatchIdx() >= 0 && msgID.BatchSize() > 0 &&
+                       msgID.BatchIdx() >= msgID.BatchSize() {
+                       pc.log.Errorf("%v invalid batch index %v (size: %v)", msgID, msgID.BatchIdx(), msgID.BatchSize())
+               } else {
+                       valid := true
+                       switch convertedMsgID := msgID.(type) {
+                       case *trackingMessageID:
+                               pos := newPosition(msgID)
+                               if convertedMsgID.ack() {
+                                       pendingAcks[pos] = nil
+                               } else if pc.options.enableBatchIndexAck {
+                                       pendingAcks[pos] = convertedMsgID.tracker.getAckBitSet()
+                               }
+                       case *chunkMessageID:
+                               for _, id := range pc.unAckChunksTracker.get(convertedMsgID) {
+                                       pendingAcks[newPosition(id)] = nil
+                               }
+                               chunkedMsgIDs = append(chunkedMsgIDs, convertedMsgID)
+                       case *messageID:
+                               pendingAcks[newPosition(msgID)] = nil
+                       default:
+                               pc.log.Errorf("invalid message id type %T: %v", msgID, msgID)
+                               valid = false
+                       }
+                       if valid {
+                               validMsgIDs = append(validMsgIDs, msgID)
+                       }
+               }
+       }
+
+       if state := pc.getConsumerState(); state == consumerClosed || state == consumerClosing {
+               pc.log.WithField("state", state).Error("Failed to ack by closing or closed consumer")
+               return toAckError(map[error][]MessageID{errors.New("consumer state is closed"): validMsgIDs})
+       }
+
+       req := &ackListRequest{
+               errCh:  make(chan error),
+               msgIDs: toMsgIDDataList(pendingAcks),
+       }
+       pc.eventsCh <- req
+       if err := <-req.errCh; err != nil {
+               return toAckError(map[error][]MessageID{err: validMsgIDs})
+       }
+       for _, id := range chunkedMsgIDs {
+               pc.unAckChunksTracker.remove(id)
+       }
+       // Only IDs that passed validation were included in the ACK command, so only they reach interceptors.
+       for _, id := range validMsgIDs {
+               pc.options.interceptors.OnAcknowledge(pc.parentConsumer, id)
+       }
+       return nil
+}
+
+// toAckError converts a mapping from an error to the message IDs it applies to into an AckError.
+func toAckError(errorMap map[error][]MessageID) AckError {
+       e := AckError{}
+       for err, ids := range errorMap {
+               for _, id := range ids {
+                       e[id] = err
+               }
+       }
+       return e
+}
+
 func (pc *partitionConsumer) AckIDCumulative(msgID MessageID) error {
        if !checkMessageIDType(msgID) {
                pc.log.Errorf("invalid message id type %T", msgID)
@@ -1027,11 +1121,23 @@ func (pc *partitionConsumer) internalAck(req *ackRequest) {
        }
 }
 
-func (pc *partitionConsumer) internalAckList(msgIDs []*pb.MessageIdData) {
+// internalAckList sends an individual ACK for request.msgIDs; if errCh is set it waits for the broker's reply and reports it there.
+func (pc *partitionConsumer) internalAckList(request *ackListRequest) {
+       if request.errCh != nil {
+               reqID := pc.client.rpcClient.NewRequestID()
+               _, err := pc.client.rpcClient.RequestOnCnx(pc._getConn(), reqID, pb.BaseCommand_ACK, &pb.CommandAck{
+                       AckType:    pb.CommandAck_Individual.Enum(),
+                       ConsumerId: proto.Uint64(pc.consumerID),
+                       MessageId:  request.msgIDs,
+                       RequestId:  &reqID,
+               })
+               request.errCh <- err
+               return
+       }
        pc.client.rpcClient.RequestOnCnxNoWait(pc._getConn(), pb.BaseCommand_ACK, &pb.CommandAck{
                AckType:    pb.CommandAck_Individual.Enum(),
                ConsumerId: proto.Uint64(pc.consumerID),
-               MessageId:  msgIDs,
+               MessageId:  request.msgIDs,
        })
 }
 
@@ -1563,6 +1669,12 @@ type ackRequest struct {
        err     error
 }
 
+// ackListRequest is handled by internalAckList; a nil errCh means the ACK is sent without waiting for a response.
+type ackListRequest struct {
+       errCh  chan error
+       msgIDs []*pb.MessageIdData
+}
+
 type ackWithTxnRequest struct {
        doneCh      chan struct{}
        msgID       trackingMessageID
@@ -1623,7 +1735,7 @@ func (pc *partitionConsumer) runEventsLoop() {
                                pc.internalAck(v)
                        case *ackWithTxnRequest:
                                pc.internalAckWithTxn(v)
-                       case []*pb.MessageIdData:
+                       case *ackListRequest:
                                pc.internalAckList(v)
                        case *redeliveryRequest:
                                pc.internalRedeliver(v)
diff --git a/pulsar/consumer_regex.go b/pulsar/consumer_regex.go
index 58cfa80f..ced770e9 100644
--- a/pulsar/consumer_regex.go
+++ b/pulsar/consumer_regex.go
@@ -215,6 +215,19 @@ func (c *regexConsumer) AckID(msgID MessageID) error {
        return mid.consumer.AckID(msgID)
 }
 
+// AckIDList acknowledges msgIDs, routing each ID to the consumer recorded in its tracking message ID.
+func (c *regexConsumer) AckIDList(msgIDs []MessageID) error {
+       return ackIDListFromMultiTopics(c.log, msgIDs, func(msgID MessageID) (acker, error) {
+               if !checkMessageIDType(msgID) {
+                       return nil, fmt.Errorf("invalid message id type %T", msgID)
+               }
+               if mid := toTrackingMessageID(msgID); mid != nil && mid.consumer != nil {
+                       return mid.consumer, nil
+               }
+               return nil, errors.New("consumer is nil in consumer_regex")
+       })
+}
+
 // AckID the consumption of a single message, identified by its MessageID
 func (c *regexConsumer) AckWithTxn(msg Message, txn Transaction) error {
        msgID := msg.ID()
diff --git a/pulsar/consumer_test.go b/pulsar/consumer_test.go
index d8f31458..2524f681 100644
--- a/pulsar/consumer_test.go
+++ b/pulsar/consumer_test.go
@@ -4745,3 +4745,135 @@ func TestLookupConsumer(t *testing.T) {
                consumer.Ack(msg)
        }
 }
+
+func TestAckIDList(t *testing.T) {
+       for _, params := range []bool{true, false} {
+               t.Run(fmt.Sprintf("TestAckIDList_%v", params), func(t *testing.T) {
+                       runAckIDListTest(t, params)
+               })
+       }
+}
+
+func runAckIDListTest(t *testing.T, enableBatchIndexAck bool) {
+       client, err := NewClient(ClientOptions{URL: lookupURL})
+       assert.Nil(t, err)
+       defer client.Close()
+
+       topic := fmt.Sprintf("test-ack-id-list-%v", time.Now().UnixNano())
+
+       consumer := createSharedConsumer(t, client, topic, enableBatchIndexAck)
+       sendMessages(t, client, topic, 0, 5, true)  // entry 0: [0, 1, 2, 3, 4]
+       sendMessages(t, client, topic, 5, 3, false) // entry 2: [5], 3: [6], 4: [7]
+       sendMessages(t, client, topic, 8, 2, true)  // entry 5: [8, 9]
+
+       msgs := receiveMessages(t, consumer, 10)
+       originalMsgIDs := make([]MessageID, 0)
+       for i := 0; i < 10; i++ {
+               originalMsgIDs = append(originalMsgIDs, msgs[i].ID())
+               assert.Equal(t, fmt.Sprintf("msg-%d", i), string(msgs[i].Payload()))
+       }
+
+       ackedIndexes := []int{0, 2, 3, 6, 8, 9}
+       unackedIndexes := []int{1, 4, 5, 7}
+       if !enableBatchIndexAck {
+               // [0, 4] is the first batch range but only part of it is acked
+               unackedIndexes = []int{0, 1, 2, 3, 4, 5, 7}
+       }
+       msgIDs := make([]MessageID, len(ackedIndexes))
+       for i := 0; i < len(ackedIndexes); i++ {
+               msgIDs[i] = msgs[ackedIndexes[i]].ID()
+       }
+       assert.Nil(t, consumer.AckIDList(msgIDs))
+       consumer.Close()
+
+       consumer = createSharedConsumer(t, client, topic, enableBatchIndexAck)
+       msgs = receiveMessages(t, consumer, len(unackedIndexes))
+       for i := 0; i < len(unackedIndexes); i++ {
+               assert.Equal(t, fmt.Sprintf("msg-%d", unackedIndexes[i]), string(msgs[i].Payload()))
+       }
+
+       if !enableBatchIndexAck {
+               msgIDs = make([]MessageID, 0)
+               for i := 0; i < 5; i++ {
+                       msgIDs = append(msgIDs, originalMsgIDs[i])
+               }
+               assert.Nil(t, consumer.AckIDList(msgIDs))
+               consumer.Close()
+
+               consumer = createSharedConsumer(t, client, topic, enableBatchIndexAck)
+               msgs = receiveMessages(t, consumer, 2)
+               assert.Equal(t, "msg-5", string(msgs[0].Payload()))
+               assert.Equal(t, "msg-7", string(msgs[1].Payload()))
+               consumer.Close()
+       }
+       consumer.Close()
+       err = consumer.AckIDList(msgIDs)
+       assert.NotNil(t, err)
+       if ackError, ok := err.(AckError); ok {
+               assert.Equal(t, len(msgIDs), len(ackError))
+               for _, id := range msgIDs {
+                       assert.Contains(t, ackError, id)
+                       assert.Equal(t, "consumer state is closed", ackError[id].Error())
+               }
+       } else {
+               assert.Fail(t, "AckIDList should return AckError")
+       }
+}
+
+func createSharedConsumer(t *testing.T, client Client, topic string, enableBatchIndexAck bool) Consumer {
+       consumer, err := client.Subscribe(ConsumerOptions{
+               Topic:                          topic,
+               SubscriptionName:               "my-sub",
+               SubscriptionInitialPosition:    SubscriptionPositionEarliest,
+               Type:                           Shared,
+               EnableBatchIndexAcknowledgment: enableBatchIndexAck,
+               AckWithResponse:                true,
+       })
+       assert.Nil(t, err)
+       return consumer
+}
+
+func sendMessages(t *testing.T, client Client, topic string, startIndex int, numMessages int, batching bool) {
+       producer, err := client.CreateProducer(ProducerOptions{
+               Topic:                   topic,
+               DisableBatching:         !batching,
+               BatchingMaxMessages:     uint(numMessages),
+               BatchingMaxSize:         1024 * 1024 * 10,
+               BatchingMaxPublishDelay: 1 * time.Hour,
+       })
+       assert.Nil(t, err)
+       defer producer.Close()
+
+       ctx := context.Background()
+       for i := 0; i < numMessages; i++ {
+               msg := &ProducerMessage{Payload: []byte(fmt.Sprintf("msg-%d", startIndex+i))}
+               if batching {
+                       producer.SendAsync(ctx, msg, func(_ MessageID, _ *ProducerMessage, err error) {
+                               if err != nil {
+                                       t.Logf("Failed to send message: %v", err)
+                               }
+                       })
+               } else {
+                       if _, err := producer.Send(ctx, msg); err != nil {
+                               assert.Fail(t, fmt.Sprintf("Failed to send message: %v", err))
+                       }
+               }
+       }
+       assert.Nil(t, producer.Flush())
+}
+
+func receiveMessages(t *testing.T, consumer Consumer, numMessages int) []Message {
+       ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+       defer cancel()
+       msgs := make([]Message, 0)
+       for i := 0; i < numMessages; i++ {
+               if msg, err := consumer.Receive(ctx); err == nil {
+                       msgs = append(msgs, msg)
+               } else {
+                       t.Logf("Failed to receive message: %v", err)
+                       break
+               }
+       }
+       assert.Equal(t, numMessages, len(msgs))
+       return msgs
+}
diff --git a/pulsar/consumer_zero_queue.go b/pulsar/consumer_zero_queue.go
index 81171272..3f2862da 100644
--- a/pulsar/consumer_zero_queue.go
+++ b/pulsar/consumer_zero_queue.go
@@ -171,6 +171,11 @@ func (z *zeroQueueConsumer) AckID(msgID MessageID) error {
        return z.pc.AckID(msgID)
 }
 
+// AckIDList delegates to the underlying partition consumer's AckIDList.
+func (z *zeroQueueConsumer) AckIDList(msgIDs []MessageID) error {
+       return z.pc.AckIDList(msgIDs)
+}
+
 func (z *zeroQueueConsumer) AckWithTxn(msg Message, txn Transaction) error {
        msgID := msg.ID()
        if err := z.checkMsgIDPartition(msgID); err != nil {
diff --git a/pulsar/impl_message.go b/pulsar/impl_message.go
index 478b1af2..0acd782b 100644
--- a/pulsar/impl_message.go
+++ b/pulsar/impl_message.go
@@ -404,6 +404,13 @@ type ackTracker struct {
        prevBatchAcked uint32
 }
 
+// getAckBitSet returns a clone of batchIDs taken under the lock, so later acks do not mutate the returned set.
+func (t *ackTracker) getAckBitSet() *bitset.BitSet {
+       t.Lock()
+       defer t.Unlock()
+       return t.batchIDs.Clone()
+}
+
 func (t *ackTracker) ack(batchID int) bool {
        if batchID < 0 {
                return true
diff --git a/pulsar/internal/pulsartracing/consumer_interceptor_test.go b/pulsar/internal/pulsartracing/consumer_interceptor_test.go
index 1fa1bf0d..e7712356 100644
--- a/pulsar/internal/pulsartracing/consumer_interceptor_test.go
+++ b/pulsar/internal/pulsartracing/consumer_interceptor_test.go
@@ -79,6 +79,10 @@ func (c *mockConsumer) AckID(_ pulsar.MessageID) error {
        return nil
 }
 
+func (c *mockConsumer) AckIDList(_ []pulsar.MessageID) error {
+       return nil
+}
+
 func (c *mockConsumer) AckCumulative(_ pulsar.Message) error {
        return nil
 }

Reply via email to