This is an automated email from the ASF dual-hosted git repository.
lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git
The following commit(s) were added to refs/heads/main by this push:
new 631f068 fix(go/adbc/driver/flightsql): guard against inconsistent schemas (#409)
631f068 is described below
commit 631f068d794e4a3eb2299eaa98d12618ee6d2c90
Author: David Li <[email protected]>
AuthorDate: Fri Feb 3 11:20:54 2023 -0500
fix(go/adbc/driver/flightsql): guard against inconsistent schemas (#409)
In case the FlightInfo schema doesn't match the DoGet schema, return an
error instead of allowing the client to misinterpret the result.
---
.github/workflows/native-unix.yml | 4 +-
go/adbc/driver/flightsql/record_reader.go | 54 ++++++++++++++++++++++++++
go/adbc/driver/flightsql/record_reader_test.go | 29 ++++++++++++++
3 files changed, 85 insertions(+), 2 deletions(-)
diff --git a/.github/workflows/native-unix.yml b/.github/workflows/native-unix.yml
index 35b480e..1211936 100644
--- a/.github/workflows/native-unix.yml
+++ b/.github/workflows/native-unix.yml
@@ -392,7 +392,7 @@ jobs:
cache: true
cache-dependency-path: go/adbc/go.sum
- name: Install staticcheck
- run: go install honnef.co/go/tools/cmd/staticcheck@latest
+ run: go install honnef.co/go/tools/cmd/[email protected]
- name: Go Build
run: |
./ci/scripts/go_build.sh "$(pwd)" "$(pwd)/build" "$HOME/local"
@@ -450,7 +450,7 @@ jobs:
- name: Install staticcheck
shell: bash -l {0}
if: ${{ !contains('macos-latest', matrix.os) }}
- run: go install honnef.co/go/tools/cmd/staticcheck@latest
+ run: go install honnef.co/go/tools/cmd/[email protected]
- uses: actions/download-artifact@v3
with:
diff --git a/go/adbc/driver/flightsql/record_reader.go b/go/adbc/driver/flightsql/record_reader.go
index 042c661..9e204c0 100644
--- a/go/adbc/driver/flightsql/record_reader.go
+++ b/go/adbc/driver/flightsql/record_reader.go
@@ -19,6 +19,7 @@ package flightsql
import (
"context"
+ "fmt"
"sync/atomic"
"github.com/apache/arrow-adbc/go/adbc"
@@ -116,6 +117,7 @@ func newRecordReader(ctx context.Context, alloc memory.Allocator, cl *flightsql.
lastChannelIndex := len(chs) - 1
+ referenceSchema := removeSchemaMetadata(schema)
for i, ep := range endpoints {
endpoint := ep
endpointIndex := i
@@ -132,6 +134,11 @@ func newRecordReader(ctx context.Context, alloc memory.Allocator, cl *flightsql.
}
defer rdr.Release()
+ streamSchema := removeSchemaMetadata(rdr.Schema())
+ if !streamSchema.Equal(referenceSchema) {
+ return fmt.Errorf("endpoint %d returned inconsistent schema: expected %s but got %s", endpointIndex, referenceSchema.String(), streamSchema.String())
+ }
+
for rdr.Next() && ctx.Err() == nil {
rec := rdr.Record()
rec.Retain()
@@ -201,3 +208,50 @@ func (r *reader) Schema() *arrow.Schema {
func (r *reader) Record() arrow.Record {
return r.rec
}
+
+func removeSchemaMetadata(schema *arrow.Schema) *arrow.Schema {
+ fields := make([]arrow.Field, len(schema.Fields()))
+ for i, field := range schema.Fields() {
+ fields[i] = removeFieldMetadata(&field)
+ }
+ return arrow.NewSchema(fields, nil)
+}
+
+func removeFieldMetadata(field *arrow.Field) arrow.Field {
+ fieldType := field.Type
+
+ if nestedType, ok := field.Type.(arrow.NestedType); ok {
+ childFields := make([]arrow.Field, len(nestedType.Fields()))
+ for i, field := range nestedType.Fields() {
+ childFields[i] = removeFieldMetadata(&field)
+ }
+
+ switch ty := field.Type.(type) {
+ case *arrow.DenseUnionType:
+ fieldType = arrow.DenseUnionOf(childFields, ty.TypeCodes())
+ case *arrow.FixedSizeListType:
+ fieldType = arrow.FixedSizeListOfField(ty.Len(), childFields[0])
+ case *arrow.ListType:
+ fieldType = arrow.ListOfField(childFields[0])
+ case *arrow.LargeListType:
+ fieldType = arrow.LargeListOfField(childFields[0])
+ case *arrow.MapType:
+ mapType := arrow.MapOf(childFields[0].Type, childFields[1].Type)
+ mapType.KeysSorted = ty.KeysSorted
+ fieldType = mapType
+ case *arrow.SparseUnionType:
+ fieldType = arrow.SparseUnionOf(childFields, ty.TypeCodes())
+ case *arrow.StructType:
+ fieldType = arrow.StructOf(childFields...)
+ default:
+ // XXX: ignore it
+ }
+ }
+
+ return arrow.Field{
+ Name: field.Name,
+ Type: fieldType,
+ Nullable: field.Nullable,
+ Metadata: arrow.Metadata{},
+ }
+}
diff --git a/go/adbc/driver/flightsql/record_reader_test.go b/go/adbc/driver/flightsql/record_reader_test.go
index fd4d31a..c210122 100644
--- a/go/adbc/driver/flightsql/record_reader_test.go
+++ b/go/adbc/driver/flightsql/record_reader_test.go
@@ -282,6 +282,35 @@ func (suite *RecordReaderTests) TestNoSchema() {
suite.NoError(reader.Err())
}
+func (suite *RecordReaderTests) TestSchemaEndpointMismatch() {
+ location := "grpc://" + suite.server.Addr().String()
+ badSchema := arrow.NewSchema([]arrow.Field{
+ {Name: "epIndex", Type: arrow.PrimitiveTypes.Int32},
+ {Name: "batchIndex", Type: arrow.PrimitiveTypes.Int32},
+ }, nil)
+ info := flight.FlightInfo{
+ Schema: flight.SerializeSchema(badSchema, suite.alloc),
+ Endpoint: []*flight.FlightEndpoint{
+ {
+ Ticket: &flight.Ticket{Ticket: []byte{0}},
+ Location: []*flight.Location{{Uri: location}},
+ },
+ {
+ Ticket: &flight.Ticket{Ticket: []byte{1}},
+ Location: []*flight.Location{{Uri: location}},
+ },
+ },
+ }
+
+ reader, err := newRecordReader(context.Background(), suite.alloc, suite.cl, &info, suite.clCache, 3)
+ suite.NoError(err)
+ defer reader.Release()
+
+ suite.True(reader.Schema().Equal(badSchema))
+ suite.False(reader.Next())
+ suite.ErrorContains(reader.Err(), "returned inconsistent schema: expected schema:")
+}
+
func (suite *RecordReaderTests) TestOrdering() {
// Info with a ton of endpoints; we want to make sure data comes back in order
location := "grpc://" + suite.server.Addr().String()