This is an automated email from the ASF dual-hosted git repository.

jensg pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/thrift.git


The following commit(s) were added to refs/heads/master by this push:
     new 81334cd  THRIFT-5152: introduce connect timeout and socket timeout Client: Go Patch: Qian Lv
81334cd is described below

commit 81334cd7345d3b5af165aa875b733a491f1fd5c7
Author: lvqian <[email protected]>
AuthorDate: Thu Mar 26 19:08:55 2020 +0800

    THRIFT-5152: introduce connect timeout and socket timeout
    Client: Go
    Patch: Qian Lv
    
    This closes #2071
---
 lib/go/test/tests/multiplexed_protocol_test.go |  2 +-
 lib/go/test/tests/one_way_test.go              |  2 +-
 lib/go/test/tests/protocols_test.go            |  2 +-
 lib/go/thrift/socket.go                        | 37 +++++++++++++++-----------
 4 files changed, 25 insertions(+), 18 deletions(-)

diff --git a/lib/go/test/tests/multiplexed_protocol_test.go b/lib/go/test/tests/multiplexed_protocol_test.go
index 61ac628..4fb6f4f 100644
--- a/lib/go/test/tests/multiplexed_protocol_test.go
+++ b/lib/go/test/tests/multiplexed_protocol_test.go
@@ -50,7 +50,7 @@ func (s *SecondImpl) ReturnTwo(ctx context.Context) (r int64, err error) {
 }
 
 func createTransport(addr net.Addr) (thrift.TTransport, error) {
-       socket := thrift.NewTSocketFromAddrTimeout(addr, TIMEOUT)
+       socket := thrift.NewTSocketFromAddrTimeout(addr, TIMEOUT, TIMEOUT)
        transport := thrift.NewTFramedTransport(socket)
        err := transport.Open()
        if err != nil {
diff --git a/lib/go/test/tests/one_way_test.go b/lib/go/test/tests/one_way_test.go
index 48d0bbe..010e3bb 100644
--- a/lib/go/test/tests/one_way_test.go
+++ b/lib/go/test/tests/one_way_test.go
@@ -65,7 +65,7 @@ func TestInitOneway(t *testing.T) {
 }
 
 func TestInitOnewayClient(t *testing.T) {
-       transport := thrift.NewTSocketFromAddrTimeout(addr, TIMEOUT)
+       transport := thrift.NewTSocketFromAddrTimeout(addr, TIMEOUT, TIMEOUT)
        protocol := thrift.NewTBinaryProtocolTransport(transport)
        client = onewaytest.NewOneWayClient(thrift.NewTStandardClient(protocol, 
protocol))
        err := transport.Open()
diff --git a/lib/go/test/tests/protocols_test.go b/lib/go/test/tests/protocols_test.go
index cffd9c3..9030e9d 100644
--- a/lib/go/test/tests/protocols_test.go
+++ b/lib/go/test/tests/protocols_test.go
@@ -41,7 +41,7 @@ func RunSocketTestSuite(t *testing.T, protocolFactory thrift.TProtocolFactory,
        go server.Serve()
 
        // client
-       var transport thrift.TTransport = 
thrift.NewTSocketFromAddrTimeout(addr, TIMEOUT)
+       var transport thrift.TTransport = 
thrift.NewTSocketFromAddrTimeout(addr, TIMEOUT, TIMEOUT)
        transport, err = transportFactory.GetTransport(transport)
        if err != nil {
                t.Fatal(err)
diff --git a/lib/go/thrift/socket.go b/lib/go/thrift/socket.go
index 88b98f5..558818a 100644
--- a/lib/go/thrift/socket.go
+++ b/lib/go/thrift/socket.go
@@ -26,9 +26,10 @@ import (
 )
 
 type TSocket struct {
-       conn    net.Conn
-       addr    net.Addr
-       timeout time.Duration
+       conn           net.Conn
+       addr           net.Addr
+       connectTimeout time.Duration
+       socketTimeout  time.Duration
 }
 
 // NewTSocket creates a net.Conn-backed TTransport, given a host and port
@@ -36,40 +37,46 @@ type TSocket struct {
 // Example:
 //     trans, err := thrift.NewTSocket("localhost:9090")
 func NewTSocket(hostPort string) (*TSocket, error) {
-       return NewTSocketTimeout(hostPort, 0)
+       return NewTSocketTimeout(hostPort, 0, 0)
 }
 
 // NewTSocketTimeout creates a net.Conn-backed TTransport, given a host and 
port
 // it also accepts a timeout as a time.Duration
-func NewTSocketTimeout(hostPort string, timeout time.Duration) (*TSocket, 
error) {
+func NewTSocketTimeout(hostPort string, connTimeout time.Duration, soTimeout 
time.Duration) (*TSocket, error) {
        //conn, err := net.DialTimeout(network, address, timeout)
        addr, err := net.ResolveTCPAddr("tcp", hostPort)
        if err != nil {
                return nil, err
        }
-       return NewTSocketFromAddrTimeout(addr, timeout), nil
+       return NewTSocketFromAddrTimeout(addr, connTimeout, soTimeout), nil
 }
 
 // Creates a TSocket from a net.Addr
-func NewTSocketFromAddrTimeout(addr net.Addr, timeout time.Duration) *TSocket {
-       return &TSocket{addr: addr, timeout: timeout}
+func NewTSocketFromAddrTimeout(addr net.Addr, connTimeout time.Duration, 
soTimeout time.Duration) *TSocket {
+       return &TSocket{addr: addr, connectTimeout: connTimeout, socketTimeout: 
soTimeout}
 }
 
 // Creates a TSocket from an existing net.Conn
-func NewTSocketFromConnTimeout(conn net.Conn, timeout time.Duration) *TSocket {
-       return &TSocket{conn: conn, addr: conn.RemoteAddr(), timeout: timeout}
+func NewTSocketFromConnTimeout(conn net.Conn, connTimeout time.Duration) 
*TSocket {
+       return &TSocket{conn: conn, addr: conn.RemoteAddr(), connectTimeout: 
connTimeout, socketTimeout: connTimeout}
+}
+
+// Sets the connect timeout
+func (p *TSocket) SetConnTimeout(timeout time.Duration) error {
+       p.connectTimeout = timeout
+       return nil
 }
 
 // Sets the socket timeout
-func (p *TSocket) SetTimeout(timeout time.Duration) error {
-       p.timeout = timeout
+func (p *TSocket) SetSocketTimeout(timeout time.Duration) error {
+       p.socketTimeout = timeout
        return nil
 }
 
 func (p *TSocket) pushDeadline(read, write bool) {
        var t time.Time
-       if p.timeout > 0 {
-               t = time.Now().Add(time.Duration(p.timeout))
+       if p.socketTimeout > 0 {
+               t = time.Now().Add(time.Duration(p.socketTimeout))
        }
        if read && write {
                p.conn.SetDeadline(t)
@@ -95,7 +102,7 @@ func (p *TSocket) Open() error {
                return NewTTransportException(NOT_OPEN, "Cannot open bad 
address.")
        }
        var err error
-       if p.conn, err = net.DialTimeout(p.addr.Network(), p.addr.String(), 
p.timeout); err != nil {
+       if p.conn, err = net.DialTimeout(p.addr.Network(), p.addr.String(), 
p.connectTimeout); err != nil {
                return NewTTransportException(NOT_OPEN, err.Error())
        }
        return nil

Reply via email to