This is an automated email from the ASF dual-hosted git repository.

baodi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pulsar-client-go.git


The following commit(s) were added to refs/heads/master by this push:
     new 5fa431d  [feat] Support memory limit for producer. (#955)
5fa431d is described below

commit 5fa431d06dae5615fbadda78e9c8ef68de0aab4f
Author: Baodi Shi <[email protected]>
AuthorDate: Wed Mar 1 16:33:22 2023 +0800

    [feat] Support memory limit for producer. (#955)
    
    * [feat] Support memory limit for producer.
    
    * Fix code reviews.
    
    * Change channel_cond access level
    
    * Change name and add note.
    
    * change name.
---
 pulsar/client.go                                |   4 +
 pulsar/client_impl.go                           |  12 +-
 pulsar/error.go                                 |   5 +
 pulsar/internal/channel_cond.go                 |  76 ++++++++++
 pulsar/internal/channel_cond_test.go            |  55 +++++++
 pulsar/internal/memory_limit_controller.go      | 101 +++++++++++++
 pulsar/internal/memory_limit_controller_test.go | 184 ++++++++++++++++++++++++
 pulsar/producer_partition.go                    |  68 ++++++---
 pulsar/producer_test.go                         | 126 ++++++++++++++++
 9 files changed, 606 insertions(+), 25 deletions(-)

diff --git a/pulsar/client.go b/pulsar/client.go
index 135d22b..75b363d 100644
--- a/pulsar/client.go
+++ b/pulsar/client.go
@@ -144,6 +144,10 @@ type ClientOptions struct {
        MetricsRegisterer prometheus.Registerer
 
        EnableTransaction bool
+
+       // MemoryLimitBytes limits the memory used by the client (in bytes). The default of 64 MiB,
+       // used when this is 0, can guarantee a high producer throughput.
+       // A negative value disables the memory limit.
+       MemoryLimitBytes int64
 }
 
 // Client represents a pulsar client
diff --git a/pulsar/client_impl.go b/pulsar/client_impl.go
index f444804..7d90922 100644
--- a/pulsar/client_impl.go
+++ b/pulsar/client_impl.go
@@ -33,6 +33,7 @@ const (
        defaultConnectionTimeout = 10 * time.Second
        defaultOperationTimeout  = 30 * time.Second
        defaultKeepAliveInterval = 30 * time.Second
+       defaultMemoryLimitBytes  = 64 * 1024 * 1024
 )
 
 type client struct {
@@ -42,6 +43,7 @@ type client struct {
        lookupService internal.LookupService
        metrics       *internal.Metrics
        tcClient      *transactionCoordinatorClient
+       memLimit      internal.MemoryLimitController
 
        log log.Logger
 }
@@ -134,11 +136,17 @@ func newClient(options ClientOptions) (Client, error) {
                keepAliveInterval = defaultKeepAliveInterval
        }
 
+       memLimitBytes := options.MemoryLimitBytes
+       if memLimitBytes == 0 {
+               memLimitBytes = defaultMemoryLimitBytes
+       }
+
        c := &client{
                cnxPool: internal.NewConnectionPool(tlsConfig, authProvider, 
connectionTimeout, keepAliveInterval,
                        maxConnectionsPerHost, logger, metrics),
-               log:     logger,
-               metrics: metrics,
+               log:      logger,
+               metrics:  metrics,
+               memLimit: internal.NewMemoryLimitController(memLimitBytes),
        }
        serviceNameResolver := internal.NewPulsarServiceNameResolver(url)
 
diff --git a/pulsar/error.go b/pulsar/error.go
index ce366f5..0aa1e3c 100644
--- a/pulsar/error.go
+++ b/pulsar/error.go
@@ -110,6 +110,9 @@ const (
        InvalidStatus
        // TransactionError means this is a transaction related error
        TransactionError
+
+       // ClientMemoryBufferIsFull means the client memory buffer is full
+       ClientMemoryBufferIsFull
 )
 
 // Error implement error interface, composed of two parts: msg and result.
@@ -216,6 +219,8 @@ func getResultStr(r Result) string {
                return "ProducerClosed"
        case SchemaFailure:
                return "SchemaFailure"
+       case ClientMemoryBufferIsFull:
+               return "ClientMemoryBufferIsFull"
        default:
                return fmt.Sprintf("Result(%d)", r)
        }
diff --git a/pulsar/internal/channel_cond.go b/pulsar/internal/channel_cond.go
new file mode 100644
index 0000000..38301ab
--- /dev/null
+++ b/pulsar/internal/channel_cond.go
@@ -0,0 +1,76 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package internal
+
+import (
+       "context"
+       "sync"
+       "sync/atomic"
+       "unsafe"
+)
+
+// chCond is a channel-based condition variable: waiters block on a channel
+// that broadcast closes, which also lets a wait select on a context.
+type chCond struct {
+       L sync.Locker
+       // notifyChPtr points at a chan struct{}. broadcast replaces it with a fresh
+       // channel and closes the previous one, so the pointee changes over time.
+       notifyChPtr unsafe.Pointer
+}
+
+// newCond returns a chCond guarded by l, primed with an open notify channel.
+func newCond(l sync.Locker) *chCond {
+       notifyCh := make(chan struct{})
+       return &chCond{
+               L:           l,
+               notifyChPtr: unsafe.Pointer(&notifyCh),
+       }
+}
+
+// wait releases c.L, blocks until the next broadcast, then re-acquires c.L
+// before returning, similar to sync.Cond.Wait. The caller must hold c.L.
+func (c *chCond) wait() {
+       // Capture the channel before unlocking so a broadcast issued after the
+       // unlock still closes the channel this goroutine is waiting on.
+       ch := c.notifyChan()
+       c.L.Unlock()
+       defer c.L.Lock()
+       <-ch
+}
+
+// waitWithContext is like wait, but also stops waiting once ctx is done.
+// It returns true when woken by broadcast and false when ctx was cancelled or
+// its deadline passed. Either way c.L is held again when it returns; the
+// caller must hold c.L when calling it.
+func (c *chCond) waitWithContext(ctx context.Context) bool {
+       n := c.notifyChan()
+       c.L.Unlock()
+       defer c.L.Lock()
+       // Deliberately no default case: a default would make the select return
+       // immediately, turning callers' retry loops into a busy spin instead of
+       // blocking until a broadcast or cancellation.
+       select {
+       case <-n:
+               return true
+       case <-ctx.Done():
+               return false
+       }
+}
+
+// broadcast wakes every goroutine currently waiting on c. It installs a fresh
+// notify channel for future waiters and closes the one they were blocked on.
+// It is not required for the caller to hold c.L during the call.
+func (c *chCond) broadcast() {
+       fresh := make(chan struct{})
+       old := (*chan struct{})(atomic.SwapPointer(&c.notifyChPtr, unsafe.Pointer(&fresh)))
+       close(*old)
+}
+
+// notifyChan returns the channel that the next broadcast will close.
+func (c *chCond) notifyChan() <-chan struct{} {
+       chPtr := (*chan struct{})(atomic.LoadPointer(&c.notifyChPtr))
+       return *chPtr
+}
diff --git a/pulsar/internal/channel_cond_test.go 
b/pulsar/internal/channel_cond_test.go
new file mode 100644
index 0000000..a73d44e
--- /dev/null
+++ b/pulsar/internal/channel_cond_test.go
@@ -0,0 +1,55 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package internal
+
+import (
+       "context"
+       "sync"
+       "testing"
+       "time"
+)
+
+func TestChCond(t *testing.T) {
+       cond := newCond(&sync.Mutex{})
+       ready := make(chan struct{})
+       done := make(chan struct{})
+       go func() {
+               defer close(done)
+               cond.L.Lock()
+               close(ready)
+               cond.wait()
+               cond.L.Unlock()
+       }()
+
+       // The waiter signals ready while holding cond.L, so this Lock can only
+       // succeed after wait() has captured its notify channel and released the
+       // lock. The broadcast therefore cannot be missed, which a fixed sleep
+       // did not guarantee.
+       <-ready
+       cond.L.Lock()
+       cond.broadcast()
+       cond.L.Unlock()
+
+       select {
+       case <-done:
+       case <-time.After(5 * time.Second):
+               t.Fatal("waiter was not woken by broadcast")
+       }
+}
+
+func TestChCondWithContext(t *testing.T) {
+       cond := newCond(&sync.Mutex{})
+       wg := sync.WaitGroup{}
+       ctx, cancel := context.WithCancel(context.Background())
+       wg.Add(1)
+       go func() {
+               cond.L.Lock()
+               cond.waitWithContext(ctx)
+               cond.L.Unlock()
+               wg.Done()
+       }()
+       cancel()
+       wg.Wait()
+}
diff --git a/pulsar/internal/memory_limit_controller.go 
b/pulsar/internal/memory_limit_controller.go
new file mode 100644
index 0000000..5bf8d59
--- /dev/null
+++ b/pulsar/internal/memory_limit_controller.go
@@ -0,0 +1,101 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package internal
+
+import (
+       "context"
+       "sync"
+       "sync/atomic"
+)
+
+// MemoryLimitController tracks the bytes reserved by a client against a limit.
+type MemoryLimitController interface {
+       // ReserveMemory reserves size bytes, blocking while the limit is exceeded.
+       // It returns false if ctx is done before the reservation succeeds.
+       ReserveMemory(ctx context.Context, size int64) bool
+       // TryReserveMemory reserves size bytes without blocking and reports success.
+       TryReserveMemory(size int64) bool
+       // ForceReserveMemory reserves size bytes regardless of the limit.
+       ForceReserveMemory(size int64)
+       // ReleaseMemory returns size previously reserved bytes.
+       ReleaseMemory(size int64)
+       // CurrentUsage returns the number of bytes currently reserved.
+       CurrentUsage() int64
+       // CurrentUsagePercent returns the current usage as a fraction of the limit.
+       CurrentUsagePercent() float64
+       // IsMemoryLimited reports whether the limit is enforced.
+       IsMemoryLimited() bool
+}
+
+type memoryLimitController struct {
+       // currentUsage is accessed with 64-bit sync/atomic operations. It must be
+       // the first field so it stays 64-bit aligned on 386, ARM and 32-bit MIPS,
+       // where atomic.*Int64 panics on misaligned words.
+       currentUsage int64
+       // limit is the maximum number of bytes; values <= 0 disable the limit.
+       limit int64
+       // chCond wakes goroutines blocked in ReserveMemory when memory is released.
+       chCond *chCond
+}
+
+// NewMemoryLimitController returns a MemoryLimitController that caps reserved
+// memory at limit bytes; a limit <= 0 disables the cap.
+func NewMemoryLimitController(limit int64) MemoryLimitController {
+       return &memoryLimitController{
+               limit:  limit,
+               chCond: newCond(&sync.Mutex{}),
+       }
+}
+
+// ReserveMemory reserves size bytes, blocking while usage is over the limit.
+// It returns false if ctx is done before the reservation succeeds.
+func (m *memoryLimitController) ReserveMemory(ctx context.Context, size int64) bool {
+       // Fast path: no locking when the reservation succeeds right away.
+       if m.TryReserveMemory(size) {
+               return true
+       }
+
+       m.chCond.L.Lock()
+       defer m.chCond.L.Unlock()
+       for {
+               if m.TryReserveMemory(size) {
+                       return true
+               }
+               if !m.chCond.waitWithContext(ctx) {
+                       return false
+               }
+       }
+}
+
+// TryReserveMemory reserves size bytes without blocking. It fails only when the
+// limit is enforced and usage is already above it, so the reservation that
+// crosses the limit is itself allowed.
+func (m *memoryLimitController) TryReserveMemory(size int64) bool {
+       limited := m.IsMemoryLimited()
+       for {
+               current := atomic.LoadInt64(&m.currentUsage)
+               // Rejecting only when already over the limit deliberately lets one
+               // request go over it.
+               if limited && current > m.limit {
+                       return false
+               }
+               if atomic.CompareAndSwapInt64(&m.currentUsage, current, current+size) {
+                       return true
+               }
+               // Usage changed concurrently between load and CAS; retry.
+       }
+}
+
+// ForceReserveMemory adds size bytes to the usage unconditionally, ignoring the limit.
+func (m *memoryLimitController) ForceReserveMemory(size int64) {
+       _ = atomic.AddInt64(&m.currentUsage, size)
+}
+
+// ReleaseMemory returns size bytes. When this release brings usage from above
+// the limit back to (or below) it, goroutines blocked in ReserveMemory are
+// woken so they can retry.
+func (m *memoryLimitController) ReleaseMemory(size int64) {
+       newUsage := atomic.AddInt64(&m.currentUsage, -size)
+       if newUsage+size > m.limit && newUsage <= m.limit {
+               // Broadcast while holding the condition's lock. ReserveMemory holds
+               // this lock from its failed TryReserveMemory until waitWithContext
+               // has captured the notify channel. Taking the lock here therefore
+               // means a waiter either already sees the released memory or is
+               // waiting on the channel this broadcast closes. Without the lock, a
+               // release landing in that window would be a lost wakeup.
+               m.chCond.L.Lock()
+               m.chCond.broadcast()
+               m.chCond.L.Unlock()
+       }
+}
+
+// CurrentUsage returns the number of bytes currently reserved.
+func (m *memoryLimitController) CurrentUsage() int64 {
+       usage := atomic.LoadInt64(&m.currentUsage)
+       return usage
+}
+
+// CurrentUsagePercent returns reserved bytes as a fraction of the limit, where
+// 1.0 means usage equals the limit. With a disabled limit (limit <= 0) the
+// result is not a meaningful ratio.
+func (m *memoryLimitController) CurrentUsagePercent() float64 {
+       usage := float64(atomic.LoadInt64(&m.currentUsage))
+       return usage / float64(m.limit)
+}
+
+// IsMemoryLimited reports whether a limit is enforced, i.e. limit > 0.
+func (m *memoryLimitController) IsMemoryLimited() bool {
+       return 0 < m.limit
+}
diff --git a/pulsar/internal/memory_limit_controller_test.go 
b/pulsar/internal/memory_limit_controller_test.go
new file mode 100644
index 0000000..a62c6e6
--- /dev/null
+++ b/pulsar/internal/memory_limit_controller_test.go
@@ -0,0 +1,184 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package internal
+
+import (
+       "context"
+       "sync"
+       "testing"
+       "time"
+
+       "github.com/stretchr/testify/assert"
+)
+
+func TestLimit(t *testing.T) {
+       mlc := NewMemoryLimitController(100)
+       assertUsage := func(wantUsage int64, wantPercent float64) {
+               assert.Equal(t, wantUsage, mlc.CurrentUsage())
+               assert.InDelta(t, wantPercent, mlc.CurrentUsagePercent(), 0.000001)
+       }
+
+       // 101 single-byte reservations succeed: the one crossing the limit is allowed.
+       for n := 0; n <= 100; n++ {
+               assert.True(t, mlc.TryReserveMemory(1))
+       }
+       assert.False(t, mlc.TryReserveMemory(1))
+       assertUsage(101, 1.01)
+
+       mlc.ReleaseMemory(1)
+       assertUsage(100, 1.0)
+
+       assert.True(t, mlc.TryReserveMemory(1))
+       assert.Equal(t, int64(101), mlc.CurrentUsage())
+
+       // Forced reservations ignore the limit.
+       mlc.ForceReserveMemory(99)
+       assert.False(t, mlc.TryReserveMemory(1))
+       assertUsage(200, 2.0)
+
+       mlc.ReleaseMemory(50)
+       assert.False(t, mlc.TryReserveMemory(1))
+       assertUsage(150, 1.5)
+}
+
+func TestDisableLimit(t *testing.T) {
+       mlc := NewMemoryLimitController(-1)
+       assert.True(t, mlc.TryReserveMemory(1000000))
+       assert.True(t, mlc.ReserveMemory(context.Background(), 1000000))
+       mlc.ReleaseMemory(1000000)
+       assert.Equal(t, int64(1000000), mlc.CurrentUsage())
+}
+
+func TestMultiGoroutineTryReserveMem(t *testing.T) {
+       mlc := NewMemoryLimitController(10000)
+
+       // Multi goroutine try reserve memory.
+       wg := sync.WaitGroup{}
+
+       wg.Add(10)
+       for i := 0; i < 10; i++ {
+               go func() {
+                       for i := 0; i < 1000; i++ {
+                               assert.True(t, mlc.TryReserveMemory(1))
+                       }
+                       wg.Done()
+               }()
+       }
+       assert.True(t, mlc.TryReserveMemory(1))
+       wg.Wait()
+       assert.False(t, mlc.TryReserveMemory(1))
+       assert.Equal(t, int64(10001), mlc.CurrentUsage())
+       assert.InDelta(t, 1.0001, mlc.CurrentUsagePercent(), 0.000001)
+}
+
+func TestReserveWithContext(t *testing.T) {
+       mlc := NewMemoryLimitController(100)
+       assert.True(t, mlc.TryReserveMemory(101))
+       gorNum := 10
+
+       // Reserve ctx timeout
+       waitGroup := sync.WaitGroup{}
+       waitGroup.Add(gorNum)
+       ctx, cancel := context.WithTimeout(context.Background(), 
100*time.Millisecond)
+       defer cancel()
+       for i := 0; i < gorNum; i++ {
+               go func() {
+                       assert.False(t, mlc.ReserveMemory(ctx, 1))
+                       waitGroup.Done()
+               }()
+       }
+       waitGroup.Wait()
+       assert.Equal(t, int64(101), mlc.CurrentUsage())
+
+       // Reserve ctx cancel
+       waitGroup.Add(gorNum)
+       cancelCtx, cancel := context.WithCancel(context.Background())
+       for i := 0; i < gorNum; i++ {
+               go func() {
+                       assert.False(t, mlc.ReserveMemory(cancelCtx, 1))
+                       waitGroup.Done()
+               }()
+       }
+       cancel()
+       waitGroup.Wait()
+       assert.Equal(t, int64(101), mlc.CurrentUsage())
+}
+
+func TestBlocking(t *testing.T) {
+       mlc := NewMemoryLimitController(100)
+       assert.True(t, mlc.TryReserveMemory(101))
+       assert.Equal(t, int64(101), mlc.CurrentUsage())
+       assert.InDelta(t, 1.01, mlc.CurrentUsagePercent(), 0.000001)
+
+       gorNum := 10
+       chs := make([]chan int, gorNum)
+       for i := 0; i < gorNum; i++ {
+               chs[i] = make(chan int, 1)
+               go reserveMemory(mlc, chs[i])
+       }
+
+       // The threads are blocked since the quota is full
+       for i := 0; i < gorNum; i++ {
+               assert.False(t, awaitCh(chs[i]))
+       }
+       assert.Equal(t, int64(101), mlc.CurrentUsage())
+
+       mlc.ReleaseMemory(int64(gorNum))
+       for i := 0; i < gorNum; i++ {
+               assert.True(t, awaitCh(chs[i]))
+       }
+       assert.Equal(t, int64(101), mlc.CurrentUsage())
+}
+
+func TestStepRelease(t *testing.T) {
+       mlc := NewMemoryLimitController(100)
+       assert.True(t, mlc.TryReserveMemory(101))
+       assert.Equal(t, int64(101), mlc.CurrentUsage())
+       assert.InDelta(t, 1.01, mlc.CurrentUsagePercent(), 0.000001)
+
+       gorNum := 10
+       ch := make(chan int, 1)
+       for i := 0; i < gorNum; i++ {
+               go reserveMemory(mlc, ch)
+       }
+
+       // The threads are blocked since the quota is full
+       assert.False(t, awaitCh(ch))
+       assert.Equal(t, int64(101), mlc.CurrentUsage())
+
+       for i := 0; i < gorNum; i++ {
+               mlc.ReleaseMemory(1)
+               assert.True(t, awaitCh(ch))
+               assert.False(t, awaitCh(ch))
+       }
+       assert.Equal(t, int64(101), mlc.CurrentUsage())
+}
+
+// reserveMemory reserves one byte from mlc, blocking as needed, then signals on ch.
+func reserveMemory(mlc MemoryLimitController, ch chan int) {
+       ctx := context.Background()
+       _ = mlc.ReserveMemory(ctx, 1)
+       ch <- 1
+}
+
+// awaitCh reports whether a value arrives on ch within 100ms.
+func awaitCh(ch chan int) bool {
+       timeout := time.After(100 * time.Millisecond)
+       select {
+       case <-timeout:
+               return false
+       case <-ch:
+               return true
+       }
+}
diff --git a/pulsar/producer_partition.go b/pulsar/producer_partition.go
index 160693c..c3a0aa9 100644
--- a/pulsar/producer_partition.go
+++ b/pulsar/producer_partition.go
@@ -51,13 +51,14 @@ const (
 )
 
 var (
-       errFailAddToBatch  = newError(AddToBatchFailed, "message add to batch 
failed")
-       errSendTimeout     = newError(TimeoutError, "message send timeout")
-       errSendQueueIsFull = newError(ProducerQueueIsFull, "producer send queue 
is full")
-       errContextExpired  = newError(TimeoutError, "message send context 
expired")
-       errMessageTooLarge = newError(MessageTooBig, "message size exceeds 
MaxMessageSize")
-       errMetaTooLarge    = newError(InvalidMessage, "message metadata size 
exceeds MaxMessageSize")
-       errProducerClosed  = newError(ProducerClosed, "producer already been 
closed")
+       errFailAddToBatch     = newError(AddToBatchFailed, "message add to 
batch failed")
+       errSendTimeout        = newError(TimeoutError, "message send timeout")
+       errSendQueueIsFull    = newError(ProducerQueueIsFull, "producer send 
queue is full")
+       errContextExpired     = newError(TimeoutError, "message send context 
expired")
+       errMessageTooLarge    = newError(MessageTooBig, "message size exceeds 
MaxMessageSize")
+       errMetaTooLarge       = newError(InvalidMessage, "message metadata size 
exceeds MaxMessageSize")
+       errProducerClosed     = newError(ProducerClosed, "producer already been 
closed")
+       errMemoryBufferIsFull = newError(ClientMemoryBufferIsFull, "client 
memory buffer is full")
 
        buffersPool sync.Pool
 )
@@ -483,6 +484,7 @@ func (p *partitionProducer) internalSend(request 
*sendRequest) {
 
        // read payload from message
        uncompressedPayload := msg.Payload
+       uncompressedPayloadSize := int64(len(uncompressedPayload))
 
        var schemaPayload []byte
        var err error
@@ -494,14 +496,14 @@ func (p *partitionProducer) internalSend(request 
*sendRequest) {
 
        // The block chan must be closed when returned with exception
        defer request.stopBlock()
-       if !p.canAddToQueue(request) {
+       if !p.canAddToQueue(request, uncompressedPayloadSize) {
                return
        }
 
        if p.options.DisableMultiSchema {
                if msg.Schema != nil && p.options.Schema != nil &&
                        msg.Schema.GetSchemaInfo().hash() != 
p.options.Schema.GetSchemaInfo().hash() {
-                       p.publishSemaphore.Release()
+                       p.releaseSemaphoreAndMem(uncompressedPayloadSize)
                        request.callback(nil, request.msg, fmt.Errorf("msg 
schema can not match with producer schema"))
                        p.log.WithError(err).Errorf("The producer %s of the 
topic %s is disabled the `MultiSchema`", p.producerName, p.topic)
                        return
@@ -520,7 +522,7 @@ func (p *partitionProducer) internalSend(request 
*sendRequest) {
                if uncompressedPayload == nil && schema != nil {
                        schemaPayload, err = schema.Encode(msg.Value)
                        if err != nil {
-                               p.publishSemaphore.Release()
+                               
p.releaseSemaphoreAndMem(uncompressedPayloadSize)
                                request.callback(nil, request.msg, 
newError(SchemaFailure, err.Error()))
                                p.log.WithError(err).Errorf("Schema encode 
message failed %s", msg.Value)
                                return
@@ -536,7 +538,7 @@ func (p *partitionProducer) internalSend(request 
*sendRequest) {
                if schemaVersion == nil {
                        schemaVersion, err = 
p.getOrCreateSchema(schema.GetSchemaInfo())
                        if err != nil {
-                               p.publishSemaphore.Release()
+                               
p.releaseSemaphoreAndMem(uncompressedPayloadSize)
                                p.log.WithError(err).Error("get schema version 
fail")
                                request.callback(nil, request.msg, 
fmt.Errorf("get schema version fail, err: %w", err))
                                return
@@ -596,7 +598,7 @@ func (p *partitionProducer) internalSend(request 
*sendRequest) {
 
        // if msg is too large and chunking is disabled
        if checkSize > maxMessageSize && !p.options.EnableChunking {
-               p.publishSemaphore.Release()
+               p.releaseSemaphoreAndMem(uncompressedPayloadSize)
                request.callback(nil, request.msg, errMessageTooLarge)
                p.log.WithError(errMessageTooLarge).
                        WithField("size", checkSize).
@@ -615,7 +617,7 @@ func (p *partitionProducer) internalSend(request 
*sendRequest) {
        } else {
                payloadChunkSize = int(p._getConn().GetMaxMessageSize()) - 
proto.Size(mm)
                if payloadChunkSize <= 0 {
-                       p.publishSemaphore.Release()
+                       p.releaseSemaphoreAndMem(uncompressedPayloadSize)
                        request.callback(nil, msg, errMetaTooLarge)
                        p.log.WithError(errMetaTooLarge).
                                WithField("metadata size", proto.Size(mm)).
@@ -663,7 +665,8 @@ func (p *partitionProducer) internalSend(request 
*sendRequest) {
                                        chunkRecorder:    cr,
                                }
                                // the permit of first chunk has acquired
-                               if chunkID != 0 && !p.canAddToQueue(nsr) {
+                               if chunkID != 0 && !p.canAddToQueue(nsr, 0) {
+                                       
p.releaseSemaphoreAndMem(uncompressedPayloadSize - int64(rhs))
                                        return
                                }
                                p.internalSingleSend(mm, 
compressedPayload[lhs:rhs], nsr, uint32(maxMessageSize))
@@ -688,7 +691,7 @@ func (p *partitionProducer) internalSend(request 
*sendRequest) {
                        // after flushing try again to add the current payload
                        if ok := p.batchBuilder.Add(smm, p.sequenceIDGenerator, 
uncompressedPayload, request,
                                msg.ReplicationClusters, deliverAt, 
schemaVersion, multiSchemaEnabled); !ok {
-                               p.publishSemaphore.Release()
+                               
p.releaseSemaphoreAndMem(uncompressedPayloadSize)
                                request.callback(nil, request.msg, 
errFailAddToBatch)
                                p.log.WithField("size", uncompressedSize).
                                        WithField("properties", msg.Properties).
@@ -797,7 +800,7 @@ func (p *partitionProducer) internalSingleSend(mm 
*pb.MessageMetadata,
                maxMessageSize,
        ); err != nil {
                request.callback(nil, request.msg, err)
-               p.publishSemaphore.Release()
+               p.releaseSemaphoreAndMem(int64(len(msg.Payload)))
                p.log.WithError(err).Errorf("Single message serialize failed 
%s", msg.Value)
                return
        }
@@ -935,7 +938,7 @@ func (p *partitionProducer) failTimeoutMessages() {
                                sr := i.(*sendRequest)
                                if sr.msg != nil {
                                        size := len(sr.msg.Payload)
-                                       p.publishSemaphore.Release()
+                                       p.releaseSemaphoreAndMem(int64(size))
                                        p.metrics.MessagesPending.Dec()
                                        
p.metrics.BytesPending.Sub(float64(size))
                                        p.metrics.PublishErrorsTimeout.Inc()
@@ -1139,8 +1142,7 @@ func (p *partitionProducer) ReceivedSendReceipt(response 
*pb.CommandSendReceipt)
                        sr := i.(*sendRequest)
                        if sr.msg != nil {
                                atomic.StoreInt64(&p.lastSequenceID, 
int64(pi.sequenceID))
-                               p.publishSemaphore.Release()
-
+                               
p.releaseSemaphoreAndMem(int64(len(sr.msg.Payload)))
                                
p.metrics.PublishLatency.Observe(float64(now-sr.publishTime.UnixNano()) / 1.0e9)
                                p.metrics.MessagesPublished.Inc()
                                p.metrics.MessagesPending.Dec()
@@ -1326,7 +1328,12 @@ func (p *partitionProducer) _getConn() 
internal.Connection {
        return p.conn.Load().(internal.Connection)
 }
 
-func (p *partitionProducer) canAddToQueue(sr *sendRequest) bool {
+// releaseSemaphoreAndMem gives back one publish permit and size bytes of the
+// client's memory budget, undoing what canAddToQueue acquired for a message.
+func (p *partitionProducer) releaseSemaphoreAndMem(size int64) {
+       p.publishSemaphore.Release()
+       memLimit := p.client.memLimit
+       memLimit.ReleaseMemory(size)
+}
+
+func (p *partitionProducer) canAddToQueue(sr *sendRequest, 
uncompressedPayloadSize int64) bool {
        if p.options.DisableBlockIfQueueFull {
                if !p.publishSemaphore.TryAcquire() {
                        if sr.callback != nil {
@@ -1334,9 +1341,24 @@ func (p *partitionProducer) canAddToQueue(sr 
*sendRequest) bool {
                        }
                        return false
                }
-       } else if !p.publishSemaphore.Acquire(sr.ctx) {
-               sr.callback(nil, sr.msg, errContextExpired)
-               return false
+               if !p.client.memLimit.TryReserveMemory(uncompressedPayloadSize) 
{
+                       p.publishSemaphore.Release()
+                       if sr.callback != nil {
+                               sr.callback(nil, sr.msg, errMemoryBufferIsFull)
+                       }
+                       return false
+               }
+
+       } else {
+               if !p.publishSemaphore.Acquire(sr.ctx) {
+                       sr.callback(nil, sr.msg, errContextExpired)
+                       return false
+               }
+               if !p.client.memLimit.ReserveMemory(sr.ctx, 
uncompressedPayloadSize) {
+                       p.publishSemaphore.Release()
+                       sr.callback(nil, sr.msg, errContextExpired)
+                       return false
+               }
        }
        p.metrics.MessagesPending.Inc()
        p.metrics.BytesPending.Add(float64(len(sr.msg.Payload)))
diff --git a/pulsar/producer_test.go b/pulsar/producer_test.go
index e69a14c..f86d01a 100644
--- a/pulsar/producer_test.go
+++ b/pulsar/producer_test.go
@@ -1779,3 +1779,129 @@ func TestWaitForExclusiveProducer(t *testing.T) {
        producer1.Close()
        wg.Wait()
 }
+
+func TestMemLimitRejectProducerMessages(t *testing.T) {
+       c, err := NewClient(ClientOptions{
+               URL:              serviceURL,
+               MemoryLimitBytes: 100 * 1024,
+       })
+       if !assert.NoError(t, err) {
+               return
+       }
+       defer c.Close()
+
+       topicName := newTopicName()
+       producerOpts := ProducerOptions{
+               Topic:                   topicName,
+               DisableBlockIfQueueFull: true,
+               DisableBatching:         false,
+               BatchingMaxPublishDelay: 100 * time.Second,
+               SendTimeout:             2 * time.Second,
+       }
+       // Check creation errors: a nil producer would otherwise panic below.
+       producer1, err := c.CreateProducer(producerOpts)
+       if !assert.NoError(t, err) {
+               return
+       }
+       defer producer1.Close()
+
+       producer2, err := c.CreateProducer(producerOpts)
+       if !assert.NoError(t, err) {
+               return
+       }
+       defer producer2.Close()
+
+       newMsg := func() *ProducerMessage {
+               return &ProducerMessage{Payload: make([]byte, 1024)}
+       }
+       ignore := func(MessageID, *ProducerMessage, error) {}
+
+       n := 101
+       for i := 0; i < n/2; i++ {
+               producer1.SendAsync(context.Background(), newMsg(), ignore)
+               producer2.SendAsync(context.Background(), newMsg(), ignore)
+       }
+       // Last message in order to reach the limit
+       producer1.SendAsync(context.Background(), newMsg(), ignore)
+       time.Sleep(100 * time.Millisecond)
+       assert.Equal(t, int64(n*1024), c.(*client).memLimit.CurrentUsage())
+
+       // Both producers draw from the same client budget, so both are rejected.
+       for _, p := range []Producer{producer1, producer2} {
+               _, err = p.Send(context.Background(), newMsg())
+               assert.Error(t, err)
+               assert.ErrorContains(t, err, getResultStr(ClientMemoryBufferIsFull))
+       }
+
+       // flush pending msg
+       assert.NoError(t, producer1.Flush())
+       assert.NoError(t, producer2.Flush())
+       assert.Equal(t, int64(0), c.(*client).memLimit.CurrentUsage())
+
+       _, err = producer1.Send(context.Background(), newMsg())
+       assert.NoError(t, err)
+       _, err = producer2.Send(context.Background(), newMsg())
+       assert.NoError(t, err)
+}
+
+func TestMemLimitContextCancel(t *testing.T) {
+       c, err := NewClient(ClientOptions{
+               URL:              serviceURL,
+               MemoryLimitBytes: 100 * 1024,
+       })
+       if !assert.NoError(t, err) {
+               return
+       }
+       defer c.Close()
+
+       topicName := newTopicName()
+       // Check the creation error: a nil producer would otherwise panic below.
+       producer, err := c.CreateProducer(ProducerOptions{
+               Topic:                   topicName,
+               DisableBlockIfQueueFull: false,
+               DisableBatching:         false,
+               BatchingMaxPublishDelay: 100 * time.Second,
+               SendTimeout:             2 * time.Second,
+       })
+       if !assert.NoError(t, err) {
+               return
+       }
+       defer producer.Close()
+
+       newMsg := func() *ProducerMessage {
+               return &ProducerMessage{Payload: make([]byte, 1024)}
+       }
+
+       n := 101
+       ctx, cancel := context.WithCancel(context.Background())
+       for i := 0; i < n; i++ {
+               producer.SendAsync(ctx, newMsg(), func(MessageID, *ProducerMessage, error) {})
+       }
+       time.Sleep(100 * time.Millisecond)
+       assert.Equal(t, int64(n*1024), c.(*client).memLimit.CurrentUsage())
+
+       // With blocking enabled, this send waits for memory until ctx is cancelled.
+       wg := sync.WaitGroup{}
+       wg.Add(1)
+       go func() {
+               producer.SendAsync(ctx, newMsg(), func(id MessageID, message *ProducerMessage, e error) {
+                       assert.Error(t, e)
+                       assert.ErrorContains(t, e, getResultStr(TimeoutError))
+                       wg.Done()
+               })
+       }()
+
+       // cancel pending msg
+       cancel()
+       wg.Wait()
+
+       assert.NoError(t, producer.Flush())
+       assert.Equal(t, int64(0), c.(*client).memLimit.CurrentUsage())
+
+       _, err = producer.Send(context.Background(), newMsg())
+       assert.NoError(t, err)
+}

Reply via email to