This is an automated email from the ASF dual-hosted git repository.

mmerli pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pulsar-client-go.git


The following commit(s) were added to refs/heads/master by this push:
     new bd30a32  [ISSUE #72] Fix data race conditions. (#77)
bd30a32 is described below

commit bd30a324bb30bb7a049cb44576f392cc2ed01575
Author: cckellogg <[email protected]>
AuthorDate: Mon Oct 28 15:09:59 2019 -0700

    [ISSUE #72] Fix data race conditions. (#77)
    
    * [ISSUE #72] Fix data race conditions.
    
    * Remove commented out code.
    
    * Revert write request ch to unbuffered.
---
 license_test.go                                    |   2 +-
 pulsar/impl_consumer.go                            |  10 +-
 pulsar/impl_partition_producer.go                  |  19 ++-
 pulsar/internal/connection.go                      | 150 +++++++++++++--------
 ...unackedMsgTracker.go => unacked_msg_tracker.go} |   6 +-
 ...Tracker_test.go => unacked_msg_tracker_test.go} |   3 +-
 6 files changed, 119 insertions(+), 71 deletions(-)

diff --git a/license_test.go b/license_test.go
index 51b7f44..84998c5 100644
--- a/license_test.go
+++ b/license_test.go
@@ -65,7 +65,7 @@ var otherCheck = regexp.MustCompile(`#
 `)
 
 var skip = map[string]bool{
-       "pkg/pb/PulsarApi.pb.go":true,
+       "pkg/pb/PulsarApi.pb.go": true,
 }
 
 func TestLicense(t *testing.T) {
diff --git a/pulsar/impl_consumer.go b/pulsar/impl_consumer.go
index 0178943..6a196ad 100644
--- a/pulsar/impl_consumer.go
+++ b/pulsar/impl_consumer.go
@@ -98,10 +98,12 @@ func singleTopicSubscribe(client *client, options 
*ConsumerOptions, topic string
        ch := make(chan ConsumerError, numPartitions)
 
        for partitionIdx, partitionTopic := range partitions {
+               // this needs to be created outside in the same go routine since
+               // newPartitionConsumer can modify the shared options struct 
causing a race condition
+               cons, err := newPartitionConsumer(client, partitionTopic, 
options, partitionIdx, numPartitions, c.queue)
                go func(partitionIdx int, partitionTopic string) {
-                       cons, e := newPartitionConsumer(client, partitionTopic, 
options, partitionIdx, numPartitions, c.queue)
                        ch <- ConsumerError{
-                               err:       e,
+                               err:       err,
                                partition: partitionIdx,
                                cons:      cons,
                        }
@@ -141,8 +143,8 @@ func (c *consumer) Subscription() string {
 
 func (c *consumer) Unsubscribe() error {
        var errMsg string
-       for _, c := range c.consumers {
-               if err := c.Unsubscribe(); err != nil {
+       for _, consumer := range c.consumers {
+               if err := consumer.Unsubscribe(); err != nil {
                        errMsg += fmt.Sprintf("topic %s, subscription %s: %s", 
c.Topic(), c.Subscription(), err)
                }
        }
diff --git a/pulsar/impl_partition_producer.go 
b/pulsar/impl_partition_producer.go
index f09cd42..196dbbc 100644
--- a/pulsar/impl_partition_producer.go
+++ b/pulsar/impl_partition_producer.go
@@ -25,11 +25,11 @@ import (
 
        "github.com/golang/protobuf/proto"
 
+       log "github.com/sirupsen/logrus"
+
        "github.com/apache/pulsar-client-go/pkg/pb"
        "github.com/apache/pulsar-client-go/pulsar/internal"
        "github.com/apache/pulsar-client-go/util"
-
-       log "github.com/sirupsen/logrus"
 )
 
 type producerState int
@@ -272,6 +272,7 @@ func (p *partitionProducer) internalSend(request 
*sendRequest) {
 }
 
 type pendingItem struct {
+       sync.Mutex
        batchData    []byte
        sequenceID   uint64
        sendRequests []interface{}
@@ -300,13 +301,19 @@ func (p *partitionProducer) internalFlush(fr 
*flushRequest) {
                return
        }
 
-       pi.sendRequests = append(pi.sendRequests, &sendRequest{
+       sendReq := &sendRequest{
                msg: nil,
                callback: func(id MessageID, message *ProducerMessage, e error) 
{
                        fr.err = e
                        fr.waitGroup.Done()
                },
-       })
+       }
+
+       // lock the pending request while adding requests
+       // since the ReceivedSendReceipt func iterates over this list
+       pi.Lock()
+       pi.sendRequests = append(pi.sendRequests, sendReq)
+       pi.Unlock()
 }
 
 func (p *partitionProducer) Send(ctx context.Context, msg *ProducerMessage) 
error {
@@ -370,6 +377,10 @@ func (p *partitionProducer) ReceivedSendReceipt(response 
*pb.CommandSendReceipt)
 
        // The ack was indeed for the expected item in the queue, we can remove 
it and trigger the callback
        p.pendingQueue.Poll()
+
+       // lock the pending item while sending the requests
+       pi.Lock()
+       defer pi.Unlock()
        for idx, i := range pi.sendRequests {
                sr := i.(*sendRequest)
                if sr.msg != nil {
diff --git a/pulsar/internal/connection.go b/pulsar/internal/connection.go
index fe43f79..592beca 100644
--- a/pulsar/internal/connection.go
+++ b/pulsar/internal/connection.go
@@ -81,6 +81,23 @@ const (
        connectionClosed
 )
 
+func (s connectionState) String() string {
+       switch s {
+       case connectionInit:
+               return "Initializing"
+       case connectionConnecting:
+               return "Connecting"
+       case connectionTCPConnected:
+               return "TCPConnected"
+       case connectionReady:
+               return "Ready"
+       case connectionClosed:
+               return "Closed"
+       default:
+               return "Unknown"
+       }
+}
+
 const keepAliveInterval = 30 * time.Second
 
 type request struct {
@@ -98,8 +115,11 @@ type connection struct {
        physicalAddr *url.URL
        cnx          net.Conn
 
+       writeBufferLock sync.Mutex
        writeBuffer          Buffer
        reader               *connectionReader
+
+       lastDataReceivedLock sync.Mutex
        lastDataReceivedTime time.Time
        pingTicker           *time.Ticker
 
@@ -107,13 +127,15 @@ type connection struct {
 
        requestIDGenerator uint64
 
-       incomingRequests chan *request
-       writeRequests    chan []byte
+       incomingRequestsCh chan *request
+       writeRequestsCh    chan []byte
 
        mapMutex    sync.RWMutex
        pendingReqs map[uint64]*request
        listeners   map[uint64]ConnectionListener
-       connWrapper *ConnWrapper
+
+       consumerHandlersLock sync.RWMutex
+       consumerHandlers map[uint64]ConsumerHandler
 
        tlsOptions *TLSOptions
        auth       auth.Provider
@@ -125,17 +147,17 @@ func newConnection(logicalAddr *url.URL, physicalAddr 
*url.URL, tlsOptions *TLSO
                logicalAddr:          logicalAddr,
                physicalAddr:         physicalAddr,
                writeBuffer:          NewBuffer(4096),
-               log:                  log.WithField("raddr", physicalAddr),
+               log:                  log.WithField("remote_addr", 
physicalAddr),
                pendingReqs:          make(map[uint64]*request),
                lastDataReceivedTime: time.Now(),
                pingTicker:           time.NewTicker(keepAliveInterval),
                tlsOptions:           tlsOptions,
                auth:                 auth,
 
-               incomingRequests: make(chan *request),
-               writeRequests:    make(chan []byte),
-               listeners:        make(map[uint64]ConnectionListener),
-               connWrapper:      NewConnWrapper(),
+               incomingRequestsCh: make(chan *request),
+               writeRequestsCh:    make(chan []byte),
+               listeners:          make(map[uint64]ConnectionListener),
+               consumerHandlers:   make(map[uint64]ConsumerHandler),
        }
        cnx.reader = newConnectionReader(cnx)
        cnx.cond = sync.NewCond(cnx)
@@ -157,7 +179,7 @@ func (c *connection) start() {
        }()
 }
 
-func (c *connection) connect() (ok bool) {
+func (c *connection) connect() bool {
        c.log.Info("Connecting to broker")
 
        var (
@@ -185,15 +207,19 @@ func (c *connection) connect() (ok bool) {
                c.Close()
                return false
        }
+
+       c.Lock()
        c.cnx = cnx
-       c.log = c.log.WithField("laddr", c.cnx.LocalAddr())
-       c.log.Debug("TCP connection established")
-       c.state = connectionTCPConnected
+       c.log = c.log.WithField("local_addr", c.cnx.LocalAddr())
+       c.log.Info("TCP connection established")
+       c.Unlock()
+
+       c.changeState(connectionTCPConnected)
 
        return true
 }
 
-func (c *connection) doHandshake() (ok bool) {
+func (c *connection) doHandshake() bool {
        // Send 'Connect' command to initiate handshake
        version := int32(pb.ProtocolVersion_v13)
 
@@ -231,24 +257,16 @@ func (c *connection) waitUntilReady() error {
        c.Lock()
        defer c.Unlock()
 
-       for {
+       for c.state != connectionReady {
                c.log.Debug("Wait until connection is ready. State: ", c.state)
-               switch c.state {
-               case connectionInit:
-                       fallthrough
-               case connectionConnecting:
-                       fallthrough
-               case connectionTCPConnected:
-                       // Wait for the state to change
-                       c.cond.Wait()
-
-               case connectionReady:
-                       return nil
-
-               case connectionClosed:
+               if c.state == connectionClosed {
                        return errors.New("connection error")
                }
+               // wait for a new connection state change
+               c.cond.Wait()
        }
+
+       return nil
 }
 
 func (c *connection) run() {
@@ -257,7 +275,7 @@ func (c *connection) run() {
 
        for {
                select {
-               case req := <-c.incomingRequests:
+               case req := <-c.incomingRequestsCh:
                        if req == nil {
                                return
                        }
@@ -266,7 +284,7 @@ func (c *connection) run() {
                        c.mapMutex.Unlock()
                        c.writeCommand(req.cmd)
 
-               case data := <-c.writeRequests:
+               case data := <-c.writeRequestsCh:
                        if data == nil {
                                return
                        }
@@ -279,7 +297,7 @@ func (c *connection) run() {
 }
 
 func (c *connection) WriteData(data []byte) {
-       c.writeRequests <- data
+       c.writeRequestsCh <- data
 }
 
 func (c *connection) internalWriteData(data []byte) {
@@ -296,6 +314,9 @@ func (c *connection) writeCommand(cmd proto.Message) {
        cmdSize := uint32(proto.Size(cmd))
        frameSize := cmdSize + 4
 
+       c.writeBufferLock.Lock()
+       defer c.writeBufferLock.Unlock()
+
        c.writeBuffer.Clear()
        c.writeBuffer.WriteUint32(frameSize)
        c.writeBuffer.WriteUint32(cmdSize)
@@ -305,12 +326,13 @@ func (c *connection) writeCommand(cmd proto.Message) {
        }
 
        c.writeBuffer.Write(serialized)
-       c.internalWriteData(c.writeBuffer.ReadableSlice())
+       data := c.writeBuffer.ReadableSlice()
+       c.internalWriteData(data)
 }
 
 func (c *connection) receivedCommand(cmd *pb.BaseCommand, headersAndPayload 
[]byte) {
        c.log.Debugf("Received command: %s -- payload: %v", cmd, 
headersAndPayload)
-       c.lastDataReceivedTime = time.Now()
+       c.setLastDataReceived(time.Now())
        var err error
 
        switch *cmd.Type {
@@ -374,11 +396,11 @@ func (c *connection) receivedCommand(cmd *pb.BaseCommand, 
headersAndPayload []by
 }
 
 func (c *connection) Write(data []byte) {
-       c.writeRequests <- data
+       c.writeRequestsCh <- data
 }
 
 func (c *connection) SendRequest(requestID uint64, req *pb.BaseCommand, 
callback func(command *pb.BaseCommand)) {
-       c.incomingRequests <- &request{
+       c.incomingRequestsCh <- &request{
                id:       requestID,
                cmd:      req,
                callback: callback,
@@ -406,7 +428,6 @@ func (c *connection) handleResponse(requestID uint64, 
response *pb.BaseCommand)
 }
 
 func (c *connection) handleSendReceipt(response *pb.CommandSendReceipt) {
-       c.log.Debug("Got SEND_RECEIPT: ", response)
        producerID := response.GetProducerId()
        if producer, ok := c.listeners[producerID]; ok {
                producer.ReceivedSendReceipt(response)
@@ -418,7 +439,7 @@ func (c *connection) handleSendReceipt(response 
*pb.CommandSendReceipt) {
 func (c *connection) handleMessage(response *pb.CommandMessage, payload 
[]byte) error {
        c.log.Debug("Got Message: ", response)
        consumerID := response.GetConsumerId()
-       if consumer, ok := c.connWrapper.Consumers[consumerID]; ok {
+       if consumer, ok := c.consumerHandler(consumerID); ok {
                err := consumer.MessageReceived(response, payload)
                if err != nil {
                        c.log.WithField("consumerID", consumerID).Error("handle 
message err: ", response.MessageId)
@@ -430,8 +451,21 @@ func (c *connection) handleMessage(response 
*pb.CommandMessage, payload []byte)
        return nil
 }
 
+func (c *connection) lastDataReceived() time.Time {
+       c.lastDataReceivedLock.Lock()
+       defer c.lastDataReceivedLock.Unlock()
+       t := c.lastDataReceivedTime
+       return t;
+}
+
+func (c *connection) setLastDataReceived(t time.Time) {
+       c.lastDataReceivedLock.Lock()
+       defer c.lastDataReceivedLock.Unlock()
+       c.lastDataReceivedTime = t
+}
+
 func (c *connection) sendPing() {
-       if c.lastDataReceivedTime.Add(2 * keepAliveInterval).Before(time.Now()) 
{
+       if c.lastDataReceived().Add(2 * keepAliveInterval).Before(time.Now()) {
                // We have not received a response to the previous Ping 
request, the
                // connection to broker is stale
                c.log.Info("Detected stale connection to broker")
@@ -454,7 +488,7 @@ func (c *connection) handlePing() {
 func (c *connection) handleCloseConsumer(closeConsumer 
*pb.CommandCloseConsumer) {
        c.log.Infof("Broker notification of Closed consumer: %d", 
closeConsumer.GetConsumerId())
        consumerID := closeConsumer.GetConsumerId()
-       if consumer, ok := c.connWrapper.Consumers[consumerID]; ok {
+       if consumer, ok := c.consumerHandler(consumerID); ok {
                if !util.IsNil(consumer) {
                        consumer.ConnectionClosed()
                }
@@ -503,15 +537,17 @@ func (c *connection) Close() {
                c.cnx.Close()
        }
        c.pingTicker.Stop()
-       close(c.incomingRequests)
-       close(c.writeRequests)
+       close(c.incomingRequestsCh)
+       close(c.writeRequestsCh)
 
        for _, listener := range c.listeners {
                listener.ConnectionClosed()
        }
 
-       for _, cnx := range c.connWrapper.Consumers {
-               cnx.ConnectionClosed()
+       c.consumerHandlersLock.RLock()
+       defer c.consumerHandlersLock.RUnlock()
+       for _, handler := range c.consumerHandlers {
+               handler.ConnectionClosed()
        }
 }
 
@@ -560,25 +596,21 @@ func (c *connection) getTLSConfig() (*tls.Config, error) {
        return tlsConfig, nil
 }
 
-type ConnWrapper struct {
-       Rwmu      sync.RWMutex
-       Consumers map[uint64]ConsumerHandler
-}
-
-func NewConnWrapper() *ConnWrapper {
-       return &ConnWrapper{
-               Consumers: make(map[uint64]ConsumerHandler),
-       }
-}
-
 func (c *connection) AddConsumeHandler(id uint64, handler ConsumerHandler) {
-       c.connWrapper.Rwmu.Lock()
-       c.connWrapper.Consumers[id] = handler
-       c.connWrapper.Rwmu.Unlock()
+       c.consumerHandlersLock.Lock()
+       defer c.consumerHandlersLock.Unlock()
+       c.consumerHandlers[id] = handler
 }
 
 func (c *connection) DeleteConsumeHandler(id uint64) {
-       c.connWrapper.Rwmu.Lock()
-       delete(c.connWrapper.Consumers, id)
-       c.connWrapper.Rwmu.Unlock()
+       c.consumerHandlersLock.Lock()
+       defer c.consumerHandlersLock.Unlock()
+       delete(c.consumerHandlers, id)
+}
+
+func (c *connection) consumerHandler(id uint64) (ConsumerHandler, bool) {
+       c.consumerHandlersLock.RLock()
+       defer c.consumerHandlersLock.RUnlock()
+       h, ok := c.consumerHandlers[id]
+       return h, ok
 }
diff --git a/pulsar/unackedMsgTracker.go b/pulsar/unacked_msg_tracker.go
similarity index 99%
rename from pulsar/unackedMsgTracker.go
rename to pulsar/unacked_msg_tracker.go
index 09ff0cb..ffc6eff 100644
--- a/pulsar/unackedMsgTracker.go
+++ b/pulsar/unacked_msg_tracker.go
@@ -21,11 +21,12 @@ import (
        "sync"
        "time"
 
-       "github.com/apache/pulsar-client-go/pkg/pb"
        "github.com/golang/protobuf/proto"
 
        set "github.com/deckarep/golang-set"
        log "github.com/sirupsen/logrus"
+
+       "github.com/apache/pulsar-client-go/pkg/pb"
 )
 
 type UnackedMessageTracker struct {
@@ -146,6 +147,7 @@ func (t *UnackedMessageTracker) handlerCmd() {
                select {
                case tick := <-t.timeout.C:
                        if t.isAckTimeout() {
+                               t.cmu.Lock()
                                log.Debugf(" %d messages have timed-out", 
t.oldOpenSet.Cardinality())
                                messageIds := make([]*pb.MessageIdData, 0)
 
@@ -153,10 +155,10 @@ func (t *UnackedMessageTracker) handlerCmd() {
                                        messageIds = append(messageIds, 
i.(*pb.MessageIdData))
                                        return false
                                })
-
                                log.Debugf("messageID length is:%d", 
len(messageIds))
 
                                t.oldOpenSet.Clear()
+                               t.cmu.Unlock()
 
                                if t.pcs != nil {
                                        messageIdsMap := 
make(map[int32][]*pb.MessageIdData)
diff --git a/pulsar/unackMsgTracker_test.go b/pulsar/unacked_msg_tracker_test.go
similarity index 99%
rename from pulsar/unackMsgTracker_test.go
rename to pulsar/unacked_msg_tracker_test.go
index edf7ddc..3848ce9 100644
--- a/pulsar/unackMsgTracker_test.go
+++ b/pulsar/unacked_msg_tracker_test.go
@@ -20,9 +20,10 @@ package pulsar
 import (
        "testing"
 
-       "github.com/apache/pulsar-client-go/pkg/pb"
        "github.com/golang/protobuf/proto"
        "github.com/stretchr/testify/assert"
+
+       "github.com/apache/pulsar-client-go/pkg/pb"
 )
 
 func TestUnackedMessageTracker(t *testing.T) {

Reply via email to