This is an automated email from the ASF dual-hosted git repository.

nodece pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pulsar-client-go.git


The following commit(s) were added to refs/heads/master by this push:
     new 0129b2dd fix: the access to consumer.consumers might not be thread safe (#1494)
0129b2dd is described below

commit 0129b2dd5d91a7b09dada4b101856d4300156bec
Author: Yunze Xu <[email protected]>
AuthorDate: Thu May 14 10:12:38 2026 +0800

    fix: the access to consumer.consumers might not be thread safe (#1494)
    
    * fix: the access to ConsumerImpl.consumers might not be thread safe
    
    * Revert "fix: the access to ConsumerImpl.consumers might not be thread safe"
    
    This reverts commit 0596fe1a26aaef48af084f79620a96127ecf527f.
    
    * change consumers to copy-on-write for thread safety
    
    * fix format
    
    * Revert "fix format"
    
    This reverts commit 9f80a0fd92052d2fe33e82186782e8d4b52b8e20.
    
    * Revert "change consumers to copy-on-write for thread safety"
    
    This reverts commit ba0903ee2f0862595238720efbf8db3aed552e58.
    
    * Revert "Revert "fix: the access to ConsumerImpl.consumers might not be thread safe""
    
    This reverts commit 90336afd7bab410d561fc61952cc58e280b7aa8e.
    
    * add lock for other access on consumers and explain why atomic.Value is unacceptable
---
 pulsar/consumer_impl.go | 93 +++++++++++++++++++++++++++++--------------------
 1 file changed, 55 insertions(+), 38 deletions(-)

diff --git a/pulsar/consumer_impl.go b/pulsar/consumer_impl.go
index ece37370..5579d374 100644
--- a/pulsar/consumer_impl.go
+++ b/pulsar/consumer_impl.go
@@ -50,9 +50,16 @@ type acker interface {
 
 type consumer struct {
        sync.Mutex
-       topic                     string
-       client                    *client
-       options                   ConsumerOptions
+       topic   string
+       client  *client
+       options ConsumerOptions
+
+       // When accessing `consumers`, the lock must be acquired in case 
partitions are being added
+       // in the background by `internalTopicSubscribeToPartitions`. 
Currently, when a new sub-consumer
+       // is created, the current consumer can immediately receive messages 
from the new partition. However,
+       // before the new sub-consumers are visible in `consumers`, the Ack 
related methods cannot find the
+       // sub-consumer for the message's message ID, so we cannot simply 
change `consumers` to `atomic.Value`
+       // and perform copy-on-write when partitions are added.
        consumers                 []*partitionConsumer
        consumerName              string
        disableForceTopicCreation bool
@@ -549,11 +556,11 @@ func (c *consumer) Receive(ctx context.Context) (message 
Message, err error) {
 
 func (c *consumer) AckWithTxn(msg Message, txn Transaction) error {
        msgID := msg.ID()
-       if err := c.checkMsgIDPartition(msgID); err != nil {
+       consumer, err := c.findPartitionConsumer(msgID)
+       if err != nil {
                return err
        }
-
-       return c.consumers[msgID.PartitionIdx()].AckIDWithTxn(msgID, txn)
+       return consumer.AckIDWithTxn(msgID, txn)
 }
 
 // Chan return the message chan to users
@@ -568,23 +575,19 @@ func (c *consumer) Ack(msg Message) error {
 
 // AckID the consumption of a single message, identified by its MessageID
 func (c *consumer) AckID(msgID MessageID) error {
-       if err := c.checkMsgIDPartition(msgID); err != nil {
+       consumer, err := c.findPartitionConsumer(msgID)
+       if err != nil {
                return err
        }
-
        if c.options.AckWithResponse {
-               return 
c.consumers[msgID.PartitionIdx()].AckIDWithResponse(msgID)
+               return consumer.AckIDWithResponse(msgID)
        }
-
-       return c.consumers[msgID.PartitionIdx()].AckID(msgID)
+       return consumer.AckID(msgID)
 }
 
 func (c *consumer) AckIDList(msgIDs []MessageID) error {
        return ackIDListFromMultiTopics(c.log, msgIDs, func(msgID MessageID) 
(acker, error) {
-               if err := c.checkMsgIDPartition(msgID); err != nil {
-                       return nil, err
-               }
-               return c.consumers[msgID.PartitionIdx()], nil
+               return c.findPartitionConsumer(msgID)
        })
 }
 
@@ -597,15 +600,14 @@ func (c *consumer) AckCumulative(msg Message) error {
 // AckIDCumulative the reception of all the messages in the stream up to (and 
including)
 // the provided message, identified by its MessageID
 func (c *consumer) AckIDCumulative(msgID MessageID) error {
-       if err := c.checkMsgIDPartition(msgID); err != nil {
+       consumer, err := c.findPartitionConsumer(msgID)
+       if err != nil {
                return err
        }
-
        if c.options.AckWithResponse {
-               return 
c.consumers[msgID.PartitionIdx()].AckIDWithResponseCumulative(msgID)
+               return consumer.AckIDWithResponseCumulative(msgID)
        }
-
-       return c.consumers[msgID.PartitionIdx()].AckIDCumulative(msgID)
+       return consumer.AckIDCumulative(msgID)
 }
 
 // ReconsumeLater mark a message for redelivery after custom delay
@@ -695,7 +697,9 @@ func (c *consumer) Nack(msg Message) {
                        mid.NackByMsg(msg)
                        return
                }
-               c.consumers[mid.partitionIdx].NackMsg(msg)
+               if consumer, err := c.findPartitionConsumer(mid); err == nil {
+                       consumer.NackMsg(msg)
+               }
                return
        }
 
@@ -703,11 +707,9 @@ func (c *consumer) Nack(msg Message) {
 }
 
 func (c *consumer) NackID(msgID MessageID) {
-       if err := c.checkMsgIDPartition(msgID); err != nil {
-               return
+       if consumer, err := c.findPartitionConsumer(msgID); err == nil {
+               consumer.NackID(msgID)
        }
-
-       c.consumers[msgID.PartitionIdx()].NackID(msgID)
 }
 
 func (c *consumer) Close() {
@@ -743,11 +745,10 @@ func (c *consumer) Seek(msgID MessageID) error {
                return newError(SeekFailed, "for partition topic, seek command 
should perform on the individual partitions")
        }
 
-       if err := c.checkMsgIDPartition(msgID); err != nil {
+       consumer, err := c.unsafeFindPartitionConsumer(msgID)
+       if err != nil {
                return err
        }
-
-       consumer := c.consumers[msgID.PartitionIdx()]
        consumer.pauseDispatchMessage()
        // clear messageCh
        for len(c.messageCh) > 0 {
@@ -781,26 +782,41 @@ func (c *consumer) SeekByTime(time time.Time) error {
        return errs
 }
 
-func (c *consumer) checkMsgIDPartition(msgID MessageID) error {
-       partition := msgID.PartitionIdx()
-       if partition < 0 || int(partition) >= len(c.consumers) {
+func (c *consumer) findPartitionConsumer(msgID MessageID) (*partitionConsumer, 
error) {
+       c.Lock()
+       defer c.Unlock()
+       return c.unsafeFindPartitionConsumer(msgID)
+}
+
+// NOTE: This method must be called when c.Lock is held
+func (c *consumer) unsafeFindPartitionConsumer(msgID MessageID) 
(*partitionConsumer, error) {
+       partition := int(msgID.PartitionIdx())
+       if partition < 0 || partition >= len(c.consumers) {
                c.log.Errorf("invalid partition index %d expected a partition 
between [0-%d]",
                        partition, len(c.consumers))
-               return fmt.Errorf("invalid partition index %d expected a 
partition between [0-%d]",
+               return nil, fmt.Errorf("invalid partition index %d expected a 
partition between [0-%d]",
                        partition, len(c.consumers))
        }
-       return nil
+       return c.consumers[partition], nil
 }
 
 func (c *consumer) hasNext() bool {
        ctx, cancel := context.WithCancel(context.Background())
        defer cancel() // Make sure all paths cancel the context to avoid 
context leak
 
+       // We have to make a snapshot consumers, because we have to iterate 
over all consumers in
+       // other goroutines. But when this method returns, there might be still 
other consumers
+       // not completing the `hasNext` call, so we cannot just call defer 
`c.Unlock()` after acquiring the lock.
+       c.Lock()
+       consumers := make([]*partitionConsumer, len(c.consumers))
+       copy(consumers, c.consumers)
+       c.Unlock()
+
        var wg sync.WaitGroup
-       wg.Add(len(c.consumers))
+       wg.Add(len(consumers))
 
        hasNext := make(chan bool)
-       for _, pc := range c.consumers {
+       for _, pc := range consumers {
                go func() {
                        defer wg.Done()
                        if pc.hasNext() {
@@ -828,10 +844,11 @@ func (c *consumer) hasNext() bool {
 }
 
 func (c *consumer) setLastDequeuedMsg(msgID MessageID) error {
-       if err := c.checkMsgIDPartition(msgID); err != nil {
+       consumer, err := c.findPartitionConsumer(msgID)
+       if err != nil {
                return err
        }
-       c.consumers[msgID.PartitionIdx()].lastDequeuedMsg = 
toTrackingMessageID(msgID)
+       consumer.lastDequeuedMsg = toTrackingMessageID(msgID)
        return nil
 }
 
@@ -894,7 +911,7 @@ func toProtoInitialPosition(p SubscriptionInitialPosition) 
pb.CommandSubscribe_I
 }
 
 func (c *consumer) messageID(msgID MessageID) *trackingMessageID {
-       if err := c.checkMsgIDPartition(msgID); err != nil {
+       if _, err := c.findPartitionConsumer(msgID); err != nil {
                return nil
        }
        return toTrackingMessageID(msgID)

Reply via email to