This is an automated email from the ASF dual-hosted git repository.

baodi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pulsar-client-go.git


The following commit(s) were added to refs/heads/master by this push:
     new 6f01a7c  [Fix] check if callback is nil before calling it (#1036)
6f01a7c is described below

commit 6f01a7cead8704aa59afcd819545512d0259af07
Author: gunli <[email protected]>
AuthorDate: Mon Jul 3 09:44:33 2023 +0800

    [Fix] check if callback is nil before calling it (#1036)
    
    Co-authored-by: gunli <[email protected]>
---
 pulsar/producer_partition.go | 58 +++++++++++++++++++++++---------------------
 1 file changed, 30 insertions(+), 28 deletions(-)

diff --git a/pulsar/producer_partition.go b/pulsar/producer_partition.go
index 837d1d7..98c6c98 100644
--- a/pulsar/producer_partition.go
+++ b/pulsar/producer_partition.go
@@ -467,6 +467,14 @@ func (p *partitionProducer) Name() string {
        return p.producerName
 }
 
+func runCallback(cb func(MessageID, *ProducerMessage, error), id MessageID, 
msg *ProducerMessage, err error) {
+       if cb == nil {
+               return
+       }
+
+       cb(id, msg, err)
+}
+
 func (p *partitionProducer) internalSend(request *sendRequest) {
        p.log.Debug("Received send request: ", *request.msg)
 
@@ -480,7 +488,7 @@ func (p *partitionProducer) internalSend(request 
*sendRequest) {
        var err error
        if msg.Value != nil && msg.Payload != nil {
                p.log.Error("Can not set Value and Payload both")
-               request.callback(nil, request.msg, errors.New("can not set 
Value and Payload both"))
+               runCallback(request.callback, nil, request.msg, errors.New("can 
not set Value and Payload both"))
                return
        }
 
@@ -494,7 +502,7 @@ func (p *partitionProducer) internalSend(request 
*sendRequest) {
                if msg.Schema != nil && p.options.Schema != nil &&
                        msg.Schema.GetSchemaInfo().hash() != 
p.options.Schema.GetSchemaInfo().hash() {
                        p.releaseSemaphoreAndMem(uncompressedPayloadSize)
-                       request.callback(nil, request.msg, fmt.Errorf("msg 
schema can not match with producer schema"))
+                       runCallback(request.callback, nil, request.msg, 
fmt.Errorf("msg schema can not match with producer schema"))
                        p.log.WithError(err).Errorf("The producer %s of the 
topic %s is disabled the `MultiSchema`", p.producerName, p.topic)
                        return
                }
@@ -513,7 +521,7 @@ func (p *partitionProducer) internalSend(request 
*sendRequest) {
                        schemaPayload, err = schema.Encode(msg.Value)
                        if err != nil {
                                
p.releaseSemaphoreAndMem(uncompressedPayloadSize)
-                               request.callback(nil, request.msg, 
newError(SchemaFailure, err.Error()))
+                               runCallback(request.callback, nil, request.msg, 
newError(SchemaFailure, err.Error()))
                                p.log.WithError(err).Errorf("Schema encode 
message failed %s", msg.Value)
                                return
                        }
@@ -530,7 +538,7 @@ func (p *partitionProducer) internalSend(request 
*sendRequest) {
                        if err != nil {
                                
p.releaseSemaphoreAndMem(uncompressedPayloadSize)
                                p.log.WithError(err).Error("get schema version 
fail")
-                               request.callback(nil, request.msg, 
fmt.Errorf("get schema version fail, err: %w", err))
+                               runCallback(request.callback, nil, request.msg, 
fmt.Errorf("get schema version fail, err: %w", err))
                                return
                        }
                        p.schemaCache.Put(schema.GetSchemaInfo(), schemaVersion)
@@ -589,7 +597,7 @@ func (p *partitionProducer) internalSend(request 
*sendRequest) {
        // if msg is too large and chunking is disabled
        if checkSize > maxMessageSize && !p.options.EnableChunking {
                p.releaseSemaphoreAndMem(uncompressedPayloadSize)
-               request.callback(nil, request.msg, errMessageTooLarge)
+               runCallback(request.callback, nil, request.msg, 
errMessageTooLarge)
                p.log.WithError(errMessageTooLarge).
                        WithField("size", checkSize).
                        WithField("properties", msg.Properties).
@@ -608,7 +616,7 @@ func (p *partitionProducer) internalSend(request 
*sendRequest) {
                payloadChunkSize = int(p._getConn().GetMaxMessageSize()) - 
proto.Size(mm)
                if payloadChunkSize <= 0 {
                        p.releaseSemaphoreAndMem(uncompressedPayloadSize)
-                       request.callback(nil, msg, errMetaTooLarge)
+                       runCallback(request.callback, nil, msg, errMetaTooLarge)
                        p.log.WithError(errMetaTooLarge).
                                WithField("metadata size", proto.Size(mm)).
                                WithField("properties", msg.Properties).
@@ -683,7 +691,7 @@ func (p *partitionProducer) internalSend(request 
*sendRequest) {
                        if ok := addRequestToBatch(smm, p, uncompressedPayload, 
request, msg, deliverAt, schemaVersion,
                                multiSchemaEnabled); !ok {
                                
p.releaseSemaphoreAndMem(uncompressedPayloadSize)
-                               request.callback(nil, request.msg, 
errFailAddToBatch)
+                               runCallback(request.callback, nil, request.msg, 
errFailAddToBatch)
                                p.log.WithField("size", uncompressedSize).
                                        WithField("properties", msg.Properties).
                                        Error("unable to add message to batch")
@@ -835,7 +843,7 @@ func (p *partitionProducer) internalSingleSend(mm 
*pb.MessageMetadata,
                )
        }
        if err != nil {
-               request.callback(nil, request.msg, err)
+               runCallback(request.callback, nil, request.msg, err)
                p.releaseSemaphoreAndMem(int64(len(msg.Payload)))
                p.log.WithError(err).Errorf("Single message serialize failed 
%s", msg.Value)
                return
@@ -875,7 +883,7 @@ func (p *partitionProducer) internalFlushCurrentBatch() {
        if err != nil {
                for _, cb := range callbacks {
                        if sr, ok := cb.(*sendRequest); ok {
-                               sr.callback(nil, sr.msg, err)
+                               runCallback(sr.callback, nil, sr.msg, err)
                        }
                }
                if errors.Is(err, internal.ErrExceedMaxMessageSize) {
@@ -985,7 +993,7 @@ func (p *partitionProducer) failTimeoutMessages() {
 
                                if sr.callback != nil {
                                        sr.callbackOnce.Do(func() {
-                                               sr.callback(nil, sr.msg, 
errSendTimeout)
+                                               runCallback(sr.callback, nil, 
sr.msg, errSendTimeout)
                                        })
                                }
                                if sr.transaction != nil {
@@ -1018,7 +1026,7 @@ func (p *partitionProducer) internalFlushCurrentBatches() 
{
                if errs[i] != nil {
                        for _, cb := range callbacks[i] {
                                if sr, ok := cb.(*sendRequest); ok {
-                                       sr.callback(nil, sr.msg, errs[i])
+                                       runCallback(sr.callback, nil, sr.msg, 
errs[i])
                                }
                        }
                        if errors.Is(errs[i], internal.ErrExceedMaxMessageSize) 
{
@@ -1106,26 +1114,26 @@ func (p *partitionProducer) SendAsync(ctx 
context.Context, msg *ProducerMessage,
 
 func (p *partitionProducer) internalSendAsync(ctx context.Context, msg 
*ProducerMessage,
        callback func(MessageID, *ProducerMessage, error), flushImmediately 
bool) {
-       //Register transaction operation to transaction and the transaction 
coordinator.
+       // Register transaction operation to transaction and the transaction 
coordinator.
        var newCallback func(MessageID, *ProducerMessage, error)
        if msg.Transaction != nil {
                transactionImpl := (msg.Transaction).(*transaction)
                if transactionImpl.state != TxnOpen {
                        p.log.WithField("state", 
transactionImpl.state).Error("Failed to send message" +
                                " by a non-open transaction.")
-                       callback(nil, msg, newError(InvalidStatus, "Failed to 
send message by a non-open transaction."))
+                       runCallback(callback, nil, msg, newError(InvalidStatus, 
"Failed to send message by a non-open transaction."))
                        return
                }
 
                if err := transactionImpl.registerProducerTopic(p.topic); err 
!= nil {
-                       callback(nil, msg, err)
+                       runCallback(callback, nil, msg, err)
                        return
                }
                if err := transactionImpl.registerSendOrAckOp(); err != nil {
-                       callback(nil, msg, err)
+                       runCallback(callback, nil, msg, err)
                }
                newCallback = func(id MessageID, producerMessage 
*ProducerMessage, err error) {
-                       callback(id, producerMessage, err)
+                       runCallback(callback, id, producerMessage, err)
                        transactionImpl.endSendOrAckOp(err)
                }
        } else {
@@ -1133,7 +1141,7 @@ func (p *partitionProducer) internalSendAsync(ctx 
context.Context, msg *Producer
        }
        if p.getProducerState() != producerReady {
                // Producer is closing
-               newCallback(nil, msg, errProducerClosed)
+               runCallback(newCallback, nil, msg, errProducerClosed)
                return
        }
 
@@ -1253,9 +1261,7 @@ func (p *partitionProducer) ReceivedSendReceipt(response 
*pb.CommandSendReceipt)
                                }
 
                                if sr.totalChunks <= 1 || sr.chunkID == 
sr.totalChunks-1 {
-                                       if sr.callback != nil {
-                                               sr.callback(msgID, sr.msg, nil)
-                                       }
+                                       runCallback(sr.callback, msgID, sr.msg, 
nil)
                                        
p.options.Interceptors.OnSendAcknowledgement(p, sr.msg, msgID)
                                }
                        }
@@ -1406,27 +1412,23 @@ func (p *partitionProducer) releaseSemaphoreAndMem(size 
int64) {
 func (p *partitionProducer) canAddToQueue(sr *sendRequest, 
uncompressedPayloadSize int64) bool {
        if p.options.DisableBlockIfQueueFull {
                if !p.publishSemaphore.TryAcquire() {
-                       if sr.callback != nil {
-                               sr.callback(nil, sr.msg, errSendQueueIsFull)
-                       }
+                       runCallback(sr.callback, nil, sr.msg, 
errSendQueueIsFull)
                        return false
                }
                if !p.client.memLimit.TryReserveMemory(uncompressedPayloadSize) 
{
                        p.publishSemaphore.Release()
-                       if sr.callback != nil {
-                               sr.callback(nil, sr.msg, errMemoryBufferIsFull)
-                       }
+                       runCallback(sr.callback, nil, sr.msg, 
errMemoryBufferIsFull)
                        return false
                }
 
        } else {
                if !p.publishSemaphore.Acquire(sr.ctx) {
-                       sr.callback(nil, sr.msg, errContextExpired)
+                       runCallback(sr.callback, nil, sr.msg, errContextExpired)
                        return false
                }
                if !p.client.memLimit.ReserveMemory(sr.ctx, 
uncompressedPayloadSize) {
                        p.publishSemaphore.Release()
-                       sr.callback(nil, sr.msg, errContextExpired)
+                       runCallback(sr.callback, nil, sr.msg, errContextExpired)
                        return false
                }
        }

Reply via email to