This is an automated email from the ASF dual-hosted git repository.

dcelasun pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/thrift.git


The following commit(s) were added to refs/heads/master by this push:
     new 05023e8  THRIFT-5214: Connectivity check on go's TSocket
05023e8 is described below

commit 05023e81b264f249affdacad4ebae788b3ada85c
Author: Yuxuan 'fishy' Wang <[email protected]>
AuthorDate: Tue May 26 15:31:20 2020 -0700

    THRIFT-5214: Connectivity check on go's TSocket
    
    Client: go
    
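    Implement connectivity check on Go's TSocket and TSSLSocket for
    non-Windows systems. The check requires the underlying connection to
    implement syscall.Conn; connections that don't (such as the *tls.Conn
    wrapped by TSSLSocket) are reported as open whenever a connection exists.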
    
    The implementation is inspired by
    https://github.blog/2020-05-20-three-bugs-in-the-go-mysql-driver/
---
 lib/go/thrift/socket.go                |  23 +++---
 lib/go/thrift/socket_conn.go           | 111 +++++++++++++++++++++++++++++
 lib/go/thrift/socket_conn_test.go      | 125 +++++++++++++++++++++++++++++++++
 lib/go/thrift/socket_unix_conn.go      |  73 +++++++++++++++++++
 lib/go/thrift/socket_unix_conn_test.go | 105 +++++++++++++++++++++++++++
 lib/go/thrift/socket_windows_conn.go   |  34 +++++++++
 lib/go/thrift/ssl_socket.go            |  37 ++++++----
 7 files changed, 483 insertions(+), 25 deletions(-)

diff --git a/lib/go/thrift/socket.go b/lib/go/thrift/socket.go
index 558818a..7c765f5 100644
--- a/lib/go/thrift/socket.go
+++ b/lib/go/thrift/socket.go
@@ -26,7 +26,7 @@ import (
 )
 
 type TSocket struct {
-       conn           net.Conn
+       conn           *socketConn
        addr           net.Addr
        connectTimeout time.Duration
        socketTimeout  time.Duration
@@ -58,7 +58,7 @@ func NewTSocketFromAddrTimeout(addr net.Addr, connTimeout 
time.Duration, soTimeo
 
 // Creates a TSocket from an existing net.Conn
 func NewTSocketFromConnTimeout(conn net.Conn, connTimeout time.Duration) 
*TSocket {
-       return &TSocket{conn: conn, addr: conn.RemoteAddr(), connectTimeout: 
connTimeout, socketTimeout: connTimeout}
+       return &TSocket{conn: wrapSocketConn(conn), addr: conn.RemoteAddr(), 
connectTimeout: connTimeout, socketTimeout: connTimeout}
 }
 
 // Sets the connect timeout
@@ -89,7 +89,7 @@ func (p *TSocket) pushDeadline(read, write bool) {
 
 // Connects the socket, creating a new socket object if necessary.
 func (p *TSocket) Open() error {
-       if p.IsOpen() {
+       if p.conn.isValid() {
                return NewTTransportException(ALREADY_OPEN, "Socket already 
connected.")
        }
        if p.addr == nil {
@@ -102,7 +102,11 @@ func (p *TSocket) Open() error {
                return NewTTransportException(NOT_OPEN, "Cannot open bad 
address.")
        }
        var err error
-       if p.conn, err = net.DialTimeout(p.addr.Network(), p.addr.String(), 
p.connectTimeout); err != nil {
+       if p.conn, err = createSocketConnFromReturn(net.DialTimeout(
+               p.addr.Network(),
+               p.addr.String(),
+               p.connectTimeout,
+       )); err != nil {
                return NewTTransportException(NOT_OPEN, err.Error())
        }
        return nil
@@ -115,10 +119,7 @@ func (p *TSocket) Conn() net.Conn {
 
 // Returns true if the connection is open
 func (p *TSocket) IsOpen() bool {
-       if p.conn == nil {
-               return false
-       }
-       return true
+       return p.conn.IsOpen()
 }
 
 // Closes the socket.
@@ -140,7 +141,7 @@ func (p *TSocket) Addr() net.Addr {
 }
 
 func (p *TSocket) Read(buf []byte) (int, error) {
-       if !p.IsOpen() {
+       if !p.conn.isValid() {
                return 0, NewTTransportException(NOT_OPEN, "Connection not 
open")
        }
        p.pushDeadline(true, false)
@@ -149,7 +150,7 @@ func (p *TSocket) Read(buf []byte) (int, error) {
 }
 
 func (p *TSocket) Write(buf []byte) (int, error) {
-       if !p.IsOpen() {
+       if !p.conn.isValid() {
                return 0, NewTTransportException(NOT_OPEN, "Connection not 
open")
        }
        p.pushDeadline(false, true)
@@ -161,7 +162,7 @@ func (p *TSocket) Flush(ctx context.Context) error {
 }
 
 func (p *TSocket) Interrupt() error {
-       if !p.IsOpen() {
+       if !p.conn.isValid() {
                return nil
        }
        return p.conn.Close()
diff --git a/lib/go/thrift/socket_conn.go b/lib/go/thrift/socket_conn.go
new file mode 100644
index 0000000..b0f7b3e
--- /dev/null
+++ b/lib/go/thrift/socket_conn.go
@@ -0,0 +1,111 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package thrift
+
+import (
+       "bytes"
+       "io"
+       "net"
+)
+
+// socketConn wraps a net.Conn and adds a best-effort connectivity check.
+//
+// Bytes consumed from the wire by that check are kept in buf so that the
+// next Read hands them back to the caller instead of losing them.
+type socketConn struct {
+       // The wrapped connection. It may be nil, which isValid treats as
+       // "no connection".
+       net.Conn
+
+       // Read-ahead bytes taken off the wire by checkConn that Read has not
+       // yet returned to the caller.
+       buf bytes.Buffer
+}
+
+// Compile-time assertion that *socketConn satisfies net.Conn.
+var _ net.Conn = (*socketConn)(nil)
+
+// createSocketConnFromReturn is syntactic sugar for wrapping the (net.Conn,
+// error) pair returned by functions like net.Dial, tls.Dial,
+// net.Listener.Accept, etc. directly into a *socketConn.
+//
+// A non-nil err is passed through unchanged together with a nil *socketConn.
+// Otherwise conn is wrapped via wrapSocketConn, so a conn that already is a
+// *socketConn is returned as-is rather than being wrapped a second time.
+func createSocketConnFromReturn(conn net.Conn, err error) (*socketConn, error) {
+       if err != nil {
+               return nil, err
+       }
+       return wrapSocketConn(conn), nil
+}
+
+// wrapSocketConn returns conn as a *socketConn.
+//
+// If conn already is a *socketConn it is returned unchanged, so wrapping is
+// idempotent and never nests one socketConn inside another.
+func wrapSocketConn(conn net.Conn) *socketConn {
+       sc, alreadyWrapped := conn.(*socketConn)
+       if !alreadyWrapped {
+               sc = &socketConn{Conn: conn}
+       }
+       return sc
+}
+
+// isValid reports whether sc holds an underlying connection at all.
+//
+// It is safe to call on a nil *socketConn and performs no I/O: it only
+// checks that neither sc nor its embedded net.Conn is nil. This matches what
+// TSocket.IsOpen and TSSLSocket.IsOpen did before the connectivity check was
+// introduced.
+func (sc *socketConn) isValid() bool {
+       if sc == nil {
+               return false
+       }
+       return sc.Conn != nil
+}
+
+// IsOpen reports whether the connection appears to be open.
+//
+// A nil *socketConn, or one without an underlying connection, is reported as
+// closed. Otherwise the answer comes from the platform-specific connectivity
+// check (checkConn): any error from it means "not open".
+func (sc *socketConn) IsOpen() bool {
+       return sc.isValid() && sc.checkConn() == nil
+}
+
+// Read implements io.Reader.
+//
+// On Windows, it behaves the same as the underlying net.Conn.Read.
+//
+// On non-Windows, it treats len(p) == 0 as a connectivity check instead of
+// readability check, which means instead of blocking until there's something 
to
+// read (readability check), or always return (0, nil) (the default behavior of
+// go's stdlib implementation on non-Windows), it never blocks, and will return
+// an error if the connection is lost.
+func (sc *socketConn) Read(p []byte) (n int, err error) {
+       if len(p) == 0 {
+               return 0, sc.read0()
+       }
+
+       n, err = sc.buf.Read(p)
+       if err != nil && err != io.EOF {
+               return
+       }
+       if n == len(p) {
+               return n, nil
+       }
+       // Continue reading from the wire.
+       var newRead int
+       newRead, err = sc.Conn.Read(p[n:])
+       n += newRead
+       return
+}
diff --git a/lib/go/thrift/socket_conn_test.go 
b/lib/go/thrift/socket_conn_test.go
new file mode 100644
index 0000000..ab92462
--- /dev/null
+++ b/lib/go/thrift/socket_conn_test.go
@@ -0,0 +1,125 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package thrift
+
+import (
+       "io"
+       "net"
+       "strings"
+       "testing"
+       "time"
+)
+
+// serverSocketConnCallback handles one accepted server-side connection.
+type serverSocketConnCallback func(testing.TB, *socketConn)
+
+// serverSocketConn starts a TCP listener on an ephemeral localhost port and
+// hands every accepted connection, wrapped as *socketConn, to f in its own
+// goroutine. The accept loop exits on the first Accept error, which is what
+// closing the returned listener produces, so callers should close it.
+func serverSocketConn(tb testing.TB, f serverSocketConnCallback) (net.Listener, error) {
+       tb.Helper()
+
+       ln, err := net.Listen("tcp", "localhost:0")
+       if err != nil {
+               return nil, err
+       }
+
+       acceptLoop := func() {
+               for {
+                       conn, acceptErr := createSocketConnFromReturn(ln.Accept())
+                       if acceptErr != nil {
+                               // Usually the listener was closed by the test,
+                               // which is not a real failure.
+                               return
+                       }
+                       go f(tb, conn)
+               }
+       }
+       go acceptLoop()
+
+       return ln, nil
+}
+
+// writeFully writes all of s to w. On failure or a short write it records a
+// (non-fatal) test error and returns false; otherwise it returns true.
+func writeFully(tb testing.TB, w io.Writer, s string) bool {
+       tb.Helper()
+
+       written, err := io.Copy(w, strings.NewReader(s))
+       switch {
+       case err != nil:
+               tb.Errorf("Failed to write %q: %v", s, err)
+               return false
+       case int(written) < len(s):
+               tb.Errorf("Only wrote %d out of %q", written, s)
+               return false
+       default:
+               return true
+       }
+}
+
+// TestSocketConn checks that socketConn.Read returns data from the wire as it
+// arrives: the server writes two strings separated by a short pause, and the
+// client expects to receive each one from a separate Read call.
+func TestSocketConn(t *testing.T) {
+       const (
+               interval = time.Millisecond * 10
+               first    = "hello"
+               second   = "world"
+       )
+
+       ln, err := serverSocketConn(
+               t,
+               func(tb testing.TB, sc *socketConn) {
+                       defer sc.Close()
+
+                       if !writeFully(tb, sc, first) {
+                               return
+                       }
+                       // The pause lets the client's first Read see only first.
+                       time.Sleep(interval)
+                       writeFully(tb, sc, second)
+               },
+       )
+       if err != nil {
+               t.Fatal(err)
+       }
+       defer ln.Close()
+
+       sc, err := createSocketConnFromReturn(net.Dial("tcp", ln.Addr().String()))
+       if err != nil {
+               t.Fatal(err)
+       }
+       // Release the client side of the connection when the test ends.
+       defer sc.Close()
+       buf := make([]byte, 1024)
+
+       n, err := sc.Read(buf)
+       if err != nil {
+               t.Fatal(err)
+       }
+       read := string(buf[:n])
+       if read != first {
+               t.Errorf("Expected read %q, got %q", first, read)
+       }
+
+       n, err = sc.Read(buf)
+       if err != nil {
+               t.Fatal(err)
+       }
+       read = string(buf[:n])
+       if read != second {
+               t.Errorf("Expected read %q, got %q", second, read)
+       }
+}
+
+func TestSocketConnNilSafe(t *testing.T) {
+       sc := (*socketConn)(nil)
+       if sc.isValid() {
+               t.Error("Expected false for nil.isValid(), got true")
+       }
+       if sc.IsOpen() {
+               t.Error("Expected false for nil.IsOpen(), got true")
+       }
+}
diff --git a/lib/go/thrift/socket_unix_conn.go 
b/lib/go/thrift/socket_unix_conn.go
new file mode 100644
index 0000000..f18e0e6
--- /dev/null
+++ b/lib/go/thrift/socket_unix_conn.go
@@ -0,0 +1,73 @@
+// +build !windows
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package thrift
+
+import (
+       "io"
+       "syscall"
+       "time"
+)
+
+// read0 handles a zero-length Read. On non-Windows systems it is simply the
+// connectivity check, which is intended never to block waiting for data.
+func (sc *socketConn) read0() error {
+       err := sc.checkConn()
+       return err
+}
+
+// checkConn performs a non-blocking connectivity check on the underlying
+// connection and returns nil if it still looks open.
+//
+// It only works for connections implementing syscall.Conn; for anything else
+// there is no way to check and nil is returned. Otherwise it makes a single
+// raw read attempt of one byte. EAGAIN/EWOULDBLOCK means the connection is
+// open with nothing to read yet. A byte read means it is open; that byte is
+// stashed in sc.buf so the next Read returns it. Any other read error is
+// returned as-is, and zero bytes with no error means the peer closed the
+// connection, reported as io.EOF.
+func (sc *socketConn) checkConn() error {
+       syscallConn, ok := sc.Conn.(syscall.Conn)
+       if !ok {
+               // No way to check, return nil
+               return nil
+       }
+
+       // The raw read below is a single non-blocking attempt, so it needs no
+       // deadline. A previously set read deadline that has since expired
+       // (TSocket.Read pushes one via pushDeadline before reading) would
+       // otherwise make rc.Read fail with a timeout and wrongly report a
+       // healthy, idle connection as closed.
+       if err := sc.Conn.SetReadDeadline(time.Time{}); err != nil {
+               return err
+       }
+
+       rc, err := syscallConn.SyscallConn()
+       if err != nil {
+               return err
+       }
+
+       var n int
+       var buf [1]byte
+
+       if readErr := rc.Read(func(fd uintptr) bool {
+               n, err = syscall.Read(int(fd), buf[:])
+               // Always report done: exactly one attempt, never a wait for
+               // readability.
+               return true
+       }); readErr != nil {
+               return readErr
+       }
+
+       if err == syscall.EAGAIN || err == syscall.EWOULDBLOCK {
+               // This means the connection is still open but we don't have
+               // anything to read right now.
+               return nil
+       }
+
+       if n > 0 {
+               // We got a byte; keep it for the next real read. Writes to a
+               // bytes.Buffer never return an error.
+               sc.buf.Write(buf[:n])
+               return nil
+       }
+
+       if err != nil {
+               return err
+       }
+
+       // At this point, it means the other side already closed the connection.
+       return io.EOF
+}
diff --git a/lib/go/thrift/socket_unix_conn_test.go 
b/lib/go/thrift/socket_unix_conn_test.go
new file mode 100644
index 0000000..3563a25
--- /dev/null
+++ b/lib/go/thrift/socket_unix_conn_test.go
@@ -0,0 +1,105 @@
+// +build !windows
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package thrift
+
+import (
+       "io"
+       "net"
+       "testing"
+       "time"
+)
+
+// TestSocketConnUnix exercises the non-Windows connectivity check. IsOpen
+// must report true while the server is alive, and a byte consumed by the
+// check must be buffered and returned by the next Read. Once the server has
+// closed its side, read0 must return io.EOF and IsOpen must report false.
+func TestSocketConnUnix(t *testing.T) {
+       const (
+               interval = time.Millisecond * 10
+               first    = "hello"
+               second   = "world"
+       )
+
+       ln, err := serverSocketConn(
+               t,
+               func(tb testing.TB, sc *socketConn) {
+                       defer sc.Close()
+
+                       time.Sleep(interval)
+                       if !writeFully(tb, sc, first) {
+                               return
+                       }
+                       time.Sleep(interval)
+                       writeFully(tb, sc, second)
+               },
+       )
+       if err != nil {
+               t.Fatal(err)
+       }
+       defer ln.Close()
+
+       sc, err := createSocketConnFromReturn(net.Dial("tcp", ln.Addr().String()))
+       if err != nil {
+               t.Fatal(err)
+       }
+       // Release the client side of the connection when the test ends; the
+       // final EOF checks below run before this deferred Close.
+       defer sc.Close()
+       buf := make([]byte, 1024)
+
+       if !sc.IsOpen() {
+               t.Error("Expected sc to report open, got false")
+       }
+       n, err := sc.Read(buf)
+       if err != nil {
+               t.Fatal(err)
+       }
+       read := string(buf[:n])
+       if read != first {
+               t.Errorf("Expected read %q, got %q", first, read)
+       }
+
+       if !sc.IsOpen() {
+               t.Error("Expected sc to report open, got false")
+       }
+       // Do connection check again twice after server already wrote new data,
+       // make sure we correctly buffered the read bytes
+       time.Sleep(interval * 10)
+       if !sc.IsOpen() {
+               t.Error("Expected sc to report open, got false")
+       }
+       if !sc.IsOpen() {
+               t.Error("Expected sc to report open, got false")
+       }
+       if sc.buf.Len() == 0 {
+               t.Error("Expected sc to buffer read bytes, got empty buffer")
+       }
+       n, err = sc.Read(buf)
+       if err != nil {
+               t.Fatal(err)
+       }
+       read = string(buf[:n])
+       if read != second {
+               t.Errorf("Expected read %q, got %q", second, read)
+       }
+
+       // Now it's supposed to be closed on the server side
+       if err := sc.read0(); err != io.EOF {
+               t.Errorf("Expected to get EOF on read0, got %v", err)
+       }
+       if sc.IsOpen() {
+               t.Error("Expected sc to report not open, got true")
+       }
+}
diff --git a/lib/go/thrift/socket_windows_conn.go 
b/lib/go/thrift/socket_windows_conn.go
new file mode 100644
index 0000000..679838c
--- /dev/null
+++ b/lib/go/thrift/socket_windows_conn.go
@@ -0,0 +1,34 @@
+// +build windows
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package thrift
+
+func (sc *socketConn) read0() error {
+       // On windows, we fallback to the default behavior of reading 0 bytes.
+       var p []byte
+       _, err := sc.Conn.Read(p)
+       return err
+}
+
+// checkConn is a no-op on Windows: no connectivity check is implemented
+// there, so it never reports an error.
+func (sc *socketConn) checkConn() error {
+       return nil
+}
diff --git a/lib/go/thrift/ssl_socket.go b/lib/go/thrift/ssl_socket.go
index 45bf38a..661111c 100644
--- a/lib/go/thrift/ssl_socket.go
+++ b/lib/go/thrift/ssl_socket.go
@@ -27,7 +27,7 @@ import (
 )
 
 type TSSLSocket struct {
-       conn net.Conn
+       conn *socketConn
        // hostPort contains host:port (e.g. "asdf.com:12345"). The field is
        // only valid if addr is nil.
        hostPort string
@@ -62,7 +62,7 @@ func NewTSSLSocketFromAddrTimeout(addr net.Addr, cfg 
*tls.Config, timeout time.D
 
 // Creates a TSSLSocket from an existing net.Conn
 func NewTSSLSocketFromConnTimeout(conn net.Conn, cfg *tls.Config, timeout 
time.Duration) *TSSLSocket {
-       return &TSSLSocket{conn: conn, addr: conn.RemoteAddr(), timeout: 
timeout, cfg: cfg}
+       return &TSSLSocket{conn: wrapSocketConn(conn), addr: conn.RemoteAddr(), 
timeout: timeout, cfg: cfg}
 }
 
 // Sets the socket timeout
@@ -91,12 +91,18 @@ func (p *TSSLSocket) Open() error {
        // If we have a hostname, we need to pass the hostname to tls.Dial for
        // certificate hostname checks.
        if p.hostPort != "" {
-               if p.conn, err = tls.DialWithDialer(&net.Dialer{
-                       Timeout: p.timeout}, "tcp", p.hostPort, p.cfg); err != 
nil {
+               if p.conn, err = createSocketConnFromReturn(tls.DialWithDialer(
+                       &net.Dialer{
+                               Timeout: p.timeout,
+                       },
+                       "tcp",
+                       p.hostPort,
+                       p.cfg,
+               )); err != nil {
                        return NewTTransportException(NOT_OPEN, err.Error())
                }
        } else {
-               if p.IsOpen() {
+               if p.conn.isValid() {
                        return NewTTransportException(ALREADY_OPEN, "Socket 
already connected.")
                }
                if p.addr == nil {
@@ -108,8 +114,14 @@ func (p *TSSLSocket) Open() error {
                if len(p.addr.String()) == 0 {
                        return NewTTransportException(NOT_OPEN, "Cannot open 
bad address.")
                }
-               if p.conn, err = tls.DialWithDialer(&net.Dialer{
-                       Timeout: p.timeout}, p.addr.Network(), p.addr.String(), 
p.cfg); err != nil {
+               if p.conn, err = createSocketConnFromReturn(tls.DialWithDialer(
+                       &net.Dialer{
+                               Timeout: p.timeout,
+                       },
+                       p.addr.Network(),
+                       p.addr.String(),
+                       p.cfg,
+               )); err != nil {
                        return NewTTransportException(NOT_OPEN, err.Error())
                }
        }
@@ -123,10 +135,7 @@ func (p *TSSLSocket) Conn() net.Conn {
 
 // Returns true if the connection is open
 func (p *TSSLSocket) IsOpen() bool {
-       if p.conn == nil {
-               return false
-       }
-       return true
+       return p.conn.IsOpen()
 }
 
 // Closes the socket.
@@ -143,7 +152,7 @@ func (p *TSSLSocket) Close() error {
 }
 
 func (p *TSSLSocket) Read(buf []byte) (int, error) {
-       if !p.IsOpen() {
+       if !p.conn.isValid() {
                return 0, NewTTransportException(NOT_OPEN, "Connection not 
open")
        }
        p.pushDeadline(true, false)
@@ -152,7 +161,7 @@ func (p *TSSLSocket) Read(buf []byte) (int, error) {
 }
 
 func (p *TSSLSocket) Write(buf []byte) (int, error) {
-       if !p.IsOpen() {
+       if !p.conn.isValid() {
                return 0, NewTTransportException(NOT_OPEN, "Connection not 
open")
        }
        p.pushDeadline(false, true)
@@ -164,7 +173,7 @@ func (p *TSSLSocket) Flush(ctx context.Context) error {
 }
 
 func (p *TSSLSocket) Interrupt() error {
-       if !p.IsOpen() {
+       if !p.conn.isValid() {
                return nil
        }
        return p.conn.Close()

Reply via email to