This is an automated email from the ASF dual-hosted git repository.

lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new 6f32626f test(go/adbc/driver/flightsql): test handling of bad locations (#1533)
6f32626f is described below

commit 6f32626f0c8bb6f92848d5971f440302fcc3ce0a
Author: David Li <[email protected]>
AuthorDate: Fri Feb 9 10:29:56 2024 -0500

    test(go/adbc/driver/flightsql): test handling of bad locations (#1533)
    
    Add a test to ensure that the driver can still fetch data from a server
    that returns 1 unreachable and 1 reachable location. Also, add a new
    option to adjust the connect timeout.
    
    Fixes #1527.
---
 docs/source/driver/flight_sql.rst                  |  6 +++
 .../driver/flightsql/flightsql_adbc_server_test.go | 63 +++++++++++++++++++++-
 go/adbc/driver/flightsql/flightsql_database.go     | 23 ++++++--
 go/adbc/driver/flightsql/flightsql_driver.go       |  8 ++-
 go/adbc/driver/flightsql/timeouts.go               | 12 +++++
 5 files changed, 106 insertions(+), 6 deletions(-)

diff --git a/docs/source/driver/flight_sql.rst b/docs/source/driver/flight_sql.rst
index 7473a7cb..a9067fcb 100644
--- a/docs/source/driver/flight_sql.rst
+++ b/docs/source/driver/flight_sql.rst
@@ -326,6 +326,12 @@ The options are as follows:
     For example, this controls the timeout of the underlying Flight
     calls that implement bulk ingestion, or transaction support.
 
+There is also a timeout that is set on the :cpp:class:`AdbcDatabase`:
+
+``adbc.flight.sql.rpc.timeout_seconds.connect``
+    A timeout (in floating-point seconds) for establishing a connection.  The
+    default is 20 seconds.
+
 Transactions
 ------------
 
diff --git a/go/adbc/driver/flightsql/flightsql_adbc_server_test.go b/go/adbc/driver/flightsql/flightsql_adbc_server_test.go
index 3df7c9a3..44ebb1b5 100644
--- a/go/adbc/driver/flightsql/flightsql_adbc_server_test.go
+++ b/go/adbc/driver/flightsql/flightsql_adbc_server_test.go
@@ -23,6 +23,7 @@ import (
        "context"
        "errors"
        "fmt"
+       "net"
        "net/textproto"
        "os"
        "strconv"
@@ -810,11 +811,18 @@ func (ts *IncrementalPollTests) TestQueryTransaction() {
 
 type TimeoutTestServer struct {
        flightsql.BaseServer
+       badPort  int
+       goodPort int
 }
 
 func (ts *TimeoutTestServer) DoGetStatement(ctx context.Context, tkt 
flightsql.StatementQueryTicket) (*arrow.Schema, <-chan flight.StreamChunk, 
error) {
-       if string(tkt.GetStatementHandle()) == "sleep and succeed" {
+       ticket := string(tkt.GetStatementHandle())
+       if ticket == "sleep and succeed" {
                time.Sleep(1 * time.Second)
+       }
+
+       switch ticket {
+       case "bad endpoint", "sleep and succeed":
                sc := arrow.NewSchema([]arrow.Field{{Name: "a", Type: 
arrow.PrimitiveTypes.Int32, Nullable: true}}, nil)
                rec, _, err := array.RecordFromJSON(memory.DefaultAllocator, 
sc, strings.NewReader(`[{"a": 5}]`))
                if err != nil {
@@ -850,6 +858,23 @@ func (ts *TimeoutTestServer) GetFlightInfoStatement(ctx context.Context, cmd fli
        switch cmd.GetQuery() {
        case "timeout":
                <-ctx.Done()
+       case "bad endpoint":
+               tkt, _ := flightsql.CreateStatementQueryTicket([]byte("bad 
endpoint"))
+               info := &flight.FlightInfo{
+                       FlightDescriptor: desc,
+                       Endpoint: []*flight.FlightEndpoint{
+                               {
+                                       Ticket: &flight.Ticket{Ticket: tkt},
+                                       Location: []*flight.Location{
+                                               {Uri: 
fmt.Sprintf("grpc://localhost:%d", ts.badPort)},
+                                               {Uri: 
fmt.Sprintf("grpc://localhost:%d", ts.goodPort)},
+                                       },
+                               },
+                       },
+                       TotalRecords: -1,
+                       TotalBytes:   -1,
+               }
+               return info, nil
        case "fetch":
                tkt, _ := flightsql.CreateStatementQueryTicket([]byte("fetch"))
                info := &flight.FlightInfo{
@@ -884,10 +909,23 @@ func (ts *TimeoutTestServer) CreatePreparedStatement(ctx context.Context, req fl
 
 type TimeoutTests struct {
        ServerBasedTests
+       server net.Listener
 }
 
 func (suite *TimeoutTests) SetupSuite() {
-       suite.DoSetupSuite(&TimeoutTestServer{}, nil, nil)
+       var err error
+       suite.server, err = net.Listen("tcp", "localhost:0")
+       suite.NoError(err)
+
+       badPort := suite.server.Addr().(*net.TCPAddr).Port
+       server := &TimeoutTestServer{badPort: badPort}
+       suite.DoSetupSuite(server, nil, nil)
+       server.goodPort = suite.s.Addr().(*net.TCPAddr).Port
+}
+
+func (suite *TimeoutTests) TearDownSuite() {
+       suite.ServerBasedTests.TearDownSuite()
+       suite.NoError(suite.server.Close())
 }
 
 func (ts *TimeoutTests) TestInvalidValues() {
@@ -1075,6 +1113,27 @@ func (ts *TimeoutTests) TestDontTimeout() {
        ts.Truef(array.RecordEqual(rec, expected), "expected: %s\nactual: %s", 
expected, rec)
 }
 
+func (ts *TimeoutTests) TestBadAddress() {
+       stmt, err := ts.cnxn.NewStatement()
+       ts.Require().NoError(err)
+       defer stmt.Close()
+       ts.Require().NoError(stmt.SetSqlQuery("bad endpoint"))
+
+       
ts.Require().NoError(ts.db.(adbc.GetSetOptions).SetOptionDouble(driver.OptionTimeoutConnect,
 5))
+
+       rr, _, err := stmt.ExecuteQuery(context.Background())
+       ts.Require().NoError(err)
+       defer rr.Release()
+
+       rr, _, err = stmt.ExecuteQuery(context.Background())
+       ts.Require().NoError(err)
+       defer rr.Release()
+
+       rr, _, err = stmt.ExecuteQuery(context.Background())
+       ts.Require().NoError(err)
+       defer rr.Release()
+}
+
 // ---- Cookie Tests --------------------
 type CookieTestServer struct {
        flightsql.BaseServer
diff --git a/go/adbc/driver/flightsql/flightsql_database.go b/go/adbc/driver/flightsql/flightsql_database.go
index 1407fedf..5e5e3af9 100644
--- a/go/adbc/driver/flightsql/flightsql_database.go
+++ b/go/adbc/driver/flightsql/flightsql_database.go
@@ -194,6 +194,13 @@ func (d *databaseImpl) SetOptions(cnOptions map[string]string) error {
                delete(cnOptions, OptionTimeoutUpdate)
        }
 
+       if tv, ok := cnOptions[OptionTimeoutConnect]; ok {
+               if err = d.timeout.setTimeoutString(OptionTimeoutConnect, tv); 
err != nil {
+                       return err
+               }
+               delete(cnOptions, OptionTimeoutConnect)
+       }
+
        if val, ok := cnOptions[OptionWithBlock]; ok {
                if val == adbc.OptionValueEnabled {
                        d.dialOpts.block = true
@@ -257,6 +264,8 @@ func (d *databaseImpl) GetOption(key string) (string, error) {
                return d.timeout.queryTimeout.String(), nil
        case OptionTimeoutUpdate:
                return d.timeout.updateTimeout.String(), nil
+       case OptionTimeoutConnect:
+               return d.timeout.connectTimeout.String(), nil
        }
        if val, ok := d.options[key]; ok {
                return val, nil
@@ -271,6 +280,8 @@ func (d *databaseImpl) GetOptionInt(key string) (int64, error) {
        case OptionTimeoutQuery:
                fallthrough
        case OptionTimeoutUpdate:
+               fallthrough
+       case OptionTimeoutConnect:
                val, err := d.GetOptionDouble(key)
                if err != nil {
                        return 0, err
@@ -289,6 +300,8 @@ func (d *databaseImpl) GetOptionDouble(key string) (float64, error) {
                return d.timeout.queryTimeout.Seconds(), nil
        case OptionTimeoutUpdate:
                return d.timeout.updateTimeout.Seconds(), nil
+       case OptionTimeoutConnect:
+               return d.timeout.connectTimeout.Seconds(), nil
        }
 
        return d.DatabaseImplBase.GetOptionDouble(key)
@@ -297,7 +310,7 @@ func (d *databaseImpl) GetOptionDouble(key string) (float64, error) {
 func (d *databaseImpl) SetOption(key, value string) error {
        // We can't change most options post-init
        switch key {
-       case OptionTimeoutFetch, OptionTimeoutQuery, OptionTimeoutUpdate:
+       case OptionTimeoutFetch, OptionTimeoutQuery, OptionTimeoutUpdate, 
OptionTimeoutConnect:
                return d.timeout.setTimeoutString(key, value)
        }
        if strings.HasPrefix(key, OptionRPCCallHeaderPrefix) {
@@ -313,6 +326,8 @@ func (d *databaseImpl) SetOptionInt(key string, value int64) error {
        case OptionTimeoutQuery:
                fallthrough
        case OptionTimeoutUpdate:
+               fallthrough
+       case OptionTimeoutConnect:
                return d.timeout.setTimeout(key, float64(value))
        }
 
@@ -326,6 +341,8 @@ func (d *databaseImpl) SetOptionDouble(key string, value float64) error {
        case OptionTimeoutQuery:
                fallthrough
        case OptionTimeoutUpdate:
+               fallthrough
+       case OptionTimeoutConnect:
                return d.timeout.setTimeout(key, value)
        }
 
@@ -366,8 +383,9 @@ func getFlightClient(ctx context.Context, loc string, d *databaseImpl, authMiddl
                creds = insecure.NewCredentials()
                target = "unix:" + uri.Path
        }
-       dialOpts := append(d.dialOpts.opts, 
grpc.WithTransportCredentials(creds))
+       dialOpts := append(d.dialOpts.opts, 
grpc.WithConnectParams(d.timeout.connectParams()), 
grpc.WithTransportCredentials(creds))
 
+       d.Logger.DebugContext(ctx, "new client", "location", loc)
        cl, err := flightsql.NewClient(target, nil, middleware, dialOpts...)
        if err != nil {
                return nil, adbc.Error{
@@ -395,7 +413,6 @@ func getFlightClient(ctx context.Context, loc string, d *databaseImpl, authMiddl
                }
        }
 
-       d.Logger.DebugContext(ctx, "new client", "location", loc)
        return cl, nil
 }
 
diff --git a/go/adbc/driver/flightsql/flightsql_driver.go b/go/adbc/driver/flightsql/flightsql_driver.go
index 727ed827..4914ad1c 100644
--- a/go/adbc/driver/flightsql/flightsql_driver.go
+++ b/go/adbc/driver/flightsql/flightsql_driver.go
@@ -35,6 +35,7 @@ import (
        "net/url"
        "runtime/debug"
        "strings"
+       "time"
 
        "github.com/apache/arrow-adbc/go/adbc"
        "github.com/apache/arrow-adbc/go/adbc/driver/driverbase"
@@ -53,6 +54,7 @@ const (
        OptionWithBlock           = "adbc.flight.sql.client_option.with_block"
        OptionWithMaxMsgSize      = 
"adbc.flight.sql.client_option.with_max_msg_size"
        OptionAuthorizationHeader = "adbc.flight.sql.authorization_header"
+       OptionTimeoutConnect      = 
"adbc.flight.sql.rpc.timeout_seconds.connect"
        OptionTimeoutFetch        = "adbc.flight.sql.rpc.timeout_seconds.fetch"
        OptionTimeoutQuery        = "adbc.flight.sql.rpc.timeout_seconds.query"
        OptionTimeoutUpdate       = "adbc.flight.sql.rpc.timeout_seconds.update"
@@ -126,7 +128,11 @@ func (d *driverImpl) NewDatabase(opts map[string]string) (adbc.Database, error)
 
        db := &databaseImpl{
                DatabaseImplBase: 
driverbase.NewDatabaseImplBase(&d.DriverImplBase),
-               hdrs:             make(metadata.MD),
+               timeout: timeoutOption{
+                       // Match gRPC default
+                       connectTimeout: time.Second * 20,
+               },
+               hdrs: make(metadata.MD),
        }
 
        var err error
diff --git a/go/adbc/driver/flightsql/timeouts.go b/go/adbc/driver/flightsql/timeouts.go
index 77375268..db39ca46 100644
--- a/go/adbc/driver/flightsql/timeouts.go
+++ b/go/adbc/driver/flightsql/timeouts.go
@@ -28,6 +28,7 @@ import (
 
        "github.com/apache/arrow-adbc/go/adbc"
        "google.golang.org/grpc"
+       "google.golang.org/grpc/backoff"
        "google.golang.org/grpc/metadata"
 )
 
@@ -40,6 +41,8 @@ type timeoutOption struct {
        queryTimeout time.Duration
        // timeout for DoPut or DoAction requests
        updateTimeout time.Duration
+       // timeout for establishing a new connection
+       connectTimeout time.Duration
 }
 
 func (t *timeoutOption) setTimeout(key string, value float64) error {
@@ -60,6 +63,8 @@ func (t *timeoutOption) setTimeout(key string, value float64) error {
                t.queryTimeout = timeout
        case OptionTimeoutUpdate:
                t.updateTimeout = timeout
+       case OptionTimeoutConnect:
+               t.connectTimeout = timeout
        default:
                return adbc.Error{
                        Msg:  fmt.Sprintf("[Flight SQL] Unknown timeout option 
'%s'", key),
@@ -81,6 +86,13 @@ func (t *timeoutOption) setTimeoutString(key string, value string) error {
        return t.setTimeout(key, timeout)
 }
 
+func (t *timeoutOption) connectParams() grpc.ConnectParams {
+       return grpc.ConnectParams{
+               Backoff:           backoff.DefaultConfig,
+               MinConnectTimeout: t.connectTimeout,
+       }
+}
+
 func getTimeout(method string, callOptions []grpc.CallOption) (time.Duration, 
bool) {
        for _, opt := range callOptions {
                if to, ok := opt.(timeoutOption); ok {

Reply via email to