This is an automated email from the ASF dual-hosted git repository.

rxl pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pulsar-client-go.git


The following commit(s) were added to refs/heads/master by this push:
     new 154bff0  [Issue 172] Add key based batcher (#400)
154bff0 is described below

commit 154bff0bb825a4f23aed675eecd3959376214850
Author: Rui Fu <[email protected]>
AuthorDate: Tue Dec 1 16:36:39 2020 +0800

    [Issue 172] Add key based batcher (#400)
    
    Fixes #172
    
    ### Motivation
    
    Add a new batch message container named `keyBasedBatchContainer` to support 
batching messages in key_shared subscription mode.
    
    ### Modifications
    
    - add `BatchBuilder` interface, add `FlushBatches` and `IsMultiBatches` func
    - change old `BatchBuilder` struct to `batchContainer`
    - add `keyBasedBatchContainer`
    - add tests
    
    ### Verifying this change
    
    This change added tests and can be verified as follows:
    
      - *Added integration tests for key based batch producer with multiple 
consumers in KeyShared mode*
      - *Added integration tests for message ordering with key based batch 
producer and KeyShared consumer*
---
 go.mod                                     |   2 +
 pulsar/batcher_builder.go                  |  44 ++++++
 pulsar/consumer_test.go                    | 167 ++++++++++++++++++++
 pulsar/internal/batch_builder.go           | 189 ++++++++++++++++-------
 pulsar/internal/key_based_batch_builder.go | 237 +++++++++++++++++++++++++++++
 pulsar/producer.go                         |   7 +
 pulsar/producer_partition.go               |  70 +++++++--
 7 files changed, 649 insertions(+), 67 deletions(-)

diff --git a/go.mod b/go.mod
index bf0b627..817e223 100644
--- a/go.mod
+++ b/go.mod
@@ -14,6 +14,8 @@ require (
        github.com/klauspost/compress v1.10.8
        github.com/kr/pretty v0.2.0 // indirect
        github.com/linkedin/goavro/v2 v2.9.8
+       github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // 
indirect
+       github.com/modern-go/reflect2 v1.0.1 // indirect
        github.com/pierrec/lz4 v2.0.5+incompatible
        github.com/pkg/errors v0.9.1
        github.com/prometheus/client_golang v1.7.1
diff --git a/pulsar/batcher_builder.go b/pulsar/batcher_builder.go
new file mode 100644
index 0000000..caefa8d
--- /dev/null
+++ b/pulsar/batcher_builder.go
@@ -0,0 +1,44 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package pulsar
+
+import (
+       "errors"
+
+       "github.com/apache/pulsar-client-go/pulsar/internal"
+)
+
+type BatcherBuilderType int
+
+const (
+       DefaultBatchBuilder BatcherBuilderType = iota
+       KeyBasedBatchBuilder
+)
+
+func GetBatcherBuilderProvider(typ BatcherBuilderType) (
+       internal.BatcherBuilderProvider, error,
+) {
+       switch typ {
+       case DefaultBatchBuilder:
+               return internal.NewBatchBuilder, nil
+       case KeyBasedBatchBuilder:
+               return internal.NewKeyBasedBatchBuilder, nil
+       default:
+               return nil, errors.New("unsupported batcher builder provider 
type")
+       }
+}
diff --git a/pulsar/consumer_test.go b/pulsar/consumer_test.go
index f31bb69..6d58cd4 100644
--- a/pulsar/consumer_test.go
+++ b/pulsar/consumer_test.go
@@ -1712,3 +1712,170 @@ func TestConsumerName(t *testing.T) {
 
        assert.Equal(consumerName, consumer.Name())
 }
+
+func TestKeyBasedBatchProducerConsumerKeyShared(t *testing.T) {
+       const MsgBatchCount = 100
+       client, err := NewClient(ClientOptions{
+               URL: lookupURL,
+       })
+       assert.Nil(t, err)
+       defer client.Close()
+
+       topic := 
"persistent://public/default/test-key-based-batch-with-key-shared"
+
+       consumer1, err := client.Subscribe(ConsumerOptions{
+               Topic:            topic,
+               SubscriptionName: "sub-1",
+               Type:             KeyShared,
+       })
+       assert.Nil(t, err)
+       defer consumer1.Close()
+
+       consumer2, err := client.Subscribe(ConsumerOptions{
+               Topic:            topic,
+               SubscriptionName: "sub-1",
+               Type:             KeyShared,
+       })
+       assert.Nil(t, err)
+       defer consumer2.Close()
+
+       // create producer
+       producer, err := client.CreateProducer(ProducerOptions{
+               Topic:               topic,
+               DisableBatching:     false,
+               BatcherBuilderType:  KeyBasedBatchBuilder,
+               BatchingMaxMessages: 10,
+       })
+       assert.Nil(t, err)
+       defer producer.Close()
+
+       ctx := context.Background()
+       keys := []string{"key1", "key2", "key3"}
+       for i := 0; i < MsgBatchCount; i++ {
+               for _, k := range keys {
+                       producer.SendAsync(ctx, &ProducerMessage{
+                               Key:     k,
+                               Payload: []byte(fmt.Sprintf("value-%d", i)),
+                       }, func(id MessageID, producerMessage *ProducerMessage, 
err error) {
+                               assert.Nil(t, err)
+                       },
+                       )
+               }
+       }
+
+       receivedConsumer1 := 0
+       receivedConsumer2 := 0
+       consumer1Keys := make(map[string]int)
+       consumer2Keys := make(map[string]int)
+       for (receivedConsumer1 + receivedConsumer2) < 300 {
+               select {
+               case cm, ok := <-consumer1.Chan():
+                       if !ok {
+                               break
+                       }
+                       receivedConsumer1++
+                       cnt := 0
+                       if _, has := consumer1Keys[cm.Key()]; has {
+                               cnt = consumer1Keys[cm.Key()]
+                       }
+                       assert.Equal(
+                               t, fmt.Sprintf("value-%d", cnt),
+                               string(cm.Payload()),
+                       )
+                       consumer1Keys[cm.Key()] = cnt + 1
+                       consumer1.Ack(cm.Message)
+               case cm, ok := <-consumer2.Chan():
+                       if !ok {
+                               break
+                       }
+                       receivedConsumer2++
+                       cnt := 0
+                       if _, has := consumer2Keys[cm.Key()]; has {
+                               cnt = consumer2Keys[cm.Key()]
+                       }
+                       assert.Equal(
+                               t, fmt.Sprintf("value-%d", cnt),
+                               string(cm.Payload()),
+                       )
+                       consumer2Keys[cm.Key()] = cnt + 1
+                       consumer2.Ack(cm.Message)
+               }
+       }
+
+       assert.NotEqual(t, 0, receivedConsumer1)
+       assert.NotEqual(t, 0, receivedConsumer2)
+       assert.Equal(t, len(consumer1Keys)*MsgBatchCount, receivedConsumer1)
+       assert.Equal(t, len(consumer2Keys)*MsgBatchCount, receivedConsumer2)
+
+       fmt.Printf("TestKeyBasedBatchProducerConsumerKeyShared received 
messages consumer1: %d consumser2: %d\n",
+               receivedConsumer1, receivedConsumer2)
+       assert.Equal(t, 300, receivedConsumer1+receivedConsumer2)
+
+       fmt.Printf("TestKeyBasedBatchProducerConsumerKeyShared received 
messages keys consumer1: %v consumser2: %v\n",
+               consumer1Keys, consumer2Keys)
+}
+
+func TestOrderingOfKeyBasedBatchProducerConsumerKeyShared(t *testing.T) {
+       const MsgBatchCount = 10
+       client, err := NewClient(ClientOptions{
+               URL: lookupURL,
+       })
+       assert.Nil(t, err)
+       defer client.Close()
+
+       topic := 
"persistent://public/default/test-ordering-of-key-based-batch-with-key-shared"
+
+       consumer1, err := client.Subscribe(ConsumerOptions{
+               Topic:            topic,
+               SubscriptionName: "sub-1",
+               Type:             KeyShared,
+       })
+       assert.Nil(t, err)
+       defer consumer1.Close()
+
+       // create producer
+       producer, err := client.CreateProducer(ProducerOptions{
+               Topic:                   topic,
+               DisableBatching:         false,
+               BatcherBuilderType:      KeyBasedBatchBuilder,
+               BatchingMaxMessages:     30,
+               BatchingMaxPublishDelay: time.Second * 5,
+       })
+       assert.Nil(t, err)
+       defer producer.Close()
+
+       ctx := context.Background()
+       keys := []string{"key1", "key2", "key3"}
+       for i := 0; i < MsgBatchCount; i++ {
+               for _, k := range keys {
+                       producer.SendAsync(ctx, &ProducerMessage{
+                               Key:     k,
+                               Payload: []byte(fmt.Sprintf("value-%d", i)),
+                       }, func(id MessageID, producerMessage *ProducerMessage, 
err error) {
+                               assert.Nil(t, err)
+                       },
+                       )
+               }
+       }
+
+       var receivedKey string
+       var receivedMessageIndex int
+       for i := 0; i < len(keys)*MsgBatchCount; i++ {
+               cm, ok := <-consumer1.Chan()
+               if !ok {
+                       break
+               }
+               if receivedKey != cm.Key() {
+                       receivedKey = cm.Key()
+                       receivedMessageIndex = 0
+               }
+               assert.Equal(
+                       t, fmt.Sprintf("value-%d", receivedMessageIndex%10),
+                       string(cm.Payload()),
+               )
+               consumer1.Ack(cm.Message)
+               receivedMessageIndex++
+       }
+
+       // TODO: add OrderingKey support, see GH issue #401
+}
diff --git a/pulsar/internal/batch_builder.go b/pulsar/internal/batch_builder.go
index ecf2b88..3e1601f 100644
--- a/pulsar/internal/batch_builder.go
+++ b/pulsar/internal/batch_builder.go
@@ -31,8 +31,44 @@ type BuffersPool interface {
        GetBuffer() Buffer
 }
 
-// BatchBuilder wraps the objects needed to build a batch.
-type BatchBuilder struct {
+// BatcherBuilderProvider defines func which returns the BatchBuilder.
+type BatcherBuilderProvider func(
+       maxMessages uint, maxBatchSize uint, producerName string, producerID 
uint64,
+       compressionType pb.CompressionType, level compression.Level,
+       bufferPool BuffersPool, logger log.Logger,
+) (BatchBuilder, error)
+
+// BatchBuilder is an interface for batch builders.
+type BatchBuilder interface {
+       // IsFull check if the size in the current batch exceeds the maximum 
size allowed by the batch
+       IsFull() bool
+
+       // Add will add single message to batch.
+       Add(
+               metadata *pb.SingleMessageMetadata, sequenceIDGenerator *uint64,
+               payload []byte,
+               callback interface{}, replicateTo []string, deliverAt time.Time,
+       ) bool
+
+       // Flush all the messages buffered in the client and wait until all 
messages have been successfully persisted.
+       Flush() (batchData Buffer, sequenceID uint64, callbacks []interface{})
+
+       // Flush all the messages buffered in multiple batches and wait until 
all
+       // messages have been successfully persisted.
+       FlushBatches() (
+               batchData []Buffer, sequenceID []uint64, callbacks 
[][]interface{},
+       )
+
+       // IsMultiBatches reports whether the container batches messages into multiple batches.
+       IsMultiBatches() bool
+
+       reset()
+       Close() error
+}
+
+// batchContainer wraps the objects needed to build a batch.
+// batchContainer implements BatchBuilder as a single batch container.
+type batchContainer struct {
        buffer Buffer
 
        // Current number of messages in the batch
@@ -41,7 +77,7 @@ type BatchBuilder struct {
        // Max number of message allowed in the batch
        maxMessages uint
 
-       // The largest size for a batch sent from this praticular producer.
+       // The largest size for a batch sent from this particular producer.
        // This is used as a baseline to allocate a new buffer that can hold 
the entire batch
        // without needing costly re-allocations.
        maxBatchSize uint
@@ -59,22 +95,26 @@ type BatchBuilder struct {
        log log.Logger
 }
 
-// NewBatchBuilder init batch builder and return BatchBuilder pointer. Build a 
new batch message container.
-func NewBatchBuilder(maxMessages uint, maxBatchSize uint, producerName string, 
producerID uint64,
+// newBatchContainer init a batchContainer
+func newBatchContainer(
+       maxMessages uint, maxBatchSize uint, producerName string, producerID 
uint64,
        compressionType pb.CompressionType, level compression.Level,
-       bufferPool BuffersPool, logger log.Logger) (*BatchBuilder, error) {
+       bufferPool BuffersPool, logger log.Logger,
+) batchContainer {
 
-       bb := &BatchBuilder{
+       bc := batchContainer{
                buffer:       NewBuffer(4096),
                numMessages:  0,
                maxMessages:  maxMessages,
                maxBatchSize: maxBatchSize,
                producerName: producerName,
                producerID:   producerID,
-               cmdSend: baseCommand(pb.BaseCommand_SEND,
+               cmdSend: baseCommand(
+                       pb.BaseCommand_SEND,
                        &pb.CommandSend{
                                ProducerId: &producerID,
-                       }),
+                       },
+               ),
                msgMetadata: &pb.MessageMetadata{
                        ProducerName: &producerName,
                },
@@ -85,99 +125,140 @@ func NewBatchBuilder(maxMessages uint, maxBatchSize uint, 
producerName string, p
        }
 
        if compressionType != pb.CompressionType_NONE {
-               bb.msgMetadata.Compression = &compressionType
+               bc.msgMetadata.Compression = &compressionType
        }
 
-       return bb, nil
+       return bc
+}
+
+// NewBatchBuilder init batch builder and return BatchBuilder pointer. Build a 
new batch message container.
+func NewBatchBuilder(
+       maxMessages uint, maxBatchSize uint, producerName string, producerID 
uint64,
+       compressionType pb.CompressionType, level compression.Level,
+       bufferPool BuffersPool, logger log.Logger,
+) (BatchBuilder, error) {
+
+       bc := newBatchContainer(
+               maxMessages, maxBatchSize, producerName, producerID, 
compressionType,
+               level, bufferPool, logger,
+       )
+
+       return &bc, nil
 }
 
 // IsFull check if the size in the current batch exceeds the maximum size 
allowed by the batch
-func (bb *BatchBuilder) IsFull() bool {
-       return bb.numMessages >= bb.maxMessages || bb.buffer.ReadableBytes() > 
uint32(bb.maxBatchSize)
+func (bc *batchContainer) IsFull() bool {
+       return bc.numMessages >= bc.maxMessages || bc.buffer.ReadableBytes() > 
uint32(bc.maxBatchSize)
 }
 
-func (bb *BatchBuilder) hasSpace(payload []byte) bool {
+func (bc *batchContainer) hasSpace(payload []byte) bool {
        msgSize := uint32(len(payload))
-       return bb.numMessages > 0 && (bb.buffer.ReadableBytes()+msgSize) > 
uint32(bb.maxBatchSize)
+       return bc.numMessages > 0 && (bc.buffer.ReadableBytes()+msgSize) > 
uint32(bc.maxBatchSize)
 }
 
 // Add will add single message to batch.
-func (bb *BatchBuilder) Add(metadata *pb.SingleMessageMetadata, sequenceID 
uint64, payload []byte,
-       callback interface{}, replicateTo []string, deliverAt time.Time) bool {
-       if replicateTo != nil && bb.numMessages != 0 {
+func (bc *batchContainer) Add(
+       metadata *pb.SingleMessageMetadata, sequenceIDGenerator *uint64,
+       payload []byte,
+       callback interface{}, replicateTo []string, deliverAt time.Time,
+) bool {
+       if replicateTo != nil && bc.numMessages != 0 {
                // If the current batch is not empty and we're trying to set 
the replication clusters,
                // then we need to force the current batch to flush and send 
the message individually
                return false
-       } else if bb.msgMetadata.ReplicateTo != nil {
+       } else if bc.msgMetadata.ReplicateTo != nil {
                // There's already a message with cluster replication list. 
need to flush before next
                // message can be sent
                return false
-       } else if bb.hasSpace(payload) {
+       } else if bc.hasSpace(payload) {
                // The current batch is full. Producer has to call Flush() to
                return false
        }
 
-       if bb.numMessages == 0 {
-               bb.msgMetadata.SequenceId = proto.Uint64(sequenceID)
-               bb.msgMetadata.PublishTime = 
proto.Uint64(TimestampMillis(time.Now()))
-               bb.msgMetadata.SequenceId = proto.Uint64(sequenceID)
-               bb.msgMetadata.ProducerName = &bb.producerName
-               bb.msgMetadata.ReplicateTo = replicateTo
-               bb.msgMetadata.PartitionKey = metadata.PartitionKey
+       if bc.numMessages == 0 {
+               var sequenceID uint64
+               if metadata.SequenceId != nil {
+                       sequenceID = *metadata.SequenceId
+               } else {
+                       sequenceID = GetAndAdd(sequenceIDGenerator, 1)
+               }
+               bc.msgMetadata.SequenceId = proto.Uint64(sequenceID)
+               bc.msgMetadata.PublishTime = 
proto.Uint64(TimestampMillis(time.Now()))
+               bc.msgMetadata.ProducerName = &bc.producerName
+               bc.msgMetadata.ReplicateTo = replicateTo
+               bc.msgMetadata.PartitionKey = metadata.PartitionKey
 
                if deliverAt.UnixNano() > 0 {
-                       bb.msgMetadata.DeliverAtTime = 
proto.Int64(int64(TimestampMillis(deliverAt)))
+                       bc.msgMetadata.DeliverAtTime = 
proto.Int64(int64(TimestampMillis(deliverAt)))
                }
 
-               bb.cmdSend.Send.SequenceId = proto.Uint64(sequenceID)
+               bc.cmdSend.Send.SequenceId = proto.Uint64(sequenceID)
        }
-       addSingleMessageToBatch(bb.buffer, metadata, payload)
+       addSingleMessageToBatch(bc.buffer, metadata, payload)
 
-       bb.numMessages++
-       bb.callbacks = append(bb.callbacks, callback)
+       bc.numMessages++
+       bc.callbacks = append(bc.callbacks, callback)
        return true
 }
 
-func (bb *BatchBuilder) reset() {
-       bb.numMessages = 0
-       bb.buffer.Clear()
-       bb.callbacks = []interface{}{}
-       bb.msgMetadata.ReplicateTo = nil
-       bb.msgMetadata.DeliverAtTime = nil
+func (bc *batchContainer) reset() {
+       bc.numMessages = 0
+       bc.buffer.Clear()
+       bc.callbacks = []interface{}{}
+       bc.msgMetadata.ReplicateTo = nil
+       bc.msgMetadata.DeliverAtTime = nil
 }
 
 // Flush all the messages buffered in the client and wait until all messages 
have been successfully persisted.
-func (bb *BatchBuilder) Flush() (batchData Buffer, sequenceID uint64, 
callbacks []interface{}) {
-       if bb.numMessages == 0 {
+func (bc *batchContainer) Flush() (
+       batchData Buffer, sequenceID uint64, callbacks []interface{},
+) {
+       if bc.numMessages == 0 {
                // No-Op for empty batch
                return nil, 0, nil
        }
-       bb.log.Debug("BatchBuilder flush: messages: ", bb.numMessages)
+       bc.log.Debug("BatchBuilder flush: messages: ", bc.numMessages)
 
-       bb.msgMetadata.NumMessagesInBatch = proto.Int32(int32(bb.numMessages))
-       bb.cmdSend.Send.NumMessages = proto.Int32(int32(bb.numMessages))
+       bc.msgMetadata.NumMessagesInBatch = proto.Int32(int32(bc.numMessages))
+       bc.cmdSend.Send.NumMessages = proto.Int32(int32(bc.numMessages))
 
-       uncompressedSize := bb.buffer.ReadableBytes()
-       bb.msgMetadata.UncompressedSize = &uncompressedSize
+       uncompressedSize := bc.buffer.ReadableBytes()
+       bc.msgMetadata.UncompressedSize = &uncompressedSize
 
-       buffer := bb.buffersPool.GetBuffer()
+       buffer := bc.buffersPool.GetBuffer()
        if buffer == nil {
                buffer = NewBuffer(int(uncompressedSize * 3 / 2))
        }
-       serializeBatch(buffer, bb.cmdSend, bb.msgMetadata, bb.buffer, 
bb.compressionProvider)
+       serializeBatch(
+               buffer, bc.cmdSend, bc.msgMetadata, bc.buffer, 
bc.compressionProvider,
+       )
 
-       callbacks = bb.callbacks
-       sequenceID = bb.cmdSend.Send.GetSequenceId()
-       bb.reset()
+       callbacks = bc.callbacks
+       sequenceID = bc.cmdSend.Send.GetSequenceId()
+       bc.reset()
        return buffer, sequenceID, callbacks
 }
 
-func (bb *BatchBuilder) Close() error {
-       return bb.compressionProvider.Close()
+// FlushBatches is only supported by multiple batches containers; this single batch container panics.
+func (bc *batchContainer) FlushBatches() (
+       batchData []Buffer, sequenceID []uint64, callbacks [][]interface{},
+) {
+       panic("single batch container not support FlushBatches(), please use 
Flush() instead")
+}
+
+// IsMultiBatches returns false because batchContainer is a single batch container.
+func (bc *batchContainer) IsMultiBatches() bool {
+       return false
+}
+
+func (bc *batchContainer) Close() error {
+       return bc.compressionProvider.Close()
 }
 
-func getCompressionProvider(compressionType pb.CompressionType,
-       level compression.Level) compression.Provider {
+func getCompressionProvider(
+       compressionType pb.CompressionType,
+       level compression.Level,
+) compression.Provider {
        switch compressionType {
        case pb.CompressionType_NONE:
                return compression.NewNoopProvider()
diff --git a/pulsar/internal/key_based_batch_builder.go 
b/pulsar/internal/key_based_batch_builder.go
new file mode 100644
index 0000000..545c2c8
--- /dev/null
+++ b/pulsar/internal/key_based_batch_builder.go
@@ -0,0 +1,237 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package internal
+
+import (
+       "encoding/base64"
+       "sort"
+       "sync"
+       "time"
+
+       "github.com/apache/pulsar-client-go/pulsar/internal/compression"
+       pb "github.com/apache/pulsar-client-go/pulsar/internal/pulsar_proto"
+       "github.com/apache/pulsar-client-go/pulsar/log"
+)
+
+/**
+ * Key based batch message container
+ *
+ * incoming single messages:
+ * (k1, v1), (k2, v1), (k3, v1), (k1, v2), (k2, v2), (k3, v2), (k1, v3), (k2, 
v3), (k3, v3)
+ *
+ * batched into multiple batch messages:
+ * [(k1, v1), (k1, v2), (k1, v3)], [(k2, v1), (k2, v2), (k2, v3)], [(k3, v1), 
(k3, v2), (k3, v3)]
+ */
+
+// keyBasedBatches is a simple concurrent-safe map for the batchContainer type
+type keyBasedBatches struct {
+       containers map[string]*batchContainer
+       l          *sync.RWMutex
+}
+
+// keyBasedBatchContainer wraps the objects needed to build key based batches.
+// keyBasedBatchContainer implements BatchBuilder as a multiple batches
+// container.
+type keyBasedBatchContainer struct {
+       batches keyBasedBatches
+       batchContainer
+       compressionType pb.CompressionType
+       level           compression.Level
+}
+
+// newKeyBasedBatches init a keyBasedBatches
+func newKeyBasedBatches() keyBasedBatches {
+       return keyBasedBatches{
+               containers: map[string]*batchContainer{},
+               l:          &sync.RWMutex{},
+       }
+}
+
+func (h *keyBasedBatches) Add(key string, val *batchContainer) {
+       h.l.Lock()
+       defer h.l.Unlock()
+       h.containers[key] = val
+}
+
+func (h *keyBasedBatches) Del(key string) {
+       h.l.Lock()
+       defer h.l.Unlock()
+       delete(h.containers, key)
+}
+
+func (h *keyBasedBatches) Val(key string) *batchContainer {
+       h.l.RLock()
+       defer h.l.RUnlock()
+       return h.containers[key]
+}
+
+// NewKeyBasedBatchBuilder init batch builder and return BatchBuilder
+// pointer. Build a new key based batch message container.
+func NewKeyBasedBatchBuilder(
+       maxMessages uint, maxBatchSize uint, producerName string, producerID 
uint64,
+       compressionType pb.CompressionType, level compression.Level,
+       bufferPool BuffersPool, logger log.Logger,
+) (BatchBuilder, error) {
+
+       bb := &keyBasedBatchContainer{
+               batches: newKeyBasedBatches(),
+               batchContainer: newBatchContainer(
+                       maxMessages, maxBatchSize, producerName, producerID,
+                       compressionType, level, bufferPool, logger,
+               ),
+               compressionType: compressionType,
+               level:           level,
+       }
+
+       if compressionType != pb.CompressionType_NONE {
+               bb.msgMetadata.Compression = &compressionType
+       }
+
+       return bb, nil
+}
+
+// IsFull check if the size in the current batch exceeds the maximum size 
allowed by the batch
+func (bc *keyBasedBatchContainer) IsFull() bool {
+       return bc.numMessages >= bc.maxMessages || bc.buffer.ReadableBytes() > 
uint32(bc.maxBatchSize)
+}
+
+func (bc *keyBasedBatchContainer) IsMultiBatches() bool {
+       return true
+}
+
+func (bc *keyBasedBatchContainer) hasSpace(payload []byte) bool {
+       msgSize := uint32(len(payload))
+       return bc.numMessages > 0 && (bc.buffer.ReadableBytes()+msgSize) > 
uint32(bc.maxBatchSize)
+}
+
+// Add will add single message to key-based batch with message key.
+func (bc *keyBasedBatchContainer) Add(
+       metadata *pb.SingleMessageMetadata, sequenceIDGenerator *uint64,
+       payload []byte,
+       callback interface{}, replicateTo []string, deliverAt time.Time,
+) bool {
+       if replicateTo != nil && bc.numMessages != 0 {
+               // If the current batch is not empty and we're trying to set 
the replication clusters,
+               // then we need to force the current batch to flush and send 
the message individually
+               return false
+       } else if bc.msgMetadata.ReplicateTo != nil {
+               // There's already a message with cluster replication list. 
need to flush before next
+               // message can be sent
+               return false
+       } else if bc.hasSpace(payload) {
+               // The current batch is full. Producer has to call FlushBatches() first
+               return false
+       }
+
+       var msgKey = getMessageKey(metadata)
+       batchPart := bc.batches.Val(msgKey)
+       if batchPart == nil {
+               // create batchContainer for new key
+               t := newBatchContainer(
+                       bc.maxMessages, bc.maxBatchSize, bc.producerName, 
bc.producerID,
+                       bc.compressionType, bc.level, bc.buffersPool, bc.log,
+               )
+               batchPart = &t
+               bc.batches.Add(msgKey, &t)
+       }
+
+       // add message to batch container
+       batchPart.Add(
+               metadata, sequenceIDGenerator, payload, callback, replicateTo,
+               deliverAt,
+       )
+       addSingleMessageToBatch(bc.buffer, metadata, payload)
+
+       bc.numMessages++
+       bc.callbacks = append(bc.callbacks, callback)
+       return true
+}
+
+func (bc *keyBasedBatchContainer) reset() {
+       bc.batches.l.RLock()
+       defer bc.batches.l.RUnlock()
+       for _, container := range bc.batches.containers {
+               container.reset()
+       }
+       bc.numMessages = 0
+       bc.buffer.Clear()
+       bc.callbacks = []interface{}{}
+       bc.msgMetadata.ReplicateTo = nil
+       bc.msgMetadata.DeliverAtTime = nil
+       bc.batches.containers = map[string]*batchContainer{}
+}
+
+// Flush all the messages buffered in multiple batches and wait until all
+// messages have been successfully persisted.
+func (bc *keyBasedBatchContainer) FlushBatches() (
+       batchesData []Buffer, sequenceIDs []uint64, callbacks [][]interface{},
+) {
+       if bc.numMessages == 0 {
+               // No-Op for empty batch
+               return nil, nil, nil
+       }
+
+       bc.log.Debug("keyBasedBatchContainer flush: messages: ", bc.numMessages)
+       var batchesLen = len(bc.batches.containers)
+       var idx = 0
+       sortedKeys := make([]string, 0, batchesLen)
+
+       batchesData = make([]Buffer, batchesLen)
+       sequenceIDs = make([]uint64, batchesLen)
+       callbacks = make([][]interface{}, batchesLen)
+
+       bc.batches.l.RLock()
+       defer bc.batches.l.RUnlock()
+       for k := range bc.batches.containers {
+               sortedKeys = append(sortedKeys, k)
+       }
+       sort.Strings(sortedKeys)
+       for _, k := range sortedKeys {
+               container := bc.batches.containers[k]
+               b, s, c := container.Flush()
+               if b != nil {
+                       batchesData[idx] = b
+                       sequenceIDs[idx] = s
+                       callbacks[idx] = c
+               }
+               idx++
+       }
+
+       bc.reset()
+       return batchesData, sequenceIDs, callbacks
+}
+
+func (bc *keyBasedBatchContainer) Flush() (
+       batchData Buffer, sequenceID uint64, callbacks []interface{},
+) {
+       panic("multi batches container not support Flush(), please use 
FlushBatches() instead")
+}
+
+func (bc *keyBasedBatchContainer) Close() error {
+       return bc.compressionProvider.Close()
+}
+
+// getMessageKey extracts message key from message metadata.
+// If the OrderingKey exists, the base64-encoded string is returned,
+// otherwise the PartitionKey is returned.
+func getMessageKey(metadata *pb.SingleMessageMetadata) string {
+       if k := metadata.GetOrderingKey(); k != nil {
+               return base64.StdEncoding.EncodeToString(k)
+       }
+       return metadata.GetPartitionKey()
+}
diff --git a/pulsar/producer.go b/pulsar/producer.go
index 1dc0775..b41415a 100644
--- a/pulsar/producer.go
+++ b/pulsar/producer.go
@@ -152,6 +152,13 @@ type ProducerOptions struct {
 
        // MaxReconnectToBroker set the maximum retry number of 
reconnectToBroker. (default: ultimate)
        MaxReconnectToBroker *uint
+
+       // BatcherBuilderType sets the batch builder type (default 
DefaultBatchBuilder)
+       // This will be used to create batch container when batching is enabled.
+       // Options:
+       // - DefaultBatchBuilder
+       // - KeyBasedBatchBuilder
+       BatcherBuilderType
 }
 
 // Producer is used to publish messages on a topic
diff --git a/pulsar/producer_partition.go b/pulsar/producer_partition.go
index 09d9eb8..615bc01 100644
--- a/pulsar/producer_partition.go
+++ b/pulsar/producer_partition.go
@@ -115,7 +115,7 @@ type partitionProducer struct {
        options             *ProducerOptions
        producerName        string
        producerID          uint64
-       batchBuilder        *internal.BatchBuilder
+       batchBuilder        internal.BatchBuilder
        sequenceIDGenerator *uint64
        batchFlushTicker    *time.Ticker
 
@@ -242,8 +242,23 @@ func (p *partitionProducer) grabCnx() error {
        }
 
        p.producerName = res.Response.ProducerSuccess.GetProducerName()
-       if p.batchBuilder == nil {
-               p.batchBuilder, err = internal.NewBatchBuilder(p.options.BatchingMaxMessages, p.options.BatchingMaxSize,
+       if p.options.DisableBatching {
+               provider, _ := GetBatcherBuilderProvider(DefaultBatchBuilder)
+               p.batchBuilder, err = provider(p.options.BatchingMaxMessages, p.options.BatchingMaxSize,
+                       p.producerName, p.producerID, pb.CompressionType(p.options.CompressionType),
+                       compression.Level(p.options.CompressionLevel),
+                       p,
+                       p.log)
+               if err != nil {
+                       return err
+               }
+       } else if p.batchBuilder == nil {
+               provider, err := GetBatcherBuilderProvider(p.options.BatcherBuilderType)
+               if err != nil {
+                       provider, _ = GetBatcherBuilderProvider(DefaultBatchBuilder)
+               }
+
+               p.batchBuilder, err = provider(p.options.BatchingMaxMessages, p.options.BatchingMaxSize,
                        p.producerName, p.producerID, pb.CompressionType(p.options.CompressionType),
                        compression.Level(p.options.CompressionLevel),
                        p,
@@ -338,7 +353,11 @@ func (p *partitionProducer) runEventsLoop() {
                case <-p.connectClosedCh:
                        p.reconnectToBroker()
                case <-p.batchFlushTicker.C:
-                       p.internalFlushCurrentBatch()
+                       if p.batchBuilder.IsMultiBatches() {
+                               p.internalFlushCurrentBatches()
+                       } else {
+                               p.internalFlushCurrentBatch()
+                       }
                }
        }
 }
@@ -407,29 +426,26 @@ func (p *partitionProducer) internalSend(request *sendRequest) {
                smm.Properties = internal.ConvertFromStringMap(msg.Properties)
        }
 
-       var sequenceID uint64
        if msg.SequenceID != nil {
-               sequenceID = uint64(*msg.SequenceID)
-       } else {
-               sequenceID = internal.GetAndAdd(p.sequenceIDGenerator, 1)
+               sequenceID := uint64(*msg.SequenceID)
+               smm.SequenceId = proto.Uint64(sequenceID)
        }
 
        if !sendAsBatch {
                p.internalFlushCurrentBatch()
        }
-       added := p.batchBuilder.Add(smm, sequenceID, payload, request,
+       added := p.batchBuilder.Add(smm, p.sequenceIDGenerator, payload, request,
                msg.ReplicationClusters, deliverAt)
        if !added {
                // The current batch is full.. flush it and retry
                p.internalFlushCurrentBatch()
 
                // after flushing try again to add the current payload
-               if ok := p.batchBuilder.Add(smm, sequenceID, payload, request,
+               if ok := p.batchBuilder.Add(smm, p.sequenceIDGenerator, payload, request,
                        msg.ReplicationClusters, deliverAt); !ok {
                        p.publishSemaphore.Release()
                        request.callback(nil, request.msg, errFailAddBatch)
                        p.log.WithField("size", len(payload)).
-                               WithField("sequenceID", sequenceID).
                                WithField("properties", msg.Properties).
                                Error("unable to add message to batch")
                        return
@@ -437,7 +453,11 @@ func (p *partitionProducer) internalSend(request *sendRequest) {
        }
 
        if !sendAsBatch || request.flushImmediately {
-               p.internalFlushCurrentBatch()
+               if p.batchBuilder.IsMultiBatches() {
+                       p.internalFlushCurrentBatches()
+               } else {
+                       p.internalFlushCurrentBatch()
+               }
        }
 }
 
@@ -513,8 +533,32 @@ func (p *partitionProducer) failTimeoutMessages() {
        }
 }
 
+func (p *partitionProducer) internalFlushCurrentBatches() {
+       batchesData, sequenceIDs, callbacks := p.batchBuilder.FlushBatches()
+       if batchesData == nil {
+               return
+       }
+
+       for i := range batchesData {
+               if batchesData[i] == nil {
+                       continue
+               }
+               p.pendingQueue.Put(&pendingItem{
+                       batchData:    batchesData[i],
+                       sequenceID:   sequenceIDs[i],
+                       sendRequests: callbacks[i],
+               })
+               p.cnx.WriteData(batchesData[i])
+       }
+
+}
+
 func (p *partitionProducer) internalFlush(fr *flushRequest) {
-       p.internalFlushCurrentBatch()
+       if p.batchBuilder.IsMultiBatches() {
+               p.internalFlushCurrentBatches()
+       } else {
+               p.internalFlushCurrentBatch()
+       }
 
        pi, ok := p.pendingQueue.PeekLast().(*pendingItem)
        if !ok {

Reply via email to