This is an automated email from the ASF dual-hosted git repository.
lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git
The following commit(s) were added to refs/heads/main by this push:
new ddbfaecc fix(go/adbc/driver/flightsql): Have GetTableSchema check for
table name match instead of the first schema it receives (#980)
ddbfaecc is described below
commit ddbfaeccba2be01fe0e54cacd29c058fdf5359e3
Author: Solomon Choe <[email protected]>
AuthorDate: Tue Aug 22 11:42:44 2023 -0700
fix(go/adbc/driver/flightsql): Have GetTableSchema check for table name
match instead of the first schema it receives (#980)
Fixes #934.
---
go/adbc/driver/flightsql/flightsql_adbc.go | 42 +++++++---
.../driver/flightsql/flightsql_adbc_server_test.go | 92 ++++++++++++++++++++++
2 files changed, 122 insertions(+), 12 deletions(-)
diff --git a/go/adbc/driver/flightsql/flightsql_adbc.go
b/go/adbc/driver/flightsql/flightsql_adbc.go
index 1ae99a6a..e00310cf 100644
--- a/go/adbc/driver/flightsql/flightsql_adbc.go
+++ b/go/adbc/driver/flightsql/flightsql_adbc.go
@@ -1231,24 +1231,42 @@ func (c *cnxn) GetTableSchema(ctx context.Context,
catalog *string, dbSchema *st
return nil, adbcFromFlightStatus(err, "GetTableSchema(DoGet)")
}
- if rec.NumRows() == 0 {
+ numRows := rec.NumRows()
+ switch {
+ case numRows == 0:
return nil, adbc.Error{
Code: adbc.StatusNotFound,
}
+ case numRows > math.MaxInt32:
+ return nil, adbc.Error{
+ Msg: "[Flight SQL] GetTableSchema cannot handle tables
with number of rows > 2^31 - 1",
+ Code: adbc.StatusNotImplemented,
+ }
}
- // returned schema should be
- // 0: catalog_name: utf8
- // 1: db_schema_name: utf8
- // 2: table_name: utf8 not null
- // 3: table_type: utf8 not null
- // 4: table_schema: bytes not null
- schemaBytes := rec.Column(4).(*array.Binary).Value(0)
- s, err := flight.DeserializeSchema(schemaBytes, c.db.alloc)
- if err != nil {
- return nil, adbcFromFlightStatus(err, "GetTableSchema")
+ var s *arrow.Schema
+ for i := 0; i < int(numRows); i++ {
+ currentTableName := rec.Column(2).(*array.String).Value(i)
+ if currentTableName == tableName {
+ // returned schema should be
+ // 0: catalog_name: utf8
+ // 1: db_schema_name: utf8
+ // 2: table_name: utf8 not null
+ // 3: table_type: utf8 not null
+ // 4: table_schema: bytes not null
+ schemaBytes := rec.Column(4).(*array.Binary).Value(i)
+ s, err = flight.DeserializeSchema(schemaBytes,
c.db.alloc)
+ if err != nil {
+ return nil, adbcFromFlightStatus(err,
"GetTableSchema")
+ }
+ return s, nil
+ }
+ }
+
+ return s, adbc.Error{
+ Msg: "[Flight SQL] GetTableSchema could not find a table with
a matching schema",
+ Code: adbc.StatusNotFound,
}
- return s, nil
}
// GetTableTypes returns a list of the table types in the database.
diff --git a/go/adbc/driver/flightsql/flightsql_adbc_server_test.go
b/go/adbc/driver/flightsql/flightsql_adbc_server_test.go
index dd6171c4..d8af6a65 100644
--- a/go/adbc/driver/flightsql/flightsql_adbc_server_test.go
+++ b/go/adbc/driver/flightsql/flightsql_adbc_server_test.go
@@ -35,6 +35,7 @@ import (
"github.com/apache/arrow/go/v13/arrow/array"
"github.com/apache/arrow/go/v13/arrow/flight"
"github.com/apache/arrow/go/v13/arrow/flight/flightsql"
+ "github.com/apache/arrow/go/v13/arrow/flight/flightsql/schema_ref"
"github.com/apache/arrow/go/v13/arrow/memory"
"github.com/stretchr/testify/suite"
"golang.org/x/exp/maps"
@@ -107,6 +108,10 @@ func TestDataType(t *testing.T) {
suite.Run(t, &DataTypeTests{})
}
+// TestMultiTable runs the MultiTableTests suite, whose Flight SQL test server
+// reports more than one table from GetTables.
+func TestMultiTable(t *testing.T) {
+	suite.Run(t, new(MultiTableTests))
+}
+
// ---- AuthN Tests --------------------
type AuthnTestServer struct {
@@ -627,3 +632,90 @@ func (suite *DataTypeTests) TestListInt() {
func (suite *DataTypeTests) TestMapIntInt() {
suite.DoTestCase("map[int]int", SchemaMapIntInt)
}
+
+// ---- Multi Table Tests --------------------
+
+// MultiTableTestServer is a minimal Flight SQL test server used by
+// MultiTableTests. It embeds flightsql.BaseServer and overrides the
+// statement-query and GetTables endpoints defined for it in this file.
+type MultiTableTestServer struct {
+flightsql.BaseServer
+}
+
+// GetFlightInfoStatement answers a statement query by wrapping the query text
+// in a statement-query ticket and returning it as the only endpoint. Record
+// and byte totals are reported as unknown (-1).
+func (server *MultiTableTestServer) GetFlightInfoStatement(ctx context.Context, cmd flightsql.StatementQuery, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) {
+	ticket, err := flightsql.CreateStatementQueryTicket([]byte(cmd.GetQuery()))
+	if err != nil {
+		return nil, err
+	}
+
+	endpoint := &flight.FlightEndpoint{Ticket: &flight.Ticket{Ticket: ticket}}
+	info := &flight.FlightInfo{
+		Endpoint:         []*flight.FlightEndpoint{endpoint},
+		FlightDescriptor: desc,
+		TotalRecords:     -1,
+		TotalBytes:       -1,
+	}
+	return info, nil
+}
+
+func (server *MultiTableTestServer) GetFlightInfoTables(ctx context.Context,
cmd flightsql.GetTables, desc *flight.FlightDescriptor) (*flight.FlightInfo,
error) {
+ schema := schema_ref.Tables
+ if cmd.GetIncludeSchema() {
+ schema = schema_ref.TablesWithIncludedSchema
+ }
+ server.Alloc = memory.NewCheckedAllocator(memory.DefaultAllocator)
+ info := &flight.FlightInfo{
+ Endpoint: []*flight.FlightEndpoint{
+ {Ticket: &flight.Ticket{Ticket: desc.Cmd}},
+ },
+ FlightDescriptor: desc,
+ Schema: flight.SerializeSchema(schema, server.Alloc),
+ TotalRecords: -1,
+ TotalBytes: -1,
+ }
+
+ return info, nil
+}
+
+// DoGetTables streams one record describing two tables, "tbl1" and "tbl2",
+// regardless of the filters in cmd. Catalog, schema and table-type columns are
+// empty strings. tbl1's serialized schema has a single nullable int32 column
+// "a", and tbl2's has a single nullable int32 column "b". Returning more than
+// one row lets the client-side GetTableSchema test check that the driver picks
+// the row whose table_name matches, not simply the first row it receives.
+func (server *MultiTableTestServer) DoGetTables(ctx context.Context, cmd flightsql.GetTables) (*arrow.Schema, <-chan flight.StreamChunk, error) {
+	builder := array.NewRecordBuilder(server.Alloc, adbc.GetTableSchemaSchema)
+	defer builder.Release()
+
+	blanks := []string{"", ""}
+	builder.Field(0).(*array.StringBuilder).AppendValues(blanks, nil)
+	builder.Field(1).(*array.StringBuilder).AppendValues(blanks, nil)
+	builder.Field(2).(*array.StringBuilder).AppendValues([]string{"tbl1", "tbl2"}, nil)
+	builder.Field(3).(*array.StringBuilder).AppendValues(blanks, nil)
+
+	// singleInt32 builds a one-column schema with a nullable int32 field.
+	singleInt32 := func(column string) *arrow.Schema {
+		return arrow.NewSchema([]arrow.Field{{Name: column, Type: arrow.PrimitiveTypes.Int32, Nullable: true}}, nil)
+	}
+	serialized := [][]byte{
+		flight.SerializeSchema(singleInt32("a"), server.Alloc),
+		flight.SerializeSchema(singleInt32("b"), server.Alloc),
+	}
+	builder.Field(4).(*array.BinaryBuilder).AppendValues(serialized, nil)
+
+	record := builder.NewRecord()
+
+	chunks := make(chan flight.StreamChunk)
+	go func() {
+		defer close(chunks)
+		chunks <- flight.StreamChunk{Data: record}
+	}()
+	return adbc.GetTableSchemaSchema, chunks, nil
+}
+
+// MultiTableTests exercises the driver against MultiTableTestServer, whose
+// GetTables response contains more than one table. It reuses the
+// ServerBasedTests fixture.
+type MultiTableTests struct {
+ServerBasedTests
+}
+
+// SetupSuite delegates to ServerBasedTests.DoSetupSuite with a fresh
+// MultiTableTestServer, no server middleware, and an empty set of database
+// options.
+func (suite *MultiTableTests) SetupSuite() {
+	srv := new(MultiTableTestServer)
+	suite.DoSetupSuite(srv, nil, map[string]string{})
+}
+
+// TestGetTableSchema is a regression test for
+// https://github.com/apache/arrow-adbc/issues/934. The test server returns
+// several tables from GetTables, and GetTableSchema must return the schema of
+// the row whose table_name matches the request, not the first row received.
+//
+// Both tables are checked. A driver that always takes the first row fails on
+// tbl2, and one that always takes the last row fails on tbl1. Schemas are
+// compared with arrow.Schema.Equal rather than reflect-based equality, so the
+// assertion does not depend on unexported internals of *arrow.Schema.
+func (suite *MultiTableTests) TestGetTableSchema() {
+	cases := []struct {
+		table  string
+		column string
+	}{
+		{table: "tbl1", column: "a"},
+		{table: "tbl2", column: "b"},
+	}
+
+	for _, tc := range cases {
+		actual, err := suite.cnxn.GetTableSchema(context.Background(), nil, nil, tc.table)
+		// Stop this test on error; comparing against a nil schema would only
+		// add a confusing second failure.
+		suite.Require().NoError(err, tc.table)
+
+		expected := arrow.NewSchema([]arrow.Field{{Name: tc.column, Type: arrow.PrimitiveTypes.Int32, Nullable: true}}, nil)
+		suite.Truef(expected.Equal(actual), "table %s: expected schema %s, got %s", tc.table, expected, actual)
+	}
+}