This is an automated email from the ASF dual-hosted git repository.

huzongtang pushed a commit to branch native
in repository https://gitbox.apache.org/repos/asf/rocketmq-client-go.git


The following commit(s) were added to refs/heads/native by this push:
     new e696fd8  [ISSUE #75] support shutdown method (#97)
e696fd8 is described below

commit e696fd8a71de1b39e206e734b2b366f9db325886
Author: 高峰 <[email protected]>
AuthorDate: Tue Jul 9 14:47:34 2019 +0800

    [ISSUE #75] support shutdown method (#97)
    
    * make all remote unit tests pass
    
    * 1: use time.Duration for timeouts
    2: add unit tests for the InvokeAsync method
    3: change the visibility of some fields
    
    * add shutdown
    
    * 1: use time.Duration as the unit for time-related values; 2: support shutdown method; 3: fix typo
---
 internal/client.go                    | 29 ++++++++---
 internal/remote/codec.go              |  2 +-
 internal/remote/codec_test.go         |  6 +--
 internal/remote/future.go             | 16 +++---
 internal/remote/remote_client.go      | 57 ++++++++++-----------
 internal/remote/remote_client_test.go | 94 +++++++++++++++++------------------
 internal/validators.go                |  2 +-
 7 files changed, 106 insertions(+), 100 deletions(-)

diff --git a/internal/client.go b/internal/client.go
index c3fafbc..a6e1bc5 100644
--- a/internal/client.go
+++ b/internal/client.go
@@ -43,7 +43,7 @@ const (
        // Pulling topic information interval from the named server
        _PullNameServerInterval = 30 * time.Second
 
-       // Pulling topic information interval from the named server
+       // Sending heart beat interval to all broker
        _HeartbeatBrokerInterval = 30 * time.Second
 
        // Offset persistent interval for consumer
@@ -54,7 +54,7 @@ const (
 )
 
 var (
-       ErrServiceState = errors.New("service state is not running, please 
check")
+       ErrServiceState = errors.New("service close is not running, please 
check")
 
        _VIPChannelEnable = false
 )
@@ -129,6 +129,7 @@ type RMQClient struct {
 
        remoteClient *remote.RemotingClient
        hbMutex      sync.Mutex
+       close        bool
 }
 
 var clientMap sync.Map
@@ -150,6 +151,9 @@ func GetOrNewRocketMQClient(option ClientOptions) 
*RMQClient {
 }
 
 func (c *RMQClient) Start() {
+       //ctx, cancel := context.WithCancel(context.Background())
+       //c.cancel = cancel
+       c.close = false
        c.once.Do(func() {
                // TODO fetchNameServerAddr
                go func() {}()
@@ -158,7 +162,7 @@ func (c *RMQClient) Start() {
                go func() {
                        // delay
                        time.Sleep(50 * time.Millisecond)
-                       for {
+                       for !c.close{
                                c.UpdateTopicRouteInfo()
                                time.Sleep(_PullNameServerInterval)
                        }
@@ -166,7 +170,7 @@ func (c *RMQClient) Start() {
 
                // TODO cleanOfflineBroker & sendHeartbeatToAllBrokerWithLock
                go func() {
-                       for {
+                       for !c.close{
                                cleanOfflineBroker()
                                c.SendHeartbeatToAllBrokerWithLock()
                                time.Sleep(_HeartbeatBrokerInterval)
@@ -176,7 +180,7 @@ func (c *RMQClient) Start() {
                // schedule persist offset
                go func() {
                        //time.Sleep(10 * time.Second)
-                       for {
+                       for !c.close{
                                c.consumerMap.Range(func(key, value 
interface{}) bool {
                                        consumer := value.(InnerConsumer)
                                        consumer.PersistConsumerOffset()
@@ -187,7 +191,7 @@ func (c *RMQClient) Start() {
                }()
 
                go func() {
-                       for {
+                       for !c.close{
                                c.RebalanceImmediately()
                                time.Sleep(_RebalanceInterval)
                        }
@@ -196,7 +200,8 @@ func (c *RMQClient) Start() {
 }
 
 func (c *RMQClient) Shutdown() {
-       // TODO
+       c.remoteClient.ShutDown()
+       c.close = true
 }
 
 func (c *RMQClient) ClientID() string {
@@ -209,18 +214,28 @@ func (c *RMQClient) ClientID() string {
 
 func (c *RMQClient) InvokeSync(addr string, request *remote.RemotingCommand,
        timeoutMillis time.Duration) (*remote.RemotingCommand, error) {
+       if c.close {
+               return nil, ErrServiceState
+       }
        return c.remoteClient.InvokeSync(addr, request, timeoutMillis)
 }
 
 func (c *RMQClient) InvokeAsync(addr string, request *remote.RemotingCommand,
        timeoutMillis time.Duration, f func(*remote.RemotingCommand, error)) 
error {
+       if c.close {
+               return ErrServiceState
+       }
        return c.remoteClient.InvokeAsync(addr, request, timeoutMillis, 
func(future *remote.ResponseFuture) {
                f(future.ResponseCommand, future.Err)
        })
+
 }
 
 func (c *RMQClient) InvokeOneWay(addr string, request *remote.RemotingCommand,
        timeoutMillis time.Duration) error {
+       if c.close {
+               return ErrServiceState
+       }
        return c.remoteClient.InvokeOneWay(addr, request, timeoutMillis)
 }
 
diff --git a/internal/remote/codec.go b/internal/remote/codec.go
index aebf75b..434678e 100644
--- a/internal/remote/codec.go
+++ b/internal/remote/codec.go
@@ -73,7 +73,7 @@ func NewRemotingCommand(code int16, header CustomHeader, body 
[]byte) *RemotingC
 }
 
 func (command *RemotingCommand) String() string {
-       return fmt.Sprintf("Code: %d, Opaque: %d, Remark: %s, ExtFields: %v",
+       return fmt.Sprintf("Code: %d, opaque: %d, Remark: %s, ExtFields: %v",
                command.Code, command.Opaque, command.Remark, command.ExtFields)
 }
 
diff --git a/internal/remote/codec_test.go b/internal/remote/codec_test.go
index 2e836c7..a315947 100644
--- a/internal/remote/codec_test.go
+++ b/internal/remote/codec_test.go
@@ -96,7 +96,7 @@ func Test_decode(t *testing.T) {
                        t.Fatalf("wrong Version. want=%d, got=%d", rc.Version, 
decodedRc.Version)
                }
                if rc.Opaque != decodedRc.Opaque {
-                       t.Fatalf("wrong Opaque. want=%d, got=%d", rc.Opaque, 
decodedRc.Opaque)
+                       t.Fatalf("wrong opaque. want=%d, got=%d", rc.Opaque, 
decodedRc.Opaque)
                }
                if rc.Remark != decodedRc.Remark {
                        t.Fatalf("wrong remark. want=%s, got=%s", rc.Remark, 
decodedRc.Remark)
@@ -167,7 +167,7 @@ func Test_jsonCodec_decodeHeader(t *testing.T) {
                        t.Fatalf("wrong Version. want=%d, got=%d", rc.Version, 
decodedRc.Version)
                }
                if rc.Opaque != decodedRc.Opaque {
-                       t.Fatalf("wrong Opaque. want=%d, got=%d", rc.Opaque, 
decodedRc.Opaque)
+                       t.Fatalf("wrong opaque. want=%d, got=%d", rc.Opaque, 
decodedRc.Opaque)
                }
                if rc.Remark != decodedRc.Remark {
                        t.Fatalf("wrong remark. want=%s, got=%s", rc.Remark, 
decodedRc.Remark)
@@ -237,7 +237,7 @@ func Test_rmqCodec_decodeHeader(t *testing.T) {
                        t.Fatalf("wrong Version. want=%d, got=%d", rc.Version, 
decodedRc.Version)
                }
                if rc.Opaque != decodedRc.Opaque {
-                       t.Fatalf("wrong Opaque. want=%d, got=%d", rc.Opaque, 
decodedRc.Opaque)
+                       t.Fatalf("wrong opaque. want=%d, got=%d", rc.Opaque, 
decodedRc.Opaque)
                }
                if rc.Remark != decodedRc.Remark {
                        t.Fatalf("wrong remark. want=%s, got=%s", rc.Remark, 
decodedRc.Remark)
diff --git a/internal/remote/future.go b/internal/remote/future.go
index 62a9640..93990e5 100644
--- a/internal/remote/future.go
+++ b/internal/remote/future.go
@@ -28,21 +28,21 @@ type ResponseFuture struct {
        SendRequestOK   bool
        Err             error
        Opaque          int32
-       TimeoutMillis   time.Duration
+       Timeout         time.Duration
        callback        func(*ResponseFuture)
-       BeginTimestamp  int64
+       BeginTimestamp  time.Duration
        Done            chan bool
        callbackOnce    sync.Once
 }
 
 // NewResponseFuture create ResponseFuture with opaque, timeout and callback
-func NewResponseFuture(opaque int32, timeoutMillis time.Duration, callback 
func(*ResponseFuture)) *ResponseFuture {
+func NewResponseFuture(opaque int32, timeout time.Duration, callback 
func(*ResponseFuture)) *ResponseFuture {
        return &ResponseFuture{
                Opaque:         opaque,
                Done:           make(chan bool),
-               TimeoutMillis:  timeoutMillis,
+               Timeout:        timeout,
                callback:       callback,
-               BeginTimestamp: time.Now().Unix() * 1000,
+               BeginTimestamp: time.Duration(time.Now().Unix()) * time.Second,
        }
 }
 
@@ -55,8 +55,8 @@ func (r *ResponseFuture) executeInvokeCallback() {
 }
 
 func (r *ResponseFuture) isTimeout() bool {
-       diff := time.Now().Unix()*1000 - r.BeginTimestamp
-       return diff > int64(r.TimeoutMillis)
+       elapse := time.Duration(time.Now().Unix())*time.Second - 
r.BeginTimestamp
+       return elapse > r.Timeout
 }
 
 func (r *ResponseFuture) waitResponse() (*RemotingCommand, error) {
@@ -64,7 +64,7 @@ func (r *ResponseFuture) waitResponse() (*RemotingCommand, 
error) {
                cmd *RemotingCommand
                err error
        )
-       timer := time.NewTimer(r.TimeoutMillis * time.Millisecond)
+       timer := time.NewTimer(r.Timeout)
        for {
                select {
                case <-r.Done:
diff --git a/internal/remote/remote_client.go b/internal/remote/remote_client.go
index fde17e4..1d0330a 100644
--- a/internal/remote/remote_client.go
+++ b/internal/remote/remote_client.go
@@ -32,7 +32,6 @@ import (
 var (
        //ErrRequestTimeout for request timeout error
        ErrRequestTimeout = errors.New("request timeout")
-       connectionLocker  sync.Mutex
 )
 
 type ClientRequestFunc func(*RemotingCommand) *RemotingCommand
@@ -42,10 +41,11 @@ type TcpOption struct {
 }
 
 type RemotingClient struct {
-       responseTable   sync.Map
-       connectionTable sync.Map
-       option          TcpOption
-       processors      map[int16]ClientRequestFunc
+       responseTable    sync.Map
+       connectionTable  sync.Map
+       option           TcpOption
+       processors       map[int16]ClientRequestFunc
+       connectionLocker sync.Mutex
 }
 
 func NewRemotingClient() *RemotingClient {
@@ -59,15 +59,15 @@ func (c *RemotingClient) RegisterRequestFunc(code int16, f 
ClientRequestFunc) {
 }
 
 // TODO: merge sync and async model. sync should run on async model by 
blocking on chan
-func (c *RemotingClient) InvokeSync(addr string, request *RemotingCommand, 
timeoutMillis time.Duration) (*RemotingCommand, error) {
+func (c *RemotingClient) InvokeSync(addr string, request *RemotingCommand, 
timeout time.Duration) (*RemotingCommand, error) {
        conn, err := c.connect(addr)
        if err != nil {
                return nil, err
        }
-       resp := NewResponseFuture(request.Opaque, timeoutMillis, nil)
+       resp := NewResponseFuture(request.Opaque, timeout, nil)
        c.responseTable.Store(resp.Opaque, resp)
-       err = c.sendRequest(conn, request)
        defer c.responseTable.Delete(request.Opaque)
+       err = c.sendRequest(conn, request)
        if err != nil {
                return nil, err
        }
@@ -75,13 +75,13 @@ func (c *RemotingClient) InvokeSync(addr string, request 
*RemotingCommand, timeo
        return resp.waitResponse()
 }
 
-// InvokeAsync send request witout blocking, just return immediately.
-func (c *RemotingClient) InvokeAsync(addr string, request *RemotingCommand, 
timeoutMillis time.Duration, callback func(*ResponseFuture)) error {
+// InvokeAsync send request without blocking, just return immediately.
+func (c *RemotingClient) InvokeAsync(addr string, request *RemotingCommand, 
timeout time.Duration, callback func(*ResponseFuture)) error {
        conn, err := c.connect(addr)
        if err != nil {
                return err
        }
-       resp := NewResponseFuture(request.Opaque, timeoutMillis, callback)
+       resp := NewResponseFuture(request.Opaque, timeout, callback)
        c.responseTable.Store(resp.Opaque, resp)
        err = c.sendRequest(conn, request)
        if err != nil {
@@ -107,27 +107,10 @@ func (c *RemotingClient) InvokeOneWay(addr string, 
request *RemotingCommand, tim
        return c.sendRequest(conn, request)
 }
 
-func (c *RemotingClient) ScanResponseTable() {
-       rfs := make([]*ResponseFuture, 0)
-       c.responseTable.Range(func(key, value interface{}) bool {
-               if resp, ok := value.(*ResponseFuture); ok {
-                       if (resp.BeginTimestamp + int64(resp.TimeoutMillis) + 
1000) <= time.Now().Unix()*1000 {
-                               rfs = append(rfs, resp)
-                               c.responseTable.Delete(key)
-                       }
-               }
-               return true
-       })
-       for _, rf := range rfs {
-               rf.Err = ErrRequestTimeout
-               rf.executeInvokeCallback()
-       }
-}
-
 func (c *RemotingClient) connect(addr string) (net.Conn, error) {
        //it needs additional locker.
-       connectionLocker.Lock()
-       defer connectionLocker.Unlock()
+       c.connectionLocker.Lock()
+       defer c.connectionLocker.Unlock()
        conn, ok := c.connectionTable.Load(addr)
        if ok {
                return conn.(net.Conn), nil
@@ -181,7 +164,7 @@ func (c *RemotingClient) receiveResponse(r net.Conn) {
                }
        }
        if scanner.Err() != nil {
-               rlog.Errorf("net: %s scanner exit, err: %s.", 
r.RemoteAddr().String(), scanner.Err())
+               rlog.Errorf("net: %s scanner exit, Err: %s.", 
r.RemoteAddr().String(), scanner.Err())
        } else {
                rlog.Infof("net: %s scanner exit.", r.RemoteAddr().String())
        }
@@ -237,3 +220,15 @@ func (c *RemotingClient) closeConnection(toCloseConn 
net.Conn) {
                }
        })
 }
+
+func (c *RemotingClient) ShutDown() {
+       c.responseTable.Range(func(key, value interface{}) bool {
+               c.responseTable.Delete(key)
+               return true
+       })
+       c.connectionTable.Range(func(key, value interface{}) bool {
+               conn := value.(net.Conn)
+               conn.Close()
+               return true
+       })
+}
diff --git a/internal/remote/remote_client_test.go 
b/internal/remote/remote_client_test.go
index 6cc63cf..4efb990 100644
--- a/internal/remote/remote_client_test.go
+++ b/internal/remote/remote_client_test.go
@@ -19,6 +19,7 @@ package remote
 import (
        "bytes"
        "errors"
+       "math/rand"
        "net"
        "reflect"
        "sync"
@@ -31,23 +32,23 @@ import (
 func TestNewResponseFuture(t *testing.T) {
        future := NewResponseFuture(10, time.Duration(1000), nil)
        if future.Opaque != 10 {
-               t.Errorf("wrong ResponseFuture's Opaque. want=%d, got=%d", 10, 
future.Opaque)
+               t.Errorf("wrong ResponseFuture's opaque. want=%d, got=%d", 10, 
future.Opaque)
        }
        if future.SendRequestOK != false {
-               t.Errorf("wrong ResposneFutrue's SendRequestOK. want=%t, 
got=%t", false, future.SendRequestOK)
+               t.Errorf("wrong ResposneFutrue's sendRequestOK. want=%t, 
got=%t", false, future.SendRequestOK)
        }
        if future.Err != nil {
                t.Errorf("wrong RespnseFuture's Err. want=<nil>, got=%v", 
future.Err)
        }
-       if future.TimeoutMillis != time.Duration(1000) {
+       if future.Timeout != time.Duration(1000) {
                t.Errorf("wrong ResponseFuture's TimeoutMills. want=%d, got=%d",
-                       future.TimeoutMillis, time.Duration(1000))
+                       future.Timeout, time.Duration(1000))
        }
        if future.callback != nil {
                t.Errorf("wrong ResponseFuture's callback. want=<nil>, 
got!=<nil>")
        }
        if future.Done == nil {
-               t.Errorf("wrong ResponseFuture's Done. want=<channel>, 
got=<nil>")
+               t.Errorf("wrong ResponseFuture's done. want=<channel>, 
got=<nil>")
        }
 }
 
@@ -80,11 +81,11 @@ func TestResponseFutureTimeout(t *testing.T) {
 }
 
 func TestResponseFutureIsTimeout(t *testing.T) {
-       future := NewResponseFuture(10, time.Duration(500), nil)
+       future := NewResponseFuture(10, 500 * time.Millisecond, nil)
        if future.isTimeout() != false {
                t.Errorf("wrong ResponseFuture's istimeout. want=%t, got=%t", 
false, future.isTimeout())
        }
-       time.Sleep(time.Duration(2000) * time.Millisecond)
+       time.Sleep(time.Duration(700) * time.Millisecond)
        if future.isTimeout() != true {
                t.Errorf("wrong ResponseFuture's istimeout. want=%t, got=%t", 
true, future.isTimeout())
        }
@@ -92,12 +93,12 @@ func TestResponseFutureIsTimeout(t *testing.T) {
 }
 
 func TestResponseFutureWaitResponse(t *testing.T) {
-       future := NewResponseFuture(10, time.Duration(500), nil)
+       future := NewResponseFuture(10, 500 * time.Millisecond, nil)
        if _, err := future.waitResponse(); err != ErrRequestTimeout {
                t.Errorf("wrong ResponseFuture waitResponse. want=%v, got=%v",
                        ErrRequestTimeout, err)
        }
-       future = NewResponseFuture(10, time.Duration(500), nil)
+       future = NewResponseFuture(10, 500 * time.Millisecond, nil)
        responseError := errors.New("response error")
        go func() {
                time.Sleep(100 * time.Millisecond)
@@ -108,7 +109,7 @@ func TestResponseFutureWaitResponse(t *testing.T) {
                t.Errorf("wrong ResponseFuture waitResponse. want=%v. got=%v",
                        responseError, err)
        }
-       future = NewResponseFuture(10, time.Duration(500), nil)
+       future = NewResponseFuture(10, 500 * time.Millisecond, nil)
        responseRemotingCommand := NewRemotingCommand(202, nil, nil)
        go func() {
                time.Sleep(100 * time.Millisecond)
@@ -146,7 +147,7 @@ func TestCreateScanner(t *testing.T) {
                        t.Fatalf("wrong Version. want=%d, got=%d", r.Version, 
rcr.Version)
                }
                if r.Opaque != rcr.Opaque {
-                       t.Fatalf("wrong Opaque. want=%d, got=%d", r.Opaque, 
rcr.Opaque)
+                       t.Fatalf("wrong opaque. want=%d, got=%d", r.Opaque, 
rcr.Opaque)
                }
                if r.Flag != rcr.Flag {
                        t.Fatalf("wrong flag. want=%d, got=%d", r.Opaque, 
rcr.Opaque)
@@ -167,7 +168,7 @@ func TestInvokeSync(t *testing.T) {
        client := NewRemotingClient()
        go func() {
                receiveCommand, err := client.InvokeSync(":3000",
-                       clientSendRemtingCommand, time.Duration(1000))
+                       clientSendRemtingCommand, time.Second)
                if err != nil {
                        t.Fatalf("failed to invoke synchronous. %s", err)
                } else {
@@ -214,65 +215,58 @@ func TestInvokeSync(t *testing.T) {
 }
 
 func TestInvokeAsync(t *testing.T) {
-       clientSendRemtingCommand := NewRemotingCommand(10, nil, []byte("Hello 
RocketMQ"))
-       serverSendRemotingCommand := NewRemotingCommand(20, nil, 
[]byte("Welcome native"))
-       serverSendRemotingCommand.Opaque = clientSendRemtingCommand.Opaque
-       serverSendRemotingCommand.Flag = ResponseType
-
        var wg sync.WaitGroup
-       wg.Add(1)
+       cnt := 50
+       wg.Add(cnt)
        client := NewRemotingClient()
-       go func() {
-               time.Sleep(1 * time.Second)
-               t.Logf("invoke async method")
-               err := client.InvokeAsync(":3000", clientSendRemtingCommand,
-                       time.Duration(1000), func(r *ResponseFuture) {
-                               t.Logf("invoke async callback")
-                               if string(r.ResponseCommand.Body) != "Welcome 
native" {
-                                       t.Errorf("wrong responseCommand.Body. 
want=%s, got=%s",
-                                               "Welcome native", 
string(r.ResponseCommand.Body))
+       for i:=0; i < cnt; i++ {
+               go func(index int) {
+                       time.Sleep(time.Duration(rand.Intn(100)) * 
time.Millisecond)
+                       t.Logf("[Send: %d] asychronous message", index)
+                       sendRemotingCommand := randomNewRemotingCommand()
+                       err := client.InvokeAsync(":3000", sendRemotingCommand, 
time.Second, func(r *ResponseFuture) {
+                               t.Logf("[Receive: %d] asychronous message 
response", index)
+                               if string(sendRemotingCommand.Body) != 
string(r.ResponseCommand.Body) {
+                                       t.Errorf("wrong response message. 
want=%s, got=%s", string(sendRemotingCommand.Body),
+                                               string(r.ResponseCommand.Body))
                                }
                                wg.Done()
                        })
-               if err != nil {
-                       t.Errorf("failed to invokeSync. %s", err)
-               }
-
-       }()
+                       if err != nil {
+                               t.Errorf("failed to invokeAsync. %s", err)
+                       }
 
+               }(i)
+       }
        l, err := net.Listen("tcp", ":3000")
        if err != nil {
-               t.Fatal(err)
+               t.Fatalf("failed to create tcp network. %s", err)
        }
        defer l.Close()
+       count := 0
        for {
                conn, err := l.Accept()
                if err != nil {
-                       return
+                       t.Fatalf("failed to create connection. %s", err)
                }
                defer conn.Close()
                scanner := client.createScanner(conn)
                for scanner.Scan() {
-                       t.Logf("receive request.")
-                       receivedRemotingCommand, err := decode(scanner.Bytes())
+                       t.Log("receive request")
+                       r, err := decode(scanner.Bytes())
                        if err != nil {
-                               t.Errorf("failed to decode RemotingCommnad. 
%s", err)
-                       }
-                       if clientSendRemtingCommand.Code != 
receivedRemotingCommand.Code {
-                               t.Errorf("wrong code. want=%d, got=%d", 
receivedRemotingCommand.Code,
-                                       clientSendRemtingCommand.Code)
-                       }
-                       t.Logf("encoding response")
-                       body, err := encode(serverSendRemotingCommand)
-                       if err != nil {
-                               t.Fatalf("failed to encode RemotingCommand")
+                               t.Errorf("failed to decode RemotingCommand %s", 
err)
                        }
+                       r.markResponseType()
+                       body, _ := encode(r)
                        _, err = conn.Write(body)
-                       t.Logf("sent response")
                        if err != nil {
-                               t.Fatalf("failed to write body to conneciton.")
+                               t.Fatalf("failed to send response %s", err)
+                       }
+                       count++
+                       if count >= cnt {
+                               goto done
                        }
-                       goto done
                }
        }
 done:
@@ -367,3 +361,5 @@ func TestInvokeOneWay(t *testing.T) {
        }
        wg.Done()
 }
+
+
diff --git a/internal/validators.go b/internal/validators.go
index 1f8f5f4..8b6f122 100644
--- a/internal/validators.go
+++ b/internal/validators.go
@@ -29,7 +29,7 @@ const (
 )
 
 var (
-       _Pattern, _ = regexp.Compile("_ValidPattern")
+       _Pattern, _ = regexp.Compile(_ValidPattern)
 )
 
 func ValidateGroup(group string) {

Reply via email to