This is an automated email from the ASF dual-hosted git repository.

xyz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pulsar-client-go.git


The following commit(s) were added to refs/heads/master by this push:
     new 3cd02b1  [feature] Support batch index ACK (#938)
3cd02b1 is described below

commit 3cd02b1ce8ed15716aff9b04641109e11a704a80
Author: Yunze Xu <[email protected]>
AuthorDate: Wed Jan 11 16:02:49 2023 +0800

    [feature] Support batch index ACK (#938)
    
    Fixes https://github.com/apache/pulsar-client-go/issues/894
    
    ### Modifications
    
    Add an `EnableBatchIndexAcknowledgment` option to specify whether batch
    index ACK is enabled. Because this feature requires converting between a
    bit set and its underlying long array, much like Java's `BitSet`, this
    commit introduces the github.com/bits-and-blooms/bitset dependency to
    replace the `big.Int`-based bit set implementation.
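
    As an illustrative aside (not part of the patch itself; the helper below
    is hypothetical, though the `bitset` calls mirror the ones used in
    `impl_message.go`), a minimal sketch of that bit set/long array round
    trip:

    ```go
    package main

    import (
        "fmt"

        "github.com/bits-and-blooms/bitset"
    )

    // toAckSet builds a bit set with every bit set (all entries of the
    // batch pending), clears one acknowledged batch index, and exports the
    // underlying 64-bit words as []int64, much like Java's
    // BitSet.toLongArray().
    func toAckSet(size uint, ackedIndex uint) []int64 {
        b := bitset.New(size)
        for i := uint(0); i < size; i++ {
            b.Set(i)
        }
        b.Clear(ackedIndex)
        words := b.Bytes() // the underlying []uint64
        ackSet := make([]int64, len(words))
        for i, w := range words {
            ackSet[i] = int64(w)
        }
        return ackSet
    }

    func main() {
        fmt.Println(toAckSet(10, 3)) // bits 0-9 set, bit 3 cleared: [1015]
    }
    ```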
    
    Add a `BatchSize()` method to `MessageID` to indicate the size of the
    `ack_set` field. When a batch index ACK happens, convert the `[]uint64`
    to a `[]int64` for the `ack_set` field in `CommandAck`. When receiving
    messages, convert the `ack_set` field in `CommandMessage` to a bit set
    and use it to filter out already-acknowledged single messages.
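
    For example, enabling the feature on the consumer side looks like the
    following sketch (illustrative only; the service URL, topic, and
    subscription names are placeholders):

    ```go
    package main

    import (
        "context"
        "log"

        "github.com/apache/pulsar-client-go/pulsar"
    )

    func main() {
        client, err := pulsar.NewClient(pulsar.ClientOptions{
            URL: "pulsar://localhost:6650",
        })
        if err != nil {
            log.Fatal(err)
        }
        defer client.Close()

        consumer, err := client.Subscribe(pulsar.ConsumerOptions{
            Topic:            "my-topic",
            SubscriptionName: "my-sub",
            // Requires acknowledgmentAtBatchIndexLevelEnabled=true on the broker.
            EnableBatchIndexAcknowledgment: true,
        })
        if err != nil {
            log.Fatal(err)
        }
        defer consumer.Close()

        msg, err := consumer.Receive(context.Background())
        if err != nil {
            log.Fatal(err)
        }
        // With batch index ACK enabled, acknowledging a single message from
        // a batch records only its batch index in the ack_set.
        consumer.AckID(msg.ID())
    }
    ```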
    
    Remove the duplicated code in `AckID` and `AckIDWithResponse`.
    
    ### Verifications
    
    `TestBatchIndexAck` is added to cover both individual and cumulative ACK,
    each with `AckWithResponse` enabled and disabled.
---
 go.mod                                 |   1 +
 go.sum                                 |   2 +
 integration-tests/conf/standalone.conf |   2 +
 pulsar/consumer.go                     |   4 +
 pulsar/consumer_impl.go                |   1 +
 pulsar/consumer_partition.go           |  80 ++++++++++---------
 pulsar/consumer_test.go                | 135 +++++++++++++++++++++++++++++++++
 pulsar/impl_message.go                 |  61 ++++++++++-----
 pulsar/impl_message_test.go            |  16 ++--
 pulsar/message.go                      |   5 +-
 pulsar/producer_partition.go           |  12 +++
 pulsar/reader_test.go                  |   4 +
 12 files changed, 259 insertions(+), 64 deletions(-)

diff --git a/go.mod b/go.mod
index 0435db9..e143b04 100644
--- a/go.mod
+++ b/go.mod
@@ -6,6 +6,7 @@ require (
        github.com/99designs/keyring v1.2.1
        github.com/AthenZ/athenz v1.10.39
        github.com/DataDog/zstd v1.5.0
+       github.com/bits-and-blooms/bitset v1.4.0
        github.com/bmizerany/perks v0.0.0-20141205001514-d9a9656a3a4b
        github.com/davecgh/go-spew v1.1.1
        github.com/golang-jwt/jwt v3.2.1+incompatible
diff --git a/go.sum b/go.sum
index 9dbd99f..6b07567 100644
--- a/go.sum
+++ b/go.sum
@@ -65,6 +65,8 @@ github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+Ce
 github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
 github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
 github.com/bgentry/speakeasy v0.1.0/go.mod h1:+zsyZBPWlz7T6j88CTgSN5bM796AkVf0kBD4zp0CCIs=
+github.com/bits-and-blooms/bitset v1.4.0 h1:+YZ8ePm+He2pU3dZlIZiOeAKfrBkXi1lSrXJ/Xzgbu8=
+github.com/bits-and-blooms/bitset v1.4.0/go.mod h1:gIdJ4wp64HaoK2YrL1Q5/N7Y16edYb8uY+O0FJTyyDA=
 github.com/bketelsen/crypt v0.0.4/go.mod h1:aI6NrJ0pMGgvZKL1iVgXLnfIFJtfV+bKCoqOes/6LfM=
 github.com/bmizerany/perks v0.0.0-20141205001514-d9a9656a3a4b h1:AP/Y7sqYicnjGDfD5VcY4CIfh1hRXBUavxrvELjTiOE=
 github.com/bmizerany/perks v0.0.0-20141205001514-d9a9656a3a4b/go.mod h1:ac9efd0D1fsDb3EJvhqgXRbFx7bs2wqZ10HQPeU8U/Q=
diff --git a/integration-tests/conf/standalone.conf b/integration-tests/conf/standalone.conf
index 8cd2828..b9ff87b 100644
--- a/integration-tests/conf/standalone.conf
+++ b/integration-tests/conf/standalone.conf
@@ -292,3 +292,5 @@ globalZookeeperServers=
 
 # Deprecated. Use brokerDeleteInactiveTopicsFrequencySeconds
 brokerServicePurgeInactiveFrequencyInSeconds=60
+
+acknowledgmentAtBatchIndexLevelEnabled=true
diff --git a/pulsar/consumer.go b/pulsar/consumer.go
index c9f5e89..8bae57d 100644
--- a/pulsar/consumer.go
+++ b/pulsar/consumer.go
@@ -211,6 +211,10 @@ type ConsumerOptions struct {
        // AutoAckIncompleteChunk sets whether consumer auto acknowledges incomplete chunked message when it should
        // be removed (e.g.the chunked message pending queue is full). (default: false)
        AutoAckIncompleteChunk bool
+
+       // Enable or disable batch index acknowledgment. To enable this feature, ensure batch index acknowledgment
+       // is enabled on the broker side. (default: false)
+       EnableBatchIndexAcknowledgment bool
 }
 
 // Consumer is an interface that abstracts behavior of Pulsar's consumer
diff --git a/pulsar/consumer_impl.go b/pulsar/consumer_impl.go
index e0120ad..bf136c8 100644
--- a/pulsar/consumer_impl.go
+++ b/pulsar/consumer_impl.go
@@ -397,6 +397,7 @@ func (c *consumer) internalTopicSubscribeToPartitions() error {
                                expireTimeOfIncompleteChunk: c.options.ExpireTimeOfIncompleteChunk,
                                autoAckIncompleteChunk:      c.options.AutoAckIncompleteChunk,
                                consumerEventListener:       c.options.EventListener,
+                               enableBatchIndexAck:         c.options.EnableBatchIndexAcknowledgment,
                        }
                        cons, err := newPartitionConsumer(c, c.client, opts, c.messageCh, c.dlq, c.metrics)
                        ch <- ConsumerError{
diff --git a/pulsar/consumer_partition.go b/pulsar/consumer_partition.go
index 5ee09ab..e723f8a 100644
--- a/pulsar/consumer_partition.go
+++ b/pulsar/consumer_partition.go
@@ -36,6 +36,7 @@ import (
        cryptointernal "github.com/apache/pulsar-client-go/pulsar/internal/crypto"
        pb "github.com/apache/pulsar-client-go/pulsar/internal/pulsar_proto"
        "github.com/apache/pulsar-client-go/pulsar/log"
+       "github.com/bits-and-blooms/bitset"
 
        uAtomic "go.uber.org/atomic"
 )
@@ -114,6 +115,7 @@ type partitionConsumerOpts struct {
        autoAckIncompleteChunk      bool
        // in failover mode, this callback will be called when consumer change
        consumerEventListener ConsumerEventListener
+       enableBatchIndexAck   bool
 }
 
 type ConsumerEventListener interface {
@@ -450,7 +452,7 @@ func (pc *partitionConsumer) requestGetLastMessageID() (trackingMessageID, error
        return convertToMessageID(id), nil
 }
 
-func (pc *partitionConsumer) AckIDWithResponse(msgID MessageID) error {
+func (pc *partitionConsumer) ackID(msgID MessageID, withResponse bool) error {
        if state := pc.getConsumerState(); state == consumerClosed || state == consumerClosing {
                pc.log.WithField("state", state).Error("Failed to ack by closing or closed consumer")
                return errors.New("consumer state is closed")
@@ -474,47 +476,31 @@ func (pc *partitionConsumer) AckIDWithResponse(msgID MessageID) error {
                ackReq.msgID = trackingID
                // send ack request to eventsCh
                pc.eventsCh <- ackReq
-               // wait for the request to complete
-               <-ackReq.doneCh
-
-               pc.options.interceptors.OnAcknowledge(pc.parentConsumer, msgID)
-       }
-
-       return ackReq.err
-}
-
-func (pc *partitionConsumer) AckID(msgID MessageID) error {
-       if state := pc.getConsumerState(); state == consumerClosed || state == consumerClosing {
-               pc.log.WithField("state", state).Error("Failed to ack by closing or closed consumer")
-               return errors.New("consumer state is closed")
-       }
-
-       if cmid, ok := toChunkedMessageID(msgID); ok {
-               return pc.unAckChunksTracker.ack(cmid)
-       }
 
-       trackingID, ok := toTrackingMessageID(msgID)
-       if !ok {
-               return errors.New("failed to convert trackingMessageID")
-       }
+               if withResponse {
+                       <-ackReq.doneCh
+               }
 
-       ackReq := new(ackRequest)
-       ackReq.doneCh = make(chan struct{})
-       ackReq.ackType = individualAck
-       if !trackingID.Undefined() && trackingID.ack() {
-               pc.metrics.AcksCounter.Inc()
-               pc.metrics.ProcessingTime.Observe(float64(time.Now().UnixNano()-trackingID.receivedTime.UnixNano()) / 1.0e9)
+               pc.options.interceptors.OnAcknowledge(pc.parentConsumer, msgID)
+       } else if pc.options.enableBatchIndexAck {
                ackReq.msgID = trackingID
-               // send ack request to eventsCh
                pc.eventsCh <- ackReq
-               // No need to wait for ackReq.doneCh to finish
-
-               pc.options.interceptors.OnAcknowledge(pc.parentConsumer, msgID)
        }
 
+       if withResponse {
+               return ackReq.err
+       }
        return nil
 }
 
+func (pc *partitionConsumer) AckIDWithResponse(msgID MessageID) error {
+       return pc.ackID(msgID, true)
+}
+
+func (pc *partitionConsumer) AckID(msgID MessageID) error {
+       return pc.ackID(msgID, false)
+}
+
 func (pc *partitionConsumer) AckIDCumulative(msgID MessageID) error {
        return pc.internalAckIDCumulative(msgID, false)
 }
@@ -541,7 +527,7 @@ func (pc *partitionConsumer) internalAckIDCumulative(msgID MessageID, withRespon
        ackReq := new(ackRequest)
        ackReq.doneCh = make(chan struct{})
        ackReq.ackType = cumulativeAck
-       if trackingID.ackCumulative() {
+       if trackingID.ackCumulative() || pc.options.enableBatchIndexAck {
                ackReq.msgID = trackingID
        } else if !trackingID.tracker.hasPrevBatchAcked() {
                // get previous batch message id
@@ -774,6 +760,12 @@ func (pc *partitionConsumer) internalAck(req *ackRequest) {
                LedgerId: proto.Uint64(uint64(msgID.ledgerID)),
                EntryId:  proto.Uint64(uint64(msgID.entryID)),
        }
+       if pc.options.enableBatchIndexAck && msgID.tracker != nil {
+               ackSet := msgID.tracker.toAckSet()
+               if ackSet != nil {
+                       messageIDs[0].AckSet = ackSet
+               }
+       }
 
        reqID := pc.client.rpcClient.NewRequestID()
        cmdAck := &pb.CommandAck{
@@ -832,7 +824,7 @@ func (pc *partitionConsumer) MessageReceived(response *pb.CommandMessage, header
                switch crypToFailureAction {
                case crypto.ConsumerCryptoFailureActionFail:
                        pc.log.Errorf("consuming message failed due to decryption err :%v", err)
-                       pc.NackID(newTrackingMessageID(int64(pbMsgID.GetLedgerId()), int64(pbMsgID.GetEntryId()), 0, 0, nil))
+                       pc.NackID(newTrackingMessageID(int64(pbMsgID.GetLedgerId()), int64(pbMsgID.GetEntryId()), 0, 0, 0, nil))
                        return err
                case crypto.ConsumerCryptoFailureActionDiscard:
                        pc.discardCorruptedMessage(pbMsgID, pb.CommandAck_DecryptionError)
@@ -852,6 +844,7 @@ func (pc *partitionConsumer) MessageReceived(response *pb.CommandMessage, header
                                                int64(pbMsgID.GetEntryId()),
                                                pbMsgID.GetBatchIndex(),
                                                pc.partitionIdx,
+                                               pbMsgID.GetBatchSize(),
                                        ),
                                        payLoad:             headersAndPayload.ReadableSlice(),
                                        schema:              pc.options.schema,
@@ -899,7 +892,17 @@ func (pc *partitionConsumer) MessageReceived(response *pb.CommandMessage, header
        var ackTracker *ackTracker
        // are there multiple messages in this batch?
        if numMsgs > 1 {
-               ackTracker = newAckTracker(numMsgs)
+               ackTracker = newAckTracker(uint(numMsgs))
+       }
+
+       var ackSet *bitset.BitSet
+       if response.GetAckSet() != nil {
+               ackSetFromResponse := response.GetAckSet()
+               buf := make([]uint64, len(ackSetFromResponse))
+               for i := 0; i < len(buf); i++ {
+                       buf[i] = uint64(ackSetFromResponse[i])
+               }
+               ackSet = bitset.From(buf)
        }
 
        pc.metrics.MessagesReceived.Add(float64(numMsgs))
@@ -911,6 +914,10 @@ func (pc *partitionConsumer) MessageReceived(response *pb.CommandMessage, header
                        pc.discardCorruptedMessage(pbMsgID, pb.CommandAck_BatchDeSerializeError)
                        return err
                }
+               if ackSet != nil && !ackSet.Test(uint(i)) {
+                       pc.log.Debugf("Ignoring message from %vth message, which has been acknowledged", i)
+                       continue
+               }
 
                pc.metrics.BytesReceived.Add(float64(len(payload)))
                pc.metrics.PrefetchedBytes.Add(float64(len(payload)))
@@ -920,6 +927,7 @@ func (pc *partitionConsumer) MessageReceived(response *pb.CommandMessage, header
                        int64(pbMsgID.GetEntryId()),
                        int32(i),
                        pc.partitionIdx,
+                       int32(numMsgs),
                        ackTracker)
                // set the consumer so we know how to ack the message id
                trackingMsgID.consumer = pc
diff --git a/pulsar/consumer_test.go b/pulsar/consumer_test.go
index 2f1a056..366be5d 100644
--- a/pulsar/consumer_test.go
+++ b/pulsar/consumer_test.go
@@ -3851,3 +3851,138 @@ func TestAckWithMessageID(t *testing.T) {
        err = consumer.AckID(newID)
        assert.Nil(t, err)
 }
+
+func TestBatchIndexAck(t *testing.T) {
+       tests := []struct {
+               AckWithResponse bool
+               Cumulative      bool
+       }{
+               {
+                       AckWithResponse: true,
+                       Cumulative:      true,
+               },
+               {
+                       AckWithResponse: true,
+                       Cumulative:      false,
+               },
+               {
+                       AckWithResponse: false,
+                       Cumulative:      true,
+               },
+               {
+                       AckWithResponse: false,
+                       Cumulative:      false,
+               },
+       }
+       for _, params := range tests {
+               t.Run(fmt.Sprintf("TestBatchIndexAck_WithResponse_%v_Cumulative_%v",
+                       params.AckWithResponse, params.Cumulative),
+                       func(t *testing.T) {
+                               runBatchIndexAckTest(t, params.AckWithResponse, params.Cumulative)
+                       })
+       }
+}
+
+func runBatchIndexAckTest(t *testing.T, ackWithResponse bool, cumulative bool) {
+       client, err := NewClient(ClientOptions{
+               URL: lookupURL,
+       })
+
+       assert.Nil(t, err)
+
+       topic := newTopicName()
+       createConsumer := func() Consumer {
+               consumer, err := client.Subscribe(ConsumerOptions{
+                       Topic:                          topic,
+                       SubscriptionName:               "my-sub",
+                       AckWithResponse:                ackWithResponse,
+                       EnableBatchIndexAcknowledgment: true,
+               })
+               assert.Nil(t, err)
+               return consumer
+       }
+
+       consumer := createConsumer()
+
+       duration, err := time.ParseDuration("1h")
+       assert.Nil(t, err)
+
+       const BatchingMaxSize int = 2 * 5
+       producer, err := client.CreateProducer(ProducerOptions{
+               Topic:                   topic,
+               DisableBatching:         false,
+               BatchingMaxMessages:     uint(BatchingMaxSize),
+               BatchingMaxSize:         uint(1024 * 1024 * 10),
+               BatchingMaxPublishDelay: duration,
+       })
+       assert.Nil(t, err)
+       for i := 0; i < BatchingMaxSize; i++ {
+               producer.SendAsync(context.Background(), &ProducerMessage{
+                       Payload: []byte(fmt.Sprintf("msg-%d", i)),
+               }, func(id MessageID, producerMessage *ProducerMessage, err error) {
+                       assert.Nil(t, err)
+                       log.Printf("Sent to %v:%d:%d", id, id.BatchIdx(), id.BatchSize())
+               })
+       }
+       assert.Nil(t, producer.Flush())
+
+       msgIds := make([]MessageID, BatchingMaxSize)
+       for i := 0; i < BatchingMaxSize; i++ {
+               message, err := consumer.Receive(context.Background())
+               assert.Nil(t, err)
+               msgIds[i] = message.ID()
+               log.Printf("Received %v from %v:%d:%d", string(message.Payload()), message.ID(),
+                       message.ID().BatchIdx(), message.ID().BatchSize())
+       }
+
+       // Acknowledge half of the messages
+       if cumulative {
+               msgID := msgIds[BatchingMaxSize/2-1]
+               consumer.AckIDCumulative(msgID)
+               log.Printf("Acknowledge %v:%d cumulatively\n", msgID, msgID.BatchIdx())
+       } else {
+               for i := 0; i < BatchingMaxSize; i++ {
+                       msgID := msgIds[i]
+                       if i%2 == 0 {
+                               consumer.AckID(msgID)
+                               log.Printf("Acknowledge %v:%d\n", msgID, msgID.BatchIdx())
+                       }
+               }
+       }
+       consumer.Close()
+       consumer = createConsumer()
+
+       for i := 0; i < BatchingMaxSize/2; i++ {
+               message, err := consumer.Receive(context.Background())
+               assert.Nil(t, err)
+               log.Printf("Received %v from %v:%d:%d", string(message.Payload()), message.ID(),
+                       message.ID().BatchIdx(), message.ID().BatchSize())
+               index := i*2 + 1
+               if cumulative {
+                       index = i + BatchingMaxSize/2
+               }
+               assert.Equal(t, []byte(fmt.Sprintf("msg-%d", index)), message.Payload())
+               assert.Equal(t, msgIds[index].BatchIdx(), message.ID().BatchIdx())
+               // We should not acknowledge message.ID() here because message.ID() shares a different
+               // tracker with msgIds
+               if !cumulative {
+                       msgID := msgIds[index]
+                       consumer.AckID(msgID)
+                       log.Printf("Acknowledge %v:%d\n", msgID, msgID.BatchIdx())
+               }
+       }
+       if cumulative {
+               msgID := msgIds[BatchingMaxSize-1]
+               consumer.AckIDCumulative(msgID)
+               log.Printf("Acknowledge %v:%d cumulatively\n", msgID, msgID.BatchIdx())
+       }
+       consumer.Close()
+       consumer = createConsumer()
+       _, err = producer.Send(context.Background(), &ProducerMessage{Payload: []byte("end-marker")})
+       assert.Nil(t, err)
+       msg, err := consumer.Receive(context.Background())
+       assert.Nil(t, err)
+       assert.Equal(t, "end-marker", string(msg.Payload()))
+
+       client.Close()
+}
diff --git a/pulsar/impl_message.go b/pulsar/impl_message.go
index d863da9..39db8e1 100644
--- a/pulsar/impl_message.go
+++ b/pulsar/impl_message.go
@@ -21,8 +21,6 @@ import (
        "errors"
        "fmt"
        "math"
-       "math/big"
-       "strings"
        "sync"
        "sync/atomic"
        "time"
@@ -30,6 +28,7 @@ import (
        "google.golang.org/protobuf/proto"
 
        pb "github.com/apache/pulsar-client-go/pulsar/internal/pulsar_proto"
+       "github.com/bits-and-blooms/bitset"
 )
 
 type messageID struct {
@@ -37,6 +36,7 @@ type messageID struct {
        entryID      int64
        batchIdx     int32
        partitionIdx int32
+       batchSize    int32
 }
 
 var latestMessageID = messageID{
@@ -44,6 +44,7 @@ var latestMessageID = messageID{
        entryID:      math.MaxInt64,
        batchIdx:     -1,
        partitionIdx: -1,
+       batchSize:    0,
 }
 
 var earliestMessageID = messageID{
@@ -51,6 +52,7 @@ var earliestMessageID = messageID{
        entryID:      -1,
        batchIdx:     -1,
        partitionIdx: -1,
+       batchSize:    0,
 }
 
 type trackingMessageID struct {
@@ -159,6 +161,7 @@ func (id messageID) Serialize() []byte {
                EntryId:    proto.Uint64(uint64(id.entryID)),
                BatchIndex: proto.Int32(id.batchIdx),
                Partition:  proto.Int32(id.partitionIdx),
+               BatchSize:  proto.Int32(id.batchSize),
        }
        data, _ := proto.Marshal(msgID)
        return data
@@ -180,6 +183,10 @@ func (id messageID) PartitionIdx() int32 {
        return id.partitionIdx
 }
 
+func (id messageID) BatchSize() int32 {
+       return id.batchSize
+}
+
 func (id messageID) String() string {
        return fmt.Sprintf("%d:%d:%d", id.ledgerID, id.entryID, id.partitionIdx)
 }
@@ -195,20 +202,22 @@ func deserializeMessageID(data []byte) (MessageID, error) {
                int64(msgID.GetEntryId()),
                msgID.GetBatchIndex(),
                msgID.GetPartition(),
+               msgID.GetBatchSize(),
        )
        return id, nil
 }
 
-func newMessageID(ledgerID int64, entryID int64, batchIdx int32, partitionIdx int32) MessageID {
+func newMessageID(ledgerID int64, entryID int64, batchIdx int32, partitionIdx int32, batchSize int32) MessageID {
        return messageID{
                ledgerID:     ledgerID,
                entryID:      entryID,
                batchIdx:     batchIdx,
                partitionIdx: partitionIdx,
+               batchSize:    batchSize,
        }
 }
 
-func newTrackingMessageID(ledgerID int64, entryID int64, batchIdx int32, partitionIdx int32,
+func newTrackingMessageID(ledgerID int64, entryID int64, batchIdx int32, partitionIdx int32, batchSize int32,
        tracker *ackTracker) trackingMessageID {
        return trackingMessageID{
                messageID: messageID{
@@ -216,6 +225,7 @@ func newTrackingMessageID(ledgerID int64, entryID int64, batchIdx int32, partiti
                        entryID:      entryID,
                        batchIdx:     batchIdx,
                        partitionIdx: partitionIdx,
+                       batchSize:    batchSize,
                },
                tracker:      tracker,
                receivedTime: time.Now(),
@@ -370,14 +380,10 @@ func (msg *message) BrokerPublishTime() *time.Time {
        return msg.brokerPublishTime
 }
 
-func newAckTracker(size int) *ackTracker {
-       var batchIDs *big.Int
-       if size <= 64 {
-               shift := uint32(64 - size)
-               setBits := ^uint64(0) >> shift
-               batchIDs = new(big.Int).SetUint64(setBits)
-       } else {
-               batchIDs, _ = new(big.Int).SetString(strings.Repeat("1", size), 2)
+func newAckTracker(size uint) *ackTracker {
+       batchIDs := bitset.New(size)
+       for i := uint(0); i < size; i++ {
+               batchIDs.Set(i)
        }
        return &ackTracker{
                size:     size,
@@ -387,8 +393,8 @@ func newAckTracker(size int) *ackTracker {
 
 type ackTracker struct {
        sync.Mutex
-       size           int
-       batchIDs       *big.Int
+       size           uint
+       batchIDs       *bitset.BitSet
        prevBatchAcked uint32
 }
 
@@ -398,19 +404,20 @@ func (t *ackTracker) ack(batchID int) bool {
        }
        t.Lock()
        defer t.Unlock()
-       t.batchIDs = t.batchIDs.SetBit(t.batchIDs, batchID, 0)
-       return len(t.batchIDs.Bits()) == 0
+       t.batchIDs.Clear(uint(batchID))
+       return t.batchIDs.None()
 }
 
 func (t *ackTracker) ackCumulative(batchID int) bool {
        if batchID < 0 {
                return true
        }
-       mask := big.NewInt(-1)
        t.Lock()
        defer t.Unlock()
-       t.batchIDs.And(t.batchIDs, mask.Lsh(mask, uint(batchID+1)))
-       return len(t.batchIDs.Bits()) == 0
+       for i := 0; i <= batchID; i++ {
+               t.batchIDs.Clear(uint(i))
+       }
+       return t.batchIDs.None()
 }
 
 func (t *ackTracker) hasPrevBatchAcked() bool {
@@ -424,7 +431,21 @@ func (t *ackTracker) setPrevBatchAcked() {
 func (t *ackTracker) completed() bool {
        t.Lock()
        defer t.Unlock()
-       return len(t.batchIDs.Bits()) == 0
+       return t.batchIDs.None()
+}
+
+func (t *ackTracker) toAckSet() []int64 {
+       t.Lock()
+       defer t.Unlock()
+       if t.batchIDs.None() {
+               return nil
+       }
+       bytes := t.batchIDs.Bytes()
+       ackSet := make([]int64, len(bytes))
+       for i := 0; i < len(bytes); i++ {
+               ackSet[i] = int64(bytes[i])
+       }
+       return ackSet
 }
 
 type chunkMessageID struct {
diff --git a/pulsar/impl_message_test.go b/pulsar/impl_message_test.go
index 89aab8a..413a39f 100644
--- a/pulsar/impl_message_test.go
+++ b/pulsar/impl_message_test.go
@@ -24,7 +24,7 @@ import (
 )
 
 func TestMessageId(t *testing.T) {
-       id := newMessageID(1, 2, 3, 4)
+       id := newMessageID(1, 2, 3, 4, 5)
        bytes := id.Serialize()
 
        id2, err := DeserializeMessageID(bytes)
@@ -35,6 +35,7 @@ func TestMessageId(t *testing.T) {
        assert.Equal(t, int64(2), id2.(messageID).entryID)
        assert.Equal(t, int32(3), id2.(messageID).batchIdx)
        assert.Equal(t, int32(4), id2.(messageID).partitionIdx)
+       assert.Equal(t, int32(5), id2.(messageID).batchSize)
 
        id, err = DeserializeMessageID(nil)
        assert.Error(t, err)
@@ -47,11 +48,12 @@ func TestMessageId(t *testing.T) {
 
 func TestMessageIdGetFuncs(t *testing.T) {
        // test LedgerId,EntryId,BatchIdx,PartitionIdx
-       id := newMessageID(1, 2, 3, 4)
+       id := newMessageID(1, 2, 3, 4, 5)
        assert.Equal(t, int64(1), id.LedgerID())
        assert.Equal(t, int64(2), id.EntryID())
        assert.Equal(t, int32(3), id.BatchIdx())
        assert.Equal(t, int32(4), id.PartitionIdx())
+       assert.Equal(t, int32(5), id.BatchSize())
 }
 
 func TestAckTracker(t *testing.T) {
@@ -101,7 +103,7 @@ func TestAckTracker(t *testing.T) {
 
 func TestAckingMessageIDBatchOne(t *testing.T) {
        tracker := newAckTracker(1)
-       msgID := newTrackingMessageID(1, 1, 0, 0, tracker)
+       msgID := newTrackingMessageID(1, 1, 0, 0, 0, tracker)
        assert.Equal(t, true, msgID.ack())
        assert.Equal(t, true, tracker.completed())
 }
@@ -109,8 +111,8 @@ func TestAckingMessageIDBatchOne(t *testing.T) {
 func TestAckingMessageIDBatchTwo(t *testing.T) {
        tracker := newAckTracker(2)
        ids := []trackingMessageID{
-               newTrackingMessageID(1, 1, 0, 0, tracker),
-               newTrackingMessageID(1, 1, 1, 0, tracker),
+               newTrackingMessageID(1, 1, 0, 0, 0, tracker),
+               newTrackingMessageID(1, 1, 1, 0, 0, tracker),
        }
 
        assert.Equal(t, false, ids[0].ack())
@@ -120,8 +122,8 @@ func TestAckingMessageIDBatchTwo(t *testing.T) {
        // try reverse order
        tracker = newAckTracker(2)
        ids = []trackingMessageID{
-               newTrackingMessageID(1, 1, 0, 0, tracker),
-               newTrackingMessageID(1, 1, 1, 0, tracker),
+               newTrackingMessageID(1, 1, 0, 0, 0, tracker),
+               newTrackingMessageID(1, 1, 1, 0, 0, tracker),
        }
        assert.Equal(t, false, ids[1].ack())
        assert.Equal(t, true, ids[0].ack())
diff --git a/pulsar/message.go b/pulsar/message.go
index 76d0176..d37692b 100644
--- a/pulsar/message.go
+++ b/pulsar/message.go
@@ -155,6 +155,9 @@ type MessageID interface {
        // PartitionIdx returns the message partitionIdx
        PartitionIdx() int32
 
+       // BatchSize returns 0 or the batch size, which must be greater than BatchIdx()
+       BatchSize() int32
+
        // String returns message id in string format
        String() string
 }
@@ -166,7 +169,7 @@ func DeserializeMessageID(data []byte) (MessageID, error) {
 
 // NewMessageID Custom Create MessageID
 func NewMessageID(ledgerID int64, entryID int64, batchIdx int32, partitionIdx int32) MessageID {
-       return newMessageID(ledgerID, entryID, batchIdx, partitionIdx)
+       return newMessageID(ledgerID, entryID, batchIdx, partitionIdx, 0)
 }
 
 // EarliestMessageID returns a messageID that points to the earliest message available in a topic
diff --git a/pulsar/producer_partition.go b/pulsar/producer_partition.go
index 6308b55..b0467f5 100644
--- a/pulsar/producer_partition.go
+++ b/pulsar/producer_partition.go
@@ -1118,6 +1128,15 @@ func (p *partitionProducer) ReceivedSendReceipt(response *pb.CommandSendReceipt)
                pi.Lock()
                defer pi.Unlock()
                p.metrics.PublishRPCLatency.Observe(float64(now-pi.sentAt.UnixNano()) / 1.0e9)
+               batchSize := int32(0)
+               for _, i := range pi.sendRequests {
+                       sr := i.(*sendRequest)
+                       if sr.msg != nil {
+                               batchSize = batchSize + 1
+                       } else { // Flush request
+                               break
+                       }
+               }
                for idx, i := range pi.sendRequests {
                        sr := i.(*sendRequest)
                        if sr.msg != nil {
@@ -1138,6 +1147,7 @@ func (p *partitionProducer) ReceivedSendReceipt(response *pb.CommandSendReceipt)
                                        int64(response.MessageId.GetEntryId()),
                                        int32(idx),
                                        p.partitionIdx,
+                                       batchSize,
                                )
 
                                if sr.totalChunks > 1 {
@@ -1148,6 +1158,7 @@ func (p *partitionProducer) ReceivedSendReceipt(response *pb.CommandSendReceipt)
                                                                int64(response.MessageId.GetEntryId()),
                                                                -1,
                                                                p.partitionIdx,
+                                                               0,
                                                        })
                                        } else if sr.chunkID == sr.totalChunks-1 {
                                                sr.chunkRecorder.setLastChunkID(
@@ -1156,6 +1167,7 @@ func (p *partitionProducer) ReceivedSendReceipt(response *pb.CommandSendReceipt)
                                                                int64(response.MessageId.GetEntryId()),
                                                                -1,
                                                                p.partitionIdx,
+                                                               0,
                                                        })
                                                // use chunkMsgID to set msgID
                                                msgID = sr.chunkRecorder.chunkedMsgID
diff --git a/pulsar/reader_test.go b/pulsar/reader_test.go
index 53bd459..3543187 100644
--- a/pulsar/reader_test.go
+++ b/pulsar/reader_test.go
@@ -426,6 +426,10 @@ func (id *myMessageID) BatchIdx() int32 {
        return id.BatchIdx()
 }
 
+func (id *myMessageID) BatchSize() int32 {
+       return id.BatchSize()
+}
+
 func (id *myMessageID) PartitionIdx() int32 {
        return id.PartitionIdx()
 }
