This is an automated email from the ASF dual-hosted git repository.
lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git
The following commit(s) were added to refs/heads/main by this push:
new d462c514 test(go/adbc/driver/flightsql): test for errors during polling (#1544)
d462c514 is described below
commit d462c5145ba76b9e50d68b80f8b3e74508af05ff
Author: David Li <[email protected]>
AuthorDate: Tue Feb 13 13:44:27 2024 -0500
test(go/adbc/driver/flightsql): test for errors during polling (#1544)
Fixes #1458.
---
.../driver/flightsql/flightsql_adbc_server_test.go | 62 ++++++++++++++++++++--
1 file changed, 58 insertions(+), 4 deletions(-)
diff --git a/go/adbc/driver/flightsql/flightsql_adbc_server_test.go b/go/adbc/driver/flightsql/flightsql_adbc_server_test.go
index 44ebb1b5..f779e6af 100644
--- a/go/adbc/driver/flightsql/flightsql_adbc_server_test.go
+++ b/go/adbc/driver/flightsql/flightsql_adbc_server_test.go
@@ -442,6 +442,9 @@ func (ts *ExecuteSchemaTests) TestQuery() {
type IncrementalQuery struct {
query string
nextIndex int
+ // if set, then return an error in the next poll and unset
+ // for testing the client's error handling
+ unavailable bool
}
type IncrementalPollTestServer struct {
@@ -451,6 +454,10 @@ type IncrementalPollTestServer struct {
testCases map[string]IncrementalPollTestCase
}
+var unavailableCase = IncrementalPollTestCase{
+ progress: []int{1, 1},
+}
+
func (srv *IncrementalPollTestServer) PollFlightInfo(ctx context.Context, desc
*flight.FlightDescriptor) (*flight.PollInfo, error) {
srv.mu.Lock()
defer srv.mu.Unlock()
@@ -478,27 +485,46 @@ func (srv *IncrementalPollTestServer) PollFlightInfo(ctx context.Context, desc *
testCase, ok := srv.testCases[query.query]
if !ok {
- return nil, status.Errorf(codes.Unimplemented,
fmt.Sprintf("Invalid case %s", query.query))
+ if query.query == "unavailable" {
+ testCase = unavailableCase
+ } else {
+ return nil, status.Errorf(codes.Unimplemented,
fmt.Sprintf("Invalid case %s", query.query))
+ }
}
if testCase.differentRetryDescriptor && progress !=
int64(query.nextIndex) {
return nil, status.Errorf(codes.InvalidArgument,
fmt.Sprintf("Used wrong retry descriptor, expected %d but got %d",
query.nextIndex, progress))
}
+ if query.unavailable {
+ query.unavailable = false
+ return nil, status.Errorf(codes.Unavailable, "Server temporarily unavailable")
+ }
+
return srv.MakePollInfo(&testCase, query, queryId)
}
func (srv *IncrementalPollTestServer) PollFlightInfoStatement(ctx
context.Context, query flightsql.StatementQuery, desc *flight.FlightDescriptor)
(*flight.PollInfo, error) {
+ srv.mu.Lock()
+ defer srv.mu.Unlock()
+
queryId := uuid.New().String()
+ if query.GetQuery() == "unavailable" {
+ srv.queries[queryId] = &IncrementalQuery{
+ query: query.GetQuery(),
+ nextIndex: 0,
+ unavailable: true,
+ }
+
+ return srv.MakePollInfo(&unavailableCase, srv.queries[queryId],
queryId)
+ }
+
testCase, ok := srv.testCases[query.GetQuery()]
if !ok {
return nil, status.Errorf(codes.Unimplemented,
fmt.Sprintf("Invalid case %s", query.GetQuery()))
}
- srv.mu.Lock()
- defer srv.mu.Unlock()
-
srv.queries[queryId] = &IncrementalQuery{
query: query.GetQuery(),
nextIndex: 0,
@@ -701,6 +727,34 @@ func (ts *IncrementalPollTests) TestOptionValue() {
ts.Equal(adbc.StatusInvalidArgument, adbcErr.Code)
}
+func (ts *IncrementalPollTests) TestUnavailable() {
+ // An error from the server should not tear down all the state. We
+ // should be able to retry the request.
+ ctx := context.Background()
+ stmt, err := ts.cnxn.NewStatement()
+ ts.NoError(err)
+ defer stmt.Close()
+
+ ts.NoError(stmt.SetOption(adbc.OptionKeyIncremental,
adbc.OptionValueEnabled))
+
+ ts.NoError(stmt.SetSqlQuery("unavailable"))
+ _, partitions, _, err := stmt.ExecutePartitions(ctx)
+ ts.NoError(err)
+ ts.Equalf(uint64(1), partitions.NumPartitions, "%#v", partitions)
+
+ _, partitions, _, err = stmt.ExecutePartitions(ctx)
+ ts.ErrorContains(err, "Server temporarily unavailable")
+ ts.Equal(uint64(0), partitions.NumPartitions)
+
+ _, partitions, _, err = stmt.ExecutePartitions(ctx)
+ ts.NoError(err)
+ ts.Equalf(uint64(1), partitions.NumPartitions, "%#v", partitions)
+
+ _, partitions, _, err = stmt.ExecutePartitions(ctx)
+ ts.NoError(err)
+ ts.Equal(uint64(0), partitions.NumPartitions)
+}
+
func (ts *IncrementalPollTests) RunOneTestCase(ctx context.Context, stmt
adbc.Statement, name string, testCase *IncrementalPollTestCase) {
opts := stmt.(adbc.GetSetOptions)