This is an automated email from the ASF dual-hosted git repository.

rxl pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pulsar-client-go.git


The following commit(s) were added to refs/heads/master by this push:
     new 9d864c2  Fix possible race condition in connection pool (#561)
9d864c2 is described below

commit 9d864c224e0943cf1cba95ec3f967211edc4f0cb
Author: cckellogg <[email protected]>
AuthorDate: Mon Jul 12 03:30:17 2021 -0400

    Fix possible race condition in connection pool (#561)
    
    * Fix possible race condition in connection pool
    
    - Use mutex instead of sync map in connection pool
    - Create tickers in run function for the connection instead
    of when the connection is created
    
    * Remove commented out code
---
 pulsar/internal/connection.go      | 41 +++++++++++---------
 pulsar/internal/connection_pool.go | 77 +++++++++++++++++++-------------------
 2 files changed, 62 insertions(+), 56 deletions(-)

diff --git a/pulsar/internal/connection.go b/pulsar/internal/connection.go
index 2549e21..a632816 100644
--- a/pulsar/internal/connection.go
+++ b/pulsar/internal/connection.go
@@ -128,6 +128,7 @@ type incomingCmd struct {
 type connection struct {
        sync.Mutex
        cond              *sync.Cond
+       started           int32
        state             ua.Int32
        connectionTimeout time.Duration
        closeOnce         sync.Once
@@ -142,8 +143,6 @@ type connection struct {
 
        lastDataReceivedLock sync.Mutex
        lastDataReceivedTime time.Time
-       pingTicker           *time.Ticker
-       pingCheckTicker      *time.Ticker
 
        log log.Logger
 
@@ -191,8 +190,6 @@ func newConnection(opts connectionOptions) *connection {
                log:                  opts.logger.SubLogger(log.Fields{"remote_addr": opts.physicalAddr}),
                pendingReqs:          make(map[uint64]*request),
                lastDataReceivedTime: time.Now(),
-               pingTicker:           time.NewTicker(keepAliveInterval),
-               pingCheckTicker:      time.NewTicker(keepAliveInterval),
                tlsOptions:           opts.tls,
                auth:                 opts.auth,
 
@@ -217,6 +214,11 @@ func newConnection(opts connectionOptions) *connection {
 }
 
 func (c *connection) start() {
+       if !atomic.CompareAndSwapInt32(&c.started, 0, 1) {
+               c.log.Warnf("connection has already started")
+               return
+       }
+
        // Each connection gets its own goroutine that will
        go func() {
                if c.connect() {
@@ -354,16 +356,17 @@ func (c *connection) failLeftRequestsWhenClose() {
 }
 
 func (c *connection) run() {
-       // All reads come from the reader goroutine
-       go c.reader.readFromConnection()
-       go c.runPingCheck()
-
-       c.log.Debugf("Connection run starting with request capacity=%d queued=%d",
-               cap(c.incomingRequestsCh), len(c.incomingRequestsCh))
+       pingSendTicker := time.NewTicker(keepAliveInterval)
+       pingCheckTicker := time.NewTicker(keepAliveInterval)
 
        defer func() {
+               // stop tickers
+               pingSendTicker.Stop()
+               pingCheckTicker.Stop()
+
                // all the accesses to the pendingReqs should be happened in this run loop thread,
-               // including the final cleanup, to avoid the issue https://github.com/apache/pulsar-client-go/issues/239
+               // including the final cleanup, to avoid the issue
+               // https://github.com/apache/pulsar-client-go/issues/239
                c.pendingLock.Lock()
                for id, req := range c.pendingReqs {
                        req.callback(nil, errConnectionClosed)
@@ -373,6 +376,13 @@ func (c *connection) run() {
                c.Close()
        }()
 
+       // All reads come from the reader goroutine
+       go c.reader.readFromConnection()
+       go c.runPingCheck(pingCheckTicker)
+
+       c.log.Debugf("Connection run starting with request capacity=%d queued=%d",
+               cap(c.incomingRequestsCh), len(c.incomingRequestsCh))
+
        go func() {
                for {
                        select {
@@ -402,18 +412,18 @@ func (c *connection) run() {
                        }
                        c.internalWriteData(data)
 
-               case <-c.pingTicker.C:
+               case <-pingSendTicker.C:
                        c.sendPing()
                }
        }
 }
 
-func (c *connection) runPingCheck() {
+func (c *connection) runPingCheck(pingCheckTicker *time.Ticker) {
        for {
                select {
                case <-c.closeCh:
                        return
-               case <-c.pingCheckTicker.C:
+               case <-pingCheckTicker.C:
                        if c.lastDataReceived().Add(2 * keepAliveInterval).Before(time.Now()) {
                                // We have not received a response to the previous Ping request, the
                                // connection to broker is stale
@@ -803,9 +813,6 @@ func (c *connection) Close() {
 
                close(c.closeCh)
 
-               c.pingTicker.Stop()
-               c.pingCheckTicker.Stop()
-
                listeners := make(map[uint64]ConnectionListener)
                c.listenersLock.Lock()
                for id, listener := range c.listeners {
diff --git a/pulsar/internal/connection_pool.go b/pulsar/internal/connection_pool.go
index 29e1267..2728a73 100644
--- a/pulsar/internal/connection_pool.go
+++ b/pulsar/internal/connection_pool.go
@@ -38,7 +38,8 @@ type ConnectionPool interface {
 }
 
 type connectionPool struct {
-       pool                  sync.Map
+       sync.Mutex
+       connections           map[string]*connection
        connectionTimeout     time.Duration
        tlsOptions            *TLSOptions
        auth                  auth.Provider
@@ -58,6 +59,7 @@ func NewConnectionPool(
        logger log.Logger,
        metrics *Metrics) ConnectionPool {
        return &connectionPool{
+               connections:           make(map[string]*connection),
                tlsOptions:            tlsOptions,
                auth:                  auth,
                connectionTimeout:     connectionTimeout,
@@ -69,54 +71,51 @@ func NewConnectionPool(
 
 func (p *connectionPool) GetConnection(logicalAddr *url.URL, physicalAddr *url.URL) (Connection, error) {
        key := p.getMapKey(logicalAddr)
-       cachedCnx, found := p.pool.Load(key)
-       if found {
-               cnx := cachedCnx.(*connection)
-               p.log.Debug("Found connection in cache:", cnx.logicalAddr, cnx.physicalAddr)
 
-               if err := cnx.waitUntilReady(); err == nil {
-                       // Connection is ready to be used
-                       return cnx, nil
+       p.Lock()
+       conn, ok := p.connections[key]
+       if ok {
+               p.log.Debugf("Found connection in pool key=%s logical_addr=%+v physical_addr=%+v",
+                       key, conn.logicalAddr, conn.physicalAddr)
+
+               // remove stale/failed connection
+               if conn.closed() {
+                       delete(p.connections, key)
+                       p.log.Debugf("Removed connection from pool key=%s logical_addr=%+v physical_addr=%+v",
+                               key, conn.logicalAddr, conn.physicalAddr)
+                       conn = nil // set to nil so we create a new one
                }
-               // The cached connection is failed
-               p.pool.Delete(key)
-               p.log.Debug("Removed failed connection from pool:", cnx.logicalAddr, cnx.physicalAddr)
        }
 
-       // Try to create a new connection
-       newConnection := newConnection(connectionOptions{
-               logicalAddr:       logicalAddr,
-               physicalAddr:      physicalAddr,
-               tls:               p.tlsOptions,
-               connectionTimeout: p.connectionTimeout,
-               auth:              p.auth,
-               logger:            p.log,
-               metrics:           p.metrics,
-       })
-       newCnx, wasCached := p.pool.LoadOrStore(key, newConnection)
-       cnx := newCnx.(*connection)
-
-       if !wasCached {
-               cnx.start()
+       if conn == nil {
+               conn = newConnection(connectionOptions{
+                       logicalAddr:       logicalAddr,
+                       physicalAddr:      physicalAddr,
+                       tls:               p.tlsOptions,
+                       connectionTimeout: p.connectionTimeout,
+                       auth:              p.auth,
+                       logger:            p.log,
+                       metrics:           p.metrics,
+               })
+               p.connections[key] = conn
+               p.Unlock()
+               conn.start()
        } else {
-               newConnection.Close()
+               // we already have a connection
+               p.Unlock()
        }
 
-       if err := cnx.waitUntilReady(); err != nil {
-               if !wasCached {
-                       p.pool.Delete(key)
-                       p.log.Debug("Removed failed connection from pool:", cnx.logicalAddr, cnx.physicalAddr)
-               }
-               return nil, err
-       }
-       return cnx, nil
+       err := conn.waitUntilReady()
+       return conn, err
 }
 
 func (p *connectionPool) Close() {
-       p.pool.Range(func(key, value interface{}) bool {
-               value.(Connection).Close()
-               return true
-       })
+       p.Lock()
+       for k, c := range p.connections {
+               delete(p.connections, k)
+               c.Close()
+       }
+       p.Unlock()
 }
 
 func (p *connectionPool) getMapKey(addr *url.URL) string {

Reply via email to