This is an automated email from the ASF dual-hosted git repository.

yuxuan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/thrift.git


The following commit(s) were added to refs/heads/master by this push:
     new 2fa907e  THRIFT-5495: close client when shutdown server in go lib 
Client: go
2fa907e is described below

commit 2fa907e84b5bf29c263c4cde657e99c2e894562f
Author: 郑桐 <[email protected]>
AuthorDate: Tue Jan 4 18:20:24 2022 +0800

    THRIFT-5495: close client when shutdown server in go lib
    Client: go
---
 lib/go/README.md                        |  24 ++++++
 lib/go/thrift/simple_server.go          |  56 +++++++++++--
 lib/go/thrift/simple_server_test.go     | 135 +++++++++++++++++++++++++++++++-
 test/go/src/common/clientserver_test.go |   2 +-
 4 files changed, 210 insertions(+), 7 deletions(-)

diff --git a/lib/go/README.md b/lib/go/README.md
index 75d7174..b2cf1df 100644
--- a/lib/go/README.md
+++ b/lib/go/README.md
@@ -132,3 +132,27 @@ if this interval is set to a value too low (for example, 
1ms), it might cause
 excessive cpu overhead.
 
 This feature is also only enabled on non-oneway endpoints.
+
+A note about server stop implementations
+========================================
+
+[TSimpleServer.Stop](https://pkg.go.dev/github.com/apache/thrift/lib/go/thrift#TSimpleServer.Stop)
+waits for all client connections to be closed after their last received
+request has been handled, so Stop may sometimes take too long:
+* When no socket timeout is set, Stop may hang until every active client
+  has finished handling its last received request, potentially
+  indefinitely for a client that stays connected without sending data.
+* When the socket timeout is very long (e.g. one hour), Stop may hang for
+  up to that duration before all active clients finish handling their
+  last received request.
+
+To prevent Stop from hanging for too long, you can set 
+thrift.ServerStopTimeout in your main or init function:
+
+    thrift.ServerStopTimeout = <max_duration_to_stop>
+
+If it's set to <=0 (the default), the feature is disabled and the server
+waits for all client connections to be closed gracefully, so no client
+sees an error, but with no upper bound on how long that takes. Otherwise,
+Stop waits for client connections to be closed gracefully until
+thrift.ServerStopTimeout is reached; client connections that are still
+open after thrift.ServerStopTimeout are closed abruptly, which may cause
+some client errors.
\ No newline at end of file
diff --git a/lib/go/thrift/simple_server.go b/lib/go/thrift/simple_server.go
index 02863ec..1cfc375 100644
--- a/lib/go/thrift/simple_server.go
+++ b/lib/go/thrift/simple_server.go
@@ -20,6 +20,7 @@
 package thrift
 
 import (
+       "context"
        "errors"
        "fmt"
        "io"
@@ -48,15 +49,26 @@ var ErrAbandonRequest = errors.New("request abandoned")
 // If it's changed to <=0, the feature will be disabled.
 var ServerConnectivityCheckInterval = time.Millisecond * 5
 
+// ServerStopTimeout is the maximum duration TSimpleServer.Stop waits for
+// client connections to finish on their own. Once it elapses, Stop closes
+// the client connections that are still open, which may surface as errors
+// on the client side.
+//
+// It's defined as a variable instead of a constant, so that thrift server
+// implementations can change its value to control the behavior.
+//
+// If it's set to <=0 (the default), the feature is disabled and the server
+// waits, with no upper bound, for all client connections to be closed
+// gracefully.
+var ServerStopTimeout = time.Duration(0)
+
 /*
  * This is not a typical TSimpleServer as it is not blocked after accept a 
socket.
  * It is more like a TThreadedServer that can handle different connections in 
different goroutines.
  * This will work if golang user implements a conn-pool like thing in client 
side.
  */
 type TSimpleServer struct {
-       closed int32
-       wg     sync.WaitGroup
-       mu     sync.Mutex
+       closed   int32
+       wg       sync.WaitGroup
+       mu       sync.Mutex
+       stopChan chan struct{}
 
        processorFactory       TProcessorFactory
        serverTransport        TServerTransport
@@ -121,6 +133,7 @@ func NewTSimpleServerFactory6(processorFactory 
TProcessorFactory, serverTranspor
                outputTransportFactory: outputTransportFactory,
                inputProtocolFactory:   inputProtocolFactory,
                outputProtocolFactory:  outputProtocolFactory,
+               stopChan:               make(chan struct{}),
        }
 }
 
@@ -192,13 +205,27 @@ func (p *TSimpleServer) innerAccept() (int32, error) {
                return 0, err
        }
        if client != nil {
-               p.wg.Add(1)
+               ctx, cancel := context.WithCancel(context.Background())
+               p.wg.Add(2)
+
                go func() {
                        defer p.wg.Done()
+                       defer cancel()
                        if err := p.processRequests(client); err != nil {
                                p.logger(fmt.Sprintf("error processing request: 
%v", err))
                        }
                }()
+
+               go func() {
+                       defer p.wg.Done()
+                       select {
+                       case <-ctx.Done():
+                               // client exited, do nothing
+                       case <-p.stopChan:
+                               // TSimpleServer.Close called, close the client 
connection
+                               client.Close()
+                       }
+               }()
        }
        return 0, nil
 }
@@ -229,12 +256,31 @@ func (p *TSimpleServer) Serve() error {
 func (p *TSimpleServer) Stop() error {
        p.mu.Lock()
        defer p.mu.Unlock()
+
        if atomic.LoadInt32(&p.closed) != 0 {
                return nil
        }
        atomic.StoreInt32(&p.closed, 1)
        p.serverTransport.Interrupt()
-       p.wg.Wait()
+
+       // Wait in the background for every goroutine tracked by p.wg. For each
+       // accepted client, innerAccept registers two: the request-processing
+       // goroutine and the close-watcher goroutine. cancel runs only after
+       // p.wg.Wait returns, so ctx.Done() means "all of them have exited".
+       ctx, cancel := context.WithCancel(context.Background())
+       go func() {
+               defer cancel()
+               p.wg.Wait()
+       }()
+
+       // With ServerStopTimeout > 0, give connections up to that long to
+       // finish on their own. Once the timer fires (or everything has already
+       // exited), closing p.stopChan tells each still-running watcher
+       // goroutine to close its client connection.
+       // With ServerStopTimeout <= 0 this branch is skipped, p.stopChan is
+       // never closed, and the wait below has no upper bound.
+       if ServerStopTimeout > 0 {
+               timer := time.NewTimer(ServerStopTimeout)
+               select {
+               case <-timer.C:
+               case <-ctx.Done():
+               }
+               close(p.stopChan)
+               timer.Stop()
+       }
+
+       // Block until all tracked goroutines have exited, then replace the
+       // (possibly closed) stopChan with a fresh, open channel.
+       // NOTE(review): presumably this is so a later Serve/Stop cycle does not
+       // start with an already-closed stopChan. p.closed is not reset in the
+       // code visible here, so confirm whether Stop can actually run again.
+       <-ctx.Done()
+       p.stopChan = make(chan struct{})
        return nil
 }
 
diff --git a/lib/go/thrift/simple_server_test.go 
b/lib/go/thrift/simple_server_test.go
index 58149a8..b92d50f 100644
--- a/lib/go/thrift/simple_server_test.go
+++ b/lib/go/thrift/simple_server_test.go
@@ -20,11 +20,17 @@
 package thrift
 
 import (
-       "testing"
+       "context"
        "errors"
+       "net"
        "runtime"
+       "sync"
+       "testing"
+       "time"
 )
 
+// networkWaitDuration is how long the tests below sleep to let the server
+// goroutine catch up, e.g. after `go serv.Serve()` and after dialing a
+// client connection.
+const networkWaitDuration = 10 * time.Millisecond
+
 type mockServerTransport struct {
        ListenFunc    func() error
        AcceptFunc    func() (TTransport, error)
@@ -154,3 +160,130 @@ func 
TestNoHangDuringStopFromDanglingLockAcquireDuringAcceptLoop(t *testing.T) {
        runtime.Gosched()
        serv.Stop()
 }
+
+// TestNoHangDuringStopFromClientNoDataSendDuringAcceptLoop checks that Stop
+// returns even though a connected client never sends any data and the
+// accepted socket is built with a nil configuration (presumably meaning no
+// socket timeout). ServerStopTimeout is set to 50ms, so Stop is expected to
+// wait at least that long before closing the idle connection. The test
+// asserts that lower bound; a hang would show up as the test never finishing.
+func TestNoHangDuringStopFromClientNoDataSendDuringAcceptLoop(t *testing.T) {
+       ln, err := net.Listen("tcp", "localhost:0")
+
+       if err != nil {
+               t.Fatalf("Failed to listen: %v", err)
+       }
+
+       // The processor only tries to read a message header. Since the client
+       // never writes, this presumably blocks until the connection is closed.
+       proc := &mockProcessor{
+               ProcessFunc: func(in, out TProtocol) (bool, TException) {
+                       in.ReadMessageBegin(context.Background())
+                       return false, nil
+               },
+       }
+
+       // Accept and Interrupt are backed by a real TCP listener. Interrupt
+       // closes the listener, which unblocks the pending ln.Accept call.
+       trans := &mockServerTransport{
+               ListenFunc: func() error {
+                       return nil
+               },
+               AcceptFunc: func() (TTransport, error) {
+                       conn, err := ln.Accept()
+                       if err != nil {
+                               return nil, err
+                       }
+
+                       return NewTSocketFromConnConf(conn, nil), nil
+               },
+               CloseFunc: func() error {
+                       return nil
+               },
+               InterruptFunc: func() error {
+                       return ln.Close()
+               },
+       }
+
+       serv := NewTSimpleServer2(proc, trans)
+       go serv.Serve()
+       time.Sleep(networkWaitDuration)
+
+       // Open a client connection that stays idle.
+       // NOTE(review): netConn is never closed by the test; consider
+       // t.Cleanup(func() { netConn.Close() }) to avoid leaking it.
+       netConn, err := net.Dial("tcp", ln.Addr().String())
+       if err != nil || netConn == nil {
+               t.Fatal("error when dial server")
+       }
+       time.Sleep(networkWaitDuration)
+
+       // Override the package-level ServerStopTimeout for this test only;
+       // t.Cleanup restores the previous value.
+       serverStopTimeout := 50 * time.Millisecond
+       backupServerStopTimeout := ServerStopTimeout
+       t.Cleanup(func() {
+               ServerStopTimeout = backupServerStopTimeout
+       })
+       ServerStopTimeout = serverStopTimeout
+
+       st := time.Now()
+       err = serv.Stop()
+       if err != nil {
+               t.Errorf("error when stop server:%v", err)
+       }
+
+       // The idle connection cannot finish on its own, so Stop should only
+       // return after the stop timeout has fired.
+       if elapsed := time.Since(st); elapsed < serverStopTimeout {
+               t.Errorf("stop cost less time than server stop timeout, server 
stop timeout:%v,cost time:%v", ServerStopTimeout, elapsed)
+       }
+}
+
+// TestStopTimeoutWithSocketTimeout checks that a short socket timeout (5ms)
+// lets Stop return well before a long ServerStopTimeout (1s) elapses. The
+// test fails if Stop takes more than half of ServerStopTimeout.
+func TestStopTimeoutWithSocketTimeout(t *testing.T) {
+       ln, err := net.Listen("tcp", "localhost:0")
+
+       if err != nil {
+               t.Fatalf("Failed to listen: %v", err)
+       }
+
+       // The processor only tries to read a message header. With the 5ms
+       // socket timeout below, this read presumably fails quickly instead of
+       // blocking, so the connection's goroutines can exit on their own.
+       proc := &mockProcessor{
+               ProcessFunc: func(in, out TProtocol) (bool, TException) {
+                       in.ReadMessageBegin(context.Background())
+                       return false, nil
+               },
+       }
+
+       conf := &TConfiguration{SocketTimeout: 5 * time.Millisecond}
+       // wg is released by AcceptFunc once a connection has been accepted, so
+       // Stop is only called after the server owns the client connection.
+       // NOTE(review): wg.Done runs on every successful accept but wg.Add(1)
+       // is called only once; a second connection would drive the counter
+       // negative and panic.
+       wg := &sync.WaitGroup{}
+       trans := &mockServerTransport{
+               ListenFunc: func() error {
+                       return nil
+               },
+               AcceptFunc: func() (TTransport, error) {
+                       conn, err := ln.Accept()
+                       if err != nil {
+                               return nil, err
+                       }
+                       defer wg.Done()
+                       return NewTSocketFromConnConf(conn, conf), nil
+               },
+               CloseFunc: func() error {
+                       return nil
+               },
+               InterruptFunc: func() error {
+                       return ln.Close()
+               },
+       }
+
+       serv := NewTSimpleServer2(proc, trans)
+       go serv.Serve()
+       time.Sleep(networkWaitDuration)
+
+       // NOTE(review): netConn is never closed by the test.
+       wg.Add(1)
+       netConn, err := net.Dial("tcp", ln.Addr().String())
+       if err != nil || netConn == nil {
+               t.Fatal("error when dial server")
+       }
+       wg.Wait()
+
+       // Override the package-level ServerStopTimeout for this test only;
+       // t.Cleanup restores the previous value.
+       expectedStopTimeout := time.Second
+       backupServerStopTimeout := ServerStopTimeout
+       t.Cleanup(func() {
+               ServerStopTimeout = backupServerStopTimeout
+       })
+       ServerStopTimeout = expectedStopTimeout
+
+       st := time.Now()
+       err = serv.Stop()
+       // Stop should be bounded by the socket timeout rather than by
+       // ServerStopTimeout; half of the latter is used as a generous margin.
+       if elapsed := time.Since(st); elapsed > expectedStopTimeout/2 {
+               t.Errorf("stop cost more time than socket timeout, socket 
timeout:%v,server stop timeout:%v,cost time:%v", conf.SocketTimeout, 
ServerStopTimeout, elapsed)
+       }
+
+       if err != nil {
+               t.Fatalf("error when stop server:%v", err)
+       }
+}
diff --git a/test/go/src/common/clientserver_test.go 
b/test/go/src/common/clientserver_test.go
index 609086b..64b326a 100644
--- a/test/go/src/common/clientserver_test.go
+++ b/test/go/src/common/clientserver_test.go
@@ -75,7 +75,7 @@ func doUnit(t *testing.T, unit *test_unit) {
                t.Errorf("Unable to start server: %v", err)
                return
        }
-       go server.AcceptLoop()
+       go server.Serve()
        defer server.Stop()
        client, trans, err := StartClient(unit.host, unit.port, 
unit.domain_socket, unit.transport, unit.protocol, unit.ssl)
        if err != nil {

Reply via email to