Repository: thrift
Updated Branches:
  refs/heads/master c9e535ea7 -> c975bbcc9


THRIFT-2388 GoLang - Fix data races in simple_server and server_socket

Patch: Chris Bannister


Project: http://git-wip-us.apache.org/repos/asf/thrift/repo
Commit: http://git-wip-us.apache.org/repos/asf/thrift/commit/c975bbcc
Tree: http://git-wip-us.apache.org/repos/asf/thrift/tree/c975bbcc
Diff: http://git-wip-us.apache.org/repos/asf/thrift/diff/c975bbcc

Branch: refs/heads/master
Commit: c975bbcc9c3c618a6ee8902ae47fed89a025b597
Parents: c9e535e
Author: Jens Geyer <[email protected]>
Authored: Thu Mar 6 21:11:46 2014 +0100
Committer: Jens Geyer <[email protected]>
Committed: Thu Mar 6 21:11:46 2014 +0100

----------------------------------------------------------------------
 lib/go/thrift/server_socket.go | 15 +++++++++++++--
 lib/go/thrift/simple_server.go | 19 ++++++++++++++-----
 2 files changed, 27 insertions(+), 7 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/thrift/blob/c975bbcc/lib/go/thrift/server_socket.go
----------------------------------------------------------------------
diff --git a/lib/go/thrift/server_socket.go b/lib/go/thrift/server_socket.go
index 1a01095..4c80714 100644
--- a/lib/go/thrift/server_socket.go
+++ b/lib/go/thrift/server_socket.go
@@ -21,6 +21,7 @@ package thrift
 
 import (
        "net"
+       "sync"
        "time"
 )
 
@@ -28,7 +29,10 @@ type TServerSocket struct {
        listener      net.Listener
        addr          net.Addr
        clientTimeout time.Duration
-       interrupted   bool
+
+       // Protects the interrupted value to make it thread safe.
+       mu          sync.RWMutex
+       interrupted bool
 }
 
 func NewTServerSocket(listenAddr string) (*TServerSocket, error) {
@@ -56,7 +60,11 @@ func (p *TServerSocket) Listen() error {
 }
 
 func (p *TServerSocket) Accept() (TTransport, error) {
-       if p.interrupted {
+       p.mu.RLock()
+       interrupted := p.interrupted
+       p.mu.RUnlock()
+
+       if interrupted {
                return nil, errTransportInterrupted
        }
        if p.listener == nil {
@@ -102,6 +110,9 @@ func (p *TServerSocket) Close() error {
 }
 
 func (p *TServerSocket) Interrupt() error {
+       p.mu.Lock()
        p.interrupted = true
+       p.mu.Unlock()
+
        return nil
 }

http://git-wip-us.apache.org/repos/asf/thrift/blob/c975bbcc/lib/go/thrift/simple_server.go
----------------------------------------------------------------------
diff --git a/lib/go/thrift/simple_server.go b/lib/go/thrift/simple_server.go
index b5cb0e1..521394c 100644
--- a/lib/go/thrift/simple_server.go
+++ b/lib/go/thrift/simple_server.go
@@ -25,7 +25,7 @@ import (
 
 // Simple, non-concurrent server for testing.
 type TSimpleServer struct {
-       stopped bool
+       quit chan struct{}
 
        processorFactory       TProcessorFactory
        serverTransport        TServerTransport
@@ -78,12 +78,14 @@ func NewTSimpleServerFactory4(processorFactory TProcessorFactory, serverTranspor
 }
 
 func NewTSimpleServerFactory6(processorFactory TProcessorFactory, serverTransport TServerTransport, inputTransportFactory TTransportFactory, outputTransportFactory TTransportFactory, inputProtocolFactory TProtocolFactory, outputProtocolFactory TProtocolFactory) *TSimpleServer {
-       return &TSimpleServer{processorFactory: processorFactory,
+       return &TSimpleServer{
+               processorFactory:       processorFactory,
                serverTransport:        serverTransport,
                inputTransportFactory:  inputTransportFactory,
                outputTransportFactory: outputTransportFactory,
                inputProtocolFactory:   inputProtocolFactory,
                outputProtocolFactory:  outputProtocolFactory,
+               quit: make(chan struct{}, 1),
        }
 }
 
@@ -112,12 +114,19 @@ func (p *TSimpleServer) OutputProtocolFactory() TProtocolFactory {
 }
 
 func (p *TSimpleServer) Serve() error {
-       p.stopped = false
        err := p.serverTransport.Listen()
        if err != nil {
                return err
        }
-       for !p.stopped {
+
+loop:
+       for {
+               select {
+               case <-p.quit:
+                       break loop
+               default:
+               }
+
                client, err := p.serverTransport.Accept()
                if err != nil {
                        log.Println("Accept err: ", err)
@@ -134,7 +143,7 @@ func (p *TSimpleServer) Serve() error {
 }
 
 func (p *TSimpleServer) Stop() error {
-       p.stopped = true
+       p.quit <- struct{}{}
        p.serverTransport.Interrupt()
        return nil
 }

Reply via email to