This is an automated email from the ASF dual-hosted git repository.

lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new d462c514 test(go/adbc/driver/flightsql): test for errors during polling (#1544)
d462c514 is described below

commit d462c5145ba76b9e50d68b80f8b3e74508af05ff
Author: David Li <[email protected]>
AuthorDate: Tue Feb 13 13:44:27 2024 -0500

    test(go/adbc/driver/flightsql): test for errors during polling (#1544)
    
    Fixes #1458.
---
 .../driver/flightsql/flightsql_adbc_server_test.go | 62 ++++++++++++++++++++--
 1 file changed, 58 insertions(+), 4 deletions(-)

diff --git a/go/adbc/driver/flightsql/flightsql_adbc_server_test.go b/go/adbc/driver/flightsql/flightsql_adbc_server_test.go
index 44ebb1b5..f779e6af 100644
--- a/go/adbc/driver/flightsql/flightsql_adbc_server_test.go
+++ b/go/adbc/driver/flightsql/flightsql_adbc_server_test.go
@@ -442,6 +442,9 @@ func (ts *ExecuteSchemaTests) TestQuery() {
 type IncrementalQuery struct {
        query     string
        nextIndex int
+       // if set, then return an error in the next poll and unset
+       // for testing the client's error handling
+       unavailable bool
 }
 
 type IncrementalPollTestServer struct {
@@ -451,6 +454,10 @@ type IncrementalPollTestServer struct {
        testCases map[string]IncrementalPollTestCase
 }
 
+var unavailableCase = IncrementalPollTestCase{
+       progress: []int{1, 1},
+}
+
 func (srv *IncrementalPollTestServer) PollFlightInfo(ctx context.Context, desc 
*flight.FlightDescriptor) (*flight.PollInfo, error) {
        srv.mu.Lock()
        defer srv.mu.Unlock()
@@ -478,27 +485,46 @@ func (srv *IncrementalPollTestServer) PollFlightInfo(ctx context.Context, desc *
 
        testCase, ok := srv.testCases[query.query]
        if !ok {
-               return nil, status.Errorf(codes.Unimplemented, 
fmt.Sprintf("Invalid case %s", query.query))
+               if query.query == "unavailable" {
+                       testCase = unavailableCase
+               } else {
+                       return nil, status.Errorf(codes.Unimplemented, 
fmt.Sprintf("Invalid case %s", query.query))
+               }
        }
 
        if testCase.differentRetryDescriptor && progress != 
int64(query.nextIndex) {
                return nil, status.Errorf(codes.InvalidArgument, 
fmt.Sprintf("Used wrong retry descriptor, expected %d but got %d", 
query.nextIndex, progress))
        }
 
+       if query.unavailable {
+               query.unavailable = false
+               return nil, status.Errorf(codes.Unavailable, "Server 
temporarily unavailable")
+       }
+
        return srv.MakePollInfo(&testCase, query, queryId)
 }
 
 func (srv *IncrementalPollTestServer) PollFlightInfoStatement(ctx 
context.Context, query flightsql.StatementQuery, desc *flight.FlightDescriptor) 
(*flight.PollInfo, error) {
+       srv.mu.Lock()
+       defer srv.mu.Unlock()
+
        queryId := uuid.New().String()
 
+       if query.GetQuery() == "unavailable" {
+               srv.queries[queryId] = &IncrementalQuery{
+                       query:       query.GetQuery(),
+                       nextIndex:   0,
+                       unavailable: true,
+               }
+
+               return srv.MakePollInfo(&unavailableCase, srv.queries[queryId], 
queryId)
+       }
+
        testCase, ok := srv.testCases[query.GetQuery()]
        if !ok {
                return nil, status.Errorf(codes.Unimplemented, 
fmt.Sprintf("Invalid case %s", query.GetQuery()))
        }
 
-       srv.mu.Lock()
-       defer srv.mu.Unlock()
-
        srv.queries[queryId] = &IncrementalQuery{
                query:     query.GetQuery(),
                nextIndex: 0,
@@ -701,6 +727,34 @@ func (ts *IncrementalPollTests) TestOptionValue() {
        ts.Equal(adbc.StatusInvalidArgument, adbcErr.Code)
 }
 
+func (ts *IncrementalPollTests) TestUnavailable() {
+       // An error from the server should not tear down all the state.  We
+       // should be able to retry the request.
+       ctx := context.Background()
+       stmt, err := ts.cnxn.NewStatement()
+       ts.NoError(err)
+       defer stmt.Close()
+
+       ts.NoError(stmt.SetOption(adbc.OptionKeyIncremental, 
adbc.OptionValueEnabled))
+
+       ts.NoError(stmt.SetSqlQuery("unavailable"))
+       _, partitions, _, err := stmt.ExecutePartitions(ctx)
+       ts.NoError(err)
+       ts.Equalf(uint64(1), partitions.NumPartitions, "%#v", partitions)
+
+       _, partitions, _, err = stmt.ExecutePartitions(ctx)
+       ts.ErrorContains(err, "Server temporarily unavailable")
+       ts.Equal(uint64(0), partitions.NumPartitions)
+
+       _, partitions, _, err = stmt.ExecutePartitions(ctx)
+       ts.NoError(err)
+       ts.Equalf(uint64(1), partitions.NumPartitions, "%#v", partitions)
+
+       _, partitions, _, err = stmt.ExecutePartitions(ctx)
+       ts.NoError(err)
+       ts.Equal(uint64(0), partitions.NumPartitions)
+}
+
 func (ts *IncrementalPollTests) RunOneTestCase(ctx context.Context, stmt 
adbc.Statement, name string, testCase *IncrementalPollTestCase) {
        opts := stmt.(adbc.GetSetOptions)
 

Reply via email to