This is an automated email from the ASF dual-hosted git repository.
baodi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pulsar-client-go.git
The following commit(s) were added to refs/heads/master by this push:
new 5fa431d [feat] Support memory limit for producer. (#955)
5fa431d is described below
commit 5fa431d06dae5615fbadda78e9c8ef68de0aab4f
Author: Baodi Shi <[email protected]>
AuthorDate: Wed Mar 1 16:33:22 2023 +0800
[feat] Support memory limit for producer. (#955)
* [feat] Support memory limit for producer.
* Fix code reviews.
* Change channel_cond access level
* Change name and add note.
* change name.
---
pulsar/client.go | 4 +
pulsar/client_impl.go | 12 +-
pulsar/error.go | 5 +
pulsar/internal/channel_cond.go | 76 ++++++++++
pulsar/internal/channel_cond_test.go | 55 +++++++
pulsar/internal/memory_limit_controller.go | 101 +++++++++++++
pulsar/internal/memory_limit_controller_test.go | 184 ++++++++++++++++++++++++
pulsar/producer_partition.go | 68 ++++++---
pulsar/producer_test.go | 126 ++++++++++++++++
9 files changed, 606 insertions(+), 25 deletions(-)
diff --git a/pulsar/client.go b/pulsar/client.go
index 135d22b..75b363d 100644
--- a/pulsar/client.go
+++ b/pulsar/client.go
@@ -144,6 +144,10 @@ type ClientOptions struct {
MetricsRegisterer prometheus.Registerer
EnableTransaction bool
+
+ // Limit of client memory usage (in bytes). The 64M default can guarantee high producer throughput.
+ // A value less than 0 disables the memory limit.
+ MemoryLimitBytes int64
}
// Client represents a pulsar client
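For context, a minimal sketch of how an application opts into the new option; the broker URL is a placeholder and the 16 MiB figure is arbitrary:

    package main

    import (
        "log"

        "github.com/apache/pulsar-client-go/pulsar"
    )

    func main() {
        // Cap client-wide producer buffering at 16 MiB. Leaving the field
        // at 0 keeps the 64 MiB default; a negative value disables the limit.
        client, err := pulsar.NewClient(pulsar.ClientOptions{
            URL:              "pulsar://localhost:6650",
            MemoryLimitBytes: 16 * 1024 * 1024,
        })
        if err != nil {
            log.Fatal(err)
        }
        defer client.Close()
    }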
diff --git a/pulsar/client_impl.go b/pulsar/client_impl.go
index f444804..7d90922 100644
--- a/pulsar/client_impl.go
+++ b/pulsar/client_impl.go
@@ -33,6 +33,7 @@ const (
defaultConnectionTimeout = 10 * time.Second
defaultOperationTimeout = 30 * time.Second
defaultKeepAliveInterval = 30 * time.Second
+ defaultMemoryLimitBytes = 64 * 1024 * 1024
)
type client struct {
@@ -42,6 +43,7 @@ type client struct {
lookupService internal.LookupService
metrics *internal.Metrics
tcClient *transactionCoordinatorClient
+ memLimit internal.MemoryLimitController
log log.Logger
}
@@ -134,11 +136,17 @@ func newClient(options ClientOptions) (Client, error) {
keepAliveInterval = defaultKeepAliveInterval
}
+ memLimitBytes := options.MemoryLimitBytes
+ if memLimitBytes == 0 {
+ memLimitBytes = defaultMemoryLimitBytes
+ }
+
c := &client{
cnxPool: internal.NewConnectionPool(tlsConfig, authProvider,
connectionTimeout, keepAliveInterval,
maxConnectionsPerHost, logger, metrics),
- log: logger,
- metrics: metrics,
+ log: logger,
+ metrics: metrics,
+ memLimit: internal.NewMemoryLimitController(memLimitBytes),
}
serviceNameResolver := internal.NewPulsarServiceNameResolver(url)
diff --git a/pulsar/error.go b/pulsar/error.go
index ce366f5..0aa1e3c 100644
--- a/pulsar/error.go
+++ b/pulsar/error.go
@@ -110,6 +110,9 @@ const (
InvalidStatus
// TransactionError means this is a transaction related error
TransactionError
+
+ // ClientMemoryBufferIsFull means the client memory buffer is full
+ ClientMemoryBufferIsFull
)
// Error implement error interface, composed of two parts: msg and result.
@@ -216,6 +219,8 @@ func getResultStr(r Result) string {
return "ProducerClosed"
case SchemaFailure:
return "SchemaFailure"
+ case ClientMemoryBufferIsFull:
+ return "ClientMemoryBufferIsFull"
default:
return fmt.Sprintf("Result(%d)", r)
}
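For producers configured with DisableBlockIfQueueFull, a hedged sketch of matching on the new result code, assuming the failure surfaces as this package's *pulsar.Error (the type newError produces):

    package example

    import (
        "errors"

        "github.com/apache/pulsar-client-go/pulsar"
    )

    // IsMemoryBufferFull reports whether a send failed because the
    // client-wide memory buffer was exhausted.
    func IsMemoryBufferFull(err error) bool {
        var perr *pulsar.Error
        return errors.As(err, &perr) && perr.Result() == pulsar.ClientMemoryBufferIsFull
    }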
diff --git a/pulsar/internal/channel_cond.go b/pulsar/internal/channel_cond.go
new file mode 100644
index 0000000..38301ab
--- /dev/null
+++ b/pulsar/internal/channel_cond.go
@@ -0,0 +1,76 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package internal
+
+import (
+ "context"
+ "sync"
+ "sync/atomic"
+ "unsafe"
+)
+
+type chCond struct {
+ L sync.Locker
+ // Pointer to the notification channel. The channel it points to may change,
+ // because the channel's close mechanism is used to implement broadcast notifications.
+ notifyChPtr unsafe.Pointer
+}
+
+func newCond(l sync.Locker) *chCond {
+ c := &chCond{L: l}
+ n := make(chan struct{})
+ c.notifyChPtr = unsafe.Pointer(&n)
+ return c
+}
+
+// wait for broadcast calls. Similar to regular sync.Cond
+func (c *chCond) wait() {
+ n := c.notifyChan()
+ c.L.Unlock()
+ <-n
+ c.L.Lock()
+}
+
+// waitWithContext is the same as wait(), but waiting can also be ended through the context.
+func (c *chCond) waitWithContext(ctx context.Context) bool {
+ n := c.notifyChan()
+ c.L.Unlock()
+ defer c.L.Lock()
+ select {
+ case <-n:
+ return true
+ case <-ctx.Done():
+ return false
+ }
+}
+
+// broadcast wakes all goroutines waiting on c.
+// It is not required for the caller to hold c.L during the call.
+func (c *chCond) broadcast() {
+ n := make(chan struct{})
+ ptrOld := atomic.SwapPointer(&c.notifyChPtr, unsafe.Pointer(&n))
+ // close the old channel to trigger the broadcast.
+ close(*(*chan struct{})(ptrOld))
+}
+
+func (c *chCond) notifyChan() <-chan struct{} {
+ ptr := atomic.LoadPointer(&c.notifyChPtr)
+ return *((*chan struct{})(ptr))
+}
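The broadcast works because closing a channel wakes every pending receiver at once; a standalone sketch of the idiom (not part of the patch):

    package main

    import (
        "fmt"
        "sync"
    )

    func main() {
        notify := make(chan struct{})
        var wg sync.WaitGroup
        for i := 0; i < 3; i++ {
            wg.Add(1)
            go func(id int) {
                defer wg.Done()
                <-notify // receiving from a closed channel never blocks
                fmt.Printf("waiter %d woken\n", id)
            }(i)
        }
        close(notify) // one close wakes all three waiters
        wg.Wait()
    }

Since a closed channel cannot be reused, broadcast() atomically swaps in a fresh channel before closing the old one, which is what makes chCond reusable across rounds of waiting.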
diff --git a/pulsar/internal/channel_cond_test.go b/pulsar/internal/channel_cond_test.go
new file mode 100644
index 0000000..a73d44e
--- /dev/null
+++ b/pulsar/internal/channel_cond_test.go
@@ -0,0 +1,55 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package internal
+
+import (
+ "context"
+ "sync"
+ "testing"
+ "time"
+)
+
+func TestChCond(t *testing.T) {
+ cond := newCond(&sync.Mutex{})
+ wg := sync.WaitGroup{}
+ wg.Add(1)
+ go func() {
+ cond.L.Lock()
+ cond.wait()
+ cond.L.Unlock()
+ wg.Done()
+ }()
+ time.Sleep(10 * time.Millisecond)
+ cond.broadcast()
+ wg.Wait()
+}
+
+func TestChCondWithContext(t *testing.T) {
+ cond := newCond(&sync.Mutex{})
+ wg := sync.WaitGroup{}
+ ctx, cancel := context.WithCancel(context.Background())
+ wg.Add(1)
+ go func() {
+ cond.L.Lock()
+ cond.waitWithContext(ctx)
+ cond.L.Unlock()
+ wg.Done()
+ }()
+ cancel()
+ wg.Wait()
+}
diff --git a/pulsar/internal/memory_limit_controller.go b/pulsar/internal/memory_limit_controller.go
new file mode 100644
index 0000000..5bf8d59
--- /dev/null
+++ b/pulsar/internal/memory_limit_controller.go
@@ -0,0 +1,101 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package internal
+
+import (
+ "context"
+ "sync"
+ "sync/atomic"
+)
+
+type MemoryLimitController interface {
+ ReserveMemory(ctx context.Context, size int64) bool
+ TryReserveMemory(size int64) bool
+ ForceReserveMemory(size int64)
+ ReleaseMemory(size int64)
+ CurrentUsage() int64
+ CurrentUsagePercent() float64
+ IsMemoryLimited() bool
+}
+
+type memoryLimitController struct {
+ limit int64
+ chCond *chCond
+ currentUsage int64
+}
+
+func NewMemoryLimitController(limit int64) MemoryLimitController {
+ mlc := &memoryLimitController{
+ limit: limit,
+ chCond: newCond(&sync.Mutex{}),
+ }
+ return mlc
+}
+
+func (m *memoryLimitController) ReserveMemory(ctx context.Context, size int64) bool {
+ if !m.TryReserveMemory(size) {
+ m.chCond.L.Lock()
+ defer m.chCond.L.Unlock()
+
+ for !m.TryReserveMemory(size) {
+ if !m.chCond.waitWithContext(ctx) {
+ return false
+ }
+ }
+ }
+ return true
+}
+
+func (m *memoryLimitController) TryReserveMemory(size int64) bool {
+ for {
+ current := atomic.LoadInt64(&m.currentUsage)
+ newUsage := current + size
+
+ // This condition means we allowed one request to go over the limit.
+ if m.IsMemoryLimited() && current > m.limit {
+ return false
+ }
+
+ if atomic.CompareAndSwapInt64(&m.currentUsage, current, newUsage) {
+ return true
+ }
+ }
+}
+
+func (m *memoryLimitController) ForceReserveMemory(size int64) {
+ atomic.AddInt64(&m.currentUsage, size)
+}
+
+func (m *memoryLimitController) ReleaseMemory(size int64) {
+ newUsage := atomic.AddInt64(&m.currentUsage, -size)
+ if newUsage+size > m.limit && newUsage <= m.limit {
+ m.chCond.broadcast()
+ }
+}
+
+func (m *memoryLimitController) CurrentUsage() int64 {
+ return atomic.LoadInt64(&m.currentUsage)
+}
+
+func (m *memoryLimitController) CurrentUsagePercent() float64 {
+ return float64(atomic.LoadInt64(&m.currentUsage)) / float64(m.limit)
+}
+
+func (m *memoryLimitController) IsMemoryLimited() bool {
+ return m.limit > 0
+}
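In short: one reservation is allowed to overshoot the limit, after which TryReserveMemory fails and ReserveMemory blocks until enough memory is released or the context ends. A sketch of those semantics as seen from inside package internal (context and time imported; illustrative, not part of the patch):

    func sketchSemantics() {
        mlc := NewMemoryLimitController(100)

        mlc.TryReserveMemory(101) // true: a single request may overshoot the limit
        mlc.TryReserveMemory(1)   // false: usage is already above the limit

        // ReserveMemory blocks until quota frees up or the context ends.
        ctx, cancel := context.WithTimeout(context.Background(), time.Second)
        defer cancel()
        if !mlc.ReserveMemory(ctx, 1) {
            // the context expired while waiting for quota
        }

        mlc.ReleaseMemory(101) // dropping to or below the limit wakes blocked waiters
    }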
diff --git a/pulsar/internal/memory_limit_controller_test.go b/pulsar/internal/memory_limit_controller_test.go
new file mode 100644
index 0000000..a62c6e6
--- /dev/null
+++ b/pulsar/internal/memory_limit_controller_test.go
@@ -0,0 +1,184 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package internal
+
+import (
+ "context"
+ "sync"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func TestLimit(t *testing.T) {
+
+ mlc := NewMemoryLimitController(100)
+
+ for i := 0; i < 101; i++ {
+ assert.True(t, mlc.TryReserveMemory(1))
+ }
+
+ assert.False(t, mlc.TryReserveMemory(1))
+ assert.Equal(t, int64(101), mlc.CurrentUsage())
+ assert.InDelta(t, 1.01, mlc.CurrentUsagePercent(), 0.000001)
+
+ mlc.ReleaseMemory(1)
+ assert.Equal(t, int64(100), mlc.CurrentUsage())
+ assert.InDelta(t, 1.0, mlc.CurrentUsagePercent(), 0.000001)
+
+ assert.True(t, mlc.TryReserveMemory(1))
+ assert.Equal(t, int64(101), mlc.CurrentUsage())
+
+ mlc.ForceReserveMemory(99)
+ assert.False(t, mlc.TryReserveMemory(1))
+ assert.Equal(t, int64(200), mlc.CurrentUsage())
+ assert.InDelta(t, 2.0, mlc.CurrentUsagePercent(), 0.000001)
+
+ mlc.ReleaseMemory(50)
+ assert.False(t, mlc.TryReserveMemory(1))
+ assert.Equal(t, int64(150), mlc.CurrentUsage())
+ assert.InDelta(t, 1.5, mlc.CurrentUsagePercent(), 0.000001)
+}
+
+func TestDisableLimit(t *testing.T) {
+ mlc := NewMemoryLimitController(-1)
+ assert.True(t, mlc.TryReserveMemory(1000000))
+ assert.True(t, mlc.ReserveMemory(context.Background(), 1000000))
+ mlc.ReleaseMemory(1000000)
+ assert.Equal(t, int64(1000000), mlc.CurrentUsage())
+}
+
+func TestMultiGoroutineTryReserveMem(t *testing.T) {
+ mlc := NewMemoryLimitController(10000)
+
+ // Multiple goroutines try to reserve memory.
+ wg := sync.WaitGroup{}
+
+ wg.Add(10)
+ for i := 0; i < 10; i++ {
+ go func() {
+ for i := 0; i < 1000; i++ {
+ assert.True(t, mlc.TryReserveMemory(1))
+ }
+ wg.Done()
+ }()
+ }
+ assert.True(t, mlc.TryReserveMemory(1))
+ wg.Wait()
+ assert.False(t, mlc.TryReserveMemory(1))
+ assert.Equal(t, int64(10001), mlc.CurrentUsage())
+ assert.InDelta(t, 1.0001, mlc.CurrentUsagePercent(), 0.000001)
+}
+
+func TestReserveWithContext(t *testing.T) {
+ mlc := NewMemoryLimitController(100)
+ assert.True(t, mlc.TryReserveMemory(101))
+ gorNum := 10
+
+ // Reserve ctx timeout
+ waitGroup := sync.WaitGroup{}
+ waitGroup.Add(gorNum)
+ ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
+ defer cancel()
+ for i := 0; i < gorNum; i++ {
+ go func() {
+ assert.False(t, mlc.ReserveMemory(ctx, 1))
+ waitGroup.Done()
+ }()
+ }
+ waitGroup.Wait()
+ assert.Equal(t, int64(101), mlc.CurrentUsage())
+
+ // Reserve ctx cancel
+ waitGroup.Add(gorNum)
+ cancelCtx, cancel := context.WithCancel(context.Background())
+ for i := 0; i < gorNum; i++ {
+ go func() {
+ assert.False(t, mlc.ReserveMemory(cancelCtx, 1))
+ waitGroup.Done()
+ }()
+ }
+ cancel()
+ waitGroup.Wait()
+ assert.Equal(t, int64(101), mlc.CurrentUsage())
+}
+
+func TestBlocking(t *testing.T) {
+ mlc := NewMemoryLimitController(100)
+ assert.True(t, mlc.TryReserveMemory(101))
+ assert.Equal(t, int64(101), mlc.CurrentUsage())
+ assert.InDelta(t, 1.01, mlc.CurrentUsagePercent(), 0.000001)
+
+ gorNum := 10
+ chs := make([]chan int, gorNum)
+ for i := 0; i < gorNum; i++ {
+ chs[i] = make(chan int, 1)
+ go reserveMemory(mlc, chs[i])
+ }
+
+ // The goroutines are blocked since the quota is full
+ for i := 0; i < gorNum; i++ {
+ assert.False(t, awaitCh(chs[i]))
+ }
+ assert.Equal(t, int64(101), mlc.CurrentUsage())
+
+ mlc.ReleaseMemory(int64(gorNum))
+ for i := 0; i < gorNum; i++ {
+ assert.True(t, awaitCh(chs[i]))
+ }
+ assert.Equal(t, int64(101), mlc.CurrentUsage())
+}
+
+func TestStepRelease(t *testing.T) {
+ mlc := NewMemoryLimitController(100)
+ assert.True(t, mlc.TryReserveMemory(101))
+ assert.Equal(t, int64(101), mlc.CurrentUsage())
+ assert.InDelta(t, 1.01, mlc.CurrentUsagePercent(), 0.000001)
+
+ gorNum := 10
+ ch := make(chan int, 1)
+ for i := 0; i < gorNum; i++ {
+ go reserveMemory(mlc, ch)
+ }
+
+ // The goroutines are blocked since the quota is full
+ assert.False(t, awaitCh(ch))
+ assert.Equal(t, int64(101), mlc.CurrentUsage())
+
+ for i := 0; i < gorNum; i++ {
+ mlc.ReleaseMemory(1)
+ assert.True(t, awaitCh(ch))
+ assert.False(t, awaitCh(ch))
+ }
+ assert.Equal(t, int64(101), mlc.CurrentUsage())
+}
+
+func reserveMemory(mlc MemoryLimitController, ch chan int) {
+ mlc.ReserveMemory(context.Background(), 1)
+ ch <- 1
+}
+
+func awaitCh(ch chan int) bool {
+ select {
+ case <-ch:
+ return true
+ case <-time.After(100 * time.Millisecond):
+ return false
+ }
+}
diff --git a/pulsar/producer_partition.go b/pulsar/producer_partition.go
index 160693c..c3a0aa9 100644
--- a/pulsar/producer_partition.go
+++ b/pulsar/producer_partition.go
@@ -51,13 +51,14 @@ const (
)
var (
- errFailAddToBatch = newError(AddToBatchFailed, "message add to batch failed")
- errSendTimeout = newError(TimeoutError, "message send timeout")
- errSendQueueIsFull = newError(ProducerQueueIsFull, "producer send queue is full")
- errContextExpired = newError(TimeoutError, "message send context expired")
- errMessageTooLarge = newError(MessageTooBig, "message size exceeds MaxMessageSize")
- errMetaTooLarge = newError(InvalidMessage, "message metadata size exceeds MaxMessageSize")
- errProducerClosed = newError(ProducerClosed, "producer already been closed")
+ errFailAddToBatch = newError(AddToBatchFailed, "message add to batch failed")
+ errSendTimeout = newError(TimeoutError, "message send timeout")
+ errSendQueueIsFull = newError(ProducerQueueIsFull, "producer send queue is full")
+ errContextExpired = newError(TimeoutError, "message send context expired")
+ errMessageTooLarge = newError(MessageTooBig, "message size exceeds MaxMessageSize")
+ errMetaTooLarge = newError(InvalidMessage, "message metadata size exceeds MaxMessageSize")
+ errProducerClosed = newError(ProducerClosed, "producer already been closed")
+ errMemoryBufferIsFull = newError(ClientMemoryBufferIsFull, "client memory buffer is full")
buffersPool sync.Pool
)
@@ -483,6 +484,7 @@ func (p *partitionProducer) internalSend(request *sendRequest) {
// read payload from message
uncompressedPayload := msg.Payload
+ uncompressedPayloadSize := int64(len(uncompressedPayload))
var schemaPayload []byte
var err error
@@ -494,14 +496,14 @@ func (p *partitionProducer) internalSend(request *sendRequest) {
// The block chan must be closed when returned with exception
defer request.stopBlock()
- if !p.canAddToQueue(request) {
+ if !p.canAddToQueue(request, uncompressedPayloadSize) {
return
}
if p.options.DisableMultiSchema {
if msg.Schema != nil && p.options.Schema != nil &&
msg.Schema.GetSchemaInfo().hash() != p.options.Schema.GetSchemaInfo().hash() {
- p.publishSemaphore.Release()
+ p.releaseSemaphoreAndMem(uncompressedPayloadSize)
request.callback(nil, request.msg, fmt.Errorf("msg schema can not match with producer schema"))
p.log.WithError(err).Errorf("The producer %s of the topic %s is disabled the `MultiSchema`", p.producerName, p.topic)
return
@@ -520,7 +522,7 @@ func (p *partitionProducer) internalSend(request *sendRequest) {
if uncompressedPayload == nil && schema != nil {
schemaPayload, err = schema.Encode(msg.Value)
if err != nil {
- p.publishSemaphore.Release()
+ p.releaseSemaphoreAndMem(uncompressedPayloadSize)
request.callback(nil, request.msg, newError(SchemaFailure, err.Error()))
p.log.WithError(err).Errorf("Schema encode message failed %s", msg.Value)
return
@@ -536,7 +538,7 @@ func (p *partitionProducer) internalSend(request *sendRequest) {
if schemaVersion == nil {
schemaVersion, err = p.getOrCreateSchema(schema.GetSchemaInfo())
if err != nil {
- p.publishSemaphore.Release()
+ p.releaseSemaphoreAndMem(uncompressedPayloadSize)
p.log.WithError(err).Error("get schema version fail")
request.callback(nil, request.msg, fmt.Errorf("get schema version fail, err: %w", err))
return
@@ -596,7 +598,7 @@ func (p *partitionProducer) internalSend(request *sendRequest) {
// if msg is too large and chunking is disabled
if checkSize > maxMessageSize && !p.options.EnableChunking {
- p.publishSemaphore.Release()
+ p.releaseSemaphoreAndMem(uncompressedPayloadSize)
request.callback(nil, request.msg, errMessageTooLarge)
p.log.WithError(errMessageTooLarge).
WithField("size", checkSize).
@@ -615,7 +617,7 @@ func (p *partitionProducer) internalSend(request *sendRequest) {
} else {
payloadChunkSize = int(p._getConn().GetMaxMessageSize()) - proto.Size(mm)
if payloadChunkSize <= 0 {
- p.publishSemaphore.Release()
+ p.releaseSemaphoreAndMem(uncompressedPayloadSize)
request.callback(nil, msg, errMetaTooLarge)
p.log.WithError(errMetaTooLarge).
WithField("metadata size", proto.Size(mm)).
@@ -663,7 +665,8 @@ func (p *partitionProducer) internalSend(request *sendRequest) {
chunkRecorder: cr,
}
// the permit of the first chunk has been acquired
- if chunkID != 0 && !p.canAddToQueue(nsr) {
+ if chunkID != 0 && !p.canAddToQueue(nsr, 0) {
+ p.releaseSemaphoreAndMem(uncompressedPayloadSize - int64(rhs))
return
}
p.internalSingleSend(mm, compressedPayload[lhs:rhs], nsr, uint32(maxMessageSize))
@@ -688,7 +691,7 @@ func (p *partitionProducer) internalSend(request *sendRequest) {
// after flushing try again to add the current payload
if ok := p.batchBuilder.Add(smm, p.sequenceIDGenerator, uncompressedPayload, request,
msg.ReplicationClusters, deliverAt, schemaVersion, multiSchemaEnabled); !ok {
- p.publishSemaphore.Release()
+ p.releaseSemaphoreAndMem(uncompressedPayloadSize)
request.callback(nil, request.msg, errFailAddToBatch)
p.log.WithField("size", uncompressedSize).
WithField("properties", msg.Properties).
@@ -797,7 +800,7 @@ func (p *partitionProducer) internalSingleSend(mm *pb.MessageMetadata,
maxMessageSize,
); err != nil {
request.callback(nil, request.msg, err)
- p.publishSemaphore.Release()
+ p.releaseSemaphoreAndMem(int64(len(msg.Payload)))
p.log.WithError(err).Errorf("Single message serialize failed
%s", msg.Value)
return
}
@@ -935,7 +938,7 @@ func (p *partitionProducer) failTimeoutMessages() {
sr := i.(*sendRequest)
if sr.msg != nil {
size := len(sr.msg.Payload)
- p.publishSemaphore.Release()
+ p.releaseSemaphoreAndMem(int64(size))
p.metrics.MessagesPending.Dec()
p.metrics.BytesPending.Sub(float64(size))
p.metrics.PublishErrorsTimeout.Inc()
@@ -1139,8 +1142,7 @@ func (p *partitionProducer) ReceivedSendReceipt(response *pb.CommandSendReceipt) {
sr := i.(*sendRequest)
if sr.msg != nil {
atomic.StoreInt64(&p.lastSequenceID, int64(pi.sequenceID))
- p.publishSemaphore.Release()
-
+ p.releaseSemaphoreAndMem(int64(len(sr.msg.Payload)))
p.metrics.PublishLatency.Observe(float64(now-sr.publishTime.UnixNano()) / 1.0e9)
p.metrics.MessagesPublished.Inc()
p.metrics.MessagesPending.Dec()
@@ -1326,7 +1328,12 @@ func (p *partitionProducer) _getConn() internal.Connection {
return p.conn.Load().(internal.Connection)
}
-func (p *partitionProducer) canAddToQueue(sr *sendRequest) bool {
+func (p *partitionProducer) releaseSemaphoreAndMem(size int64) {
+ p.publishSemaphore.Release()
+ p.client.memLimit.ReleaseMemory(size)
+}
+
+func (p *partitionProducer) canAddToQueue(sr *sendRequest, uncompressedPayloadSize int64) bool {
if p.options.DisableBlockIfQueueFull {
if !p.publishSemaphore.TryAcquire() {
if sr.callback != nil {
@@ -1334,9 +1341,24 @@ func (p *partitionProducer) canAddToQueue(sr *sendRequest) bool {
}
return false
}
- } else if !p.publishSemaphore.Acquire(sr.ctx) {
- sr.callback(nil, sr.msg, errContextExpired)
- return false
+ if !p.client.memLimit.TryReserveMemory(uncompressedPayloadSize) {
+ p.publishSemaphore.Release()
+ if sr.callback != nil {
+ sr.callback(nil, sr.msg, errMemoryBufferIsFull)
+ }
+ return false
+ }
+
+ } else {
+ if !p.publishSemaphore.Acquire(sr.ctx) {
+ sr.callback(nil, sr.msg, errContextExpired)
+ return false
+ }
+ if !p.client.memLimit.ReserveMemory(sr.ctx, uncompressedPayloadSize) {
+ p.publishSemaphore.Release()
+ sr.callback(nil, sr.msg, errContextExpired)
+ return false
+ }
}
p.metrics.MessagesPending.Inc()
p.metrics.BytesPending.Add(float64(len(sr.msg.Payload)))
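Note the ordering in canAddToQueue: the queue permit is acquired before the memory reservation, and the permit must be handed back whenever the reservation fails, otherwise the send queue would leak capacity. Distilled into a sketch (as if inside package internal, reusing its Semaphore and MemoryLimitController interfaces; illustrative only):

    // Acquire the queue permit first, then the memory quota; roll back the
    // permit if the reservation fails so queue capacity is not leaked.
    func acquireQuota(ctx context.Context, sem Semaphore, mem MemoryLimitController, size int64) bool {
        if !sem.Acquire(ctx) {
            return false // context ended while waiting for a permit
        }
        if !mem.ReserveMemory(ctx, size) {
            sem.Release() // undo the permit before reporting failure
            return false
        }
        return true
    }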
diff --git a/pulsar/producer_test.go b/pulsar/producer_test.go
index e69a14c..f86d01a 100644
--- a/pulsar/producer_test.go
+++ b/pulsar/producer_test.go
@@ -1779,3 +1779,129 @@ func TestWaitForExclusiveProducer(t *testing.T) {
producer1.Close()
wg.Wait()
}
+
+func TestMemLimitRejectProducerMessages(t *testing.T) {
+
+ c, err := NewClient(ClientOptions{
+ URL: serviceURL,
+ MemoryLimitBytes: 100 * 1024,
+ })
+ assert.NoError(t, err)
+ defer c.Close()
+
+ topicName := newTopicName()
+ producer1, _ := c.CreateProducer(ProducerOptions{
+ Topic: topicName,
+ DisableBlockIfQueueFull: true,
+ DisableBatching: false,
+ BatchingMaxPublishDelay: 100 * time.Second,
+ SendTimeout: 2 * time.Second,
+ })
+
+ producer2, _ := c.CreateProducer(ProducerOptions{
+ Topic: topicName,
+ DisableBlockIfQueueFull: true,
+ DisableBatching: false,
+ BatchingMaxPublishDelay: 100 * time.Second,
+ SendTimeout: 2 * time.Second,
+ })
+
+ n := 101
+ for i := 0; i < n/2; i++ {
+ producer1.SendAsync(context.Background(), &ProducerMessage{
+ Payload: make([]byte, 1024),
+ }, func(id MessageID, message *ProducerMessage, e error) {})
+
+ producer2.SendAsync(context.Background(), &ProducerMessage{
+ Payload: make([]byte, 1024),
+ }, func(id MessageID, message *ProducerMessage, e error) {})
+ }
+ // Last message in order to reach the limit
+ producer1.SendAsync(context.Background(), &ProducerMessage{
+ Payload: make([]byte, 1024),
+ }, func(id MessageID, message *ProducerMessage, e error) {})
+ time.Sleep(100 * time.Millisecond)
+ assert.Equal(t, int64(n*1024), c.(*client).memLimit.CurrentUsage())
+
+ _, err = producer1.Send(context.Background(), &ProducerMessage{
+ Payload: make([]byte, 1024),
+ })
+ assert.Error(t, err)
+ assert.ErrorContains(t, err, getResultStr(ClientMemoryBufferIsFull))
+
+ _, err = producer2.Send(context.Background(), &ProducerMessage{
+ Payload: make([]byte, 1024),
+ })
+ assert.Error(t, err)
+ assert.ErrorContains(t, err, getResultStr(ClientMemoryBufferIsFull))
+
+ // flush pending msg
+ err = producer1.Flush()
+ assert.NoError(t, err)
+ err = producer2.Flush()
+ assert.NoError(t, err)
+ assert.Equal(t, int64(0), c.(*client).memLimit.CurrentUsage())
+
+ _, err = producer1.Send(context.Background(), &ProducerMessage{
+ Payload: make([]byte, 1024),
+ })
+ assert.NoError(t, err)
+ _, err = producer2.Send(context.Background(), &ProducerMessage{
+ Payload: make([]byte, 1024),
+ })
+ assert.NoError(t, err)
+}
+
+func TestMemLimitContextCancel(t *testing.T) {
+
+ c, err := NewClient(ClientOptions{
+ URL: serviceURL,
+ MemoryLimitBytes: 100 * 1024,
+ })
+ assert.NoError(t, err)
+ defer c.Close()
+
+ topicName := newTopicName()
+ producer, _ := c.CreateProducer(ProducerOptions{
+ Topic: topicName,
+ DisableBlockIfQueueFull: false,
+ DisableBatching: false,
+ BatchingMaxPublishDelay: 100 * time.Second,
+ SendTimeout: 2 * time.Second,
+ })
+
+ n := 101
+ ctx, cancel := context.WithCancel(context.Background())
+ for i := 0; i < n; i++ {
+ producer.SendAsync(ctx, &ProducerMessage{
+ Payload: make([]byte, 1024),
+ }, func(id MessageID, message *ProducerMessage, e error) {})
+ }
+ time.Sleep(100 * time.Millisecond)
+ assert.Equal(t, int64(n*1024), c.(*client).memLimit.CurrentUsage())
+
+ wg := sync.WaitGroup{}
+ wg.Add(1)
+ go func() {
+ producer.SendAsync(ctx, &ProducerMessage{
+ Payload: make([]byte, 1024),
+ }, func(id MessageID, message *ProducerMessage, e error) {
+ assert.Error(t, e)
+ assert.ErrorContains(t, e, getResultStr(TimeoutError))
+ wg.Done()
+ })
+ }()
+
+ // cancel pending msg
+ cancel()
+ wg.Wait()
+
+ err = producer.Flush()
+ assert.NoError(t, err)
+ assert.Equal(t, int64(0), c.(*client).memLimit.CurrentUsage())
+
+ _, err = producer.Send(context.Background(), &ProducerMessage{
+ Payload: make([]byte, 1024),
+ })
+ assert.NoError(t, err)
+}