This is an automated email from the ASF dual-hosted git repository.

zeroshade pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new 8251403d fix(go/adbc): fix crash on map type (#854)
8251403d is described below

commit 8251403d054eb225d92ae82b687300f34d346061
Author: David Li <[email protected]>
AuthorDate: Wed Jul 5 12:12:59 2023 -0400

    fix(go/adbc): fix crash on map type (#854)
    
    Fixes #853.
---
 .../driver/flightsql/flightsql_adbc_server_test.go | 103 +++++++++++++++++++++
 go/adbc/utils/utils.go                             |  10 +-
 2 files changed, 112 insertions(+), 1 deletion(-)

diff --git a/go/adbc/driver/flightsql/flightsql_adbc_server_test.go b/go/adbc/driver/flightsql/flightsql_adbc_server_test.go
index 11c5c31e..61a46db1 100644
--- a/go/adbc/driver/flightsql/flightsql_adbc_server_test.go
+++ b/go/adbc/driver/flightsql/flightsql_adbc_server_test.go
@@ -103,6 +103,10 @@ func TestCookies(t *testing.T) {
        suite.Run(t, &CookieTests{})
 }
 
+func TestDataType(t *testing.T) {
+       suite.Run(t, &DataTypeTests{})
+}
+
 // ---- AuthN Tests --------------------
 
 type AuthnTestServer struct {
@@ -524,3 +528,102 @@ func (suite *CookieTests) TestCookieUsage() {
        suite.Require().NoError(err)
        defer reader.Release()
 }
+
+// ---- Data Type Tests --------------------
+type DataTypeTestServer struct {
+       flightsql.BaseServer
+}
+
+func (server *DataTypeTestServer) GetFlightInfoStatement(ctx context.Context, cmd flightsql.StatementQuery, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) {
+       tkt, _ := flightsql.CreateStatementQueryTicket([]byte(cmd.GetQuery()))
+       info := &flight.FlightInfo{
+               FlightDescriptor: desc,
+               Endpoint: []*flight.FlightEndpoint{
+                       {Ticket: &flight.Ticket{Ticket: tkt}},
+               },
+               TotalRecords: -1,
+               TotalBytes:   -1,
+       }
+
+       return info, nil
+}
+
+var (
+       SchemaListInt3     = arrow.NewSchema([]arrow.Field{{Name: "a", Type: arrow.FixedSizeListOf(3, arrow.PrimitiveTypes.Int32), Nullable: true}}, nil)
+       SchemaListInt      = arrow.NewSchema([]arrow.Field{{Name: "a", Type: arrow.ListOf(arrow.PrimitiveTypes.Int32), Nullable: true}}, nil)
+       SchemaLargeListInt = arrow.NewSchema([]arrow.Field{{Name: "a", Type: arrow.LargeListOf(arrow.PrimitiveTypes.Int32), Nullable: true}}, nil)
+       SchemaMapIntInt    = arrow.NewSchema([]arrow.Field{{Name: "a", Type: arrow.MapOf(arrow.PrimitiveTypes.Int32, arrow.PrimitiveTypes.Int32), Nullable: true}}, nil)
+)
+
+func (server *DataTypeTestServer) DoGetStatement(ctx context.Context, tkt flightsql.StatementQueryTicket) (*arrow.Schema, <-chan flight.StreamChunk, error) {
+       var schema *arrow.Schema
+       var record arrow.Record
+       var err error
+
+       cmd := string(tkt.GetStatementHandle())
+       switch cmd {
+       case "list[int, 3]":
+               schema = SchemaListInt3
+               record, _, err = array.RecordFromJSON(memory.DefaultAllocator, schema, strings.NewReader(`[{"a": [1, 2, 3]}]`))
+       case "list[int]":
+               schema = SchemaListInt
+               record, _, err = array.RecordFromJSON(memory.DefaultAllocator, schema, strings.NewReader(`[{"a": [1]}]`))
+       case "large_list[int]":
+               schema = SchemaLargeListInt
+               record, _, err = array.RecordFromJSON(memory.DefaultAllocator, schema, strings.NewReader(`[{"a": [1]}]`))
+       case "map[int]int":
+               schema = SchemaMapIntInt
+               record, _, err = array.RecordFromJSON(memory.DefaultAllocator, schema, strings.NewReader(`[{"a": null}]`))
+       default:
+               return nil, nil, fmt.Errorf("Unknown command: '%s'", cmd)
+       }
+
+       if err != nil {
+               return nil, nil, err
+       }
+
+       ch := make(chan flight.StreamChunk)
+       go func() {
+               defer close(ch)
+               ch <- flight.StreamChunk{
+                       Data: record,
+               }
+       }()
+       return schema, ch, nil
+}
+
+type DataTypeTests struct {
+       ServerBasedTests
+}
+
+func (suite *DataTypeTests) SetupSuite() {
+       suite.DoSetupSuite(&DataTypeTestServer{}, nil, map[string]string{})
+}
+
+func (suite *DataTypeTests) DoTestCase(name string, schema *arrow.Schema) {
+       stmt, err := suite.cnxn.NewStatement()
+       suite.NoError(err)
+       defer stmt.Close()
+
+       suite.NoError(stmt.SetSqlQuery(name))
+       reader, _, err := stmt.ExecuteQuery(context.Background())
+       suite.NoError(err)
+       suite.Equal(reader.Schema(), schema)
+       defer reader.Release()
+}
+
+func (suite *DataTypeTests) TestListInt3() {
+       suite.DoTestCase("list[int, 3]", SchemaListInt3)
+}
+
+func (suite *DataTypeTests) TestLargeListInt() {
+       suite.DoTestCase("large_list[int]", SchemaLargeListInt)
+}
+
+func (suite *DataTypeTests) TestListInt() {
+       suite.DoTestCase("list[int]", SchemaListInt)
+}
+
+func (suite *DataTypeTests) TestMapIntInt() {
+       suite.DoTestCase("map[int]int", SchemaMapIntInt)
+}
diff --git a/go/adbc/utils/utils.go b/go/adbc/utils/utils.go
index 1a34e662..4ddf71bd 100644
--- a/go/adbc/utils/utils.go
+++ b/go/adbc/utils/utils.go
@@ -46,7 +46,15 @@ func removeFieldMetadata(field *arrow.Field) arrow.Field {
                case *arrow.LargeListType:
                        fieldType = arrow.LargeListOfField(childFields[0])
                case *arrow.MapType:
-                       mapType := arrow.MapOf(childFields[0].Type, childFields[1].Type)
+                       // XXX: arrow-go doesn't let us build a map type from fields (so
+                       // nonstandard field names or nullability will be lost here)
+
+                       // child must be struct
+                       structType := ty.Elem().(*arrow.StructType)
+                       // struct must have two children
+                       keyType := structType.Field(0).Type
+                       itemType := structType.Field(1).Type
+                       mapType := arrow.MapOf(keyType, itemType)
                        mapType.KeysSorted = ty.KeysSorted
                        fieldType = mapType
                case *arrow.SparseUnionType:

Reply via email to